盒子ocr检测
This commit is contained in:
		
							
								
								
									
										52
									
								
								pp_onnx/predict_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								pp_onnx/predict_base.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,52 @@ | ||||
| import onnxruntime | ||||
|  | ||||
| class PredictBase(object): | ||||
|     def __init__(self): | ||||
|         pass | ||||
|  | ||||
|     def get_onnx_session(self, model_dir, use_gpu): | ||||
|         # 使用gpu | ||||
|         if use_gpu: | ||||
|             providers = providers=['CUDAExecutionProvider'] | ||||
|         else: | ||||
|             providers = providers = ['CPUExecutionProvider'] | ||||
|  | ||||
|         onnx_session = onnxruntime.InferenceSession(str(model_dir), None, providers=providers) | ||||
|  | ||||
|         # print("providers:", onnxruntime.get_device()) | ||||
|         return onnx_session | ||||
|  | ||||
|  | ||||
|     def get_output_name(self, onnx_session): | ||||
|         """ | ||||
|         output_name = onnx_session.get_outputs()[0].name | ||||
|         :param onnx_session: | ||||
|         :return: | ||||
|         """ | ||||
|         output_name = [] | ||||
|         for node in onnx_session.get_outputs(): | ||||
|             output_name.append(node.name) | ||||
|         return output_name | ||||
|  | ||||
|     def get_input_name(self, onnx_session): | ||||
|         """ | ||||
|         input_name = onnx_session.get_inputs()[0].name | ||||
|         :param onnx_session: | ||||
|         :return: | ||||
|         """ | ||||
|         input_name = [] | ||||
|         for node in onnx_session.get_inputs(): | ||||
|             input_name.append(node.name) | ||||
|         return input_name | ||||
|  | ||||
|     def get_input_feed(self, input_name, image_numpy): | ||||
|         """ | ||||
|         input_feed={self.input_name: image_numpy} | ||||
|         :param input_name: | ||||
|         :param image_numpy: | ||||
|         :return: | ||||
|         """ | ||||
|         input_feed = {} | ||||
|         for name in input_name: | ||||
|             input_feed[name] = image_numpy | ||||
|         return input_feed | ||||
		Reference in New Issue
	
	Block a user