Files
box_ocr/1
2025-10-16 17:18:10 +08:00

255 lines
7.8 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

文件一：test_ocr.py
import cv2
import time
from pp_onnx.onnx_paddleocr import ONNXPaddleOcr, draw_ocr
# Build the OCR pipeline.
# NOTE(review): ONNXPaddleOcr stores every keyword argument on its params
# namespace. Whether `providers` / `provider_options` are actually forwarded
# to the ONNX Runtime session depends on predict_det / predict_base, which are
# not shown here. The provider check below reports what the session really uses,
# which is the reliable way to tell whether the NPU is in play.
model = ONNXPaddleOcr(
    use_angle_cls=True,
    use_gpu=False,
    providers=['RKNNExecutionProvider'],
    provider_options=[{'device_id': 0}],
)


def _find_onnx_session(component):
    """Return the first attribute of *component* exposing ``get_providers()``.

    ONNXPaddleOcr has no ``det_session`` attribute (looking it up raised
    AttributeError). TextSystem stores the detector as ``text_detector``, and
    the attribute name under which the detector keeps its ONNX Runtime session
    is defined outside this file. So we duck-type: any attribute with a
    callable ``get_providers`` is treated as the session.

    Returns:
        The matching attribute value, or None if nothing matches (including
        when *component* is None or has no instance ``__dict__``).
    """
    for value in getattr(component, "__dict__", {}).values():
        if callable(getattr(value, "get_providers", None)):
            return value
    return None


onnx_session = _find_onnx_session(getattr(model, "text_detector", None))
if onnx_session is None:
    print("获取会话失败:未在 model.text_detector 中找到 ONNX 会话(没有带 get_providers() 的属性),请检查 predict_det.py 中的会话属性名")
else:
    # Providers the session actually runs on (may differ from what was requested).
    used_providers = onnx_session.get_providers()
    print(f"当前使用的执行提供者(计算设备):{used_providers}")
    if 'RKNNExecutionProvider' in used_providers:
        print("✅ 成功使用RK3588 NPU加速推理")
    else:
        print("❌ 未使用NPU当前设备", used_providers)
def sav2Img(org_img, result, name="./result_img/draw_ocr_996_1.jpg"):
    """Draw OCR boxes, texts and scores onto *org_img* and save the image to *name*.

    Args:
        org_img: image array in OpenCV BGR channel order (as from ``cv2.imread``).
        result: return value of ``model.ocr``. ``result[0]`` is the list of
            ``[box, (text, score)]`` entries for the image.
        name: output file path. Its parent directory is created if missing,
            because ``Image.save`` does not create directories and the default
            path lives under ``./result_img``.
    """
    import os
    from PIL import Image

    lines = result[0]
    image = org_img[:, :, ::-1]  # BGR -> RGB for draw_ocr / PIL
    boxes = [line[0] for line in lines]
    txts = [line[1][0] for line in lines]
    scores = [line[1][1] for line in lines]
    im_show = Image.fromarray(draw_ocr(image, boxes, txts, scores))
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    im_show.save(name)
# Run OCR inference on a sample image.
IMG_PATH = './test_img/test1.jpg'
# The first call on a fresh pipeline often includes one-time setup cost, so it
# is run untimed to get a representative per-image latency. Set to 0 to time
# the very first call instead.
WARMUP_RUNS = 1

img = cv2.imread(IMG_PATH)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    print(f"❌ 未找到图像文件:{IMG_PATH}")
else:
    for _ in range(WARMUP_RUNS):
        model.ocr(img)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    s = time.perf_counter()
    result = model.ocr(img)
    e = time.perf_counter()
    print(f"total time: {e - s:.3f} 秒")
    print("result:", result)
    lines = result[0] if result and result[0] else []
    for box in lines:
        print(box)
    if lines:
        sav2Img(img, result)
    else:
        print("未检测到文本,跳过结果图保存")
文件二：onnx_paddleocr.py
import time
from pp_onnx.predict_system import TextSystem
from pp_onnx.utils import infer_args as init_args
from pp_onnx.utils import str2bool, draw_ocr
import argparse
import sys
class ONNXPaddleOcr(TextSystem):
    """PaddleOCR-style OCR pipeline backed by ONNX models.

    Builds a parameter namespace from the defaults declared by
    ``pp_onnx.utils.infer_args``, overrides them with ``**kwargs`` and hands
    the result to ``TextSystem``.

    NOTE(review): every keyword argument (including ``providers`` /
    ``provider_options``) is simply stored on the namespace. Whether a given
    key has any effect depends on the code that reads ``args`` (e.g. session
    creation in predict_det / predict_base, not visible here). Unknown keys
    are silently ignored.
    """

    def __init__(self, **kwargs):
        parser = init_args()
        # Collect every declared default without parsing sys.argv.
        inference_args_dict = {action.dest: action.default for action in parser._actions}
        params = argparse.Namespace(**inference_args_dict)
        params.rec_image_shape = "3, 48, 320"
        # Caller-supplied options override the defaults.
        vars(params).update(kwargs)
        super().__init__(params)

    def ocr(self, img, det=True, rec=True, cls=True):
        """Run (a subset of) the OCR pipeline on *img*.

        Args:
            img: full image when ``det`` is true; otherwise one crop or a list
                of crops.
            det: run text detection.
            rec: run text recognition.
            cls: run the angle classifier (only effective if it was initialised).

        Returns:
            det and rec: ``[[[box, (text, score)], ...]]``, a single list for the
                image with each box as a nested list. It is empty when the
                detector returns nothing (this previously raised TypeError).
            det only: ``[[box, ...]]``, empty when the detector returns nothing.
            rec only (det=False): ``[recognizer_output]``.
            neither: list holding the classifier output, or ``[]`` when the
                classifier was not run.
        """
        if cls and not self.use_angle_cls:
            print('Since the angle classifier is not initialized, the angle classifier will not be used during the forward process')

        if det and rec:
            dt_boxes, rec_res = self(img, cls)
            if dt_boxes is None:
                # TextSystem.__call__ returns (None, None) when detection yields nothing.
                return [[]]
            return [[[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]]

        if det:
            dt_boxes = self.text_detector(img)
            if dt_boxes is None:
                return [[]]
            return [[box.tolist() for box in dt_boxes]]

        # No detection: treat img as one crop or a list of crops.
        if not isinstance(img, list):
            img = [img]
        cls_res = []
        if self.use_angle_cls and cls:
            img, cls_res_tmp = self.text_classifier(img)
            if not rec:
                cls_res.append(cls_res_tmp)
        if not rec:
            # Skip the recognizer entirely; its output would be discarded anyway.
            return cls_res
        return [self.text_recognizer(img)]
def sav2Img(org_img, result, name="draw_ocr.jpg"):
    """Render OCR results on top of the source image and write it to *name*.

    ``result`` is the list returned by ``ONNXPaddleOcr.ocr``. Only its first
    element (the per-image list of ``[box, (text, score)]`` entries) is drawn.
    ``org_img`` is expected in OpenCV's BGR channel order and is flipped to RGB
    before drawing.
    """
    from PIL import Image

    per_image = result[0]
    rgb_image = org_img[:, :, ::-1]  # OpenCV BGR -> RGB
    canvas = draw_ocr(
        rgb_image,
        [item[0] for item in per_image],
        [item[1][0] for item in per_image],
        [item[1][1] for item in per_image],
    )
    Image.fromarray(canvas).save(name)
if __name__ == '__main__':
    import cv2

    # An optional first CLI argument overrides the default sample image path.
    default_img = '/data2/liujingsong3/fiber_box/test/img/20230531230052008263304.jpg'
    img_path = sys.argv[1] if len(sys.argv) > 1 else default_img
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None rather than raising on a missing/unreadable file;
        # fail fast instead of crashing later inside the pipeline.
        sys.exit(f"cannot read image: {img_path}")
    model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)
    s = time.perf_counter()
    result = model.ocr(img)
    e = time.perf_counter()
    print("total time: {:.3f}".format(e - s))
    print("result:", result)
    for box in result[0]:
        print(box)
    sav2Img(img, result)
文件三：predict_system.py
import os
import cv2
import copy
import pp_onnx.predict_det as predict_det
import pp_onnx.predict_cls as predict_cls
import pp_onnx.predict_rec as predict_rec
from pp_onnx.utils import get_rotate_crop_image, get_minarea_rect_crop
class TextSystem(object):
    """Text detection -> optional angle classification -> text recognition.

    The constructor reads ``use_angle_cls``, ``drop_score`` from *args* and
    passes *args* whole to the detector, recognizer and (optional) classifier.
    ``__call__`` additionally reads ``det_box_type``, ``save_crop_res`` and
    ``crop_res_save_dir``.
    """

    def __init__(self, args):
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)
        self.args = args
        # Running offset so crop files from successive calls get distinct names.
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        """Write every crop in *img_crop_list* to *output_dir* as a JPEG file.

        ``rec_res`` is accepted for interface compatibility but is not used.
        """
        os.makedirs(output_dir, exist_ok=True)
        for bno, img_crop in enumerate(img_crop_list):
            cv2.imwrite(
                os.path.join(output_dir, f"mg_crop_{bno + self.crop_image_res_index}.jpg"),
                img_crop,
            )
        self.crop_image_res_index += len(img_crop_list)

    def __call__(self, img, cls=True):
        """Run the pipeline on *img*.

        Returns:
            ``(boxes, rec_results)`` for recognitions whose score is at least
            ``drop_score``. ``(None, None)`` is returned when the detector
            returns None, and ``([], [])`` when it finds no boxes.
        """
        # Keep a pristine copy for cropping in case the detector alters img.
        ori_im = img.copy()
        dt_boxes = self.text_detector(img)
        if dt_boxes is None:
            return None, None
        if len(dt_boxes) == 0:
            # Nothing to crop; skip the classifier/recognizer calls.
            return [], []
        dt_boxes = sorted_boxes(dt_boxes)

        # The crop strategy depends only on config, so choose it once.
        if self.args.det_box_type == "quad":
            crop = get_rotate_crop_image
        else:
            crop = get_minarea_rect_crop
        # deepcopy guards dt_boxes against in-place edits by the crop helper.
        img_crop_list = [crop(ori_im, copy.deepcopy(box)) for box in dt_boxes]

        if self.use_angle_cls and cls:
            img_crop_list, _ = self.text_classifier(img_crop_list)
        rec_res = self.text_recognizer(img_crop_list)
        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)

        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            _, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
def sorted_boxes(dt_boxes, y_tolerance=10):
    """Sort text boxes top-to-bottom, then left-to-right within a line.

    Boxes are first ordered by the (y, x) of their first corner. An
    insertion-style pass then swaps neighbours when two conditions hold: their
    first-corner y values differ by less than ``y_tolerance``, and the later
    box lies further left. This way, boxes on the same visual line read left
    to right.

    Args:
        dt_boxes: sequence of N boxes (e.g. an ndarray of shape [N, 4, 2]),
            each a sequence of 4 (x, y) corners. Plain lists are accepted too.
        y_tolerance: maximum first-corner y difference for two boxes to be
            treated as the same line. Defaults to 10, the previously hard-coded
            value.

    Returns:
        list of the N boxes in reading order.
    """
    boxes = sorted(dt_boxes, key=lambda b: (b[0][1], b[0][0]))
    for i in range(len(boxes) - 1):
        for j in range(i, -1, -1):
            same_line = abs(boxes[j + 1][0][1] - boxes[j][0][1]) < y_tolerance
            if same_line and boxes[j + 1][0][0] < boxes[j][0][0]:
                boxes[j], boxes[j + 1] = boxes[j + 1], boxes[j]
            else:
                break
    return boxes
运行 test_ocr.py 时报错：
(box_ocr) root@ztl:/result/ocr/pp_onnx-main# python test_ocr.py
获取会话失败:'ONNXPaddleOcr' object has no attribute 'det_session',请检查 onnx_paddleocr.py 中会话属性名是否正确(如 det_session/rec_session
total time: 11.161 秒
此外，检测一张图片的耗时太长。