盒子ocr检测
This commit is contained in:
99
pp_onnx/predict_system.py
Normal file
99
pp_onnx/predict_system.py
Normal file
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import cv2
|
||||
import copy
|
||||
import pp_onnx.predict_det as predict_det
|
||||
import pp_onnx.predict_cls as predict_cls
|
||||
import pp_onnx.predict_rec as predict_rec
|
||||
from pp_onnx.utils import get_rotate_crop_image, get_minarea_rect_crop
|
||||
|
||||
|
||||
class TextSystem(object):
|
||||
def __init__(self, args):
|
||||
self.text_detector = predict_det.TextDetector(args)
|
||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||
self.use_angle_cls = args.use_angle_cls
|
||||
self.drop_score = args.drop_score
|
||||
if self.use_angle_cls:
|
||||
self.text_classifier = predict_cls.TextClassifier(args)
|
||||
|
||||
self.args = args
|
||||
self.crop_image_res_index = 0
|
||||
|
||||
|
||||
def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
bbox_num = len(img_crop_list)
|
||||
for bno in range(bbox_num):
|
||||
cv2.imwrite(
|
||||
os.path.join(output_dir,
|
||||
f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
|
||||
img_crop_list[bno])
|
||||
|
||||
self.crop_image_res_index += bbox_num
|
||||
|
||||
def __call__(self, img, cls=True):
|
||||
ori_im = img.copy()
|
||||
# 文字检测
|
||||
dt_boxes = self.text_detector(img)
|
||||
|
||||
if dt_boxes is None:
|
||||
return None, None
|
||||
|
||||
img_crop_list = []
|
||||
|
||||
dt_boxes = sorted_boxes(dt_boxes)
|
||||
|
||||
# 图片裁剪
|
||||
for bno in range(len(dt_boxes)):
|
||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
if self.args.det_box_type == "quad":
|
||||
img_crop = get_rotate_crop_image(ori_im, tmp_box)
|
||||
else:
|
||||
img_crop = get_minarea_rect_crop(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
|
||||
# 方向分类
|
||||
if self.use_angle_cls and cls:
|
||||
img_crop_list, angle_list = self.text_classifier(img_crop_list)
|
||||
|
||||
# 图像识别
|
||||
rec_res = self.text_recognizer(img_crop_list)
|
||||
|
||||
if self.args.save_crop_res:
|
||||
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,rec_res)
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
for box, rec_result in zip(dt_boxes, rec_res):
|
||||
text, score = rec_result
|
||||
if score >= self.drop_score:
|
||||
filter_boxes.append(box)
|
||||
filter_rec_res.append(rec_result)
|
||||
|
||||
# import IPython
|
||||
# IPython.embed(header='L-70')
|
||||
|
||||
return filter_boxes, filter_rec_res
|
||||
|
||||
|
||||
def sorted_boxes(dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
||||
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
||||
tmp = _boxes[j]
|
||||
_boxes[j] = _boxes[j + 1]
|
||||
_boxes[j + 1] = tmp
|
||||
else:
|
||||
break
|
||||
return _boxes
|
||||
|
||||
Reference in New Issue
Block a user