路径写入数据库
This commit is contained in:
		
							
								
								
									
										89
									
								
								core/all.py
									
									
									
									
									
								
							
							
						
						
									
										89
									
								
								core/all.py
									
									
									
									
									
								
							| @ -2,6 +2,7 @@ import cv2 | |||||||
| import numpy as np | import numpy as np | ||||||
| from PIL.Image import Image | from PIL.Image import Image | ||||||
|  |  | ||||||
|  | from core.establish import get_image_save_path | ||||||
| from core.ocr import load_model as ocrLoadModel, detect as ocrDetect | from core.ocr import load_model as ocrLoadModel, detect as ocrDetect | ||||||
| from core.face import load_model as faceLoadModel, detect as faceDetect | from core.face import load_model as faceLoadModel, detect as faceDetect | ||||||
| from core.yolo import load_model as yoloLoadModel, detect as yoloDetect | from core.yolo import load_model as yoloLoadModel, detect as yoloDetect | ||||||
| @ -68,19 +69,24 @@ def detect(client_ip, frame): | |||||||
|     yolo_flag, yolo_result = yoloDetect(frame) |     yolo_flag, yolo_result = yoloDetect(frame) | ||||||
|     print(f"YOLO检测结果:{yolo_result}") |     print(f"YOLO检测结果:{yolo_result}") | ||||||
|     if yolo_flag: |     if yolo_flag: | ||||||
|         save_db(model_type="yolo", client_ip=client_ip, result=numpy_array_to_base64(frame)) |         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) | ||||||
|         # if full_save_path:  # 只判断完整路径是否有效(用于保存) |         if full_save_path:  # 只判断完整路径是否有效(用于保存) | ||||||
|         #     cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|         #     # 打印时使用「显示用短路径」,符合需求格式 |             # 打印时使用「显示用短路径」,符合需求格式 | ||||||
|         #     print(f"✅ YOLO违规图片已保存:{display_path}") |             print(f"✅ YOLO违规图片已保存:{display_path}") | ||||||
|  |         save_db(model_type="yolo", client_ip=client_ip, result=str(full_save_path)) | ||||||
|         return (True, yolo_result, "yolo") |         return (True, yolo_result, "yolo") | ||||||
|     # |     # | ||||||
|     # # 2. 人脸检测(优先级2) |     # # 2. 人脸检测(优先级2) | ||||||
|     face_flag, face_result = faceDetect(frame) |     face_flag, face_result = faceDetect(frame) | ||||||
|     print(f"人脸检测结果:{face_result}") |     print(f"人脸检测结果:{face_result}") | ||||||
|     if face_flag: |     if face_flag: | ||||||
|         # 将帧转化为 base64 字符串 |         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) | ||||||
|         save_db(model_type="face", client_ip=client_ip, result=numpy_array_to_base64(frame)) |         if full_save_path:  # 只判断完整路径是否有效(用于保存) | ||||||
|  |             cv2.imwrite(full_save_path, frame) | ||||||
|  |             # 打印时使用「显示用短路径」,符合需求格式 | ||||||
|  |             print(f"✅ face违规图片已保存:{display_path}") | ||||||
|  |         save_db(model_type="face", client_ip=client_ip, result=str(full_save_path)) | ||||||
|         return (True, face_result, "face") |         return (True, face_result, "face") | ||||||
|  |  | ||||||
|     # 3. OCR检测(优先级3) |     # 3. OCR检测(优先级3) | ||||||
| @ -88,70 +94,13 @@ def detect(client_ip, frame): | |||||||
|     print(f"OCR检测结果:{ocr_result}") |     print(f"OCR检测结果:{ocr_result}") | ||||||
|     if ocr_flag: |     if ocr_flag: | ||||||
|         # 解构元组,保存用完整路径,打印用短路径 |         # 解构元组,保存用完整路径,打印用短路径 | ||||||
|         save_db(model_type="ocr", client_ip=client_ip, result=ocr_result) |         full_save_path, display_path = get_image_save_path(model_type="yolo", client_ip=client_ip) | ||||||
|         # if full_save_path: |         if full_save_path:  # 只判断完整路径是否有效(用于保存) | ||||||
|         #     cv2.imwrite(full_save_path, frame) |             cv2.imwrite(full_save_path, frame) | ||||||
|         #     print(f"✅ OCR违规图片已保存:{display_path}") |             # 打印时使用「显示用短路径」,符合需求格式 | ||||||
|  |             print(f"✅ ocr违规图片已保存:{display_path}") | ||||||
|  |         save_db(model_type="ocr", client_ip=client_ip, result=str(full_save_path)) | ||||||
|         return (True, ocr_result, "ocr") |         return (True, ocr_result, "ocr") | ||||||
|  |  | ||||||
|     # 4. 无违规内容(不保存图片) |     # 4. 无违规内容(不保存图片) | ||||||
|     print(f"❌ 未检测到任何违规内容,不保存图片") |     print(f"❌ 未检测到任何违规内容,不保存图片") | ||||||
|     return (False, "未检测到任何内容", "none") |     return (False, "未检测到任何内容", "none") | ||||||
|  |  | ||||||
|  |  | ||||||
| def numpy_array_to_base64(arr, img_format='PNG'): |  | ||||||
|     """ |  | ||||||
|     将numpy数组转换为base64字符串 |  | ||||||
|  |  | ||||||
|     参数: |  | ||||||
|         arr: numpy数组,通常是图像数据,形状为(height, width, channels) |  | ||||||
|         img_format: 图像格式,默认为'PNG',也可以是'JPEG'等PIL支持的格式 |  | ||||||
|  |  | ||||||
|     返回: |  | ||||||
|         str: 转换后的base64字符串 |  | ||||||
|  |  | ||||||
|     异常: |  | ||||||
|         ValueError: 当输入不是有效的numpy数组或不支持的形状时抛出 |  | ||||||
|         Exception: 处理过程中出现的其他异常 |  | ||||||
|     """ |  | ||||||
|     try: |  | ||||||
|         # 检查输入是否为numpy数组 |  | ||||||
|         if not isinstance(arr, np.ndarray): |  | ||||||
|             raise ValueError("输入必须是numpy数组") |  | ||||||
|  |  | ||||||
|         # 处理单通道图像(灰度图) |  | ||||||
|         if len(arr.shape) == 2: |  | ||||||
|             arr = np.expand_dims(arr, axis=-1) |  | ||||||
|  |  | ||||||
|         # 检查数组形状是否有效 |  | ||||||
|         if len(arr.shape) != 3 or arr.shape[2] not in [1, 3, 4]: |  | ||||||
|             raise ValueError("numpy数组必须是形状为(height, width, channels)的图像数据,通道数应为1、3或4") |  | ||||||
|  |  | ||||||
|         # 处理数据类型,确保是uint8类型 |  | ||||||
|         if arr.dtype != np.uint8: |  | ||||||
|             # 归一化到0-255并转换为uint8 |  | ||||||
|             arr = ((arr - arr.min()) / (arr.max() - arr.min() + 1e-8) * 255).astype(np.uint8) |  | ||||||
|  |  | ||||||
|         # 将单通道图像转换为PIL支持的模式 |  | ||||||
|         if arr.shape[2] == 1: |  | ||||||
|             arr = arr.squeeze(axis=-1) |  | ||||||
|             image = Image.fromarray(arr, mode='L')  # L模式表示灰度图 |  | ||||||
|         elif arr.shape[2] == 3: |  | ||||||
|             image = Image.fromarray(arr, mode='RGB') |  | ||||||
|         else:  # 4通道 |  | ||||||
|             image = Image.fromarray(arr, mode='RGBA') |  | ||||||
|  |  | ||||||
|         # 将图像保存到内存缓冲区 |  | ||||||
|         buffer = BytesIO() |  | ||||||
|         image.save(buffer, format=img_format) |  | ||||||
|  |  | ||||||
|         # 从缓冲区读取数据并编码为base64 |  | ||||||
|         buffer.seek(0) |  | ||||||
|         base64_str = base64.b64encode(buffer.read()).decode('utf-8') |  | ||||||
|  |  | ||||||
|         return base64_str |  | ||||||
|  |  | ||||||
|     except ValueError as ve: |  | ||||||
|         raise ve |  | ||||||
|     except Exception as e: |  | ||||||
|         raise Exception(f"转换过程中发生错误: {str(e)}") |  | ||||||
		Reference in New Issue
	
	Block a user