盒子ocr检测
This commit is contained in:
52
pp_onnx/predict_base.py
Normal file
52
pp_onnx/predict_base.py
Normal file
@ -0,0 +1,52 @@
|
||||
import onnxruntime
|
||||
|
||||
class PredictBase(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_onnx_session(self, model_dir, use_gpu):
|
||||
# 使用gpu
|
||||
if use_gpu:
|
||||
providers = providers=['CUDAExecutionProvider']
|
||||
else:
|
||||
providers = providers = ['CPUExecutionProvider']
|
||||
|
||||
onnx_session = onnxruntime.InferenceSession(str(model_dir), None, providers=providers)
|
||||
|
||||
# print("providers:", onnxruntime.get_device())
|
||||
return onnx_session
|
||||
|
||||
|
||||
def get_output_name(self, onnx_session):
|
||||
"""
|
||||
output_name = onnx_session.get_outputs()[0].name
|
||||
:param onnx_session:
|
||||
:return:
|
||||
"""
|
||||
output_name = []
|
||||
for node in onnx_session.get_outputs():
|
||||
output_name.append(node.name)
|
||||
return output_name
|
||||
|
||||
def get_input_name(self, onnx_session):
|
||||
"""
|
||||
input_name = onnx_session.get_inputs()[0].name
|
||||
:param onnx_session:
|
||||
:return:
|
||||
"""
|
||||
input_name = []
|
||||
for node in onnx_session.get_inputs():
|
||||
input_name.append(node.name)
|
||||
return input_name
|
||||
|
||||
def get_input_feed(self, input_name, image_numpy):
|
||||
"""
|
||||
input_feed={self.input_name: image_numpy}
|
||||
:param input_name:
|
||||
:param image_numpy:
|
||||
:return:
|
||||
"""
|
||||
input_feed = {}
|
||||
for name in input_name:
|
||||
input_feed[name] = image_numpy
|
||||
return input_feed
|
||||
Reference in New Issue
Block a user