Files
box_ocr/pp_onnx/onnx_paddleocr.py
2025-10-16 17:18:10 +08:00

94 lines
2.8 KiB
Python

import time
from pp_onnx.predict_system import TextSystem
from pp_onnx.utils import infer_args as init_args
from pp_onnx.utils import str2bool, draw_ocr
import argparse
import sys
class ONNXPaddleOcr(TextSystem):
def __init__(self, **kwargs):
# 默认参数
parser = init_args()
# import IPython
# IPython.embed(header='L-14')
inference_args_dict = {}
for action in parser._actions:
inference_args_dict[action.dest] = action.default
params = argparse.Namespace(**inference_args_dict)
params.rec_image_shape = "3, 48, 320"
# 根据传入的参数覆盖更新默认参数
params.__dict__.update(**kwargs)
# 初始化模型
super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True):
if cls == True and self.use_angle_cls == False:
print('Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process')
if det and rec:
ocr_res = []
dt_boxes, rec_res = self.__call__(img, cls)
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and not rec:
ocr_res = []
dt_boxes = self.text_detector(img)
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
else:
ocr_res = []
cls_res = []
if not isinstance(img, list):
img = [img]
if self.use_angle_cls and cls:
img, cls_res_tmp = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res = self.text_recognizer(img)
ocr_res.append(rec_res)
if not rec:
return cls_res
return ocr_res
def sav2Img(org_img, result, name="draw_ocr.jpg"):
# 显示结果
from PIL import Image
result = result[0]
# image = Image.open(img_path).convert('RGB')
# 图像转BGR2RGB
image = org_img[:, :, ::-1]
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores)
im_show = Image.fromarray(im_show)
im_show.save(name)
if __name__ == '__main__':
import cv2
model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)
img = cv2.imread('/data2/liujingsong3/fiber_box/test/img/20230531230052008263304.jpg')
s = time.time()
result = model.ocr(img)
e = time.time()
print("total time: {:.3f}".format(e - s))
print("result:", result)
for box in result[0]:
print(box)
sav2Img(img, result)