File 1: test_ocr

import cv2
import time
from pp_onnx.onnx_paddleocr import ONNXPaddleOcr, draw_ocr

model = ONNXPaddleOcr(
    use_angle_cls=True,
    use_gpu=False,
    providers=['RKNNExecutionProvider'],
    provider_options=[{'device_id': 0}]
)

try:
    # Get the ONNX session of the text detection model
    onnx_session = model.det_session
    # Get the execution providers actually in use
    used_providers = onnx_session.get_providers()
    print(f"Execution providers (compute devices) in use: {used_providers}")

    if 'RKNNExecutionProvider' in used_providers:
        print("✅ Successfully using the RK3588 NPU for accelerated inference")
    else:
        print("❌ NPU not in use, current providers:", used_providers)
except AttributeError as e:
    print(f"Failed to get the session: {e}. Check whether the session attribute names in onnx_paddleocr.py are correct (e.g. det_session/rec_session)")


def sav2Img(org_img, result, name="./result_img/draw_ocr_996_1.jpg"):
    from PIL import Image
    result = result[0]
    # Convert the image from BGR to RGB for drawing
    image = org_img[:, :, ::-1]
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    im_show = draw_ocr(image, boxes, txts, scores)
    im_show = Image.fromarray(im_show)
    im_show.save(name)


# Run OCR inference
img = cv2.imread('./test_img/test1.jpg')
if img is None:
    print("❌ Image file not found: ./test_img/test1.jpg")
else:
    s = time.time()
    result = model.ocr(img)
    e = time.time()
    print(f"total time: {e - s:.3f} s")
    print("result:", result)
    for box in result[0]:
        print(box)
    sav2Img(img, result)

File 2: onnx_paddleocr

import time

from pp_onnx.predict_system import TextSystem
from pp_onnx.utils import infer_args as init_args
from pp_onnx.utils import str2bool, draw_ocr
import argparse
import sys


class ONNXPaddleOcr(TextSystem):
    def __init__(self, **kwargs):
        # Default arguments
        parser = init_args()
        # import IPython
        # IPython.embed(header='L-14')

        inference_args_dict = {}
        for action in parser._actions:
            inference_args_dict[action.dest] = action.default
        params = argparse.Namespace(**inference_args_dict)

        params.rec_image_shape = "3, 48, 320"

        # Override the defaults with any arguments passed in
        params.__dict__.update(**kwargs)

        # Initialize the detection/classification/recognition models
        super().__init__(params)

    def ocr(self, img, det=True, rec=True, cls=True):
        if cls == True and self.use_angle_cls == False:
            print('Since the angle classifier is not initialized, the angle classifier will not be used during the forward process')

        if det and rec:
            # Detection + recognition
            ocr_res = []
            dt_boxes, rec_res = self.__call__(img, cls)
            tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
            ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            # Detection only
            ocr_res = []
            dt_boxes = self.text_detector(img)
            tmp_res = [box.tolist() for box in dt_boxes]
            ocr_res.append(tmp_res)
            return ocr_res
        else:
            # Classification and/or recognition on already-cropped images
            ocr_res = []
            cls_res = []

            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls and cls:
                img, cls_res_tmp = self.text_classifier(img)
                if not rec:
                    cls_res.append(cls_res_tmp)
            rec_res = self.text_recognizer(img)
            ocr_res.append(rec_res)

            if not rec:
                return cls_res
            return ocr_res


def sav2Img(org_img, result, name="draw_ocr.jpg"):
    # Visualize the results
    from PIL import Image
    result = result[0]
    # image = Image.open(img_path).convert('RGB')
    # Convert the image from BGR to RGB
    image = org_img[:, :, ::-1]
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    im_show = draw_ocr(image, boxes, txts, scores)
    im_show = Image.fromarray(im_show)
    im_show.save(name)


if __name__ == '__main__':
    import cv2

    model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)

    img = cv2.imread('/data2/liujingsong3/fiber_box/test/img/20230531230052008263304.jpg')
    s = time.time()
    result = model.ocr(img)
    e = time.time()
    print("total time: {:.3f}".format(e - s))
    print("result:", result)
    for box in result[0]:
        print(box)

    sav2Img(img, result)

File 3: predict_system

import os
import cv2
import copy
import pp_onnx.predict_det as predict_det
import pp_onnx.predict_cls as predict_cls
import pp_onnx.predict_rec as predict_rec
from pp_onnx.utils import get_rotate_crop_image, get_minarea_rect_crop


class TextSystem(object):
    def __init__(self, args):
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

        self.args = args
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(output_dir,
                             f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
                img_crop_list[bno])

        self.crop_image_res_index += bbox_num

    def __call__(self, img, cls=True):
        ori_im = img.copy()
        # Text detection
        dt_boxes = self.text_detector(img)

        if dt_boxes is None:
            return None, None

        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        # Crop the detected text regions
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        # Text direction classification
        if self.use_angle_cls and cls:
            img_crop_list, angle_list = self.text_classifier(img_crop_list)

        # Text recognition
        rec_res = self.text_recognizer(img_crop_list)

        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)

        # import IPython
        # IPython.embed(header='L-70')

        return filter_boxes, filter_rec_res


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array): detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            # Boxes whose top edges are within 10 px are treated as the same
            # text line and reordered left to right
            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
                tmp = _boxes[j]
                _boxes[j] = _boxes[j + 1]
                _boxes[j + 1] = tmp
            else:
                break
    return _boxes

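For reference, a quick standalone check of how sorted_boxes orders detections (my own snippet, not part of predict_system, assuming dt_boxes is a numpy array of shape [N, 4, 2] as the docstring describes):

import numpy as np

# Two boxes on the same visual line (top edges within 10 px) and one lower box.
boxes = np.array([
    [[200, 12], [300, 12], [300, 40], [200, 40]],    # right box, first line
    [[10, 15], [120, 15], [120, 42], [10, 42]],      # left box, first line
    [[20, 90], [150, 90], [150, 120], [20, 120]],    # second line
], dtype=np.float32)

for b in sorted_boxes(boxes):
    print(b[0])
# Order: left box of the first line, then right box of the first line,
# then the lower box, i.e. left-to-right within a line, then top-to-bottom.
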
Running test_ocr produces the following output:

(box_ocr) root@ztl:/result/ocr/pp_onnx-main# python test_ocr.py
Failed to get the session: 'ONNXPaddleOcr' object has no attribute 'det_session'. Check whether the session attribute names in onnx_paddleocr.py are correct (e.g. det_session/rec_session)
total time: 11.161 s

So the session lookup fails, and on top of that roughly 11 seconds to process a single image is far too slow.
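To narrow down whether the RKNN provider is usable at all, independent of the pp_onnx wrapper, I would check onnxruntime directly with something like the sketch below. This is only a minimal sketch: the det.onnx path is a placeholder (the model file pp_onnx actually loads is not shown in the listings above), and the provider name simply reuses the one from test_ocr.

import onnxruntime as ort

# Providers compiled into this onnxruntime build. If 'RKNNExecutionProvider'
# is not listed here, every session silently falls back to CPUExecutionProvider,
# which could also explain the ~11 s per image.
print("available providers:", ort.get_available_providers())
print("device:", ort.get_device())

# Placeholder path: substitute the detection .onnx file that pp_onnx loads.
det_model_path = "./models/det.onnx"
session = ort.InferenceSession(
    det_model_path,
    providers=["RKNNExecutionProvider", "CPUExecutionProvider"],
)
print("providers actually used:", session.get_providers())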