52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
import onnxruntime
|
|
|
|
class PredictBase(object):
    """Helpers for creating an ONNX Runtime session and preparing its I/O.

    Provides methods to open an ``onnxruntime.InferenceSession``, list the
    session's input/output names, and build the ``input_feed`` dict that
    ``InferenceSession.run`` expects.
    """

    def __init__(self):
        pass

    def get_onnx_session(self, model_dir, use_gpu):
        """Create an ``onnxruntime.InferenceSession`` for a model file.

        :param model_dir: path to the ONNX model (``str`` or path-like; it is
            converted with ``str()`` before being passed to ONNX Runtime).
        :param use_gpu: when truthy, request the CUDA execution provider with
            the CPU provider listed as fallback; otherwise use CPU only.
        :return: the created ``onnxruntime.InferenceSession``.
        """
        if use_gpu:
            # Providers are tried in priority order; listing CPU after CUDA
            # keeps a CPU fallback, matching ONNX Runtime's documented usage.
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']

        # The second positional argument is ``sess_options``; None = defaults.
        return onnxruntime.InferenceSession(str(model_dir), None, providers=providers)

    def get_output_name(self, onnx_session):
        """Return the names of all outputs of ``onnx_session``, in order.

        :param onnx_session: an ``onnxruntime.InferenceSession`` (any object
            whose ``get_outputs()`` yields items with a ``name`` attribute).
        :return: ``list`` of output names; element ``[0]`` is the first output.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of all inputs of ``onnx_session``, in order.

        :param onnx_session: an ``onnxruntime.InferenceSession`` (any object
            whose ``get_inputs()`` yields items with a ``name`` attribute).
        :return: ``list`` of input names; element ``[0]`` is the first input.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``input_feed`` dict for ``InferenceSession.run``.

        The same ``image_numpy`` value is bound to every input name.

        :param input_name: a single input name (``str``) or an iterable of
            names, e.g. the list returned by :meth:`get_input_name`.
        :param image_numpy: the value to feed for each input (presumably a
            numpy array; it is not inspected here).
        :return: ``dict`` mapping each input name to ``image_numpy``.
        """
        if isinstance(input_name, str):
            # A bare string would otherwise be iterated character by character.
            input_name = [input_name]
        return {name: image_numpy for name in input_name}