343 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			343 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import os | |||
|  | import re | |||
|  | import shutil | |||
|  | from datetime import datetime | |||
|  | from PIL import ImageDraw, ImageFont | |||
|  | from fastapi import UploadFile | |||
|  | import cv2 | |||
|  | from PIL import Image | |||
|  | import numpy as np | |||
|  | 
 | |||
|  | # 上传根目录 | |||
|  | UPLOAD_ROOT = "upload" | |||
|  | PRE = "/api/file/download/" | |||
|  | 
 | |||
|  | # 确保上传根目录存在 | |||
|  | os.makedirs(UPLOAD_ROOT, exist_ok=True) | |||
|  | 
 | |||
|  | 
 | |||
|  | 
 | |||
|  | def save_detect_file(client_ip: str, image_np: np.ndarray, file_type: str) -> str: | |||
|  |     """保存numpy数组格式的PNG图片到detect目录,返回下载路径""" | |||
|  |     today = datetime.now() | |||
|  |     year = today.strftime("%Y") | |||
|  |     month = today.strftime("%m") | |||
|  |     day = today.strftime("%d") | |||
|  | 
 | |||
|  |     # 构建目录路径: upload/detect/客户端IP/type/年/月/日(包含UPLOAD_ROOT) | |||
|  |     file_dir = os.path.join( | |||
|  |         UPLOAD_ROOT, | |||
|  |         "detect", | |||
|  |         client_ip, | |||
|  |         file_type, | |||
|  |         year, | |||
|  |         month, | |||
|  |         day | |||
|  |     ) | |||
|  | 
 | |||
|  |     # 创建目录(确保目录存在) | |||
|  |     os.makedirs(file_dir, exist_ok=True) | |||
|  | 
 | |||
|  |     # 生成当前时间戳作为文件名,确保唯一性 | |||
|  |     timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f") | |||
|  |     filename = f"{timestamp}.png" | |||
|  | 
 | |||
|  |     # 1. 完整路径:用于实际保存文件(包含UPLOAD_ROOT) | |||
|  |     full_path = os.path.join(file_dir, filename) | |||
|  |     # 2. 相对路径:用于返回给前端(移除UPLOAD_ROOT前缀) | |||
|  |     relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | |||
|  | 
 | |||
|  |     # 保存numpy数组为PNG图片 | |||
|  |     try: | |||
|  |         # -------- 新增/修改:处理颜色通道和数据类型 -------- | |||
|  |         # 1. 数据类型转换:确保是uint8(若为float32且范围0-1,需转成0-255的uint8) | |||
|  |         if image_np.dtype != np.uint8: | |||
|  |             image_np = (image_np * 255).astype(np.uint8) | |||
|  | 
 | |||
|  |         # 2. 通道顺序转换:若为OpenCV的BGR格式,转成PIL需要的RGB格式 | |||
|  |         image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) | |||
|  | 
 | |||
|  |         # 3. 转换为PIL Image并保存 | |||
|  |         img = Image.fromarray(image_rgb) | |||
|  |         img.save(full_path, format='PNG') | |||
|  |     except Exception as e: | |||
|  |         # 处理可能的异常(如数组格式不正确) | |||
|  |         raise Exception(f"保存图片失败: {str(e)}") | |||
|  | 
 | |||
|  |     # 统一路径分隔符为/,拼接前缀返回 | |||
|  |     return PRE + relative_path.replace(os.sep, "/") | |||
|  | 
 | |||
|  | 
 | |||
|  | def save_detect_yolo_file( | |||
|  |         client_ip: str, | |||
|  |         image_np: np.ndarray, | |||
|  |         detection_results: list, | |||
|  |         file_type: str = "yolo" | |||
|  | ) -> str: | |||
|  | 
 | |||
|  | 
 | |||
|  |     print("......................") | |||
|  |     """
 | |||
|  |     保存YOLO检测结果图片(在原图上绘制边界框+标签),返回前端可访问的下载路径 | |||
|  |     """
 | |||
|  |     # 输入参数验证 | |||
|  |     if not isinstance(image_np, np.ndarray): | |||
|  |         raise ValueError(f"输入image_np必须是numpy数组,当前类型:{type(image_np)}") | |||
|  |     if image_np.ndim != 3 or image_np.shape[-1] != 3: | |||
|  |         raise ValueError(f"输入图像必须是 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}") | |||
|  | 
 | |||
|  |     if not isinstance(detection_results, list): | |||
|  |         raise ValueError(f"detection_results必须是列表,当前类型:{type(detection_results)}") | |||
|  |     for idx, result in enumerate(detection_results): | |||
|  |         required_keys = {"class", "confidence", "bbox"} | |||
|  |         if not isinstance(result, dict) or not required_keys.issubset(result.keys()): | |||
|  |             raise ValueError( | |||
|  |                 f"detection_results第{idx}个元素格式错误,需包含键:{required_keys}," | |||
|  |                 f"当前元素:{result}" | |||
|  |             ) | |||
|  |         bbox = result["bbox"] | |||
|  |         if not (isinstance(bbox, (tuple, list)) and len(bbox) == 4 and all(isinstance(x, int) for x in bbox)): | |||
|  |             raise ValueError( | |||
|  |                 f"detection_results第{idx}个元素的bbox格式错误,需为(x1,y1,x2,y2)整数元组," | |||
|  |                 f"当前bbox:{bbox}" | |||
|  |             ) | |||
|  | 
 | |||
|  |     #图像预处理(数据类型+通道) | |||
|  |     draw_image = image_np.copy() | |||
|  |     if draw_image.dtype != np.uint8: | |||
|  |         draw_image = np.clip(draw_image * 255, 0, 255).astype(np.uint8) | |||
|  | 
 | |||
|  |     #绘制边界框+标签 | |||
|  |     # 遍历所有检测结果,逐个绘制 | |||
|  |     for result in detection_results: | |||
|  |         class_name = result["class"] | |||
|  |         confidence = result["confidence"] | |||
|  |         x1, y1, x2, y2 = result["bbox"] | |||
|  |         cv2.rectangle(draw_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2) | |||
|  |         label = f"{class_name}: {confidence:.2f}" | |||
|  |         font = cv2.FONT_HERSHEY_SIMPLEX | |||
|  |         font_scale = 0.5 | |||
|  |         font_thickness = 2 | |||
|  |         (label_width, label_height), baseline = cv2.getTextSize( | |||
|  |             label, font, font_scale, font_thickness | |||
|  |         ) | |||
|  | 
 | |||
|  |         bg_top_left = (x1, y1 - label_height - 10) | |||
|  |         bg_bottom_right = (x1 + label_width, y1) | |||
|  |         if bg_top_left[1] < 0: | |||
|  |             bg_top_left = (x1, 0) | |||
|  |             bg_bottom_right = (x1 + label_width, label_height + 10) | |||
|  |         cv2.rectangle(draw_image, bg_top_left, bg_bottom_right, color=(0, 0, 0), thickness=-1) | |||
|  | 
 | |||
|  |         text_origin = (x1, y1 - 5) | |||
|  |         if bg_top_left[1] == 0: | |||
|  |             text_origin = (x1, label_height + 5) | |||
|  |         cv2.putText( | |||
|  |             draw_image, label, text_origin, | |||
|  |             font, font_scale, color=(255, 255, 255), thickness=font_thickness | |||
|  |         ) | |||
|  | 
 | |||
|  |     #保存图片 | |||
|  |     try: | |||
|  |         today = datetime.now() | |||
|  |         year = today.strftime("%Y") | |||
|  |         month = today.strftime("%m") | |||
|  |         day = today.strftime("%d") | |||
|  |         file_dir = os.path.join( | |||
|  |             UPLOAD_ROOT, "detect", client_ip, file_type, year, month, day | |||
|  |         ) | |||
|  | 
 | |||
|  |         #创建目录(若不存在则创建,支持多级目录) | |||
|  |         os.makedirs(file_dir, exist_ok=True) | |||
|  | 
 | |||
|  |         #生成唯一文件名 | |||
|  |         timestamp = today.strftime("%Y%m%d%H%M%S%f") | |||
|  |         filename = f"{timestamp}.png" | |||
|  | 
 | |||
|  |         # 4.4 构建完整保存路径和前端访问路径 | |||
|  |         full_path = os.path.join(file_dir, filename)  # 本地完整路径 | |||
|  |         # 相对路径:移除UPLOAD_ROOT前缀,统一用"/"作为分隔符(兼容Windows/Linux) | |||
|  |         relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | |||
|  |         download_path = PRE + relative_path.replace(os.sep, "/") | |||
|  | 
 | |||
|  |         # 4.5 保存图片(CV2绘制的是BGR,需转RGB后用PIL保存,与原逻辑一致) | |||
|  |         image_rgb = cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB) | |||
|  |         img_pil = Image.fromarray(image_rgb) | |||
|  |         img_pil.save(full_path, format="PNG", quality=95)  # PNG格式无压缩,quality可忽略 | |||
|  | 
 | |||
|  |         print(f"YOLO检测图片保存成功 | 本地路径:{full_path} | 下载路径:{download_path}") | |||
|  |         return download_path | |||
|  | 
 | |||
|  |     except Exception as e: | |||
|  |         raise Exception(f"YOLO检测图片保存失败:{str(e)}") from e | |||
|  | 
 | |||
|  | 
 | |||
|  | def save_detect_face_file( | |||
|  |     client_ip: str, | |||
|  |     image_np: np.ndarray, | |||
|  |     face_result: str, | |||
|  |     file_type: str = "face", | |||
|  |     matched_color: tuple = (0, 255, 0) | |||
|  | ) -> str: | |||
|  |     """
 | |||
|  |     保存人脸识别结果图片(仅为「匹配成功」的人脸画框,标签不包含“匹配”二字) | |||
|  |     """
 | |||
|  |     #输入参数验证 | |||
|  |     if not isinstance(image_np, np.ndarray) or image_np.ndim != 3 or image_np.shape[-1] != 3: | |||
|  |         raise ValueError(f"输入图像需为 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}") | |||
|  |     if not isinstance(face_result, str) or face_result.strip() == "": | |||
|  |         raise ValueError("face_result必须是非空字符串") | |||
|  | 
 | |||
|  |     # 解析face_result提取人脸信息 | |||
|  |     face_info_list = [] | |||
|  |     if face_result.strip() != "未检测到人脸": | |||
|  |         face_pattern = re.compile( | |||
|  |             r"(匹配|未匹配):\s*([^\s(]+)\s*\(相似度:\s*(\d+\.\d+),\s*边界框:\s*\[(\d+,\s*\d+,\s*\d+,\s*\d+)\]\)" | |||
|  |         ) | |||
|  |         for part in [p.strip() for p in face_result.split(";") if p.strip()]: | |||
|  |             match = face_pattern.match(part) | |||
|  |             if match: | |||
|  |                 status, name, similarity, bbox_str = match.groups() | |||
|  |                 bbox = list(map(int, bbox_str.replace(" ", "").split(","))) | |||
|  |                 if len(bbox) == 4: | |||
|  |                     face_info_list.append({ | |||
|  |                         "status": status, | |||
|  |                         "name": name, | |||
|  |                         "similarity": float(similarity), | |||
|  |                         "bbox": bbox | |||
|  |                     }) | |||
|  | 
 | |||
|  |     # 图像格式转换(OpenCV→PIL) | |||
|  |     image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) | |||
|  |     pil_img = Image.fromarray(image_rgb) | |||
|  |     draw = ImageDraw.Draw(pil_img) | |||
|  | 
 | |||
|  |     # 绘制边界框和标签 | |||
|  |     font_size = 12 | |||
|  |     try: | |||
|  |         font = ImageFont.truetype("simhei", font_size) | |||
|  |     except: | |||
|  |         try: | |||
|  |             font = ImageFont.truetype("simsun", font_size) | |||
|  |         except: | |||
|  |             font = ImageFont.load_default() | |||
|  |             print("警告:未找到指定中文字体,使用PIL默认字体(可能影响中文显示)") | |||
|  | 
 | |||
|  |     for face_info in face_info_list: | |||
|  |         status = face_info["status"] | |||
|  |         if status != "匹配": | |||
|  |             print(f"跳过未匹配人脸:{face_info['name']}(相似度:{face_info['similarity']:.2f})") | |||
|  |             continue | |||
|  | 
 | |||
|  |         name = face_info["name"] | |||
|  |         similarity = face_info["similarity"] | |||
|  |         x1, y1, x2, y2 = face_info["bbox"] | |||
|  | 
 | |||
|  |         # 4.1 绘制边界框(绿色) | |||
|  |         img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) | |||
|  |         cv2.rectangle(img_cv, (x1, y1), (x2, y2), color=matched_color, thickness=2) | |||
|  |         pil_img = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)) | |||
|  |         draw = ImageDraw.Draw(pil_img) | |||
|  | 
 | |||
|  |         label = f"{name} (相似度: {similarity:.2f})" | |||
|  | 
 | |||
|  |         # 4.3 计算标签尺寸(文本变短后会自动适配,无需额外调整) | |||
|  |         label_bbox = draw.textbbox((0, 0), label, font=font) | |||
|  |         label_width = label_bbox[2] - label_bbox[0] | |||
|  |         label_height = label_bbox[3] - label_bbox[1] | |||
|  | 
 | |||
|  |         # 4.4 计算标签背景位置(避免超出图像) | |||
|  |         bg_x1, bg_y1 = x1, y1 - label_height - 10 | |||
|  |         bg_x2, bg_y2 = x1 + label_width, y1 | |||
|  |         if bg_y1 < 0: | |||
|  |             bg_y1, bg_y2 = y2 + 5, y2 + label_height + 15 | |||
|  | 
 | |||
|  |         # 4.5 绘制标签背景(黑色)和文本(白色) | |||
|  |         draw.rectangle([(bg_x1, bg_y1), (bg_x2, bg_y2)], fill=(0, 0, 0)) | |||
|  |         text_x = bg_x1 | |||
|  |         text_y = bg_y1 if bg_y1 >= 0 else bg_y1 + label_height | |||
|  |         draw.text((text_x, text_y), label, font=font, fill=(255, 255, 255)) | |||
|  | 
 | |||
|  |     #保存图片 | |||
|  |     try: | |||
|  |         today = datetime.now() | |||
|  |         file_dir = os.path.join( | |||
|  |             UPLOAD_ROOT, "detect", client_ip, file_type, | |||
|  |             today.strftime("%Y"), today.strftime("%m"), today.strftime("%d") | |||
|  |         ) | |||
|  |         os.makedirs(file_dir, exist_ok=True) | |||
|  | 
 | |||
|  |         timestamp = today.strftime("%Y%m%d%H%M%S%f") | |||
|  |         filename = f"{timestamp}.png" | |||
|  |         full_path = os.path.join(file_dir, filename) | |||
|  | 
 | |||
|  |         pil_img.save(full_path, format="PNG", quality=100) | |||
|  | 
 | |||
|  |         relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | |||
|  |         download_path = PRE + relative_path.replace(os.sep, "/") | |||
|  | 
 | |||
|  |         matched_count = sum(1 for info in face_info_list if info["status"] == "匹配") | |||
|  |         print(f"人脸检测图片保存成功 | 客户端IP:{client_ip} | 匹配人脸数:{matched_count} | 保存路径:{download_path}") | |||
|  |         return download_path | |||
|  | 
 | |||
|  |     except Exception as e: | |||
|  |         raise Exception(f"人脸检测图片保存失败(客户端IP:{client_ip}):{str(e)}") from e | |||
|  | 
 | |||
|  | def save_source_file(upload_file: UploadFile, file_type: str) -> str: | |||
|  |     """保存上传的文件到source目录,返回下载路径""" | |||
|  |     today = datetime.now() | |||
|  |     year = today.strftime("%Y") | |||
|  |     month = today.strftime("%m") | |||
|  |     day = today.strftime("%d") | |||
|  | 
 | |||
|  |     # 生成精确到微秒的时间戳,确保文件名唯一 | |||
|  |     timestamp = today.strftime("%Y%m%d%H%M%S%f") | |||
|  |     # 构建新文件名:时间戳_原文件名 | |||
|  |     unique_filename = f"{timestamp}_{upload_file.filename}" | |||
|  | 
 | |||
|  |     # 构建目录路径: upload/source/type/年/月/日(包含UPLOAD_ROOT) | |||
|  |     file_dir = os.path.join( | |||
|  |         UPLOAD_ROOT, | |||
|  |         "source", | |||
|  |         file_type, | |||
|  |         year, | |||
|  |         month, | |||
|  |         day | |||
|  |     ) | |||
|  | 
 | |||
|  |     # 创建目录(确保目录存在) | |||
|  |     os.makedirs(file_dir, exist_ok=True) | |||
|  | 
 | |||
|  |     # 1. 完整路径:用于实际保存文件(使用带时间戳的唯一文件名) | |||
|  |     full_path = os.path.join(file_dir, unique_filename) | |||
|  |     # 2. 相对路径:用于返回给前端 | |||
|  |     relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep) | |||
|  | 
 | |||
|  |     # 保存文件(使用完整路径) | |||
|  |     try: | |||
|  |         with open(full_path, "wb") as buffer: | |||
|  |             shutil.copyfileobj(upload_file.file, buffer) | |||
|  |     finally: | |||
|  |         upload_file.file.close() | |||
|  | 
 | |||
|  |     # 统一路径分隔符为/ | |||
|  |     return PRE + relative_path.replace(os.sep, "/") | |||
|  | 
 | |||
|  | 
 | |||
|  | def get_absolute_path(relative_path: str) -> str: | |||
|  |     """
 | |||
|  |     根据相对路径获取服务器上的绝对路径 | |||
|  |     """
 | |||
|  |     path_without_pre = relative_path.replace(PRE, "", 1) | |||
|  | 
 | |||
|  |     # 将相对路径转换为系统兼容的格式 | |||
|  |     normalized_path = os.path.normpath(path_without_pre) | |||
|  | 
 | |||
|  |     # 拼接基础路径和相对路径,得到绝对路径 | |||
|  |     absolute_path = os.path.abspath(os.path.join(UPLOAD_ROOT, normalized_path)) | |||
|  | 
 | |||
|  |     # 安全检查:确保生成的路径在UPLOAD_ROOT目录下,防止路径遍历 | |||
|  |     if not absolute_path.startswith(os.path.abspath(UPLOAD_ROOT)): | |||
|  |         raise ValueError("无效的相对路径,可能试图访问上传目录之外的内容") | |||
|  | 
 | |||
|  |     return absolute_path |