Skip to content

Commit

Permalink
add pre-commit workflow (#11973)
Browse files Browse the repository at this point in the history
* add pre-commit workflow

* run 'pre-commit run --all-files'

* setup python version
  • Loading branch information
GreatV authored Apr 21, 2024
1 parent 66b731b commit 90cbb95
Show file tree
Hide file tree
Showing 22 changed files with 1,676 additions and 951 deletions.
1,774 changes: 1,185 additions & 589 deletions PPOCRLabel.py

Large diffs are not rendered by default.

72 changes: 52 additions & 20 deletions gen_ocr_train_val_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@ def isCreateOrDeleteFolder(path, flag):
return flagAbsPath


def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_path, train_txt, val_txt, test_txt, flag):

def splitTrainVal(
root,
abs_train_root_path,
abs_val_root_path,
abs_test_root_path,
train_txt,
val_txt,
test_txt,
flag,
):
data_abs_path = os.path.abspath(root)
label_file_name = args.detLabelFileName if flag == "det" else args.recLabelFileName
label_file_path = os.path.join(data_abs_path, label_file_name)
Expand All @@ -29,13 +37,15 @@ def splitTrainVal(root, abs_train_root_path, abs_val_root_path, abs_test_root_pa
label_record_len = len(label_file_content)

for index, label_record_info in enumerate(label_file_content):
image_relative_path, image_label = label_record_info.split('\t')
image_relative_path, image_label = label_record_info.split("\t")
image_name = os.path.basename(image_relative_path)

if flag == "det":
image_path = os.path.join(data_abs_path, image_name)
elif flag == "rec":
image_path = os.path.join(data_abs_path, args.recImageDirName, image_name)
image_path = os.path.join(
data_abs_path, args.recImageDirName, image_name
)

train_val_test_ratio = args.trainValTestRatio.split(":")
train_ratio = eval(train_val_test_ratio[0]) / 10
Expand Down Expand Up @@ -77,27 +87,46 @@ def genDetRecTrainVal(args):
removeFile(os.path.join(args.recRootPath, "val.txt"))
removeFile(os.path.join(args.recRootPath, "test.txt"))

detTrainTxt = open(os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8")
detTrainTxt = open(
os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8"
)
detValTxt = open(os.path.join(args.detRootPath, "val.txt"), "a", encoding="UTF-8")
detTestTxt = open(os.path.join(args.detRootPath, "test.txt"), "a", encoding="UTF-8")
recTrainTxt = open(os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8")
recTrainTxt = open(
os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8"
)
recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")

splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
detTestTxt, "det")
splitTrainVal(
args.datasetRootPath,
detAbsTrainRootPath,
detAbsValRootPath,
detAbsTestRootPath,
detTrainTxt,
detValTxt,
detTestTxt,
"det",
)

for root, dirs, files in os.walk(args.datasetRootPath):
for dir in dirs:
if dir == 'crop_img':
splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
recTestTxt, "rec")
if dir == "crop_img":
splitTrainVal(
root,
recAbsTrainRootPath,
recAbsValRootPath,
recAbsTestRootPath,
recTrainTxt,
recValTxt,
recTestTxt,
"rec",
)
else:
continue
break



if __name__ == "__main__":
# 功能描述:分别划分检测和识别的训练集、验证集、测试集
# 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注,
Expand All @@ -107,40 +136,43 @@ def genDetRecTrainVal(args):
"--trainValTestRatio",
type=str,
default="6:2:2",
help="ratio of trainset:valset:testset")
help="ratio of trainset:valset:testset",
)
parser.add_argument(
"--datasetRootPath",
type=str,
default="../train_data/",
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3...",
)
parser.add_argument(
"--detRootPath",
type=str,
default="../train_data/det",
help="the path where the divided detection dataset is placed")
help="the path where the divided detection dataset is placed",
)
parser.add_argument(
"--recRootPath",
type=str,
default="../train_data/rec",
help="the path where the divided recognition dataset is placed"
help="the path where the divided recognition dataset is placed",
)
parser.add_argument(
"--detLabelFileName",
type=str,
default="Label.txt",
help="the name of the detection annotation file")
help="the name of the detection annotation file",
)
parser.add_argument(
"--recLabelFileName",
type=str,
default="rec_gt.txt",
help="the name of the recognition annotation file"
help="the name of the recognition annotation file",
)
parser.add_argument(
"--recImageDirName",
type=str,
default="crop_img",
help="the name of the folder where the cropped recognition dataset is located"
help="the name of the folder where the cropped recognition dataset is located",
)
args = parser.parse_args()
genDetRecTrainVal(args)
genDetRecTrainVal(args)
4 changes: 2 additions & 2 deletions libs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version_info__ = ('1', '0', '0')
__version__ = '.'.join(__version_info__)
__version_info__ = ("1", "0", "0")
__version__ = ".".join(__version_info__)
56 changes: 38 additions & 18 deletions libs/autoDialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,53 @@ def __init__(self, ocr, mImgList, mainThread, model):
self.mImgList = mImgList
self.mainThread = mainThread
self.model = model
self.setStackSize(1024*1024)
self.setStackSize(1024 * 1024)

def run(self):
try:
findex = 0
for Imgpath in self.mImgList:
if self.handle == 0:
self.listValue.emit(Imgpath)
if self.model == 'paddle':
h, w, _ = cv2.imdecode(np.fromfile(Imgpath, dtype=np.uint8), 1).shape
if self.model == "paddle":
h, w, _ = cv2.imdecode(
np.fromfile(Imgpath, dtype=np.uint8), 1
).shape
if h > 32 and w > 32:
self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)[0]
self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)[
0
]
else:
print('The size of', Imgpath, 'is too small to be recognised')
print(
"The size of", Imgpath, "is too small to be recognised"
)
self.result_dic = None

# 结果保存
if self.result_dic is None or len(self.result_dic) == 0:
print('Can not recognise file', Imgpath)
print("Can not recognise file", Imgpath)
pass
else:
strs = ''
strs = ""
for res in self.result_dic:
chars = res[1][0]
cond = res[1][1]
posi = res[0]
strs += "Transcription: " + chars + " Probability: " + str(cond) + \
" Location: " + json.dumps(posi) +'\n'
strs += (
"Transcription: "
+ chars
+ " Probability: "
+ str(cond)
+ " Location: "
+ json.dumps(posi)
+ "\n"
)
# Sending large amounts of data repeatedly through pyqtSignal may affect the program efficiency
self.listValue.emit(strs)
self.mainThread.result_dic = self.result_dic
self.mainThread.filePath = Imgpath
# 保存
self.mainThread.saveFile(mode='Auto')
self.mainThread.saveFile(mode="Auto")
findex += 1
self.progressBarValue.emit(findex)
else:
Expand All @@ -75,8 +88,9 @@ def run(self):


class AutoDialog(QDialog):

def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=None, lenbar=0):
def __init__(
self, text="Enter object label", parent=None, ocr=None, mImgList=None, lenbar=0
):
super(AutoDialog, self).__init__(parent)
self.setFixedWidth(1000)
self.parent = parent
Expand All @@ -89,13 +103,13 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No

layout = QVBoxLayout()
layout.addWidget(self.pb)
self.model = 'paddle'
self.model = "paddle"
self.listWidget = QListWidget(self)
layout.addWidget(self.listWidget)

self.buttonBox = bb = BB(BB.Ok | BB.Cancel, Qt.Horizontal, self)
bb.button(BB.Ok).setIcon(newIcon('done'))
bb.button(BB.Cancel).setIcon(newIcon('undo'))
bb.button(BB.Ok).setIcon(newIcon("done"))
bb.button(BB.Cancel).setIcon(newIcon("undo"))
bb.accepted.connect(self.validate)
bb.rejected.connect(self.reject)
layout.addWidget(bb)
Expand All @@ -107,7 +121,7 @@ def __init__(self, text="Enter object label", parent=None, ocr=None, mImgList=No

# self.setWindowFlags(Qt.WindowCloseButtonHint)

self.thread_1 = Worker(self.ocr, self.mImgList, self.parent, 'paddle')
self.thread_1 = Worker(self.ocr, self.mImgList, self.parent, "paddle")
self.thread_1.progressBarValue.connect(self.handleProgressBarSingal)
self.thread_1.listValue.connect(self.handleListWidgetSingal)
self.thread_1.endsignal.connect(self.handleEndsignalSignal)
Expand All @@ -117,8 +131,14 @@ def handleProgressBarSingal(self, i):
self.pb.setValue(i)

# calculate time left of auto labeling
avg_time = (time.time() - self.time_start) / i # Use average time to prevent time fluctuations
time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(".")[0] # Remove microseconds
avg_time = (
time.time() - self.time_start
) / i # Use average time to prevent time fluctuations
time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(
"."
)[
0
] # Remove microseconds
self.setWindowTitle("PPOCRLabel -- " + f"Time Left: {time_left}") # show

def handleListWidgetSingal(self, i):
Expand Down
Loading

0 comments on commit 90cbb95

Please sign in to comment.