Box OCR detection
.gitignore (vendored, new file, 162 lines)
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/.gitignore (generated, vendored, new file, 8 lines)
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/inspectionProfiles/profiles_settings.xml (generated, new file, 6 lines)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/misc.xml (generated, new file, 7 lines)
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="box_ocr" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="box_ocr" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml (generated, new file, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/pp_onnx-main.iml" filepath="$PROJECT_DIR$/.idea/pp_onnx-main.iml" />
    </modules>
  </component>
</project>
.idea/pp_onnx-main.iml (generated, new file, 12 lines)
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="box_ocr" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
1 (new file, 255 lines)
@@ -0,0 +1,255 @@
File 1: test_ocr

import cv2
import time
from pp_onnx.onnx_paddleocr import ONNXPaddleOcr, draw_ocr

model = ONNXPaddleOcr(
    use_angle_cls=True,
    use_gpu=False,
    providers=['RKNNExecutionProvider'],
    provider_options=[{'device_id': 0}]
)

try:
    # Get the ONNX session of the text detection model
    onnx_session = model.det_session
    # Get the execution providers actually in use
    used_providers = onnx_session.get_providers()
    print(f"Execution providers (compute devices) in use: {used_providers}")

    if 'RKNNExecutionProvider' in used_providers:
        print("✅ Inference is accelerated on the RK3588 NPU")
    else:
        print("❌ NPU not in use, current devices:", used_providers)
except AttributeError as e:
    print(f"Failed to get the session: {e}. Check that the session attribute names in onnx_paddleocr.py are correct (e.g. det_session/rec_session)")


def sav2Img(org_img, result, name="./result_img/draw_ocr_996_1.jpg"):
    from PIL import Image
    result = result[0]
    image = org_img[:, :, ::-1]
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    im_show = draw_ocr(image, boxes, txts, scores)
    im_show = Image.fromarray(im_show)
    im_show.save(name)


# Run OCR inference
img = cv2.imread('./test_img/test1.jpg')
if img is None:
    print("❌ Image file not found: ./test_img/test1.jpg")
else:
    s = time.time()
    result = model.ocr(img)
    e = time.time()
    print(f"total time: {e - s:.3f} s")
    print("result:", result)
    for box in result[0]:
        print(box)
    sav2Img(img, result)


File 2: onnx_paddleocr

import time

from pp_onnx.predict_system import TextSystem
from pp_onnx.utils import infer_args as init_args
from pp_onnx.utils import str2bool, draw_ocr
import argparse
import sys


class ONNXPaddleOcr(TextSystem):
    def __init__(self, **kwargs):
        # Default parameters
        parser = init_args()
        # import IPython
        # IPython.embed(header='L-14')

        inference_args_dict = {}
        for action in parser._actions:
            inference_args_dict[action.dest] = action.default
        params = argparse.Namespace(**inference_args_dict)

        params.rec_image_shape = "3, 48, 320"

        # Override the defaults with any parameters passed in
        params.__dict__.update(**kwargs)

        # Initialize the models
        super().__init__(params)

    def ocr(self, img, det=True, rec=True, cls=True):
        if cls and not self.use_angle_cls:
            print('Since the angle classifier is not initialized, the angle classifier will not be used during the forward process')

        if det and rec:
            ocr_res = []
            dt_boxes, rec_res = self.__call__(img, cls)
            tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
            ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            dt_boxes = self.text_detector(img)
            tmp_res = [box.tolist() for box in dt_boxes]
            ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []

            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls and cls:
                img, cls_res_tmp = self.text_classifier(img)
                if not rec:
                    cls_res.append(cls_res_tmp)
            rec_res = self.text_recognizer(img)
            ocr_res.append(rec_res)

            if not rec:
                return cls_res
            return ocr_res


def sav2Img(org_img, result, name="draw_ocr.jpg"):
    # Display the result
    from PIL import Image
    result = result[0]
    # image = Image.open(img_path).convert('RGB')
    # Convert the image from BGR to RGB
    image = org_img[:, :, ::-1]
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    im_show = draw_ocr(image, boxes, txts, scores)
    im_show = Image.fromarray(im_show)
    im_show.save(name)


if __name__ == '__main__':
    import cv2

    model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)

    img = cv2.imread('/data2/liujingsong3/fiber_box/test/img/20230531230052008263304.jpg')
    s = time.time()
    result = model.ocr(img)
    e = time.time()
    print("total time: {:.3f}".format(e - s))
    print("result:", result)
    for box in result[0]:
        print(box)

    sav2Img(img, result)


File 3: predict_system

import os
import cv2
import copy
import pp_onnx.predict_det as predict_det
import pp_onnx.predict_cls as predict_cls
import pp_onnx.predict_rec as predict_rec
from pp_onnx.utils import get_rotate_crop_image, get_minarea_rect_crop


class TextSystem(object):
    def __init__(self, args):
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

        self.args = args
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(output_dir,
                             f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
                img_crop_list[bno])

        self.crop_image_res_index += bbox_num

    def __call__(self, img, cls=True):
        ori_im = img.copy()
        # Text detection
        dt_boxes = self.text_detector(img)

        if dt_boxes is None:
            return None, None

        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        # Crop each detected box out of the image
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        # Orientation classification
        if self.use_angle_cls and cls:
            img_crop_list, angle_list = self.text_classifier(img_crop_list)

        # Text recognition
        rec_res = self.text_recognizer(img_crop_list)

        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)

        # import IPython
        # IPython.embed(header='L-70')

        return filter_boxes, filter_rec_res


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array): detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
                tmp = _boxes[j]
                _boxes[j] = _boxes[j + 1]
                _boxes[j + 1] = tmp
            else:
                break
    return _boxes


Running test_ocr reports an error:

(box_ocr) root@ztl:/result/ocr/pp_onnx-main# python test_ocr.py
Failed to get the session: 'ONNXPaddleOcr' object has no attribute 'det_session'. Check that the session attribute names in onnx_paddleocr.py are correct (e.g. det_session/rec_session)
total time: 11.161 s

Also, over 11 seconds to detect a single image is far too slow.
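One way to chase both symptoms down, as a minimal sketch: TextSystem stores its sub-models as text_detector / text_recognizer (see predict_system above), so the live onnxruntime session can be located by introspection instead of guessing an attribute name like det_session, which does not exist on ONNXPaddleOcr. The scan below assumes predict_det.TextDetector keeps an onnxruntime.InferenceSession as a direct instance attribute; adapt it if the session is nested deeper.

import onnxruntime as ort

# Providers compiled into this onnxruntime build
print("available providers:", ort.get_available_providers())

det = model.text_detector  # exists on TextSystem; det_session does not
sessions = [v for v in vars(det).values() if isinstance(v, ort.InferenceSession)]
if sessions:
    # Providers the detector session actually ended up with
    print("providers in use:", sessions[0].get_providers())
else:
    print("no InferenceSession found as a direct attribute of TextDetector")

If 'RKNNExecutionProvider' is absent from get_available_providers(), the installed onnxruntime build simply cannot target the RK3588 NPU and inference falls back to the CPU, which would also account for the 11-second single-image time. Note too that ONNXPaddleOcr only folds providers/provider_options into its argument namespace via params.__dict__.update(**kwargs); whether predict_det passes them through to the session is not shown in the code above.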
LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
Readme.md (new file, 49 lines)
@@ -0,0 +1,49 @@
# onnxOCR

#### 1. Advantages

1. A general-purpose OCR decoupled from any deep-learning training framework, ready to deploy as-is.
2. Rebuilt from PaddleOCR models converted to ONNX; with limited compute and no loss of accuracy, it can be deployed on both ARM and x86 machines.
3. On machines of the same performance class, inference is 4-5x faster.

#### 2. Environment setup

python>=3.6

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt

Because the rec model exceeds 100 MB and GitHub restricts file size, it is hosted on
[Baidu Netdisk, extraction code: 125c](https://pan.baidu.com/s/1O1b30CMwsDjD7Ti9EnxYKQ)

After downloading, place it at ./models/ch_ppocr_server_v2.0/rec/rec.onnx

#### 3. One-command run

python test_ocr.py
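To call the package from your own code instead of the bundled script, a minimal sketch (mirroring the `__main__` block of `onnx_paddleocr`; the image path is a placeholder):

```python
import cv2
from pp_onnx.onnx_paddleocr import ONNXPaddleOcr

model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)
img = cv2.imread("your_image.jpg")  # placeholder path; use a real image
result = model.ocr(img)  # one entry per image: [[box, (text, score)], ...]
for box, (text, score) in result[0]:
    print(text, score)
```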
#### Sample results

![](onnxocr/test_images/draw_ocr.jpg)

![](result_img/draw_ocr.jpg)

![](result_img/1.jpg)

![](result_img/11.jpg)

![](result_img/12.jpg)

![](result_img/4.jpg)

#### Thanks to PaddleOCR

https://github.com/PaddlePaddle/PaddleOCR

#### Forked from
https://github.com/jingsongliujing/OnnxOCR

---

CHANGELOG

1. Added the latest `pp_ocr_v4` detection and recognition models
2. Renamed the package to `pp_onnx` to avoid conflicts with onnx
3. Changed some hard-coded parameters
__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from pp_onnx.onnx_paddleocr import ONNXPaddleOcr
det_result.jpg (binary, new file, 92 KiB)
pp_onnx/__init__.py (new file, empty)
pp_onnx/cls_postprocess.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# import paddle


class ClsPostProcess(object):
    """ Convert between text-label and text-index """

    def __init__(self, label_list=None, key=None, **kwargs):
        super(ClsPostProcess, self).__init__()
        self.label_list = label_list
        self.key = key

    def __call__(self, preds, label=None, *args, **kwargs):
        if self.key is not None:
            preds = preds[self.key]

        label_list = self.label_list
        if label_list is None:
            label_list = {idx: idx for idx in range(preds.shape[-1])}

        # if isinstance(preds, paddle.Tensor):
        #     preds = preds.numpy()

        pred_idxs = preds.argmax(axis=1)
        decode_out = [(label_list[idx], preds[i, idx])
                      for i, idx in enumerate(pred_idxs)]
        if label is None:
            return decode_out
        label = [(label_list[idx], 1.0) for idx in label]
        return decode_out, label
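A usage sketch for ClsPostProcess, inferred from its __call__ above; the softmax scores are synthetic stand-ins for classifier output:

import numpy as np
# from pp_onnx.cls_postprocess import ClsPostProcess

post = ClsPostProcess(label_list=['0', '180'])  # same labels as infer_args
preds = np.array([[0.9, 0.1], [0.2, 0.8]])      # fake classifier output, shape (N, 2)
print(post(preds))                              # highest-scoring label and score per sample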
pp_onnx/db_postprocess.py (new file, 276 lines)
@@ -0,0 +1,276 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is referred from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import cv2
# import paddle
from shapely.geometry import Polygon
import pyclipper


class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 box_type='quad',
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        self.box_type = box_type
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)

        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single map with shape (1, H, W),
            whose values are binarized as {0, 1}
        '''

        bitmap = _bitmap
        height, width = bitmap.shape

        boxes = []
        scores = []

        contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
                                       cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours[:self.max_candidates]:
            epsilon = 0.002 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue

            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            if points.shape[0] > 2:
                box = self.unclip(points, self.unclip_ratio)
                if len(box) > 1:
                    continue
            else:
                continue
            box = box.reshape(-1, 2)

            _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue

            box = np.array(box)
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.tolist())
            scores.append(score)
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single map with shape (1, H, W),
            whose values are binarized as {0, 1}
        '''

        bitmap = _bitmap
        height, width = bitmap.shape

        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        if len(outs) == 3:
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score = self.box_score_fast(pred, points.reshape(-1, 2))
            else:
                score = self.box_score_slow(pred, contour)
            if self.box_thresh > score:
                continue

            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype("int32"))
            scores.append(score)
        return np.array(boxes, dtype="int32"), scores

    def unclip(self, box, unclip_ratio):
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        '''
        box_score_fast: use bbox mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        '''
        box_score_slow: use polygon mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)

        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin

        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict['maps']
        # if isinstance(pred, paddle.Tensor):
        #     pred = pred.numpy()
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            if self.box_type == 'poly':
                boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                          mask, src_w, src_h)
            elif self.box_type == 'quad':
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       src_w, src_h)
            else:
                raise ValueError("box_type can only be one of ['quad', 'poly']")

            boxes_batch.append({'points': boxes})
        return boxes_batch


class DistillationDBPostProcess(object):
    def __init__(self,
                 model_name=["student"],
                 key=None,
                 thresh=0.3,
                 box_thresh=0.6,
                 max_candidates=1000,
                 unclip_ratio=1.5,
                 use_dilation=False,
                 score_mode="fast",
                 box_type='quad',
                 **kwargs):
        self.model_name = model_name
        self.key = key
        self.post_process = DBPostProcess(
            thresh=thresh,
            box_thresh=box_thresh,
            max_candidates=max_candidates,
            unclip_ratio=unclip_ratio,
            use_dilation=use_dilation,
            score_mode=score_mode,
            box_type=box_type)

    def __call__(self, predicts, shape_list):
        results = {}
        for k in self.model_name:
            results[k] = self.post_process(predicts[k], shape_list=shape_list)
        return results
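A usage sketch for DBPostProcess, inferred from its __call__ signature above; the probability map and shape_list are synthetic stand-ins for real detector output:

import numpy as np
# from pp_onnx.db_postprocess import DBPostProcess

post = DBPostProcess(thresh=0.3, box_thresh=0.6, unclip_ratio=1.5)
prob_map = np.random.rand(1, 1, 640, 640).astype(np.float32)   # fake (N, 1, H, W) map
shape_list = np.array([[960.0, 960.0, 640 / 960, 640 / 960]])  # src_h, src_w, ratio_h, ratio_w
boxes = post({'maps': prob_map}, shape_list)[0]['points']      # quads scaled back to source size
print(len(boxes), "candidate boxes")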
pp_onnx/fonts/simfang.ttf (binary, new file)
pp_onnx/imaug.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from pp_onnx.operators import *


def transform(data, ops=None):
    """ transform """
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data


def create_operators(op_param_list, global_config=None):
    """
    create operators based on the config

    Args:
        params(list): a dict list, used to create some operators
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        if global_config is not None:
            param.update(global_config)
        op = eval(op_name)(**param)
        ops.append(op)
    return ops
pp_onnx/legancy/utils copy.py (new file, 341 lines)
@@ -0,0 +1,341 @@
|
|||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
def get_rotate_crop_image(img, points):
|
||||||
|
'''
|
||||||
|
img_height, img_width = img.shape[0:2]
|
||||||
|
left = int(np.min(points[:, 0]))
|
||||||
|
right = int(np.max(points[:, 0]))
|
||||||
|
top = int(np.min(points[:, 1]))
|
||||||
|
bottom = int(np.max(points[:, 1]))
|
||||||
|
img_crop = img[top:bottom, left:right, :].copy()
|
||||||
|
points[:, 0] = points[:, 0] - left
|
||||||
|
points[:, 1] = points[:, 1] - top
|
||||||
|
'''
|
||||||
|
assert len(points) == 4, "shape of points must be 4*2"
|
||||||
|
img_crop_width = int(
|
||||||
|
max(
|
||||||
|
np.linalg.norm(points[0] - points[1]),
|
||||||
|
np.linalg.norm(points[2] - points[3])))
|
||||||
|
img_crop_height = int(
|
||||||
|
max(
|
||||||
|
np.linalg.norm(points[0] - points[3]),
|
||||||
|
np.linalg.norm(points[1] - points[2])))
|
||||||
|
pts_std = np.float32([[0, 0], [img_crop_width, 0],
|
||||||
|
[img_crop_width, img_crop_height],
|
||||||
|
[0, img_crop_height]])
|
||||||
|
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||||
|
dst_img = cv2.warpPerspective(
|
||||||
|
img,
|
||||||
|
M, (img_crop_width, img_crop_height),
|
||||||
|
borderMode=cv2.BORDER_REPLICATE,
|
||||||
|
flags=cv2.INTER_CUBIC)
|
||||||
|
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||||
|
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||||
|
dst_img = np.rot90(dst_img)
|
||||||
|
return dst_img
|
||||||
|
|
||||||
|
def get_minarea_rect_crop(img, points):
|
||||||
|
bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
|
||||||
|
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||||
|
|
||||||
|
index_a, index_b, index_c, index_d = 0, 1, 2, 3
|
||||||
|
if points[1][1] > points[0][1]:
|
||||||
|
index_a = 0
|
||||||
|
index_d = 1
|
||||||
|
else:
|
||||||
|
index_a = 1
|
||||||
|
index_d = 0
|
||||||
|
if points[3][1] > points[2][1]:
|
||||||
|
index_b = 2
|
||||||
|
index_c = 3
|
||||||
|
else:
|
||||||
|
index_b = 3
|
||||||
|
index_c = 2
|
||||||
|
|
||||||
|
box = [points[index_a], points[index_b], points[index_c], points[index_d]]
|
||||||
|
crop_img = get_rotate_crop_image(img, np.array(box))
|
||||||
|
return crop_img
|
||||||
|
|
||||||
|
|
||||||
|
def resize_img(img, input_size=600):
|
||||||
|
"""
|
||||||
|
resize img and limit the longest side of the image to input_size
|
||||||
|
"""
|
||||||
|
img = np.array(img)
|
||||||
|
im_shape = img.shape
|
||||||
|
im_size_max = np.max(im_shape[0:2])
|
||||||
|
im_scale = float(input_size) / float(im_size_max)
|
||||||
|
img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def str_count(s):
|
||||||
|
"""
|
||||||
|
Count the number of Chinese characters,
|
||||||
|
a single English character and a single number
|
||||||
|
equal to half the length of Chinese characters.
|
||||||
|
args:
|
||||||
|
s(string): the input of string
|
||||||
|
return(int):
|
||||||
|
the number of Chinese characters
|
||||||
|
"""
|
||||||
|
import string
|
||||||
|
count_zh = count_pu = 0
|
||||||
|
s_len = len(str(s))
|
||||||
|
en_dg_count = 0
|
||||||
|
for c in str(s):
|
||||||
|
if c in string.ascii_letters or c.isdigit() or c.isspace():
|
||||||
|
en_dg_count += 1
|
||||||
|
elif c.isalpha():
|
||||||
|
count_zh += 1
|
||||||
|
else:
|
||||||
|
count_pu += 1
|
||||||
|
return s_len - math.ceil(en_dg_count / 2)
|
||||||
|
|
||||||
|
def text_visual(texts,
|
||||||
|
scores,
|
||||||
|
img_h=400,
|
||||||
|
img_w=600,
|
||||||
|
threshold=0.,
|
||||||
|
font_path="./fonts/simfang.ttf"):
|
||||||
|
"""
|
||||||
|
create new blank img and draw txt on it
|
||||||
|
args:
|
||||||
|
texts(list): the text will be draw
|
||||||
|
scores(list|None): corresponding score of each txt
|
||||||
|
img_h(int): the height of blank img
|
||||||
|
img_w(int): the width of blank img
|
||||||
|
font_path: the path of font which is used to draw text
|
||||||
|
return(array):
|
||||||
|
"""
|
||||||
|
if scores is not None:
|
||||||
|
assert len(texts) == len(
|
||||||
|
scores), "The number of txts and corresponding scores must match"
|
||||||
|
|
||||||
|
def create_blank_img():
|
||||||
|
blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
|
||||||
|
blank_img[:, img_w - 1:] = 0
|
||||||
|
blank_img = Image.fromarray(blank_img).convert("RGB")
|
||||||
|
draw_txt = ImageDraw.Draw(blank_img)
|
||||||
|
return blank_img, draw_txt
|
||||||
|
|
||||||
|
blank_img, draw_txt = create_blank_img()
|
||||||
|
|
||||||
|
font_size = 20
|
||||||
|
txt_color = (0, 0, 0)
|
||||||
|
# import IPython; IPython.embed(header='L-129')
|
||||||
|
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||||
|
|
||||||
|
gap = font_size + 5
|
||||||
|
txt_img_list = []
|
||||||
|
count, index = 1, 0
|
||||||
|
for idx, txt in enumerate(texts):
|
||||||
|
index += 1
|
||||||
|
if scores[idx] < threshold or math.isnan(scores[idx]):
|
||||||
|
index -= 1
|
||||||
|
continue
|
||||||
|
first_line = True
|
||||||
|
while str_count(txt) >= img_w // font_size - 4:
|
||||||
|
tmp = txt
|
||||||
|
txt = tmp[:img_w // font_size - 4]
|
||||||
|
if first_line:
|
||||||
|
new_txt = str(index) + ': ' + txt
|
||||||
|
first_line = False
|
||||||
|
else:
|
||||||
|
new_txt = ' ' + txt
|
||||||
|
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
|
||||||
|
txt = tmp[img_w // font_size - 4:]
|
||||||
|
if count >= img_h // gap - 1:
|
||||||
|
txt_img_list.append(np.array(blank_img))
|
||||||
|
blank_img, draw_txt = create_blank_img()
|
||||||
|
count = 0
|
||||||
|
count += 1
|
||||||
|
if first_line:
|
||||||
|
new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
|
||||||
|
else:
|
||||||
|
new_txt = " " + txt + " " + '%.3f' % (scores[idx])
|
||||||
|
draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
|
||||||
|
# whether add new blank img or not
|
||||||
|
if count >= img_h // gap - 1 and idx + 1 < len(texts):
|
||||||
|
txt_img_list.append(np.array(blank_img))
|
||||||
|
blank_img, draw_txt = create_blank_img()
|
||||||
|
count = 0
|
||||||
|
count += 1
|
||||||
|
txt_img_list.append(np.array(blank_img))
|
||||||
|
if len(txt_img_list) == 1:
|
||||||
|
blank_img = np.array(txt_img_list[0])
|
||||||
|
else:
|
||||||
|
blank_img = np.concatenate(txt_img_list, axis=1)
|
||||||
|
return np.array(blank_img)
|
||||||
|
|
||||||
|
def draw_ocr(image,
|
||||||
|
boxes,
|
||||||
|
txts=None,
|
||||||
|
scores=None,
|
||||||
|
drop_score=0.5,
|
||||||
|
font_path="./pp_onnx/fonts/simfang.ttf"):
|
||||||
|
"""
|
||||||
|
Visualize the results of OCR detection and recognition
|
||||||
|
args:
|
||||||
|
image(Image|array): RGB image
|
||||||
|
boxes(list): boxes with shape(N, 4, 2)
|
||||||
|
txts(list): the texts
|
||||||
|
scores(list): txxs corresponding scores
|
||||||
|
drop_score(float): only scores greater than drop_threshold will be visualized
|
||||||
|
font_path: the path of font which is used to draw text
|
||||||
|
return(array):
|
||||||
|
the visualized img
|
||||||
|
"""
|
||||||
|
if scores is None:
|
||||||
|
scores = [1] * len(boxes)
|
||||||
|
box_num = len(boxes)
|
||||||
|
for i in range(box_num):
|
||||||
|
if scores is not None and (scores[i] < drop_score or
|
||||||
|
math.isnan(scores[i])):
|
||||||
|
continue
|
||||||
|
box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
|
||||||
|
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
|
||||||
|
if txts is not None:
|
||||||
|
img = np.array(resize_img(image, input_size=600))
|
||||||
|
txt_img = text_visual(
|
||||||
|
txts,
|
||||||
|
scores,
|
||||||
|
img_h=img.shape[0],
|
||||||
|
img_w=600,
|
||||||
|
threshold=drop_score,
|
||||||
|
font_path=font_path)
|
||||||
|
img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
|
||||||
|
return img
|
||||||
|
return image
|
||||||
|
|
||||||
|
def base64_to_cv2(b64str):
|
||||||
|
import base64
|
||||||
|
data = base64.b64decode(b64str.encode('utf8'))
|
||||||
|
data = np.frombuffer(data, np.uint8)
|
||||||
|
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def str2bool(v):
|
||||||
|
return v.lower() in ("true", "t", "1")
|
||||||
|
|
||||||
|
def infer_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# params for prediction engine
|
||||||
|
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||||
|
parser.add_argument("--use_xpu", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--use_npu", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||||
|
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--min_subgraph_size", type=int, default=15)
|
||||||
|
parser.add_argument("--precision", type=str, default="fp32")
|
||||||
|
parser.add_argument("--gpu_mem", type=int, default=500)
|
||||||
|
parser.add_argument("--gpu_id", type=int, default=0)
|
||||||
|
|
||||||
|
# params for text detector
|
||||||
|
parser.add_argument("--image_dir", type=str)
|
||||||
|
parser.add_argument("--page_num", type=int, default=0)
|
||||||
|
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||||
|
# parser.add_argument("--det_model_dir", type=str, default='./onnx/models/ch_ppocr_server_v2.0/det/det.onnx')
|
||||||
|
parser.add_argument("--det_model_dir", type=str, default='./pp_onnx/models/ch_PP-OCRv4/ch_PP-OCRv4_det_infer.onnx')
|
||||||
|
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
||||||
|
parser.add_argument("--det_limit_type", type=str, default='max')
|
||||||
|
parser.add_argument("--det_box_type", type=str, default='quad')
|
||||||
|
|
||||||
|
# DB parmas
|
||||||
|
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||||
|
parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
|
||||||
|
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
|
||||||
|
parser.add_argument("--max_batch_size", type=int, default=10)
|
||||||
|
parser.add_argument("--use_dilation", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--det_db_score_mode", type=str, default="fast")
|
||||||
|
|
||||||
|
# # EAST parmas
|
||||||
|
# parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||||
|
# parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||||
|
# parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||||
|
|
||||||
|
# # SAST parmas
|
||||||
|
# parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
|
||||||
|
# parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
|
||||||
|
|
||||||
|
# # PSE parmas
|
||||||
|
# parser.add_argument("--det_pse_thresh", type=float, default=0)
|
||||||
|
# parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
|
||||||
|
# parser.add_argument("--det_pse_min_area", type=float, default=16)
|
||||||
|
# parser.add_argument("--det_pse_scale", type=int, default=1)
|
||||||
|
|
||||||
|
# # FCE parmas
|
||||||
|
# parser.add_argument("--scales", type=list, default=[8, 16, 32])
|
||||||
|
# parser.add_argument("--alpha", type=float, default=1.0)
|
||||||
|
# parser.add_argument("--beta", type=float, default=1.0)
|
||||||
|
# parser.add_argument("--fourier_degree", type=int, default=5)
|
||||||
|
|
||||||
|
# params for text recognizer
|
||||||
|
parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
|
||||||
|
# parser.add_argument("--rec_model_dir", type=str, default='./onnx/models/ch_ppocr_server_v2.0/rec/rec.onnx')
|
||||||
|
parser.add_argument("--rec_model_dir", type=str, default='./pp_onnx/models/ch_PP-OCRv4/ch_PP-OCRv4_rec_infer.onnx')
|
||||||
|
parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
|
||||||
|
# parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
|
||||||
|
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||||
|
parser.add_argument("--rec_batch_num", type=int, default=6)
|
||||||
|
parser.add_argument("--max_text_length", type=int, default=25)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rec_char_dict_path",
|
||||||
|
type=str,
|
||||||
|
default='./pp_onnx/models/ch_ppocr_server_v2.0/ppocr_keys_v1.txt')
|
||||||
|
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vis_font_path", type=str, default="./pp_onnx/fonts/simfang.ttf")
|
||||||
|
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||||
|
|
||||||
|
# params for e2e
|
||||||
|
parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
|
||||||
|
parser.add_argument("--e2e_model_dir", type=str)
|
||||||
|
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
|
||||||
|
parser.add_argument("--e2e_limit_type", type=str, default='max')
|
||||||
|
|
||||||
|
# PGNet parmas
|
||||||
|
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
|
||||||
|
parser.add_argument(
|
||||||
|
"--e2e_char_dict_path", type=str, default="./onnx/ppocr/utils/ic15_dict.txt")
|
||||||
|
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
|
||||||
|
parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
|
||||||
|
|
||||||
|
# params for text classifier
|
||||||
|
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--cls_model_dir", type=str, default='./pp_onnx/models/ch_ppocr_server_v2.0/cls/cls.onnx')
|
||||||
|
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
|
||||||
|
parser.add_argument("--label_list", type=list, default=['0', '180'])
|
||||||
|
parser.add_argument("--cls_batch_num", type=int, default=6)
|
||||||
|
parser.add_argument("--cls_thresh", type=float, default=0.9)
|
||||||
|
|
||||||
|
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--cpu_threads", type=int, default=10)
|
||||||
|
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--warmup", type=str2bool, default=False)
|
||||||
|
|
||||||
|
# SR parmas
|
||||||
|
parser.add_argument("--sr_model_dir", type=str)
|
||||||
|
parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
|
||||||
|
parser.add_argument("--sr_batch_num", type=int, default=1)
|
||||||
|
|
||||||
|
#
|
||||||
|
parser.add_argument(
|
||||||
|
"--draw_img_save_dir", type=str, default="./onnx/inference_results")
|
||||||
|
parser.add_argument("--save_crop_res", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--crop_res_save_dir", type=str, default="./onnx/output")
|
||||||
|
|
||||||
|
# multi-process
|
||||||
|
parser.add_argument("--use_mp", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--total_process_num", type=int, default=1)
|
||||||
|
parser.add_argument("--process_id", type=int, default=0)
|
||||||
|
|
||||||
|
parser.add_argument("--benchmark", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--save_log_path", type=str, default="./onnx/log_output/")
|
||||||
|
|
||||||
|
parser.add_argument("--show_log", type=str2bool, default=True)
|
||||||
|
parser.add_argument("--use_onnx", type=str2bool, default=False)
|
||||||
|
return parser
|
||||||
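Nothing in this PR ever calls `parser.parse_args()`; the defaults above are harvested straight off the parser. A minimal sketch (not part of the diff) of that trick, which `ONNXPaddleOcr.__init__` below also relies on:

```python
import argparse

from pp_onnx.utils import infer_args

parser = infer_args()
# Collect every declared default without touching sys.argv.
params = argparse.Namespace(**{a.dest: a.default for a in parser._actions})
params.use_angle_cls = True  # override any default before building TextSystem
```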
45
pp_onnx/logger.py
Normal file
@ -0,0 +1,45 @@
import logging

LogName = 'Umi-OCR_log'
LogFileName = 'Umi-OCR_debug.log'


class Logger:

    def __init__(self):
        self.initLogger()

    def initLogger(self):
        '''Initialize the logger.'''

        # Logger
        self.logger = logging.getLogger(LogName)
        self.logger.setLevel(logging.DEBUG)

        # Console handler
        streamHandler = logging.StreamHandler()
        streamHandler.setLevel(logging.DEBUG)
        formatPrint = logging.Formatter(
            '【%(levelname)s】 %(message)s')
        streamHandler.setFormatter(formatPrint)
        # self.logger.addHandler(streamHandler)

        # NOTE: this early return leaves the logger without any handler,
        # so the file handler below is currently dead code.
        return
        # Log file handler
        fileHandler = logging.FileHandler(LogFileName)
        fileHandler.setLevel(logging.ERROR)
        formatFile = logging.Formatter(
            '''
【%(levelname)s】 %(asctime)s
%(message)s
File: %(module)s | Function: %(funcName)s | Line: %(lineno)d
Thread id: %(thread)d | Thread name: %(threadName)s''')
        fileHandler.setFormatter(formatFile)
        self.logger.addHandler(fileHandler)


LOG = Logger()


def GetLog():
    return LOG.logger
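A minimal usage sketch (not part of the diff) for this logger module:

```python
from pp_onnx.logger import GetLog

log = GetLog()
# With the early return above, no handler is attached, so DEBUG output
# is suppressed until the console/file handlers are re-enabled.
log.debug("text detector initialized")
```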
BIN
pp_onnx/models/ch_PP-OCRv4/ch_PP-OCRv4_det_infer.onnx
Normal file
BIN
pp_onnx/models/ch_PP-OCRv4/ch_PP-OCRv4_rec_infer.onnx
Normal file
BIN
pp_onnx/models/ch_PP-OCRv4/ppocrv4_det_rk3588_i8.rknn
Normal file
BIN
pp_onnx/models/ch_PP-OCRv4/ppocrv4_rec_rk3588_fp.rknn
Normal file
BIN
pp_onnx/models/ch_ppocr_server_v2.0/cls/cls.onnx
Normal file
BIN
pp_onnx/models/ch_ppocr_server_v2.0/det/det.onnx
Normal file
6623
pp_onnx/models/ch_ppocr_server_v2.0/ppocr_keys_v1.txt
Normal file
94
pp_onnx/onnx_paddleocr.py
Normal file
@ -0,0 +1,94 @@
import time

from pp_onnx.predict_system import TextSystem
from pp_onnx.utils import infer_args as init_args
from pp_onnx.utils import str2bool, draw_ocr
import argparse
import sys


class ONNXPaddleOcr(TextSystem):
    def __init__(self, **kwargs):
        # Default parameters
        parser = init_args()

        inference_args_dict = {}
        for action in parser._actions:
            inference_args_dict[action.dest] = action.default
        params = argparse.Namespace(**inference_args_dict)

        params.rec_image_shape = "3, 48, 320"

        # Override the defaults with whatever was passed in
        params.__dict__.update(**kwargs)

        # Initialize the models
        super().__init__(params)

    def ocr(self, img, det=True, rec=True, cls=True):
        if cls and not self.use_angle_cls:
            print('Since the angle classifier is not initialized, the angle classifier will not be used during the forward process')

        if det and rec:
            ocr_res = []
            dt_boxes, rec_res = self.__call__(img, cls)
            tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
            ocr_res.append(tmp_res)
            return ocr_res
        elif det and not rec:
            ocr_res = []
            dt_boxes = self.text_detector(img)
            tmp_res = [box.tolist() for box in dt_boxes]
            ocr_res.append(tmp_res)
            return ocr_res
        else:
            ocr_res = []
            cls_res = []

            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls and cls:
                img, cls_res_tmp = self.text_classifier(img)
                if not rec:
                    cls_res.append(cls_res_tmp)
            rec_res = self.text_recognizer(img)
            ocr_res.append(rec_res)

            if not rec:
                return cls_res
            return ocr_res


def sav2Img(org_img, result, name="draw_ocr.jpg"):
    # Visualize the result
    from PIL import Image
    result = result[0]
    # image = Image.open(img_path).convert('RGB')
    # Convert BGR to RGB
    image = org_img[:, :, ::-1]
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    im_show = draw_ocr(image, boxes, txts, scores)
    im_show = Image.fromarray(im_show)
    im_show.save(name)


if __name__ == '__main__':
    import cv2

    model = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)

    img = cv2.imread('/data2/liujingsong3/fiber_box/test/img/20230531230052008263304.jpg')
    s = time.time()
    result = model.ocr(img)
    e = time.time()
    print("total time: {:.3f}".format(e - s))
    print("result:", result)
    for box in result[0]:
        print(box)

    sav2Img(img, result)
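For the det+rec path, `ocr()` returns a single-element list wrapping `[box, (text, score)]` pairs, which is exactly how `sav2Img` indexes it. A small sketch of unpacking one result:

```python
for box, (text, score) in result[0]:
    print(f"{text!r}  score={score:.2f}  box={box}")
```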
187
pp_onnx/operators.py
Normal file
@ -0,0 +1,187 @@
import numpy as np
import cv2
import sys
import math


class NormalizeImage(object):
    """ normalize image, such as subtract mean, divide std
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data


class DetResizeForTest(object):
    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.keep_ratio = False
        if 'image_shape' in kwargs:
            self.image_shape = kwargs['image_shape']
            self.resize_type = 1
            if 'keep_ratio' in kwargs:
                self.keep_ratio = kwargs['keep_ratio']
        elif 'limit_side_len' in kwargs:
            self.limit_side_len = kwargs['limit_side_len']
            self.limit_type = kwargs.get('limit_type', 'min')
        elif 'resize_long' in kwargs:
            self.resize_type = 2
            self.resize_long = kwargs.get('resize_long', 960)
        else:
            self.limit_side_len = 736
            self.limit_type = 'min'

    def __call__(self, data):
        img = data['image']
        src_h, src_w, _ = img.shape
        if sum([src_h, src_w]) < 64:
            img = self.image_padding(img)

        if self.resize_type == 0:
            # img, shape = self.resize_image_type0(img)
            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
        elif self.resize_type == 2:
            img, [ratio_h, ratio_w] = self.resize_image_type2(img)
        else:
            # img, shape = self.resize_image_type1(img)
            img, [ratio_h, ratio_w] = self.resize_image_type1(img)
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def image_padding(self, im, value=0):
        h, w, c = im.shape
        im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
        im_pad[:h, :w, :] = im
        return im_pad

    def resize_image_type1(self, img):
        resize_h, resize_w = self.image_shape
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        if self.keep_ratio is True:
            resize_w = ori_w * resize_h / ori_h
            N = math.ceil(resize_w / 32)
            resize_w = N * 32
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        # return img, np.array([ori_h, ori_w])
        return img, [ratio_h, ratio_w]

    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # limit the max side
        if self.limit_type == 'max':
            if max(h, w) > limit_side_len:
                if h > w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'min':
            if min(h, w) < limit_side_len:
                if h < w:
                    ratio = float(limit_side_len) / h
                else:
                    ratio = float(limit_side_len) / w
            else:
                ratio = 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('unsupported limit type')
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            print(img.shape, resize_w, resize_h)
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]

    def resize_image_type2(self, img):
        h, w, _ = img.shape

        resize_w = w
        resize_h = h

        if resize_h > resize_w:
            ratio = float(self.resize_long) / resize_h
        else:
            ratio = float(self.resize_long) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return img, [ratio_h, ratio_w]


class ToCHWImage(object):
    """ convert hwc image to chw image
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data


class KeepKeys(object):
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list
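These operators are chained exactly as `TextDetector` wires them up in predict_det.py below. A standalone sketch (assumed values: a 960 px side limit and an existing `test.jpg`):

```python
import cv2

ops = [
    DetResizeForTest(limit_side_len=960, limit_type='max'),
    NormalizeImage(scale='1./255.',
                   mean=[0.485, 0.456, 0.406],
                   std=[0.229, 0.224, 0.225],
                   order='hwc'),
    ToCHWImage(),
    KeepKeys(keep_keys=['image', 'shape']),
]

data = {'image': cv2.imread('test.jpg')}  # BGR HWC uint8
for op in ops:
    data = op(data)
image, shape = data  # CHW float32 tensor plus [src_h, src_w, ratio_h, ratio_w]
```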
52
pp_onnx/predict_base.py
Normal file
@ -0,0 +1,52 @@
import onnxruntime


class PredictBase(object):
    def __init__(self):
        pass

    def get_onnx_session(self, model_dir, use_gpu):
        # Pick the execution provider; fall back to CPU when no GPU is requested
        if use_gpu:
            providers = ['CUDAExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']

        onnx_session = onnxruntime.InferenceSession(str(model_dir), None, providers=providers)

        # print("providers:", onnxruntime.get_device())
        return onnx_session

    def get_output_name(self, onnx_session):
        """
        output_name = onnx_session.get_outputs()[0].name
        :param onnx_session:
        :return:
        """
        output_name = []
        for node in onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name

    def get_input_name(self, onnx_session):
        """
        input_name = onnx_session.get_inputs()[0].name
        :param onnx_session:
        :return:
        """
        input_name = []
        for node in onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name

    def get_input_feed(self, input_name, image_numpy):
        """
        input_feed = {self.input_name: image_numpy}
        :param input_name:
        :param image_numpy:
        :return:
        """
        # Note: the same array is fed to every input name, which is fine
        # for these single-input OCR models
        input_feed = {}
        for name in input_name:
            input_feed[name] = image_numpy
        return input_feed
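A quick sketch (not part of the diff) of the session round-trip every predictor subclass performs; the model path is one of the binaries added above, and the dummy shape assumes the detector accepts dynamic H/W:

```python
import numpy as np

base = PredictBase()
session = base.get_onnx_session(
    './pp_onnx/models/ch_ppocr_server_v2.0/det/det.onnx', use_gpu=False)
input_name = base.get_input_name(session)
output_name = base.get_output_name(session)

x = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy NCHW batch
outputs = session.run(output_name, input_feed=base.get_input_feed(input_name, x))
print([o.shape for o in outputs])
```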
86
pp_onnx/predict_cls.py
Normal file
@ -0,0 +1,86 @@
import cv2
import copy
import numpy as np
import math

from pp_onnx.cls_postprocess import ClsPostProcess
from pp_onnx.predict_base import PredictBase


class TextClassifier(PredictBase):
    def __init__(self, args):
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        self.postprocess_op = ClsPostProcess(label_list=args.label_list)

        # Initialize the model
        self.cls_onnx_session = self.get_onnx_session(args.cls_model_dir, args.use_gpu)
        self.cls_input_name = self.get_input_name(self.cls_onnx_session)
        self.cls_output_name = self.get_output_name(self.cls_onnx_session)

    def resize_norm_img(self, img):
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the cls process
        indices = np.argsort(np.array(width_list))

        cls_res = [['', 0.0]] * img_num
        batch_num = self.cls_batch_num

        for beg_img_no in range(0, img_num, batch_num):

            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0

            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]])
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()

            input_feed = self.get_input_feed(self.cls_input_name, norm_img_batch)
            outputs = self.cls_onnx_session.run(self.cls_output_name, input_feed=input_feed)

            prob_out = outputs[0]

            cls_result = self.postprocess_op(prob_out)
            for rno in range(len(cls_result)):
                label, score = cls_result[rno]
                cls_res[indices[beg_img_no + rno]] = [label, score]
                if '180' in label and score > self.cls_thresh:
                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
                        img_list[indices[beg_img_no + rno]], cv2.ROTATE_180)
        return img_list, cls_res
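One contract worth spelling out: the classifier returns both the labels and the crops, rotating any crop classified as '180' above the threshold, so callers must keep using its first return value. A sketch, assuming `args` and a list of BGR `crops`:

```python
classifier = TextClassifier(args)
rotated_crops, cls_res = classifier(crops)  # cls_res like [['0', 0.99], ['180', 0.97], ...]
```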
126
pp_onnx/predict_det.py
Normal file
@ -0,0 +1,126 @@
import numpy as np
from pp_onnx.imaug import transform, create_operators
from pp_onnx.db_postprocess import DBPostProcess
from pp_onnx.predict_base import PredictBase


class TextDetector(PredictBase):
    def __init__(self, args):
        self.args = args
        self.det_algorithm = args.det_algorithm
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        postprocess_params['name'] = 'DBPostProcess'
        postprocess_params["thresh"] = args.det_db_thresh
        postprocess_params["box_thresh"] = args.det_db_box_thresh
        postprocess_params["max_candidates"] = 1000
        postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
        postprocess_params["use_dilation"] = args.use_dilation
        postprocess_params["score_mode"] = args.det_db_score_mode
        postprocess_params["box_type"] = args.det_box_type

        # Instantiate the preprocessing operators
        self.preprocess_op = create_operators(pre_process_list)
        # self.postprocess_op = build_post_process(postprocess_params)
        # Instantiate the postprocessing operator
        self.postprocess_op = DBPostProcess(**postprocess_params)

        # Initialize the model
        self.det_onnx_session = self.get_onnx_session(args.det_model_dir, args.use_gpu)
        self.det_input_name = self.get_input_name(self.det_onnx_session)
        self.det_output_name = self.get_output_name(self.det_onnx_session)

    def order_points_clockwise(self, pts):
        # top-left has the smallest x+y sum, bottom-right the largest;
        # the remaining two corners are split by the sign of (y - x)
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        ori_im = img.copy()
        data = {'image': img}

        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            # callers expect a single value, not a (None, 0) tuple
            return None
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        input_feed = self.get_input_feed(self.det_input_name, img)
        outputs = self.det_onnx_session.run(self.det_output_name, input_feed=input_feed)

        preds = {}
        preds['maps'] = outputs[0]

        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']

        if self.args.det_box_type == 'poly':
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        return dt_boxes
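A detection-only usage sketch (not part of the diff), assuming `args` is the defaults Namespace built the way `ONNXPaddleOcr.__init__` does:

```python
import cv2

detector = TextDetector(args)
dt_boxes = detector(cv2.imread('test.jpg'))
print(f"{len(dt_boxes)} boxes, each a (4, 2) array of corner points")
```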
321
pp_onnx/predict_rec.py
Normal file
@ -0,0 +1,321 @@
import cv2
import numpy as np
import math
from PIL import Image


from pp_onnx.rec_postprocess import CTCLabelDecode
from pp_onnx.predict_base import PredictBase


class TextRecognizer(PredictBase):
    def __init__(self, args):
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.inverse = args.rec_image_inverse  # needed by norm_img_can below
        self.postprocess_op = CTCLabelDecode(character_dict_path=args.rec_char_dict_path, use_space_char=args.use_space_char)

        # Initialize the model
        self.rec_onnx_session = self.get_onnx_session(args.rec_model_dir, args.use_gpu)
        self.rec_input_name = self.get_input_name(self.rec_onnx_session)
        self.rec_output_name = self.get_output_name(self.rec_onnx_session)

    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            # return padding_im
            image_pil = Image.fromarray(np.uint8(img))
            if self.rec_algorithm == 'ViTSTR':
                img = image_pil.resize([imgW, imgH], Image.BICUBIC)
            else:
                # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
                img = image_pil.resize([imgW, imgH], Image.LANCZOS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            if self.rec_algorithm == 'ViTSTR':
                norm_img = norm_img.astype(np.float32) / 255.
            else:
                norm_img = norm_img.astype(np.float32) / 128. - 1.
            return norm_img
        elif self.rec_algorithm == 'RFL':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            resized_image = cv2.resize(
                img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
            resized_image = resized_image.astype('float32')
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
            resized_image -= 0.5
            resized_image /= 0.5
            return resized_image

        assert imgC == img.shape[2]
        imgW = int((imgH * max_wh_ratio))

        # If the model was exported with a fixed width, prefer it
        w = self.rec_onnx_session.get_inputs()[0].shape[3:][0]
        if isinstance(w, int) and w > 0:
            imgW = w

        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        if self.rec_algorithm == 'RARE':
            if resized_w > self.rec_image_shape[2]:
                resized_w = self.rec_image_shape[2]
            imgW = self.rec_image_shape[2]
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_vl(self, img, image_shape):

        imgC, imgH, imgW = image_shape
        img = img[:, :, ::-1]  # bgr2rgb
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        return resized_image

    def resize_norm_img_srn(self, img, image_shape):
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, c = img_black.shape
        c = 1

        return np.reshape(img_black, (c, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):

        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio

    def resize_norm_img_spin(self, img):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # return padding_im
        img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
        img = np.array(img, np.float32)
        img = np.expand_dims(img, -1)
        img = img.transpose((2, 0, 1))
        mean = [127.5]
        std = [127.5]
        mean = np.array(mean, dtype=np.float32)
        std = np.array(std, dtype=np.float32)
        mean = np.float32(mean.reshape(1, -1))
        stdinv = 1 / np.float32(std.reshape(1, -1))
        img -= mean
        img *= stdinv
        return img

    def resize_norm_img_svtr(self, img, image_shape):

        imgC, imgH, imgW = image_shape
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        return resized_image

    def resize_norm_img_abinet(self, img, image_shape):

        imgC, imgH, imgW = image_shape

        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_image = resized_image.astype('float32')
        resized_image = resized_image / 255.

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        resized_image = (
            resized_image - mean[None, None, ...]) / std[None, None, ...]
        resized_image = resized_image.transpose((2, 0, 1))
        resized_image = resized_image.astype('float32')

        return resized_image

    def norm_img_can(self, img, image_shape):

        img = cv2.cvtColor(
            img, cv2.COLOR_BGR2GRAY)  # CAN only predicts gray-scale images

        if self.inverse:
            img = 255 - img

        if self.rec_image_shape[0] == 1:
            h, w = img.shape
            _, imgH, imgW = self.rec_image_shape
            if h < imgH or w < imgW:
                padding_h = max(imgH - h, 0)
                padding_w = max(imgW - w, 0)
                img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
                                    'constant',
                                    constant_values=(255))
                img = img_padded

        img = np.expand_dims(img, 0) / 255.0  # h,w,c -> c,h,w
        img = img.astype('float32')

        return img

    def __call__(self, img_list):
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num

        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            imgC, imgH, imgW = self.rec_image_shape[:3]
            max_wh_ratio = imgW / imgH
            # max_wh_ratio = 0
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)

            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()

            input_feed = self.get_input_feed(self.rec_input_name, norm_img_batch)

            outputs = self.rec_onnx_session.run(self.rec_output_name, input_feed=input_feed)

            preds = outputs[0]

            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]

        return rec_res
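The width probe in `resize_norm_img` (`get_inputs()[0].shape[3:][0]`) only trusts the exported width when it is a real integer, because onnxruntime reports dynamic axes as symbolic names rather than ints. A sketch of inspecting that signature (the exact output depends on how the model was exported):

```python
import onnxruntime

sess = onnxruntime.InferenceSession(
    './pp_onnx/models/ch_PP-OCRv4/ch_PP-OCRv4_rec_infer.onnx',
    providers=['CPUExecutionProvider'])
print(sess.get_inputs()[0].shape)  # e.g. ['N', 3, 48, 'W'] for dynamic axes
```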
99
pp_onnx/predict_system.py
Normal file
@ -0,0 +1,99 @@
import os
import cv2
import copy
import pp_onnx.predict_det as predict_det
import pp_onnx.predict_cls as predict_cls
import pp_onnx.predict_rec as predict_rec
from pp_onnx.utils import get_rotate_crop_image, get_minarea_rect_crop


class TextSystem(object):
    def __init__(self, args):
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

        self.args = args
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        os.makedirs(output_dir, exist_ok=True)
        bbox_num = len(img_crop_list)
        for bno in range(bbox_num):
            cv2.imwrite(
                os.path.join(output_dir,
                             f"img_crop_{bno+self.crop_image_res_index}.jpg"),
                img_crop_list[bno])

        self.crop_image_res_index += bbox_num

    def __call__(self, img, cls=True):
        ori_im = img.copy()
        # Text detection
        dt_boxes = self.text_detector(img)

        if dt_boxes is None:
            return None, None

        img_crop_list = []

        dt_boxes = sorted_boxes(dt_boxes)

        # Crop each detected box out of the original image
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            if self.args.det_box_type == "quad":
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        # Angle classification
        if self.use_angle_cls and cls:
            img_crop_list, angle_list = self.text_classifier(img_crop_list)

        # Text recognition
        rec_res = self.text_recognizer(img_crop_list)

        if self.args.save_crop_res:
            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)

        return filter_boxes, filter_rec_res


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array): detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
                tmp = _boxes[j]
                _boxes[j] = _boxes[j + 1]
                _boxes[j + 1] = tmp
            else:
                break
    return _boxes
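A worked example (not part of the diff) of `sorted_boxes`: boxes whose top edges sit within 10 px of each other count as one text line and are reordered left to right:

```python
import numpy as np

boxes = np.array([
    [[120, 12], [200, 12], [200, 40], [120, 40]],  # right box, same line
    [[10, 8], [90, 8], [90, 36], [10, 36]],        # left box, same line
    [[15, 80], [95, 80], [95, 110], [15, 110]],    # second line
], dtype=np.float32)

for b in sorted_boxes(boxes):
    print(b[0])  # top-left corners, in order: [10, 8], [120, 12], [15, 80]
```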
65
pp_onnx/readme.md
Normal file
@ -0,0 +1,65 @@
# Convert PaddleOCR models to ONNX, then run inference with the ONNX models

## 1. Install paddle2onnx
```bash
pip install paddle2onnx
```

## 2. Download the PaddleOCR model files
```bash
!wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
!wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
!wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
```
## 3. Extract the model files
```bash
!tar -xvf /home/aistudio/onnx_pred/models/ch_ppocr_mobile_v2.0_cls_infer.tar
!tar -xvf /home/aistudio/onnx_pred/models/ch_ppocr_server_v2.0_det_infer.tar
!tar -xvf /home/aistudio/onnx_pred/models/ch_ppocr_server_v2.0_rec_infer.tar
```

## 4. Convert the PaddleOCR models to ONNX
```bash
paddle2onnx --model_dir ./ch_ppocr_server_v2.0_rec_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ./ch_ppocr_server_v2.0_rec.onnx \
--opset_version 11 \
--enable_onnx_checker True


paddle2onnx --model_dir ./ch_ppocr_server_v2.0_det_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ./ch_ppocr_server_v2.0_det.onnx \
--opset_version 11 \
--enable_onnx_checker True


paddle2onnx --model_dir ./ch_ppocr_mobile_v2.0_cls_infer \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--save_file ./ch_ppocr_mobile_v2.0_cls.onnx \
--opset_version 11 \
--enable_onnx_checker True
```

## 5. Install onnx and onnxruntime
```bash
pip install onnx==1.14.0
pip install onnxruntime-gpu==1.14.1
```

## 6. Model inference
```python
import cv2

from pp_onnx.onnx_paddleocr import ONNXPaddleOcr, sav2Img

model = ONNXPaddleOcr()

img = cv2.imread('./1.jpg')

# OCR result
result = model.ocr(img)
print(result)

# Draw the detected boxes
sav2Img(img, result)
```
920
pp_onnx/rec_postprocess.py
Normal file
@ -0,0 +1,920 @@
import numpy as np
# import paddle
# from paddle.nn import functional as F
import re


class BaseRecLabelDecode(object):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False):
        self.beg_str = "sos"
        self.end_str = "eos"
        self.reverse = False
        self.character_str = []

        if character_dict_path is None:
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
        else:
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
                    line = line.decode('utf-8').strip("\n").strip("\r\n")
                    self.character_str.append(line)
            if use_space_char:
                self.character_str.append(" ")
            dict_character = list(self.character_str)
        if 'arabic' in str(character_dict_path):
            self.reverse = True

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def pred_reverse(self, pred):
        pred_re = []
        c_current = ''
        for c in pred:
            if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
                if c_current != '':
                    pred_re.append(c_current)
                pred_re.append(c)
                c_current = ''
            else:
                c_current += c
        if c_current != '':
            pred_re.append(c_current)

        return ''.join(pred_re[::-1])

    def add_special_char(self, dict_character):
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                selection[1:] = text_index[batch_idx][1:] != text_index[
                    batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            char_list = [
                self.character[text_id]
                for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]

            text = ''.join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def get_ignored_tokens(self):
        return [0]  # for ctc blank


class CTCLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, tuple) or isinstance(preds, list):
            preds = preds[-1]
        # if isinstance(preds, paddle.Tensor):
        #     preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['blank'] + dict_character
        return dict_character
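A tiny worked example (not part of the diff) of what `CTCLabelDecode` does with raw logits: argmax per time step, collapse repeats, then drop the blank at index 0:

```python
import numpy as np

decoder = CTCLabelDecode()  # falls back to the builtin "0-9a-z" charset
# Fake logits for one image over 5 time steps: blank=0, and 'a'/'b' sit at
# indices 11/12 in ['blank'] + list("0123456789abcdefghijklmnopqrstuvwxyz")
steps = [11, 11, 0, 12, 12]            # -> "ab" after collapsing and de-blanking
preds = np.zeros((1, 5, 37), dtype=np.float32)
for t, idx in enumerate(steps):
    preds[0, t, idx] = 1.0
print(decoder(preds))                  # [('ab', 1.0)]
```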
class DistillationCTCLabelDecode(CTCLabelDecode):
|
||||||
|
"""
|
||||||
|
Convert
|
||||||
|
Convert between text-label and text-index
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
character_dict_path=None,
|
||||||
|
use_space_char=False,
|
||||||
|
model_name=["student"],
|
||||||
|
key=None,
|
||||||
|
multi_head=False,
|
||||||
|
**kwargs):
|
||||||
|
super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
|
||||||
|
use_space_char)
|
||||||
|
if not isinstance(model_name, list):
|
||||||
|
model_name = [model_name]
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
self.key = key
|
||||||
|
self.multi_head = multi_head
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
output = dict()
|
||||||
|
for name in self.model_name:
|
||||||
|
pred = preds[name]
|
||||||
|
if self.key is not None:
|
||||||
|
pred = pred[self.key]
|
||||||
|
if self.multi_head and isinstance(pred, dict):
|
||||||
|
pred = pred['ctc']
|
||||||
|
output[name] = super().__call__(pred, label=label, *args, **kwargs)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(AttnLabelDecode, self).__init__(character_dict_path,
|
||||||
|
use_space_char)
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = dict_character
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
[beg_idx, end_idx] = self.get_ignored_tokens()
|
||||||
|
batch_size = len(text_index)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
|
continue
|
||||||
|
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||||
|
break
|
||||||
|
if is_remove_duplicate:
|
||||||
|
# only for predict
|
||||||
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
|
batch_idx][idx]:
|
||||||
|
continue
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text, np.mean(conf_list).tolist()))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
text = self.decode(text)
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
else:
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
"""
|
||||||
|
# if isinstance(preds, paddle.Tensor):
|
||||||
|
# preds = preds.numpy()
|
||||||
|
|
||||||
|
preds_idx = preds.argmax(axis=2)
|
||||||
|
preds_prob = preds.max(axis=2)
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
class RFLLabelDecode(BaseRecLabelDecode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(RFLLabelDecode, self).__init__(character_dict_path,
|
||||||
|
use_space_char)
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = dict_character
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
[beg_idx, end_idx] = self.get_ignored_tokens()
|
||||||
|
batch_size = len(text_index)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
|
continue
|
||||||
|
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||||
|
break
|
||||||
|
if is_remove_duplicate:
|
||||||
|
# only for predict
|
||||||
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
|
batch_idx][idx]:
|
||||||
|
continue
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text, np.mean(conf_list).tolist()))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
# if seq_outputs is not None:
|
||||||
|
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||||
|
cnt_outputs, seq_outputs = preds
|
||||||
|
# if isinstance(seq_outputs, paddle.Tensor):
|
||||||
|
# seq_outputs = seq_outputs.numpy()
|
||||||
|
preds_idx = seq_outputs.argmax(axis=2)
|
||||||
|
preds_prob = seq_outputs.max(axis=2)
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
|
||||||
|
else:
|
||||||
|
cnt_outputs = preds
|
||||||
|
# if isinstance(cnt_outputs, paddle.Tensor):
|
||||||
|
# cnt_outputs = cnt_outputs.numpy()
|
||||||
|
cnt_length = []
|
||||||
|
for lens in cnt_outputs:
|
||||||
|
length = round(np.sum(lens))
|
||||||
|
cnt_length.append(length)
|
||||||
|
if label is None:
|
||||||
|
return cnt_length
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
length = [len(res[0]) for res in label]
|
||||||
|
return cnt_length, length
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|


class SEEDLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SEEDLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        self.padding_str = "padding"
        self.end_str = "eos"
        self.unknown = "unknown"
        dict_character = dict_character + [
            self.end_str, self.padding_str, self.unknown
        ]
        return dict_character

    def get_ignored_tokens(self):
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "sos":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "eos":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupported type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return idx

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if int(text_index[batch_idx][idx]) == int(end_idx):
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][
                            idx - 1] == text_index[batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list
    def __call__(self, preds, label=None, *args, **kwargs):
        preds_idx = preds["rec_pred"]
        # if isinstance(preds_idx, paddle.Tensor):
        #     preds_idx = preds_idx.numpy()
        if "rec_pred_scores" in preds:
            preds_idx = preds["rec_pred"]
            preds_prob = preds["rec_pred_scores"]
        else:
            preds_idx = preds["rec_pred"].argmax(axis=2)
            preds_prob = preds["rec_pred"].max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
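
# Usage sketch (illustrative; input keys taken from the code above):
#   decoder = SEEDLabelDecode(character_dict_path='ppocr_keys_v1.txt')
#   decoder({'rec_pred': logits})  # -> [(text, confidence), ...]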


class SRNLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SRNLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)
        self.max_text_length = kwargs.get('max_text_length', 25)

    def __call__(self, preds, label=None, *args, **kwargs):
        pred = preds['predict']
        char_num = len(self.character_str) + 2
        # if isinstance(pred, paddle.Tensor):
        #     pred = pred.numpy()
        pred = np.reshape(pred, [-1, char_num])

        preds_idx = np.argmax(pred, axis=1)
        preds_prob = np.max(pred, axis=1)

        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][
                            idx - 1] == text_index[batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def add_special_char(self, dict_character):
        dict_character = dict_character + [self.beg_str, self.end_str]
        return dict_character

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "end":
            idx = np.array(self.dict[self.end_str])
        else:
            assert False, "unsupported type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return idx
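
# Shape note (illustrative): SRN emits a flat (batch * max_text_length,
# char_num) score matrix; __call__ above takes argmax/max per step and
# reshapes back to (batch, max_text_length), so each row is one fixed-length
# text hypothesis.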


class SARLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SARLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

        self.rm_symbol = kwargs.get('rm_symbol', False)

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()

        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if int(text_index[batch_idx][idx]) == int(self.end_idx):
                    if text_prob is None and idx == 0:
                        continue
                    else:
                        break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and text_index[batch_idx][
                            idx - 1] == text_index[batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            if self.rm_symbol:
                # keep only letters, digits and CJK ideographs; the earlier
                # pattern '[^A-Z^a-z^0-9^\u4e00-\u9fa5]' also preserved
                # literal '^' characters, which looks unintended
                comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')
                text = text.lower()
                text = comp.sub('', text)
            result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        # if isinstance(preds, paddle.Tensor):
        #     preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)

        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label

    def get_ignored_tokens(self):
        return [self.padding_idx]
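
# Usage sketch (illustrative): given SAR logits of shape (N, T, C),
#   decoder = SARLabelDecode(character_dict_path='ppocr_keys_v1.txt')
#   decoder(logits)          # -> [(text, confidence), ...]
#   decoder(logits, labels)  # -> (decoded_preds, decoded_labels)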


class DistillationSARLabelDecode(SARLabelDecode):
    """
    Convert between text-label and text-index
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 model_name=["student"],
                 key=None,
                 multi_head=False,
                 **kwargs):
        super(DistillationSARLabelDecode, self).__init__(character_dict_path,
                                                         use_space_char)
        if not isinstance(model_name, list):
            model_name = [model_name]
        self.model_name = model_name

        self.key = key
        self.multi_head = multi_head

    def __call__(self, preds, label=None, *args, **kwargs):
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            if self.multi_head and isinstance(pred, dict):
                pred = pred['sar']
            output[name] = super().__call__(pred, label=label, *args, **kwargs)
        return output
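
# Output sketch (illustrative): for model_name=['student', 'teacher'] the
# wrapper returns {'student': [(text, conf), ...], 'teacher': [...]}, one
# decoded result set per sub-model of the distillation ensemble.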


class PRENLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(PRENLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        padding_str = '<PAD>'  # 0
        end_str = '<EOS>'  # 1
        unknown_str = '<UNK>'  # 2

        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2

        return dict_character

    def decode(self, text_index, text_prob=None):
        """ convert text-index into text-label. """
        result_list = []
        batch_size = len(text_index)

        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] == self.end_idx:
                    break
                if text_index[batch_idx][idx] in \
                        [self.padding_idx, self.unknown_idx]:
                    continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)

            text = ''.join(char_list)
            if len(text) > 0:
                result_list.append((text, np.mean(conf_list).tolist()))
            else:
                # here confidence of empty recog result is 1
                result_list.append(('', 1))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        # if isinstance(preds, paddle.Tensor):
        #     preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        label = self.decode(label)
        return text, label


class NRTRLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):

        if len(preds) == 2:
            preds_id = preds[0]
            preds_prob = preds[1]
            # if isinstance(preds_id, paddle.Tensor):
            #     preds_id = preds_id.numpy()
            # if isinstance(preds_prob, paddle.Tensor):
            #     preds_prob = preds_prob.numpy()
            if preds_id[0][0] == 2:
                # strip the leading '<s>' start token (index 2)
                preds_idx = preds_id[:, 1:]
                preds_prob = preds_prob[:, 1:]
            else:
                preds_idx = preds_id
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        else:
            # if isinstance(preds, paddle.Tensor):
            #     preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                try:
                    char_idx = self.character[int(text_index[batch_idx][idx])]
                except IndexError:
                    continue
                if char_idx == '</s>':  # end
                    break
                char_list.append(char_idx)
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text.lower(), np.mean(conf_list).tolist()))
        return result_list
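
# Token layout note (not in the original source): indices 0-3 are reserved
# for 'blank', '<unk>', '<s>' and '</s>' respectively, which is why __call__
# above strips a leading index 2 ('<s>') and decode() stops at '</s>'.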


class ViTSTRLabelDecode(NRTRLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(ViTSTRLabelDecode, self).__init__(character_dict_path,
                                                use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        # if isinstance(preds, paddle.Tensor):
        #     preds = preds[:, 1:].numpy()
        # else:
        preds = preds[:, 1:]
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['<s>', '</s>'] + dict_character
        return dict_character


class ABINetLabelDecode(NRTRLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(ABINetLabelDecode, self).__init__(character_dict_path,
                                                use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, dict):
            preds = preds['align'][-1].numpy()
        # elif isinstance(preds, paddle.Tensor):
        #     preds = preds.numpy()

        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['</s>'] + dict_character
        return dict_character


class SPINLabelDecode(AttnLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(SPINLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + [self.end_str] + dict_character
        return dict_character


# class VLLabelDecode(BaseRecLabelDecode):
#     """ Convert between text-label and text-index """
#
#     def __init__(self, character_dict_path=None, use_space_char=False,
#                  **kwargs):
#         super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
#         self.max_text_length = kwargs.get('max_text_length', 25)
#         self.nclass = len(self.character) + 1
#
#     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
#         """ convert text-index into text-label. """
#         result_list = []
#         ignored_tokens = self.get_ignored_tokens()
#         batch_size = len(text_index)
#         for batch_idx in range(batch_size):
#             selection = np.ones(len(text_index[batch_idx]), dtype=bool)
#             if is_remove_duplicate:
#                 selection[1:] = text_index[batch_idx][1:] != text_index[
#                     batch_idx][:-1]
#             for ignored_token in ignored_tokens:
#                 selection &= text_index[batch_idx] != ignored_token
#
#             char_list = [
#                 self.character[text_id - 1]
#                 for text_id in text_index[batch_idx][selection]
#             ]
#             if text_prob is not None:
#                 conf_list = text_prob[batch_idx][selection]
#             else:
#                 conf_list = [1] * len(selection)
#             if len(conf_list) == 0:
#                 conf_list = [0]
#
#             text = ''.join(char_list)
#             result_list.append((text, np.mean(conf_list).tolist()))
#         return result_list
#
#     def __call__(self, preds, label=None, length=None, *args, **kwargs):
#         if len(preds) == 2:  # eval mode
#             text_pre, x = preds
#             b = text_pre.shape[1]
#             lenText = self.max_text_length
#             nsteps = self.max_text_length
#
#             if not isinstance(text_pre, paddle.Tensor):
#                 text_pre = paddle.to_tensor(text_pre, dtype='float32')
#
#             out_res = paddle.zeros(
#                 shape=[lenText, b, self.nclass], dtype=x.dtype)
#             out_length = paddle.zeros(shape=[b], dtype=x.dtype)
#             now_step = 0
#             for _ in range(nsteps):
#                 if 0 in out_length and now_step < nsteps:
#                     tmp_result = text_pre[now_step, :, :]
#                     out_res[now_step] = tmp_result
#                     tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
#                     for j in range(b):
#                         if out_length[j] == 0 and tmp_result[j] == 0:
#                             out_length[j] = now_step + 1
#                     now_step += 1
#             for j in range(0, b):
#                 if int(out_length[j]) == 0:
#                     out_length[j] = nsteps
#             start = 0
#             output = paddle.zeros(
#                 shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
#             for i in range(0, b):
#                 cur_length = int(out_length[i])
#                 output[start:start + cur_length] = out_res[0:cur_length, i, :]
#                 start += cur_length
#             net_out = output
#             length = out_length
#
#         else:  # train mode
#             net_out = preds[0]
#             length = length
#             net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
#         text = []
#         if not isinstance(net_out, paddle.Tensor):
#             net_out = paddle.to_tensor(net_out, dtype='float32')
#         net_out = F.softmax(net_out, axis=1)
#         for i in range(0, length.shape[0]):
#             preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
#             ) + length[i])].topk(1)[1][:, 0].tolist()
#             preds_text = ''.join([
#                 self.character[idx - 1]
#                 if idx > 0 and idx <= len(self.character) else ''
#                 for idx in preds_idx
#             ])
#             preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
#             ) + length[i])].topk(1)[0][:, 0]
#             preds_prob = paddle.exp(
#                 paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
#             text.append((preds_text, preds_prob.numpy()[0]))
#         if label is None:
#             return text
#         label = self.decode(label)
#         return text, label


class CANLabelDecode(BaseRecLabelDecode):
    """ Convert between latex-symbol and symbol-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        super(CANLabelDecode, self).__init__(character_dict_path,
                                             use_space_char)

    def decode(self, text_index, preds_prob=None):
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            seq_end = text_index[batch_idx].argmin(0)
            idx_list = text_index[batch_idx][:seq_end].tolist()
            symbol_list = [self.character[idx] for idx in idx_list]
            probs = []
            if preds_prob is not None:
                probs = preds_prob[batch_idx][:len(symbol_list)].tolist()

            result_list.append([' '.join(symbol_list), probs])
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        pred_prob, _, _, _ = preds
        preds_idx = pred_prob.argmax(axis=2)

        text = self.decode(preds_idx)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
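
# Output sketch (illustrative): CANLabelDecode joins LaTeX symbols with
# spaces, so a formula prediction decodes to something like
# ['x ^ { 2 } + 1', [0.98, 0.95, ...]] rather than a plain string.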
285
pp_onnx/utils.py
Normal file
@ -0,0 +1,285 @@
import numpy as np
import cv2
import argparse
import math
from PIL import Image, ImageDraw, ImageFont

from logzero import logger
from importlib.resources import files

def get_rotate_crop_image(img, points):
    '''
    img_height, img_width = img.shape[0:2]
    left = int(np.min(points[:, 0]))
    right = int(np.max(points[:, 0]))
    top = int(np.min(points[:, 1]))
    bottom = int(np.max(points[:, 1]))
    img_crop = img[top:bottom, left:right, :].copy()
    points[:, 0] = points[:, 0] - left
    points[:, 1] = points[:, 1] - top
    '''
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img
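
# Usage sketch (illustrative): `points` holds the 4 corners of a text box in
# clockwise order starting top-left, as float32, e.g.
#   pts = np.float32([[10, 10], [110, 12], [108, 42], [8, 40]])
#   crop = get_rotate_crop_image(img, pts)  # perspective-rectified text crop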
def get_minarea_rect_crop(img, points):
    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    # Order the four corners clockwise from top-left:
    # a = top-left, b = top-right, c = bottom-right, d = bottom-left.
    index_a, index_b, index_c, index_d = 0, 1, 2, 3
    if points[1][1] > points[0][1]:
        index_a = 0
        index_d = 1
    else:
        index_a = 1
        index_d = 0
    if points[3][1] > points[2][1]:
        index_b = 2
        index_c = 3
    else:
        index_b = 3
        index_c = 2

    box = [points[index_a], points[index_b], points[index_c], points[index_d]]
    crop_img = get_rotate_crop_image(img, np.array(box))
    return crop_img

def resize_img(img, input_size=600):
    """
    resize img and limit the longest side of the image to input_size
    """
    img = np.array(img)
    im_shape = img.shape
    im_size_max = np.max(im_shape[0:2])
    im_scale = float(input_size) / float(im_size_max)
    img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
    return img

def str_count(s):
    """
    Count the display width of a string: each Chinese character counts as one
    unit, while a single English letter or digit counts as half a unit.
    args:
        s(string): the input string
    return(int):
        the display width, in Chinese-character units
    """
    import string
    count_zh = count_pu = 0
    s_len = len(str(s))
    en_dg_count = 0
    for c in str(s):
        if c in string.ascii_letters or c.isdigit() or c.isspace():
            en_dg_count += 1
        elif c.isalpha():
            count_zh += 1
        else:
            count_pu += 1
    return s_len - math.ceil(en_dg_count / 2)
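
# Worked example (illustrative): str_count("abc你好") -> 5 - ceil(3 / 2) = 3,
# i.e. the three ASCII characters count as roughly 1.5 Chinese-width cells.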

def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./fonts/simfang.ttf"):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the texts to draw
        scores(list|None): corresponding score of each txt
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(str(font_path), font_size, encoding="utf-8")

    gap = font_size + 5
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        index += 1
        if scores[idx] < threshold or math.isnan(scores[idx]):
            index -= 1
            continue
        first_line = True
        # wrap lines that are wider than the image
        while str_count(txt) >= img_w // font_size - 4:
            tmp = txt
            txt = tmp[:img_w // font_size - 4]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = '    ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[img_w // font_size - 4:]
            if count >= img_h // gap - 1:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt + '   ' + '%.3f' % (scores[idx])
        else:
            new_txt = "  " + txt + "  " + '%.3f' % (scores[idx])
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)

def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path=None):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): the corresponding scores of the texts
        drop_score(float): only scores greater than drop_score will be visualized
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img
    """
    if font_path is None:
        SIMFANG_TTF = files('pp_onnx').joinpath('fonts/simfang.ttf')
        font_path = SIMFANG_TTF

    if scores is None:
        scores = [1] * len(boxes)
    box_num = len(boxes)
    for i in range(box_num):
        if scores is not None and (scores[i] < drop_score or
                                   math.isnan(scores[i])):
            continue
        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    if txts is not None:
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
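
# Usage sketch (illustrative):
#   vis = draw_ocr(image, boxes, txts=texts, scores=confs, drop_score=0.5)
#   cv2.imwrite('vis.jpg', vis[:, :, ::-1])  # draw_ocr works on RGB arrays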

def base64_to_cv2(b64str):
    import base64
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data


def str2bool(v):
    return v.lower() in ("true", "t", "1")

def infer_args():
    parser = argparse.ArgumentParser()

    DET_MODEL_DIR = files('pp_onnx').joinpath('models/ch_PP-OCRv4/ch_PP-OCRv4_det_infer.onnx')
    REC_MODEL_DIR = files('pp_onnx').joinpath('models/ch_PP-OCRv4/ch_PP-OCRv4_rec_infer.onnx')
    PPOCR_KEYS_V1 = files('pp_onnx').joinpath('models/ch_ppocr_server_v2.0/ppocr_keys_v1.txt')
    SIMFANG_TTF = files('pp_onnx').joinpath('fonts/simfang.ttf')
    CLS_MODEL_DIR = files('pp_onnx').joinpath('models/ch_ppocr_server_v2.0/cls/cls.onnx')

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--page_num", type=int, default=0)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str, default=DET_MODEL_DIR)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')
    parser.add_argument("--det_box_type", type=str, default='quad')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='SVTR_LCNet')
    parser.add_argument("--rec_model_dir", type=str, default=REC_MODEL_DIR)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument("--rec_char_dict_path", type=str, default=PPOCR_KEYS_V1)
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--vis_font_path", type=str, default=SIMFANG_TTF)
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str, default=CLS_MODEL_DIR)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    # others
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    # parser.add_argument("--draw_img_save_dir", type=str, default="./onnx/inference_results")
    # parser.add_argument("--crop_res_save_dir", type=str, default="./onnx/output")

    return parser
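
# Usage sketch (illustrative): build the default config without CLI input via
#   args = infer_args().parse_args([])
# then override fields, e.g. args.use_angle_cls = True, before constructing
# the predictor.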
32
requirements.txt
Normal file
@ -0,0 +1,32 @@
colorama==0.4.6
coloredlogs==15.0.1
cycler==0.11.0
flatbuffers==23.5.26
fonttools==4.38.0
humanfriendly==10.0
imageio==2.31.1
imgaug==0.4.0
kiwisolver==1.4.4
lmdb==1.4.1
matplotlib==3.5.3
mpmath==1.3.0
networkx==2.6.3
numpy==1.21.6
onnxruntime==1.14.1
opencv-python==3.4.18.65
packaging==23.1
Pillow==9.5.0
protobuf==4.23.4
pyclipper==1.3.0.post4
pyparsing==3.1.0
pyreadline==2.1
python-dateutil==2.8.2
PyWavelets==1.3.0
scikit-image==0.19.3
scipy==1.7.3
shapely==2.0.1
six==1.16.0
sympy==1.10.1
tifffile==2021.11.2
tqdm==4.65.0
typing_extensions==4.7.1
BIN
result_img/1.jpg
Normal file
After Width: | Height: | Size: 97 KiB
BIN
result_img/3.jpg
Normal file
After Width: | Height: | Size: 228 KiB
BIN
result_img/draw_ocr.jpg
Normal file
After Width: | Height: | Size: 74 KiB
BIN
result_img/draw_ocr2.jpg
Normal file
After Width: | Height: | Size: 110 KiB
BIN
result_img/draw_ocr3.jpg
Normal file
After Width: | Height: | Size: 86 KiB
BIN
result_img/draw_ocr4.jpg
Normal file
After Width: | Height: | Size: 108 KiB
BIN
result_img/draw_ocr5.jpg
Normal file
After Width: | Height: | Size: 74 KiB
BIN
result_img/draw_ocr_1.jpg
Normal file
After Width: | Height: | Size: 60 KiB
BIN
result_img/draw_ocr_996.jpg
Normal file
After Width: | Height: | Size: 76 KiB
BIN
result_img/draw_ocr_996_1.jpg
Normal file
After Width: | Height: | Size: 228 KiB
BIN
result_img/rec.onnx
Normal file
BIN
test_img/1.jpg
Normal file
After Width: | Height: | Size: 305 KiB
BIN
test_img/10.jpg
Normal file
After Width: | Height: | Size: 100 KiB
BIN
test_img/11.jpg
Normal file
After Width: | Height: | Size: 150 KiB
BIN
test_img/12.jpg
Normal file
After Width: | Height: | Size: 54 KiB
BIN
test_img/13.jpg
Normal file
After Width: | Height: | Size: 122 KiB
BIN
test_img/14.jpg
Normal file
After Width: | Height: | Size: 98 KiB
BIN
test_img/15.jpg
Normal file
After Width: | Height: | Size: 49 KiB
BIN
test_img/16.jpg
Normal file
After Width: | Height: | Size: 164 KiB
BIN
test_img/17.jpg
Normal file
After Width: | Height: | Size: 34 KiB
BIN
test_img/18.jpg
Normal file
After Width: | Height: | Size: 47 KiB
BIN
test_img/19.jpg
Normal file
After Width: | Height: | Size: 6.1 KiB
BIN
test_img/2.jpg
Normal file
After Width: | Height: | Size: 617 KiB
BIN
test_img/20.jpg
Normal file
After Width: | Height: | Size: 96 KiB
BIN
test_img/21.jpg
Normal file
After Width: | Height: | Size: 39 KiB
BIN
test_img/22.png
Normal file
After Width: | Height: | Size: 65 KiB
BIN
test_img/3.jpg
Normal file
After Width: | Height: | Size: 396 KiB
BIN
test_img/4.jpg
Normal file
After Width: | Height: | Size: 48 KiB
BIN
test_img/5.jpg
Normal file
After Width: | Height: | Size: 389 KiB
BIN
test_img/6.jpg
Normal file
After Width: | Height: | Size: 126 KiB
BIN
test_img/7.jpg
Normal file
After Width: | Height: | Size: 42 KiB
BIN
test_img/8.jpg
Normal file
After Width: | Height: | Size: 89 KiB
BIN
test_img/9.jpg
Normal file
After Width: | Height: | Size: 67 KiB
45
test_ocr.py
Normal file
@ -0,0 +1,45 @@
import cv2
import time
from pp_onnx.onnx_paddleocr import ONNXPaddleOcr, draw_ocr


# Tuned parameter configuration to speed up inference
model = ONNXPaddleOcr(
    use_angle_cls=True,
    use_gpu=False,
    providers=['AzureExecutionProvider'],
    provider_options=[{'device_id': 0}],
    det_limit_type='max',
)


def sav2Img(org_img, result, name="./result_img/3.jpg"):
    from PIL import Image
    result = result[0]
    image = org_img[:, :, ::-1]  # BGR -> RGB
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    im_show = draw_ocr(image, boxes, txts, scores)
    im_show = Image.fromarray(im_show)
    im_show.save(name)


# Run OCR inference
img = cv2.imread('./test_img/3.jpg')
if img is None:
    print("❌ image file not found")
else:
    # Preprocessing: shrink the image to speed up processing
    h, w = img.shape[:2]
    max_size = 1080
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        img = cv2.resize(img, (int(w * scale), int(h * scale)))

    s = time.time()
    result = model.ocr(img)
    e = time.time()
    print(f"total time: {e - s:.3f} s")
    print("result:", result)
    for box in result[0]:
        print(box)
    sav2Img(img, result)
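
# Note (illustrative): model.ocr() is expected to return results in
# PaddleOCR's layout, result[0] = [[box_points, (text, confidence)], ...],
# which is exactly what sav2Img() above unpacks for visualization.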