Files
video_detect/service/file_service.py
2025-09-30 17:17:20 +08:00

343 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
import shutil
from datetime import datetime
from PIL import ImageDraw, ImageFont
from fastapi import UploadFile
import cv2
from PIL import Image
import numpy as np
# 上传根目录
UPLOAD_ROOT = "upload"
PRE = "/api/file/download/"
# 确保上传根目录存在
os.makedirs(UPLOAD_ROOT, exist_ok=True)
def save_detect_file(client_ip: str, image_np: np.ndarray, file_type: str) -> str:
"""保存numpy数组格式的PNG图片到detect目录返回下载路径"""
today = datetime.now()
year = today.strftime("%Y")
month = today.strftime("%m")
day = today.strftime("%d")
# 构建目录路径: upload/detect/客户端IP/type/年/月/日包含UPLOAD_ROOT
file_dir = os.path.join(
UPLOAD_ROOT,
"detect",
client_ip,
file_type,
year,
month,
day
)
# 创建目录(确保目录存在)
os.makedirs(file_dir, exist_ok=True)
# 生成当前时间戳作为文件名,确保唯一性
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
filename = f"{timestamp}.png"
# 1. 完整路径用于实际保存文件包含UPLOAD_ROOT
full_path = os.path.join(file_dir, filename)
# 2. 相对路径用于返回给前端移除UPLOAD_ROOT前缀
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
# 保存numpy数组为PNG图片
try:
# -------- 新增/修改:处理颜色通道和数据类型 --------
# 1. 数据类型转换确保是uint8若为float32且范围0-1需转成0-255的uint8
if image_np.dtype != np.uint8:
image_np = (image_np * 255).astype(np.uint8)
# 2. 通道顺序转换若为OpenCV的BGR格式转成PIL需要的RGB格式
image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
# 3. 转换为PIL Image并保存
img = Image.fromarray(image_rgb)
img.save(full_path, format='PNG')
except Exception as e:
# 处理可能的异常(如数组格式不正确)
raise Exception(f"保存图片失败: {str(e)}")
# 统一路径分隔符为/,拼接前缀返回
return PRE + relative_path.replace(os.sep, "/")
def save_detect_yolo_file(
client_ip: str,
image_np: np.ndarray,
detection_results: list,
file_type: str = "yolo"
) -> str:
print("......................")
"""
保存YOLO检测结果图片在原图上绘制边界框+标签),返回前端可访问的下载路径
"""
# 输入参数验证
if not isinstance(image_np, np.ndarray):
raise ValueError(f"输入image_np必须是numpy数组当前类型{type(image_np)}")
if image_np.ndim != 3 or image_np.shape[-1] != 3:
raise ValueError(f"输入图像必须是 (h, w, 3) 的BGR数组当前shape{image_np.shape}")
if not isinstance(detection_results, list):
raise ValueError(f"detection_results必须是列表当前类型{type(detection_results)}")
for idx, result in enumerate(detection_results):
required_keys = {"class", "confidence", "bbox"}
if not isinstance(result, dict) or not required_keys.issubset(result.keys()):
raise ValueError(
f"detection_results第{idx}个元素格式错误,需包含键:{required_keys}"
f"当前元素:{result}"
)
bbox = result["bbox"]
if not (isinstance(bbox, (tuple, list)) and len(bbox) == 4 and all(isinstance(x, int) for x in bbox)):
raise ValueError(
f"detection_results第{idx}个元素的bbox格式错误需为(x1,y1,x2,y2)整数元组,"
f"当前bbox{bbox}"
)
#图像预处理(数据类型+通道)
draw_image = image_np.copy()
if draw_image.dtype != np.uint8:
draw_image = np.clip(draw_image * 255, 0, 255).astype(np.uint8)
#绘制边界框+标签
# 遍历所有检测结果,逐个绘制
for result in detection_results:
class_name = result["class"]
confidence = result["confidence"]
x1, y1, x2, y2 = result["bbox"]
cv2.rectangle(draw_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
label = f"{class_name}: {confidence:.2f}"
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
font_thickness = 2
(label_width, label_height), baseline = cv2.getTextSize(
label, font, font_scale, font_thickness
)
bg_top_left = (x1, y1 - label_height - 10)
bg_bottom_right = (x1 + label_width, y1)
if bg_top_left[1] < 0:
bg_top_left = (x1, 0)
bg_bottom_right = (x1 + label_width, label_height + 10)
cv2.rectangle(draw_image, bg_top_left, bg_bottom_right, color=(0, 0, 0), thickness=-1)
text_origin = (x1, y1 - 5)
if bg_top_left[1] == 0:
text_origin = (x1, label_height + 5)
cv2.putText(
draw_image, label, text_origin,
font, font_scale, color=(255, 255, 255), thickness=font_thickness
)
#保存图片
try:
today = datetime.now()
year = today.strftime("%Y")
month = today.strftime("%m")
day = today.strftime("%d")
file_dir = os.path.join(
UPLOAD_ROOT, "detect", client_ip, file_type, year, month, day
)
#创建目录(若不存在则创建,支持多级目录)
os.makedirs(file_dir, exist_ok=True)
#生成唯一文件名
timestamp = today.strftime("%Y%m%d%H%M%S%f")
filename = f"{timestamp}.png"
# 4.4 构建完整保存路径和前端访问路径
full_path = os.path.join(file_dir, filename) # 本地完整路径
# 相对路径移除UPLOAD_ROOT前缀统一用"/"作为分隔符兼容Windows/Linux
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
download_path = PRE + relative_path.replace(os.sep, "/")
# 4.5 保存图片CV2绘制的是BGR需转RGB后用PIL保存与原逻辑一致
image_rgb = cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB)
img_pil = Image.fromarray(image_rgb)
img_pil.save(full_path, format="PNG", quality=95) # PNG格式无压缩quality可忽略
print(f"YOLO检测图片保存成功 | 本地路径:{full_path} | 下载路径:{download_path}")
return download_path
except Exception as e:
raise Exception(f"YOLO检测图片保存失败{str(e)}") from e
def save_detect_face_file(
client_ip: str,
image_np: np.ndarray,
face_result: str,
file_type: str = "face",
matched_color: tuple = (0, 255, 0)
) -> str:
"""
保存人脸识别结果图片(仅为「匹配成功」的人脸画框,标签不包含“匹配”二字)
"""
#输入参数验证
if not isinstance(image_np, np.ndarray) or image_np.ndim != 3 or image_np.shape[-1] != 3:
raise ValueError(f"输入图像需为 (h, w, 3) 的BGR数组当前shape{image_np.shape}")
if not isinstance(face_result, str) or face_result.strip() == "":
raise ValueError("face_result必须是非空字符串")
# 解析face_result提取人脸信息
face_info_list = []
if face_result.strip() != "未检测到人脸":
face_pattern = re.compile(
r"(匹配|未匹配):\s*([^\s(]+)\s*\(相似度:\s*(\d+\.\d+),\s*边界框:\s*\[(\d+,\s*\d+,\s*\d+,\s*\d+)\]\)"
)
for part in [p.strip() for p in face_result.split(";") if p.strip()]:
match = face_pattern.match(part)
if match:
status, name, similarity, bbox_str = match.groups()
bbox = list(map(int, bbox_str.replace(" ", "").split(",")))
if len(bbox) == 4:
face_info_list.append({
"status": status,
"name": name,
"similarity": float(similarity),
"bbox": bbox
})
# 图像格式转换OpenCV→PIL
image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(image_rgb)
draw = ImageDraw.Draw(pil_img)
# 绘制边界框和标签
font_size = 12
try:
font = ImageFont.truetype("simhei", font_size)
except:
try:
font = ImageFont.truetype("simsun", font_size)
except:
font = ImageFont.load_default()
print("警告未找到指定中文字体使用PIL默认字体可能影响中文显示")
for face_info in face_info_list:
status = face_info["status"]
if status != "匹配":
print(f"跳过未匹配人脸:{face_info['name']}(相似度:{face_info['similarity']:.2f}")
continue
name = face_info["name"]
similarity = face_info["similarity"]
x1, y1, x2, y2 = face_info["bbox"]
# 4.1 绘制边界框(绿色)
img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
cv2.rectangle(img_cv, (x1, y1), (x2, y2), color=matched_color, thickness=2)
pil_img = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_img)
label = f"{name} (相似度: {similarity:.2f})"
# 4.3 计算标签尺寸(文本变短后会自动适配,无需额外调整)
label_bbox = draw.textbbox((0, 0), label, font=font)
label_width = label_bbox[2] - label_bbox[0]
label_height = label_bbox[3] - label_bbox[1]
# 4.4 计算标签背景位置(避免超出图像)
bg_x1, bg_y1 = x1, y1 - label_height - 10
bg_x2, bg_y2 = x1 + label_width, y1
if bg_y1 < 0:
bg_y1, bg_y2 = y2 + 5, y2 + label_height + 15
# 4.5 绘制标签背景(黑色)和文本(白色)
draw.rectangle([(bg_x1, bg_y1), (bg_x2, bg_y2)], fill=(0, 0, 0))
text_x = bg_x1
text_y = bg_y1 if bg_y1 >= 0 else bg_y1 + label_height
draw.text((text_x, text_y), label, font=font, fill=(255, 255, 255))
#保存图片
try:
today = datetime.now()
file_dir = os.path.join(
UPLOAD_ROOT, "detect", client_ip, file_type,
today.strftime("%Y"), today.strftime("%m"), today.strftime("%d")
)
os.makedirs(file_dir, exist_ok=True)
timestamp = today.strftime("%Y%m%d%H%M%S%f")
filename = f"{timestamp}.png"
full_path = os.path.join(file_dir, filename)
pil_img.save(full_path, format="PNG", quality=100)
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
download_path = PRE + relative_path.replace(os.sep, "/")
matched_count = sum(1 for info in face_info_list if info["status"] == "匹配")
print(f"人脸检测图片保存成功 | 客户端IP{client_ip} | 匹配人脸数:{matched_count} | 保存路径:{download_path}")
return download_path
except Exception as e:
raise Exception(f"人脸检测图片保存失败客户端IP{client_ip}{str(e)}") from e
def save_source_file(upload_file: UploadFile, file_type: str) -> str:
"""保存上传的文件到source目录返回下载路径"""
today = datetime.now()
year = today.strftime("%Y")
month = today.strftime("%m")
day = today.strftime("%d")
# 生成精确到微秒的时间戳,确保文件名唯一
timestamp = today.strftime("%Y%m%d%H%M%S%f")
# 构建新文件名时间戳_原文件名
unique_filename = f"{timestamp}_{upload_file.filename}"
# 构建目录路径: upload/source/type/年/月/日包含UPLOAD_ROOT
file_dir = os.path.join(
UPLOAD_ROOT,
"source",
file_type,
year,
month,
day
)
# 创建目录(确保目录存在)
os.makedirs(file_dir, exist_ok=True)
# 1. 完整路径:用于实际保存文件(使用带时间戳的唯一文件名)
full_path = os.path.join(file_dir, unique_filename)
# 2. 相对路径:用于返回给前端
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
# 保存文件(使用完整路径)
try:
with open(full_path, "wb") as buffer:
shutil.copyfileobj(upload_file.file, buffer)
finally:
upload_file.file.close()
# 统一路径分隔符为/
return PRE + relative_path.replace(os.sep, "/")
def get_absolute_path(relative_path: str) -> str:
"""
根据相对路径获取服务器上的绝对路径
"""
path_without_pre = relative_path.replace(PRE, "", 1)
# 将相对路径转换为系统兼容的格式
normalized_path = os.path.normpath(path_without_pre)
# 拼接基础路径和相对路径,得到绝对路径
absolute_path = os.path.abspath(os.path.join(UPLOAD_ROOT, normalized_path))
# 安全检查确保生成的路径在UPLOAD_ROOT目录下防止路径遍历
if not absolute_path.startswith(os.path.abspath(UPLOAD_ROOT)):
raise ValueError("无效的相对路径,可能试图访问上传目录之外的内容")
return absolute_path