内容安全审核
This commit is contained in:
BIN
service/__pycache__/device_action_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/device_action_service.cpython-310.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/device_danger_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/device_danger_service.cpython-310.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/device_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/device_service.cpython-310.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/face_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/face_service.cpython-310.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/file_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/file_service.cpython-310.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/model_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/model_service.cpython-310.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/ocr_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/ocr_service.cpython-310.pyc
Normal file
Binary file not shown.
BIN
service/__pycache__/sensitive_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/sensitive_service.cpython-310.pyc
Normal file
Binary file not shown.
43
service/device_action_service.py
Normal file
43
service/device_action_service.py
Normal file
@ -0,0 +1,43 @@
|
||||
from ds.db import db
|
||||
from schema.device_action_schema import DeviceActionCreate, DeviceActionResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
|
||||
# 新增设备操作记录
|
||||
def add_device_action(client_ip: str, action: int) -> DeviceActionResponse:
|
||||
"""
|
||||
新增设备操作记录(内部方法、非接口)
|
||||
:param action_data: 含client_ip和action(0/1)
|
||||
:return: 新增的完整记录
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入SQL(id自增、依赖数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO device_action
|
||||
(client_ip, action, created_at, updated_at)
|
||||
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
client_ip,
|
||||
action
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取新增记录(通过自增ID查询)
|
||||
new_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
|
||||
new_action = cursor.fetchone()
|
||||
|
||||
return DeviceActionResponse(**new_action)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"新增记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
78
service/device_danger_service.py
Normal file
78
service/device_danger_service.py
Normal file
@ -0,0 +1,78 @@
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from schema.device_danger_schema import DeviceDangerCreateRequest, DeviceDangerResponse
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 检查设备是否存在(复用设备表逻辑)
|
||||
# ------------------------------
|
||||
def check_device_exist(client_ip: str) -> bool:
|
||||
"""
|
||||
检查指定IP的设备是否在devices表中存在
|
||||
|
||||
:param client_ip: 设备IP地址
|
||||
:return: 存在返回True、不存在返回False
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("设备IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
return cursor.fetchone() is not None
|
||||
except MySQLError as e:
|
||||
raise Exception(f"检查设备存在性失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 创建设备危险记录(核心插入逻辑)
|
||||
# ------------------------------
|
||||
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
|
||||
"""
|
||||
内部工具方法:向device_danger表插入新的危险记录
|
||||
|
||||
:param danger_data: 危险记录创建请求数据
|
||||
:return: 创建成功的危险记录模型对象
|
||||
"""
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(danger_data.client_ip):
|
||||
raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在、无法创建危险记录")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入危险记录(id自增、时间自动填充)
|
||||
insert_query = """
|
||||
INSERT INTO device_danger
|
||||
(client_ip, type, result, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
danger_data.client_ip,
|
||||
danger_data.type,
|
||||
danger_data.result
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚创建的记录(用自增ID查询)
|
||||
danger_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
|
||||
new_danger = cursor.fetchone()
|
||||
|
||||
return DeviceDangerResponse(**new_danger)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"插入危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
250
service/device_service.py
Normal file
250
service/device_service.py
Normal file
@ -0,0 +1,250 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from service.device_action_service import add_device_action
|
||||
_last_alarm_timestamps: dict[str, float] = {}
|
||||
_timestamp_lock = threading.Lock()
|
||||
|
||||
# 获取所有去重的客户端IP列表
|
||||
def get_unique_client_ips() -> list[str]:
|
||||
"""获取所有去重的客户端IP列表"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
results = cursor.fetchall()
|
||||
return [item['client_ip'] for item in results]
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 通过客户端IP更新设备是否需要处理
|
||||
def update_is_need_handler_by_client_ip(client_ip: str, is_need_handler: int) -> bool:
|
||||
"""
|
||||
通过客户端IP更新设备的「是否需要处理」状态(is_need_handler字段)
|
||||
"""
|
||||
# 参数合法性校验
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
# 校验is_need_handler取值(需与数据库字段类型匹配、通常为0/1 tinyint)
|
||||
if is_need_handler not in (0, 1):
|
||||
raise ValueError("是否需要处理(is_need_handler)必须是0(不需要)或1(需要)")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 2. 获取数据库连接与游标
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 3. 先校验设备是否存在(通过client_ip定位)
|
||||
cursor.execute(
|
||||
"SELECT id FROM devices WHERE client_ip = %s",
|
||||
(client_ip,)
|
||||
)
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在、无法更新「是否需要处理」状态")
|
||||
|
||||
# 4. 执行更新操作(同时更新时间戳、保持与其他更新逻辑一致性)
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET is_need_handler = %s,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (is_need_handler, client_ip))
|
||||
|
||||
# 5. 确认更新生效(判断影响行数、避免无意义更新)
|
||||
if cursor.rowcount <= 0:
|
||||
raise Exception(f"更新失败:客户端IP {client_ip} 的设备未发生状态变更(可能已为目标值)")
|
||||
|
||||
# 6. 提交事务
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
except MySQLError as e:
|
||||
# 数据库异常时回滚事务
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"数据库操作失败:更新设备「是否需要处理」状态时出错 - {str(e)}") from e
|
||||
finally:
|
||||
# 无论成功失败、都关闭数据库连接(避免连接泄漏)
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
||||
"""
|
||||
通过客户端IP增加设备的报警次数,相同IP 200ms内重复调用会被忽略
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:return: 操作是否成功(是否实际执行了数据库更新)
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
current_time = time.time() # 获取当前时间戳(秒,含小数)
|
||||
with _timestamp_lock: # 确保线程安全的字典操作
|
||||
last_time: Optional[float] = _last_alarm_timestamps.get(client_ip)
|
||||
|
||||
# 如果存在最近记录且间隔小于200ms,直接返回False(不执行更新)
|
||||
if last_time is not None and (current_time - last_time) < 0.2:
|
||||
return False
|
||||
|
||||
# 更新当前IP的最近调用时间
|
||||
_last_alarm_timestamps[client_ip] = current_time
|
||||
|
||||
# 2. 执行数据库更新操作(只有通过时间校验才会执行)
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 报警次数加1、并更新时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET alarm_count = alarm_count + 1,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (client_ip,))
|
||||
conn.commit()
|
||||
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新报警次数失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 通过客户端IP更新设备在线状态
|
||||
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
||||
"""
|
||||
通过客户端IP更新设备的在线状态
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:param online_status: 在线状态(1-在线、0-离线)
|
||||
:return: 操作是否成功
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
if online_status not in (0, 1):
|
||||
raise ValueError("在线状态必须是0(离线)或1(在线)")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在并获取设备ID
|
||||
cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 状态无变化则不操作
|
||||
if device['device_online_status'] == online_status:
|
||||
return True
|
||||
|
||||
# 更新在线状态和时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET device_online_status = %s,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (online_status, client_ip))
|
||||
|
||||
# 记录状态变更历史
|
||||
add_device_action(client_ip, online_status)
|
||||
|
||||
conn.commit()
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 通过客户端IP查询设备在数据库中存在
|
||||
def is_device_exist_by_ip(client_ip: str) -> bool:
|
||||
"""
|
||||
通过客户端IP查询设备在数据库中是否存在
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询设备是否存在
|
||||
cursor.execute(
|
||||
"SELECT id FROM devices WHERE client_ip = %s",
|
||||
(client_ip,)
|
||||
)
|
||||
device = cursor.fetchone()
|
||||
|
||||
# 如果查询到结果则存在,否则不存在
|
||||
return bool(device)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询设备是否存在失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 根据客户端IP获取是否需要处理
|
||||
def get_is_need_handler_by_ip(client_ip: str) -> int:
|
||||
"""
|
||||
通过客户端IP查询设备的is_need_handler状态
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:return: 设备的is_need_handler状态(0-不需要处理,1-需要处理)
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询设备的is_need_handler状态
|
||||
cursor.execute(
|
||||
"SELECT is_need_handler FROM devices WHERE client_ip = %s",
|
||||
(client_ip,)
|
||||
)
|
||||
device = cursor.fetchone()
|
||||
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
return device['is_need_handler']
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询设备is_need_handler状态失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
339
service/face_service.py
Normal file
339
service/face_service.py
Normal file
@ -0,0 +1,339 @@
|
||||
import cv2
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
|
||||
import numpy as np
|
||||
import threading
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
# 全局变量定义
|
||||
_insightface_app = None
|
||||
_known_faces_embeddings = {} # 存储已知人脸特征 {姓名: 特征向量}
|
||||
_known_faces_names = [] # 存储已知人脸姓名列表
|
||||
|
||||
|
||||
def init_insightface():
|
||||
"""初始化InsightFace引擎"""
|
||||
global _insightface_app
|
||||
if _insightface_app is not None:
|
||||
print("InsightFace引擎已初始化,无需重复执行")
|
||||
return _insightface_app
|
||||
|
||||
try:
|
||||
print("正在初始化 InsightFace 引擎(模型:buffalo_l)...")
|
||||
# 初始化引擎,指定模型路径和计算 providers
|
||||
app = FaceAnalysis(
|
||||
name='buffalo_l',
|
||||
root='~/.insightface',
|
||||
providers=['CPUExecutionProvider'] # 如需GPU可添加'CUDAExecutionProvider'
|
||||
)
|
||||
app.prepare(ctx_id=0, det_size=(640, 640)) # 调整检测尺寸
|
||||
print("InsightFace 引擎初始化完成")
|
||||
|
||||
# 初始化时加载人脸特征库
|
||||
init_face_data()
|
||||
|
||||
_insightface_app = app
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace 初始化失败:{str(e)}")
|
||||
_insightface_app = None
|
||||
return None
|
||||
|
||||
|
||||
def init_face_data():
|
||||
"""初始化或更新人脸特征库(清空旧数据,避免重复)"""
|
||||
global _known_faces_embeddings, _known_faces_names
|
||||
# 清空原有数据,防止重复加载
|
||||
_known_faces_embeddings.clear()
|
||||
_known_faces_names.clear()
|
||||
|
||||
try:
|
||||
face_data = get_all_face_name_with_eigenvalue() # 假设该函数已定义
|
||||
print(f"已加载 {len(face_data)} 个人脸数据")
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 解析特征值(支持numpy数组或字符串格式)
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 增强字符串解析:支持逗号/空格分隔,清理特殊字符
|
||||
cleaned = (eigenvalue_data
|
||||
.replace("[", "").replace("]", "")
|
||||
.replace("\n", "").replace(",", " ")
|
||||
.strip())
|
||||
values = [v for v in cleaned.split() if v] # 过滤空字符串
|
||||
if not values:
|
||||
print(f"特征值解析失败(空值),跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
|
||||
continue
|
||||
|
||||
# 特征值归一化(确保相似度计算一致性)
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm == 0:
|
||||
print(f"特征值为零向量,跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
# 更新全局特征库
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库")
|
||||
except Exception as e:
|
||||
print(f"加载人脸特征库失败: {e}")
|
||||
|
||||
|
||||
def update_face_data():
|
||||
"""更新人脸特征库(清空旧数据,加载最新数据)"""
|
||||
global _known_faces_embeddings, _known_faces_names
|
||||
|
||||
print("开始更新人脸特征库...")
|
||||
|
||||
# 清空原有数据
|
||||
_known_faces_embeddings.clear()
|
||||
_known_faces_names.clear()
|
||||
|
||||
try:
|
||||
# 获取最新人脸数据
|
||||
face_data = get_all_face_name_with_eigenvalue()
|
||||
print(f"获取到 {len(face_data)} 条最新人脸数据")
|
||||
|
||||
# 处理并加载新数据(复用原有解析逻辑)
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 解析特征值(支持numpy数组或字符串格式)
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 增强字符串解析:支持逗号/空格分隔,清理特殊字符
|
||||
cleaned = (eigenvalue_data
|
||||
.replace("[", "").replace("]", "")
|
||||
.replace("\n", "").replace(",", " ")
|
||||
.strip())
|
||||
values = [v for v in cleaned.split() if v] # 过滤空字符串
|
||||
if not values:
|
||||
print(f"特征值解析失败(空值),跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
|
||||
continue
|
||||
|
||||
# 特征值归一化(确保相似度计算一致性)
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm == 0:
|
||||
print(f"特征值为零向量,跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
# 更新全局特征库
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
print(f"人脸特征库更新完成,共加载 {len(_known_faces_names)} 个人脸数据")
|
||||
return True # 更新成功
|
||||
except Exception as e:
|
||||
print(f"人脸特征库更新失败: {e}")
|
||||
return False # 更新失败
|
||||
|
||||
|
||||
def detect(frame, similarity_threshold=0.4):
|
||||
global _insightface_app, _known_faces_embeddings
|
||||
|
||||
# 校验输入有效性
|
||||
if frame is None or frame.size == 0:
|
||||
return (False, "无效的输入帧数据")
|
||||
|
||||
# 校验引擎和特征库状态
|
||||
if not _insightface_app:
|
||||
return (False, "人脸引擎未初始化")
|
||||
if not _known_faces_embeddings:
|
||||
return (False, "人脸特征库为空")
|
||||
|
||||
try:
|
||||
# 执行人脸检测与特征提取
|
||||
faces = _insightface_app.get(frame)
|
||||
except Exception as e:
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
result_parts = []
|
||||
has_matched_known_face = False # 是否有匹配到已知人脸
|
||||
|
||||
for face in faces:
|
||||
# 归一化当前人脸特征
|
||||
face_embedding = face.embedding.astype(np.float32)
|
||||
norm = np.linalg.norm(face_embedding)
|
||||
if norm == 0:
|
||||
result_parts.append("检测到人脸但特征值为零向量(忽略)")
|
||||
continue
|
||||
face_embedding = face_embedding / norm
|
||||
|
||||
# 与已知特征库比对
|
||||
max_similarity, best_match_name = -1.0, "Unknown"
|
||||
for name, known_emb in _known_faces_embeddings.items():
|
||||
similarity = np.dot(face_embedding, known_emb) # 余弦相似度
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_match_name = name
|
||||
|
||||
# 判断是否匹配成功
|
||||
is_matched = max_similarity >= similarity_threshold
|
||||
|
||||
if is_matched:
|
||||
has_matched_known_face = True
|
||||
|
||||
# 记录结果(边界框转为整数列表)
|
||||
bbox = face.bbox.astype(int).tolist()
|
||||
result_parts.append(
|
||||
f"{'匹配' if is_matched else '未匹配'}: {best_match_name} "
|
||||
f"(相似度: {max_similarity:.2f}, 边界框: {bbox})"
|
||||
)
|
||||
|
||||
# 构建最终结果
|
||||
result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts)
|
||||
return (has_matched_known_face, result_str)
|
||||
|
||||
|
||||
# 上传图片并提取特征
|
||||
def add_binary_data(binary_data):
|
||||
"""
|
||||
接收单张图片的二进制数据、提取特征并保存
|
||||
返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串)
|
||||
"""
|
||||
global _insightface_app, _feature_list
|
||||
|
||||
# 1. 先检查引擎是否初始化成功
|
||||
if not _insightface_app:
|
||||
init_result = init_insightface() # 尝试重新初始化
|
||||
if not init_result:
|
||||
error_msg = "InsightFace引擎未初始化、无法检测人脸"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
try:
|
||||
# 2. 验证二进制数据有效性
|
||||
if len(binary_data) < 1024: # 过滤过小的无效图片(小于1KB)
|
||||
error_msg = f"图片过小({len(binary_data)}字节)、可能不是有效图片"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 3. 二进制数据转CV2格式(关键步骤、避免通道错误)
|
||||
try:
|
||||
img = Image.open(BytesIO(binary_data)).convert("RGB") # 强制转RGB
|
||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) # InsightFace需要BGR格式
|
||||
except Exception as e:
|
||||
error_msg = f"图片格式转换失败:{str(e)}"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 4. 检查图片尺寸(避免极端尺寸导致检测失败)
|
||||
height, width = frame.shape[:2]
|
||||
if height < 64 or width < 64: # 人脸检测最小建议尺寸
|
||||
error_msg = f"图片尺寸过小({width}x{height})、需至少64x64像素"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 5. 调用InsightFace检测人脸
|
||||
print(f"开始检测人脸(图片尺寸:{width}x{height}、格式:BGR)")
|
||||
faces = _insightface_app.get(frame)
|
||||
|
||||
if not faces:
|
||||
error_msg = "未检测到人脸(请确保图片包含清晰正面人脸、无遮挡、不模糊)"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 6. 提取特征并保存
|
||||
current_feature = faces[0].embedding
|
||||
_feature_list.append(current_feature)
|
||||
print(f"人脸检测成功、提取特征值(维度:{current_feature.shape[0]})、累计特征数:{len(_feature_list)}")
|
||||
return True, current_feature
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"处理图片时发生异常:{str(e)}"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
# 获取数据库最新的人脸及其特征
|
||||
def get_all_face_name_with_eigenvalue() -> dict:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall()
|
||||
|
||||
name_to_eigenvalues = {}
|
||||
for face in faces:
|
||||
name = face["name"]
|
||||
eigenvalue = face["eigenvalue"]
|
||||
if name in name_to_eigenvalues:
|
||||
name_to_eigenvalues[name].append(eigenvalue)
|
||||
else:
|
||||
name_to_eigenvalues[name] = [eigenvalue]
|
||||
|
||||
face_dict = {}
|
||||
for name, eigenvalues in name_to_eigenvalues.items():
|
||||
if len(eigenvalues) > 1:
|
||||
face_dict[name] = get_average_feature(eigenvalues)
|
||||
else:
|
||||
face_dict[name] = eigenvalues[0]
|
||||
|
||||
return face_dict
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取人脸特征失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取平均特征值
|
||||
def get_average_feature(features=None):
|
||||
global _feature_list
|
||||
try:
|
||||
if features is None:
|
||||
features = _feature_list
|
||||
if not isinstance(features, list) or len(features) == 0:
|
||||
print("输入必须是包含至少一个特征值的列表")
|
||||
return None
|
||||
|
||||
processed_features = []
|
||||
for i, embedding in enumerate(features):
|
||||
try:
|
||||
if isinstance(embedding, str):
|
||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||
else:
|
||||
embedding_np = np.array(embedding, dtype=np.float32)
|
||||
|
||||
if len(embedding_np.shape) == 1:
|
||||
processed_features.append(embedding_np)
|
||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||
else:
|
||||
print(f"跳过第 {i + 1} 个特征值:不是一维数组")
|
||||
except Exception as e:
|
||||
print(f"处理第 {i + 1} 个特征值时出错:{str(e)}")
|
||||
|
||||
if not processed_features:
|
||||
print("没有有效的特征值用于计算平均值")
|
||||
return None
|
||||
|
||||
dims = {feat.shape[0] for feat in processed_features}
|
||||
if len(dims) > 1:
|
||||
print(f"特征值维度不一致:{dims}、无法计算平均值")
|
||||
return None
|
||||
|
||||
avg_feature = np.mean(processed_features, axis=0)
|
||||
print(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]})")
|
||||
return avg_feature
|
||||
except Exception as e:
|
||||
print(f"计算平均特征值出错:{str(e)}")
|
||||
return None
|
||||
343
service/file_service.py
Normal file
343
service/file_service.py
Normal file
@ -0,0 +1,343 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from PIL import ImageDraw, ImageFont
|
||||
from fastapi import UploadFile
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# 上传根目录
|
||||
UPLOAD_ROOT = "upload"
|
||||
PRE = "/api/file/download/"
|
||||
|
||||
# 确保上传根目录存在
|
||||
os.makedirs(UPLOAD_ROOT, exist_ok=True)
|
||||
|
||||
|
||||
|
||||
def save_detect_file(client_ip: str, image_np: np.ndarray, file_type: str) -> str:
|
||||
"""保存numpy数组格式的PNG图片到detect目录,返回下载路径"""
|
||||
today = datetime.now()
|
||||
year = today.strftime("%Y")
|
||||
month = today.strftime("%m")
|
||||
day = today.strftime("%d")
|
||||
|
||||
# 构建目录路径: upload/detect/客户端IP/type/年/月/日(包含UPLOAD_ROOT)
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT,
|
||||
"detect",
|
||||
client_ip,
|
||||
file_type,
|
||||
year,
|
||||
month,
|
||||
day
|
||||
)
|
||||
|
||||
# 创建目录(确保目录存在)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
# 生成当前时间戳作为文件名,确保唯一性
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
||||
filename = f"{timestamp}.png"
|
||||
|
||||
# 1. 完整路径:用于实际保存文件(包含UPLOAD_ROOT)
|
||||
full_path = os.path.join(file_dir, filename)
|
||||
# 2. 相对路径:用于返回给前端(移除UPLOAD_ROOT前缀)
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
|
||||
# 保存numpy数组为PNG图片
|
||||
try:
|
||||
# -------- 新增/修改:处理颜色通道和数据类型 --------
|
||||
# 1. 数据类型转换:确保是uint8(若为float32且范围0-1,需转成0-255的uint8)
|
||||
if image_np.dtype != np.uint8:
|
||||
image_np = (image_np * 255).astype(np.uint8)
|
||||
|
||||
# 2. 通道顺序转换:若为OpenCV的BGR格式,转成PIL需要的RGB格式
|
||||
image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 3. 转换为PIL Image并保存
|
||||
img = Image.fromarray(image_rgb)
|
||||
img.save(full_path, format='PNG')
|
||||
except Exception as e:
|
||||
# 处理可能的异常(如数组格式不正确)
|
||||
raise Exception(f"保存图片失败: {str(e)}")
|
||||
|
||||
# 统一路径分隔符为/,拼接前缀返回
|
||||
return PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
|
||||
def save_detect_yolo_file(
|
||||
client_ip: str,
|
||||
image_np: np.ndarray,
|
||||
detection_results: list,
|
||||
file_type: str = "yolo"
|
||||
) -> str:
|
||||
|
||||
|
||||
print("......................")
|
||||
"""
|
||||
保存YOLO检测结果图片(在原图上绘制边界框+标签),返回前端可访问的下载路径
|
||||
"""
|
||||
# 输入参数验证
|
||||
if not isinstance(image_np, np.ndarray):
|
||||
raise ValueError(f"输入image_np必须是numpy数组,当前类型:{type(image_np)}")
|
||||
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||
raise ValueError(f"输入图像必须是 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}")
|
||||
|
||||
if not isinstance(detection_results, list):
|
||||
raise ValueError(f"detection_results必须是列表,当前类型:{type(detection_results)}")
|
||||
for idx, result in enumerate(detection_results):
|
||||
required_keys = {"class", "confidence", "bbox"}
|
||||
if not isinstance(result, dict) or not required_keys.issubset(result.keys()):
|
||||
raise ValueError(
|
||||
f"detection_results第{idx}个元素格式错误,需包含键:{required_keys},"
|
||||
f"当前元素:{result}"
|
||||
)
|
||||
bbox = result["bbox"]
|
||||
if not (isinstance(bbox, (tuple, list)) and len(bbox) == 4 and all(isinstance(x, int) for x in bbox)):
|
||||
raise ValueError(
|
||||
f"detection_results第{idx}个元素的bbox格式错误,需为(x1,y1,x2,y2)整数元组,"
|
||||
f"当前bbox:{bbox}"
|
||||
)
|
||||
|
||||
#图像预处理(数据类型+通道)
|
||||
draw_image = image_np.copy()
|
||||
if draw_image.dtype != np.uint8:
|
||||
draw_image = np.clip(draw_image * 255, 0, 255).astype(np.uint8)
|
||||
|
||||
#绘制边界框+标签
|
||||
# 遍历所有检测结果,逐个绘制
|
||||
for result in detection_results:
|
||||
class_name = result["class"]
|
||||
confidence = result["confidence"]
|
||||
x1, y1, x2, y2 = result["bbox"]
|
||||
cv2.rectangle(draw_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
|
||||
label = f"{class_name}: {confidence:.2f}"
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.5
|
||||
font_thickness = 2
|
||||
(label_width, label_height), baseline = cv2.getTextSize(
|
||||
label, font, font_scale, font_thickness
|
||||
)
|
||||
|
||||
bg_top_left = (x1, y1 - label_height - 10)
|
||||
bg_bottom_right = (x1 + label_width, y1)
|
||||
if bg_top_left[1] < 0:
|
||||
bg_top_left = (x1, 0)
|
||||
bg_bottom_right = (x1 + label_width, label_height + 10)
|
||||
cv2.rectangle(draw_image, bg_top_left, bg_bottom_right, color=(0, 0, 0), thickness=-1)
|
||||
|
||||
text_origin = (x1, y1 - 5)
|
||||
if bg_top_left[1] == 0:
|
||||
text_origin = (x1, label_height + 5)
|
||||
cv2.putText(
|
||||
draw_image, label, text_origin,
|
||||
font, font_scale, color=(255, 255, 255), thickness=font_thickness
|
||||
)
|
||||
|
||||
#保存图片
|
||||
try:
|
||||
today = datetime.now()
|
||||
year = today.strftime("%Y")
|
||||
month = today.strftime("%m")
|
||||
day = today.strftime("%d")
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT, "detect", client_ip, file_type, year, month, day
|
||||
)
|
||||
|
||||
#创建目录(若不存在则创建,支持多级目录)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
#生成唯一文件名
|
||||
timestamp = today.strftime("%Y%m%d%H%M%S%f")
|
||||
filename = f"{timestamp}.png"
|
||||
|
||||
# 4.4 构建完整保存路径和前端访问路径
|
||||
full_path = os.path.join(file_dir, filename) # 本地完整路径
|
||||
# 相对路径:移除UPLOAD_ROOT前缀,统一用"/"作为分隔符(兼容Windows/Linux)
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
download_path = PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
# 4.5 保存图片(CV2绘制的是BGR,需转RGB后用PIL保存,与原逻辑一致)
|
||||
image_rgb = cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB)
|
||||
img_pil = Image.fromarray(image_rgb)
|
||||
img_pil.save(full_path, format="PNG", quality=95) # PNG格式无压缩,quality可忽略
|
||||
|
||||
print(f"YOLO检测图片保存成功 | 本地路径:{full_path} | 下载路径:{download_path}")
|
||||
return download_path
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"YOLO检测图片保存失败:{str(e)}") from e
|
||||
|
||||
|
||||
def save_detect_face_file(
|
||||
client_ip: str,
|
||||
image_np: np.ndarray,
|
||||
face_result: str,
|
||||
file_type: str = "face",
|
||||
matched_color: tuple = (0, 255, 0)
|
||||
) -> str:
|
||||
"""
|
||||
保存人脸识别结果图片(仅为「匹配成功」的人脸画框,标签不包含“匹配”二字)
|
||||
"""
|
||||
#输入参数验证
|
||||
if not isinstance(image_np, np.ndarray) or image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||
raise ValueError(f"输入图像需为 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}")
|
||||
if not isinstance(face_result, str) or face_result.strip() == "":
|
||||
raise ValueError("face_result必须是非空字符串")
|
||||
|
||||
# 解析face_result提取人脸信息
|
||||
face_info_list = []
|
||||
if face_result.strip() != "未检测到人脸":
|
||||
face_pattern = re.compile(
|
||||
r"(匹配|未匹配):\s*([^\s(]+)\s*\(相似度:\s*(\d+\.\d+),\s*边界框:\s*\[(\d+,\s*\d+,\s*\d+,\s*\d+)\]\)"
|
||||
)
|
||||
for part in [p.strip() for p in face_result.split(";") if p.strip()]:
|
||||
match = face_pattern.match(part)
|
||||
if match:
|
||||
status, name, similarity, bbox_str = match.groups()
|
||||
bbox = list(map(int, bbox_str.replace(" ", "").split(",")))
|
||||
if len(bbox) == 4:
|
||||
face_info_list.append({
|
||||
"status": status,
|
||||
"name": name,
|
||||
"similarity": float(similarity),
|
||||
"bbox": bbox
|
||||
})
|
||||
|
||||
# 图像格式转换(OpenCV→PIL)
|
||||
image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(image_rgb)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
# 绘制边界框和标签
|
||||
font_size = 12
|
||||
try:
|
||||
font = ImageFont.truetype("simhei", font_size)
|
||||
except:
|
||||
try:
|
||||
font = ImageFont.truetype("simsun", font_size)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
print("警告:未找到指定中文字体,使用PIL默认字体(可能影响中文显示)")
|
||||
|
||||
for face_info in face_info_list:
|
||||
status = face_info["status"]
|
||||
if status != "匹配":
|
||||
print(f"跳过未匹配人脸:{face_info['name']}(相似度:{face_info['similarity']:.2f})")
|
||||
continue
|
||||
|
||||
name = face_info["name"]
|
||||
similarity = face_info["similarity"]
|
||||
x1, y1, x2, y2 = face_info["bbox"]
|
||||
|
||||
# 4.1 绘制边界框(绿色)
|
||||
img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
cv2.rectangle(img_cv, (x1, y1), (x2, y2), color=matched_color, thickness=2)
|
||||
pil_img = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
label = f"{name} (相似度: {similarity:.2f})"
|
||||
|
||||
# 4.3 计算标签尺寸(文本变短后会自动适配,无需额外调整)
|
||||
label_bbox = draw.textbbox((0, 0), label, font=font)
|
||||
label_width = label_bbox[2] - label_bbox[0]
|
||||
label_height = label_bbox[3] - label_bbox[1]
|
||||
|
||||
# 4.4 计算标签背景位置(避免超出图像)
|
||||
bg_x1, bg_y1 = x1, y1 - label_height - 10
|
||||
bg_x2, bg_y2 = x1 + label_width, y1
|
||||
if bg_y1 < 0:
|
||||
bg_y1, bg_y2 = y2 + 5, y2 + label_height + 15
|
||||
|
||||
# 4.5 绘制标签背景(黑色)和文本(白色)
|
||||
draw.rectangle([(bg_x1, bg_y1), (bg_x2, bg_y2)], fill=(0, 0, 0))
|
||||
text_x = bg_x1
|
||||
text_y = bg_y1 if bg_y1 >= 0 else bg_y1 + label_height
|
||||
draw.text((text_x, text_y), label, font=font, fill=(255, 255, 255))
|
||||
|
||||
#保存图片
|
||||
try:
|
||||
today = datetime.now()
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT, "detect", client_ip, file_type,
|
||||
today.strftime("%Y"), today.strftime("%m"), today.strftime("%d")
|
||||
)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
timestamp = today.strftime("%Y%m%d%H%M%S%f")
|
||||
filename = f"{timestamp}.png"
|
||||
full_path = os.path.join(file_dir, filename)
|
||||
|
||||
pil_img.save(full_path, format="PNG", quality=100)
|
||||
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
download_path = PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
matched_count = sum(1 for info in face_info_list if info["status"] == "匹配")
|
||||
print(f"人脸检测图片保存成功 | 客户端IP:{client_ip} | 匹配人脸数:{matched_count} | 保存路径:{download_path}")
|
||||
return download_path
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"人脸检测图片保存失败(客户端IP:{client_ip}):{str(e)}") from e
|
||||
|
||||
def save_source_file(upload_file: UploadFile, file_type: str) -> str:
|
||||
"""保存上传的文件到source目录,返回下载路径"""
|
||||
today = datetime.now()
|
||||
year = today.strftime("%Y")
|
||||
month = today.strftime("%m")
|
||||
day = today.strftime("%d")
|
||||
|
||||
# 生成精确到微秒的时间戳,确保文件名唯一
|
||||
timestamp = today.strftime("%Y%m%d%H%M%S%f")
|
||||
# 构建新文件名:时间戳_原文件名
|
||||
unique_filename = f"{timestamp}_{upload_file.filename}"
|
||||
|
||||
# 构建目录路径: upload/source/type/年/月/日(包含UPLOAD_ROOT)
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT,
|
||||
"source",
|
||||
file_type,
|
||||
year,
|
||||
month,
|
||||
day
|
||||
)
|
||||
|
||||
# 创建目录(确保目录存在)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
# 1. 完整路径:用于实际保存文件(使用带时间戳的唯一文件名)
|
||||
full_path = os.path.join(file_dir, unique_filename)
|
||||
# 2. 相对路径:用于返回给前端
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
|
||||
# 保存文件(使用完整路径)
|
||||
try:
|
||||
with open(full_path, "wb") as buffer:
|
||||
shutil.copyfileobj(upload_file.file, buffer)
|
||||
finally:
|
||||
upload_file.file.close()
|
||||
|
||||
# 统一路径分隔符为/
|
||||
return PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
|
||||
def get_absolute_path(relative_path: str) -> str:
|
||||
"""
|
||||
根据相对路径获取服务器上的绝对路径
|
||||
"""
|
||||
path_without_pre = relative_path.replace(PRE, "", 1)
|
||||
|
||||
# 将相对路径转换为系统兼容的格式
|
||||
normalized_path = os.path.normpath(path_without_pre)
|
||||
|
||||
# 拼接基础路径和相对路径,得到绝对路径
|
||||
absolute_path = os.path.abspath(os.path.join(UPLOAD_ROOT, normalized_path))
|
||||
|
||||
# 安全检查:确保生成的路径在UPLOAD_ROOT目录下,防止路径遍历
|
||||
if not absolute_path.startswith(os.path.abspath(UPLOAD_ROOT)):
|
||||
raise ValueError("无效的相对路径,可能试图访问上传目录之外的内容")
|
||||
|
||||
return absolute_path
|
||||
131
service/model_service.py
Normal file
131
service/model_service.py
Normal file
@ -0,0 +1,131 @@
|
||||
from http.client import HTTPException
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from MySQLdb import MySQLError
|
||||
from ultralytics import YOLO
|
||||
import os
|
||||
|
||||
from ds.db import db
|
||||
from service.file_service import get_absolute_path
|
||||
|
||||
# 全局变量
|
||||
current_yolo_model = None
|
||||
current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例
|
||||
|
||||
ALLOWED_MODEL_EXT = {"pt"}
|
||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
|
||||
def load_yolo_model():
|
||||
"""加载模型并存储绝对路径"""
|
||||
global current_yolo_model, current_model_absolute_path
|
||||
model_rel_path = get_enabled_model_rel_path()
|
||||
print(f"[模型初始化] 加载模型:{model_rel_path}")
|
||||
|
||||
# 计算并存储绝对路径
|
||||
current_model_absolute_path = get_absolute_path(model_rel_path)
|
||||
print(f"[模型初始化] 绝对路径:{current_model_absolute_path}")
|
||||
|
||||
# 检查模型文件
|
||||
if not os.path.exists(current_model_absolute_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}")
|
||||
|
||||
try:
|
||||
new_model = YOLO(current_model_absolute_path)
|
||||
if torch.cuda.is_available():
|
||||
new_model.to('cuda')
|
||||
print("模型已移动到GPU")
|
||||
else:
|
||||
print("使用CPU进行推理")
|
||||
current_yolo_model = new_model
|
||||
print(f"成功加载模型: {current_model_absolute_path}")
|
||||
return current_yolo_model
|
||||
except Exception as e:
|
||||
print(f"模型加载失败:{str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_current_model():
|
||||
"""获取当前模型实例"""
|
||||
if current_yolo_model is None:
|
||||
raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型")
|
||||
return current_yolo_model
|
||||
|
||||
|
||||
def detect(image_np, conf_threshold=0.8):
|
||||
# 1. 输入格式验证
|
||||
if not isinstance(image_np, np.ndarray):
|
||||
raise ValueError("输入必须是numpy数组(BGR图像)")
|
||||
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||
raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}")
|
||||
detection_results = []
|
||||
try:
|
||||
model = get_current_model()
|
||||
if not current_model_absolute_path:
|
||||
raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"检测设备:{device} | 置信度阈值:{conf_threshold}")
|
||||
|
||||
# 图像尺寸信息
|
||||
img_height, img_width = image_np.shape[:2]
|
||||
print(f"输入图像尺寸:{img_width}x{img_height}")
|
||||
|
||||
# YOLO检测
|
||||
print("执行YOLO检测")
|
||||
results = model.predict(
|
||||
image_np,
|
||||
conf=conf_threshold,
|
||||
device=device,
|
||||
show=False,
|
||||
)
|
||||
|
||||
# 4. 整理检测结果(仅保留Chest类别,ID=2)
|
||||
for box in results[0].boxes:
|
||||
class_id = int(box.cls[0]) # 类别ID
|
||||
class_name = model.names[class_id]
|
||||
confidence = float(box.conf[0])
|
||||
bbox = tuple(map(int, box.xyxy[0]))
|
||||
|
||||
# 过滤条件:置信度达标 + 类别为Chest(class_id=2)
|
||||
# and class_id == 2
|
||||
if confidence >= conf_threshold:
|
||||
detection_results.append({
|
||||
"class": class_name,
|
||||
"confidence": confidence,
|
||||
"bbox": bbox
|
||||
})
|
||||
|
||||
# 判断是否有目标
|
||||
has_content = len(detection_results) > 0
|
||||
return has_content, detection_results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"检测过程出错:{str(e)}"
|
||||
print(error_msg)
|
||||
return False, None
|
||||
|
||||
|
||||
def get_enabled_model_rel_path():
|
||||
"""获取数据库中启用的模型相对路径"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
|
||||
cursor.execute(query)
|
||||
result = cursor.fetchone()
|
||||
|
||||
if not result or not result.get('path'):
|
||||
raise HTTPException(status_code=404, detail="未找到启用的默认模型")
|
||||
|
||||
return result['path']
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
131
service/ocr_service.py
Normal file
131
service/ocr_service.py
Normal file
@ -0,0 +1,131 @@
|
||||
# 首先添加NumPy兼容处理
|
||||
import numpy as np
|
||||
|
||||
# 修复np.int已弃用的问题
|
||||
if not hasattr(np, 'int'):
|
||||
np.int = int
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
_ocr_engine = None
|
||||
_forbidden_words = set()
|
||||
_conf_threshold = 0.5
|
||||
|
||||
def set_forbidden_words(new_words):
|
||||
global _forbidden_words
|
||||
if not isinstance(new_words, (set, list, tuple)):
|
||||
raise TypeError("新违禁词必须是集合、列表或元组类型")
|
||||
_forbidden_words = set(new_words) # 确保是集合类型
|
||||
print(f"已通过函数更新违禁词,当前数量: {len(_forbidden_words)}")
|
||||
|
||||
def load_forbidden_words():
|
||||
global _forbidden_words
|
||||
try:
|
||||
_forbidden_words = get_all_sensitive_words()
|
||||
print(f"加载的违禁词数量: {len(_forbidden_words)}")
|
||||
except Exception as e:
|
||||
print(f"Forbidden words load error: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def init_ocr_engine():
|
||||
global _ocr_engine
|
||||
try:
|
||||
_ocr_engine = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang="ch",
|
||||
show_log=False,
|
||||
use_gpu=True,
|
||||
max_text_length=1024
|
||||
)
|
||||
load_result = load_forbidden_words()
|
||||
if not load_result:
|
||||
print("警告:违禁词加载失败,可能影响检测功能")
|
||||
print("OCR引擎初始化完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"OCR引擎初始化错误: {e}")
|
||||
_ocr_engine = None
|
||||
return False
|
||||
|
||||
|
||||
def detect(frame, conf_threshold=0.8):
|
||||
print("开始进行OCR检测...")
|
||||
try:
|
||||
ocr_res = _ocr_engine.ocr(frame, cls=True)
|
||||
if not ocr_res or not isinstance(ocr_res, list):
|
||||
return (False, "无OCR结果")
|
||||
|
||||
texts = []
|
||||
confs = []
|
||||
for line in ocr_res:
|
||||
if line is None:
|
||||
continue
|
||||
if isinstance(line, list):
|
||||
items_to_process = line
|
||||
else:
|
||||
items_to_process = [line]
|
||||
|
||||
for item in items_to_process:
|
||||
if isinstance(item, list) and len(item) == 4:
|
||||
is_coordinate = True
|
||||
for point in item:
|
||||
if not (isinstance(point, list) and len(point) == 2 and
|
||||
all(isinstance(coord, (int, float)) for coord in point)):
|
||||
is_coordinate = False
|
||||
break
|
||||
if is_coordinate:
|
||||
continue
|
||||
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
text, conf = item
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
if isinstance(item, list) and len(item) >= 2:
|
||||
text_data = item[1]
|
||||
if isinstance(text_data, tuple) and len(text_data) == 2:
|
||||
text, conf = text_data
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
elif isinstance(text_data, str):
|
||||
texts.append(text_data.strip())
|
||||
confs.append(1.0)
|
||||
continue
|
||||
print(f"无法解析的OCR结果格式: {item}")
|
||||
|
||||
if len(texts) != len(confs):
|
||||
return (False, "OCR结果格式异常")
|
||||
|
||||
# 收集所有识别到的违禁词(去重且保持出现顺序)
|
||||
vio_words = []
|
||||
for txt, conf in zip(texts, confs):
|
||||
if conf < _conf_threshold: # 过滤低置信度结果
|
||||
continue
|
||||
# 提取当前文本中包含的违禁词
|
||||
matched = [w for w in _forbidden_words if w in txt]
|
||||
# 仅添加未记录过的违禁词(去重)
|
||||
for word in matched:
|
||||
if word not in vio_words:
|
||||
vio_words.append(word)
|
||||
|
||||
has_text = len(texts) > 0
|
||||
has_violation = len(vio_words) > 0
|
||||
|
||||
if not has_text:
|
||||
return (False, "未识别到文本")
|
||||
elif has_violation:
|
||||
# 多个违禁词用逗号拼接
|
||||
return (True, ", ".join(vio_words))
|
||||
else:
|
||||
return (False, "未检测到违禁词")
|
||||
|
||||
except Exception as e:
|
||||
print(f"OCR detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
36
service/sensitive_service.py
Normal file
36
service/sensitive_service.py
Normal file
@ -0,0 +1,36 @@
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
|
||||
|
||||
def get_all_sensitive_words() -> list[str]:
|
||||
"""
|
||||
获取所有敏感词(返回纯字符串列表、用于过滤业务)
|
||||
|
||||
返回:
|
||||
list[str]: 包含所有敏感词的数组
|
||||
|
||||
异常:
|
||||
MySQLError: 数据库操作相关错误
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 获取数据库连接
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 执行查询(只获取敏感词字段、按ID排序)
|
||||
query = "SELECT name FROM sensitives ORDER BY id"
|
||||
cursor.execute(query)
|
||||
sensitive_records = cursor.fetchall()
|
||||
|
||||
# 提取敏感词到纯字符串数组
|
||||
return [record['name'] for record in sensitive_records]
|
||||
|
||||
except MySQLError as e:
|
||||
# 数据库错误向上抛出、由调用方处理
|
||||
raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e
|
||||
finally:
|
||||
# 确保数据库连接正确释放
|
||||
db.close_connection(conn, cursor)
|
||||
Reference in New Issue
Block a user