52 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			52 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import onnxruntime | ||
|  | 
 | ||
class PredictBase:
    """Shared helpers for running ONNX models through onnxruntime.

    Provides session creation plus the input-name / output-name lists and the
    feed dict that ``onnxruntime.InferenceSession.run`` expects.
    """

    def __init__(self):
        pass

    def get_onnx_session(self, model_dir, use_gpu):
        """Create an onnxruntime inference session for a model.

        :param model_dir: path of the ONNX model to load (despite the name,
            it is passed straight to ``InferenceSession`` as the model path);
            anything ``str()`` turns into a path, e.g. ``str`` or
            ``pathlib.Path``, is accepted.
        :param use_gpu: when truthy, request the CUDA execution provider first.
        :return: an ``onnxruntime.InferenceSession``.
        """
        if use_gpu:
            # Providers are listed in decreasing precedence; keeping the CPU
            # provider last gives an explicit fallback instead of relying on
            # onnxruntime's implicit behavior when CUDA is unavailable.
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        # The second positional argument is sess_options; None uses defaults.
        return onnxruntime.InferenceSession(str(model_dir), None, providers=providers)

    def get_output_name(self, onnx_session):
        """Return the names of every output of ``onnx_session``.

        :param onnx_session: an ``onnxruntime.InferenceSession`` (or any object
            whose ``get_outputs()`` returns items with a ``.name`` attribute).
        :return: list of output names, in the order ``get_outputs()`` returns
            them; suitable as the ``output_names`` argument of ``run``.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of every input of ``onnx_session``.

        :param onnx_session: an ``onnxruntime.InferenceSession`` (or any object
            whose ``get_inputs()`` returns items with a ``.name`` attribute).
        :return: list of input names, in the order ``get_inputs()`` returns
            them.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``input_feed`` dict for ``InferenceSession.run``.

        Every name in ``input_name`` is mapped to the *same* ``image_numpy``
        object, so this is only meaningful for single-input models (or models
        whose inputs should all receive the same array).

        :param input_name: iterable of input names, e.g. the result of
            ``get_input_name``.
        :param image_numpy: the value to feed to each input.
        :return: dict mapping each input name to ``image_numpy``.
        """
        return {name: image_numpy for name in input_name}