内容安全审核
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
8
.idea/Detect.iml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.10 (2)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
98
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,98 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="76">
|
||||
<item index="0" class="java.lang.String" itemvalue="scipy" />
|
||||
<item index="1" class="java.lang.String" itemvalue="protobuf" />
|
||||
<item index="2" class="java.lang.String" itemvalue="thop" />
|
||||
<item index="3" class="java.lang.String" itemvalue="opencv-python" />
|
||||
<item index="4" class="java.lang.String" itemvalue="PyYAML" />
|
||||
<item index="5" class="java.lang.String" itemvalue="ipython" />
|
||||
<item index="6" class="java.lang.String" itemvalue="torch" />
|
||||
<item index="7" class="java.lang.String" itemvalue="numpy" />
|
||||
<item index="8" class="java.lang.String" itemvalue="requests" />
|
||||
<item index="9" class="java.lang.String" itemvalue="torchvision" />
|
||||
<item index="10" class="java.lang.String" itemvalue="psutil" />
|
||||
<item index="11" class="java.lang.String" itemvalue="tqdm" />
|
||||
<item index="12" class="java.lang.String" itemvalue="pandas" />
|
||||
<item index="13" class="java.lang.String" itemvalue="tensorboard" />
|
||||
<item index="14" class="java.lang.String" itemvalue="seaborn" />
|
||||
<item index="15" class="java.lang.String" itemvalue="matplotlib" />
|
||||
<item index="16" class="java.lang.String" itemvalue="Pillow" />
|
||||
<item index="17" class="java.lang.String" itemvalue="fastapi" />
|
||||
<item index="18" class="java.lang.String" itemvalue="uvicorn" />
|
||||
<item index="19" class="java.lang.String" itemvalue="python-jose" />
|
||||
<item index="20" class="java.lang.String" itemvalue="passlib" />
|
||||
<item index="21" class="java.lang.String" itemvalue="pydantic" />
|
||||
<item index="22" class="java.lang.String" itemvalue="sqlalchemy" />
|
||||
<item index="23" class="java.lang.String" itemvalue="imageio_ffmpeg" />
|
||||
<item index="24" class="java.lang.String" itemvalue="ultralytics" />
|
||||
<item index="25" class="java.lang.String" itemvalue="future" />
|
||||
<item index="26" class="java.lang.String" itemvalue="jose" />
|
||||
<item index="27" class="java.lang.String" itemvalue="ffmpeg-python" />
|
||||
<item index="28" class="java.lang.String" itemvalue="setuptools" />
|
||||
<item index="29" class="java.lang.String" itemvalue="opencv_python" />
|
||||
<item index="30" class="java.lang.String" itemvalue="rsa" />
|
||||
<item index="31" class="java.lang.String" itemvalue="greenlet" />
|
||||
<item index="32" class="java.lang.String" itemvalue="networkx" />
|
||||
<item index="33" class="java.lang.String" itemvalue="python-dateutil" />
|
||||
<item index="34" class="java.lang.String" itemvalue="SQLAlchemy" />
|
||||
<item index="35" class="java.lang.String" itemvalue="cffi" />
|
||||
<item index="36" class="java.lang.String" itemvalue="python-dotenv" />
|
||||
<item index="37" class="java.lang.String" itemvalue="h11" />
|
||||
<item index="38" class="java.lang.String" itemvalue="py-cpuinfo" />
|
||||
<item index="39" class="java.lang.String" itemvalue="cycler" />
|
||||
<item index="40" class="java.lang.String" itemvalue="MarkupSafe" />
|
||||
<item index="41" class="java.lang.String" itemvalue="pyasn1" />
|
||||
<item index="42" class="java.lang.String" itemvalue="pycparser" />
|
||||
<item index="43" class="java.lang.String" itemvalue="Jinja2" />
|
||||
<item index="44" class="java.lang.String" itemvalue="sniffio" />
|
||||
<item index="45" class="java.lang.String" itemvalue="ultralytics-thop" />
|
||||
<item index="46" class="java.lang.String" itemvalue="fsspec" />
|
||||
<item index="47" class="java.lang.String" itemvalue="filelock" />
|
||||
<item index="48" class="java.lang.String" itemvalue="starlette" />
|
||||
<item index="49" class="java.lang.String" itemvalue="certifi" />
|
||||
<item index="50" class="java.lang.String" itemvalue="anyio" />
|
||||
<item index="51" class="java.lang.String" itemvalue="urllib3" />
|
||||
<item index="52" class="java.lang.String" itemvalue="pyparsing" />
|
||||
<item index="53" class="java.lang.String" itemvalue="sympy" />
|
||||
<item index="54" class="java.lang.String" itemvalue="annotated-types" />
|
||||
<item index="55" class="java.lang.String" itemvalue="pydantic-settings" />
|
||||
<item index="56" class="java.lang.String" itemvalue="six" />
|
||||
<item index="57" class="java.lang.String" itemvalue="tzdata" />
|
||||
<item index="58" class="java.lang.String" itemvalue="ecdsa" />
|
||||
<item index="59" class="java.lang.String" itemvalue="kiwisolver" />
|
||||
<item index="60" class="java.lang.String" itemvalue="packaging" />
|
||||
<item index="61" class="java.lang.String" itemvalue="python-multipart" />
|
||||
<item index="62" class="java.lang.String" itemvalue="click" />
|
||||
<item index="63" class="java.lang.String" itemvalue="contourpy" />
|
||||
<item index="64" class="java.lang.String" itemvalue="fonttools" />
|
||||
<item index="65" class="java.lang.String" itemvalue="pydantic_core" />
|
||||
<item index="66" class="java.lang.String" itemvalue="av" />
|
||||
<item index="67" class="java.lang.String" itemvalue="colorama" />
|
||||
<item index="68" class="java.lang.String" itemvalue="mpmath" />
|
||||
<item index="69" class="java.lang.String" itemvalue="argon2-cffi-bindings" />
|
||||
<item index="70" class="java.lang.String" itemvalue="typing_extensions" />
|
||||
<item index="71" class="java.lang.String" itemvalue="charset-normalizer" />
|
||||
<item index="72" class="java.lang.String" itemvalue="pillow" />
|
||||
<item index="73" class="java.lang.String" itemvalue="argon2-cffi" />
|
||||
<item index="74" class="java.lang.String" itemvalue="pytz" />
|
||||
<item index="75" class="java.lang.String" itemvalue="idna" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
<option value="N802" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" />
|
||||
</profile>
|
||||
</component>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
.idea/misc.xml
generated
Normal file
@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="video" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (2)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/Detect.iml" filepath="$PROJECT_DIR$/.idea/Detect.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
BIN
__pycache__/main.cpython-310.pyc
Normal file
20
config.ini
Normal file
@ -0,0 +1,20 @@
|
||||
[server]
|
||||
port = 8000
|
||||
|
||||
[mysql]
|
||||
host = 192.168.110.65
|
||||
port = 6975
|
||||
user = video_check
|
||||
password = fsjPfhxCs8NrFGmL
|
||||
database = video_check
|
||||
charset = utf8mb4
|
||||
|
||||
[jwt]
|
||||
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
|
||||
algorithm = HS256
|
||||
access_token_expire_minutes = 30
|
||||
|
||||
[business]
|
||||
ocr_conf = 0.6
|
||||
face_conf = 0.6
|
||||
yolo_conf = 0.7
|
||||
BIN
core/__pycache__/detect.cpython-310.pyc
Normal file
140
core/detect.py
Normal file
@ -0,0 +1,140 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.config import BUSINESS_CONFIG
|
||||
from ds.db import db
|
||||
from service.face_service import detect as faceDetect,init_insightface
|
||||
from service.model_service import load_yolo_model,detect as yoloDetect
|
||||
from service.ocr_service import detect as ocrDetect,init_ocr_engine
|
||||
from service.file_service import save_detect_file, save_detect_yolo_file, save_detect_face_file
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
# 创建线程池执行器
|
||||
executor = ThreadPoolExecutor(max_workers=10)
|
||||
def init():
|
||||
# # 人脸相关
|
||||
init_insightface()
|
||||
# # 初始化OCR引擎
|
||||
init_ocr_engine()
|
||||
#初始化YOLO模型
|
||||
load_yolo_model()
|
||||
|
||||
def save_db(model_type, client_ip, result):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 连接数据库
|
||||
conn = db.get_connection()
|
||||
# 往表插入数据
|
||||
cursor = conn.cursor(dictionary=True) # 返回字典格式结果
|
||||
insert_query = """
|
||||
INSERT INTO device_danger (client_ip, type, result)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (client_ip, model_type, result))
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
def detectFrame(client_ip, frame):
|
||||
|
||||
# YOLO检测
|
||||
yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
|
||||
if yolo_flag:
|
||||
print(f"❌ 检测到违规内容,保存图片,YOLO")
|
||||
danger_handler(client_ip)
|
||||
path = save_detect_yolo_file(client_ip, frame, yolo_result, "yolo")
|
||||
save_db(model_type="色情", client_ip=client_ip, result=str(path))
|
||||
|
||||
# 人脸检测
|
||||
face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"]))
|
||||
if face_flag:
|
||||
print(f"❌ 检测到违规内容,保存图片,FACE")
|
||||
print("人脸识别内容:", face_result)
|
||||
model_type = extract_face_names(face_result)
|
||||
danger_handler(client_ip)
|
||||
path = save_detect_face_file(client_ip, frame, face_result, "face")
|
||||
save_db(model_type=model_type, client_ip=client_ip, result=str(path))
|
||||
|
||||
# OCR检测部分(使用修正后的提取函数)
|
||||
ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"]))
|
||||
if ocr_flag:
|
||||
print(f"❌ 检测到违规内容,保存图片,OCR")
|
||||
print("ocr识别内容:", ocr_result)
|
||||
danger_handler(client_ip)
|
||||
path = save_detect_file(client_ip, frame, "ocr")
|
||||
save_db(model_type=str(ocr_result), client_ip=client_ip, result=str(path))
|
||||
|
||||
# 仅当所有检测均未发现违规时才提示
|
||||
if not (face_flag or yolo_flag or ocr_flag):
|
||||
print(f"所有模型未检测到任何违规内容")
|
||||
|
||||
def danger_handler(client_ip):
|
||||
from ws.ws import send_message_to_client, get_current_time_str
|
||||
from service.device_service import increment_alarm_count_by_ip
|
||||
from service.device_service import update_is_need_handler_by_client_ip
|
||||
|
||||
danger_msg = {
|
||||
"type": "danger",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": client_ip,
|
||||
}
|
||||
asyncio.run(
|
||||
send_message_to_client(
|
||||
client_ip=client_ip,
|
||||
json_data=json.dumps(danger_msg)
|
||||
)
|
||||
)
|
||||
lock_msg = {
|
||||
"type": "lock",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": client_ip
|
||||
}
|
||||
asyncio.run(
|
||||
send_message_to_client(
|
||||
client_ip=client_ip,
|
||||
json_data=json.dumps(lock_msg)
|
||||
)
|
||||
)
|
||||
|
||||
# 增加危险记录次数
|
||||
increment_alarm_count_by_ip(client_ip)
|
||||
|
||||
# 更新设备状态为未处理
|
||||
update_is_need_handler_by_client_ip(client_ip, 1)
|
||||
|
||||
def extract_prohibited_words(ocr_result: str) -> str:
|
||||
"""
|
||||
从多文本块的ocr_result中提取所有违禁词(去重后用逗号拼接)
|
||||
适配格式:多个"文本: ... 包含违禁词: ...;"片段
|
||||
"""
|
||||
# 用正则匹配所有"包含违禁词: ...;"的片段(非贪婪匹配到分号)
|
||||
# 匹配规则:"包含违禁词: "后面的内容,直到遇到";"结束
|
||||
pattern = r"包含违禁词: (.*?);"
|
||||
all_prohibited_segments = re.findall(pattern, ocr_result, re.DOTALL)
|
||||
|
||||
all_words = []
|
||||
for segment in all_prohibited_segments:
|
||||
# 去除每个片段中的置信度信息(如"(置信度: 1.00)")
|
||||
cleaned = re.sub(r"\s*\([^)]*\)", "", segment.strip())
|
||||
# 分割词语并过滤空值
|
||||
words = [word.strip() for word in cleaned.split(",") if word.strip()]
|
||||
all_words.extend(words)
|
||||
|
||||
# 去重后用逗号拼接
|
||||
unique_words = list(set(all_words))
|
||||
return ",".join(unique_words)
|
||||
|
||||
|
||||
def extract_face_names(face_result: str) -> str:
|
||||
pattern = r"匹配: (.*?) \("
|
||||
all_names = re.findall(pattern, face_result)
|
||||
unique_names = list(set([name.strip() for name in all_names if name.strip()]))
|
||||
return ",".join(unique_names)
|
||||
BIN
ds/__pycache__/config.cpython-310.pyc
Normal file
BIN
ds/__pycache__/db.cpython-310.pyc
Normal file
17
ds/config.py
Normal file
@ -0,0 +1,17 @@
|
||||
import configparser
|
||||
import os
|
||||
|
||||
# 读取配置文件路径
|
||||
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../config.ini")
|
||||
|
||||
# 初始化配置解析器
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
# 读取配置文件
|
||||
config.read(config_path, encoding="utf-8")
|
||||
|
||||
# 暴露配置项(方便其他文件调用)
|
||||
SERVER_CONFIG = config["server"]
|
||||
MYSQL_CONFIG = config["mysql"]
|
||||
JWT_CONFIG = config["jwt"]
|
||||
BUSINESS_CONFIG = config["business"]
|
||||
59
ds/db.py
Normal file
@ -0,0 +1,59 @@
|
||||
import mysql.connector
|
||||
from mysql.connector import Error
|
||||
from .config import MYSQL_CONFIG
|
||||
|
||||
_connection_pool = None
|
||||
|
||||
|
||||
class Database:
|
||||
"""MySQL 连接池管理类"""
|
||||
pool_config = {
|
||||
"host": MYSQL_CONFIG.get("host", "localhost"),
|
||||
"port": int(MYSQL_CONFIG.get("port", 3306)),
|
||||
"user": MYSQL_CONFIG.get("user", "root"),
|
||||
"password": MYSQL_CONFIG.get("password", ""),
|
||||
"database": MYSQL_CONFIG.get("database", ""),
|
||||
"charset": MYSQL_CONFIG.get("charset", "utf8mb4"),
|
||||
"pool_name": "fastapi_pool",
|
||||
"pool_size": 5,
|
||||
"pool_reset_session": True
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_connection(cls):
|
||||
"""获取数据库连接"""
|
||||
try:
|
||||
# 从连接池获取连接
|
||||
conn = mysql.connector.connect(**cls.pool_config)
|
||||
if conn.is_connected():
|
||||
return conn
|
||||
except Error as e:
|
||||
raise Exception(f"MySQL 连接失败: {str(e)}") from e
|
||||
|
||||
@classmethod
|
||||
def close_connection(cls, conn, cursor=None):
|
||||
"""关闭连接和游标"""
|
||||
try:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
if conn and conn.is_connected():
|
||||
conn.close()
|
||||
except Error as e:
|
||||
raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e
|
||||
|
||||
@classmethod
|
||||
def close_all_connections(cls):
|
||||
"""清理连接池(服务重启前调用)"""
|
||||
try:
|
||||
# 先检查属性是否存在、再判断是否有值
|
||||
if hasattr(cls, "_connection_pool") and cls._connection_pool:
|
||||
cls._connection_pool = None # 重置连接池
|
||||
print("[Database] 连接池已重置、旧连接将被自动清理")
|
||||
else:
|
||||
print("[Database] 连接池未初始化或已重置、无需操作")
|
||||
except Exception as e:
|
||||
print(f"[Database] 重置连接池失败: {str(e)}")
|
||||
|
||||
|
||||
# 暴露数据库操作工具
|
||||
db = Database()
|
||||
BIN
encryption/__pycache__/encrypt_decorator.cpython-310.pyc
Normal file
BIN
encryption/__pycache__/encryption.cpython-310.pyc
Normal file
40
encryption/encrypt_decorator.py
Normal file
@ -0,0 +1,40 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from encryption.encryption import aes_encrypt
|
||||
from schema.response_schema import APIResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def encrypt_response(field: str = "data"):
|
||||
"""接口返回值加密装饰器:正确序列化自定义对象为JSON"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
original_response: APIResponse = await func(*args, **kwargs)
|
||||
field_value = getattr(original_response, field)
|
||||
|
||||
if not field_value:
|
||||
return original_response
|
||||
|
||||
# 自定义JSON序列化函数:处理Pydantic模型和datetime
|
||||
def json_default(obj: Any) -> Any:
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.model_dump()
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return str(obj)
|
||||
|
||||
# 使用自定义序列化函数、确保生成标准JSON
|
||||
field_value_json = json.dumps(field_value, default=json_default)
|
||||
encrypted_data = aes_encrypt(field_value_json)
|
||||
setattr(original_response, field, encrypted_data)
|
||||
|
||||
return original_response
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
56
encryption/encryption.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
import base64
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from fastapi import HTTPException
|
||||
|
||||
# 硬编码AES密钥(32字节、AES-256)
|
||||
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa"
|
||||
AES_BLOCK_SIZE = 16 # AES固定块大小
|
||||
|
||||
# 校验密钥长度
|
||||
valid_key_lengths = [16, 24, 32]
|
||||
if len(AES_SECRET_KEY) not in valid_key_lengths:
|
||||
raise ValueError(
|
||||
f"AES密钥长度必须为{valid_key_lengths}字节、当前为{len(AES_SECRET_KEY)}字节"
|
||||
)
|
||||
|
||||
|
||||
def aes_encrypt(plaintext: str) -> dict:
|
||||
"""AES-CBC模式加密(返回密文+IV、均为Base64编码)"""
|
||||
try:
|
||||
# 生成随机IV(16字节)
|
||||
iv = os.urandom(AES_BLOCK_SIZE)
|
||||
|
||||
# 创建加密器
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
|
||||
|
||||
# 明文填充并加密
|
||||
padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
|
||||
ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
|
||||
iv_base64 = base64.b64encode(iv).decode("utf-8")
|
||||
|
||||
return {
|
||||
"ciphertext": ciphertext,
|
||||
"iv": iv_base64,
|
||||
"algorithm": "AES-CBC"
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"AES加密失败:{str(e)}") from e
|
||||
|
||||
|
||||
def aes_decrypt(ciphertext: str, iv: str) -> str:
|
||||
"""AES-CBC模式解密"""
|
||||
try:
|
||||
# 解码Base64
|
||||
ciphertext_bytes = base64.b64decode(ciphertext)
|
||||
iv_bytes = base64.b64decode(iv)
|
||||
|
||||
# 创建解密器
|
||||
cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv_bytes)
|
||||
|
||||
# 解密并去填充
|
||||
decrypted_bytes = unpad(cipher.decrypt(ciphertext_bytes), AES_BLOCK_SIZE)
|
||||
return decrypted_bytes.decode("utf-8")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"AES解密失败:{str(e)}") from e
|
||||
66
main.py
Normal file
@ -0,0 +1,66 @@
|
||||
import uvicorn
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# 原有业务导入
|
||||
from ds.config import SERVER_CONFIG
|
||||
from middle.error_handler import global_exception_handler
|
||||
from router.user_router import router as user_router
|
||||
from router.sensitive_router import router as sensitive_router
|
||||
from router.face_router import router as face_router
|
||||
from router.device_router import router as device_router
|
||||
from router.model_router import router as model_router
|
||||
from router.file_router import router as file_router
|
||||
from router.device_danger_router import router as device_danger_router
|
||||
from core.detect import init
|
||||
from ws.ws import ws_router, lifespan
|
||||
|
||||
# 初始化 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title="内容安全审核后台",
|
||||
description="含图片访问服务和动态模型管理",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
ALLOWED_ORIGINS = [
|
||||
"*"
|
||||
]
|
||||
|
||||
# 配置 CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS, # 允许的前端域名
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
||||
allow_headers=["*"], # 允许所有请求头
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(user_router)
|
||||
app.include_router(device_router)
|
||||
app.include_router(face_router)
|
||||
app.include_router(sensitive_router)
|
||||
app.include_router(model_router)
|
||||
app.include_router(file_router)
|
||||
app.include_router(device_danger_router)
|
||||
app.include_router(ws_router)
|
||||
|
||||
# 注册全局异常处理器
|
||||
app.add_exception_handler(Exception, global_exception_handler)
|
||||
|
||||
# 主服务启动入口
|
||||
if __name__ == "__main__":
|
||||
# 启动 FastAPI 主服务(仅使用8000端口)
|
||||
port = int(SERVER_CONFIG.get("port", 8000))
|
||||
# 加载所有模型
|
||||
init()
|
||||
uvicorn.run(
|
||||
app="main:app",
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
workers=1,
|
||||
ws="websockets",
|
||||
reload=False
|
||||
)
|
||||
BIN
middle/__pycache__/auth_middleware.cpython-310.pyc
Normal file
BIN
middle/__pycache__/error_handler.cpython-310.pyc
Normal file
102
middle/auth_middleware.py
Normal file
@ -0,0 +1,102 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from ds.config import JWT_CONFIG
|
||||
from ds.db import db
|
||||
|
||||
# ------------------------------
|
||||
# 密码加密配置
|
||||
# ------------------------------
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# ------------------------------
|
||||
# JWT 配置
|
||||
# ------------------------------
|
||||
SECRET_KEY = JWT_CONFIG["secret_key"]
|
||||
ALGORITHM = JWT_CONFIG["algorithm"]
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
|
||||
|
||||
# OAuth2 依赖(从请求头获取 Token、格式: Bearer <token>)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 密码工具函数
|
||||
# ------------------------------
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证明文密码与加密密码是否匹配"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""对明文密码进行 bcrypt 加密"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# JWT 工具函数
|
||||
# ------------------------------
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""生成 JWT Token"""
|
||||
to_encode = data.copy()
|
||||
# 设置过期时间
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
||||
# 添加过期时间到 Token 数据
|
||||
to_encode.update({"exp": expire})
|
||||
# 生成 Token
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 认证依赖(获取当前登录用户)
|
||||
# ------------------------------
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)): # 移除返回类型注解
|
||||
"""从 Token 中解析用户信息、验证通过后返回当前用户"""
|
||||
# 延迟导入、打破循环依赖
|
||||
from schema.user_schema import UserResponse # 在这里导入
|
||||
|
||||
# 认证失败异常
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token 无效或已过期",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
# 解码 Token
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
# 获取 Token 中的用户名
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
# 从数据库查询用户(验证用户是否存在)
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
|
||||
cursor.execute(query, (username,))
|
||||
user = cursor.fetchone()
|
||||
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
# 转换为 UserResponse 模型(自动校验字段)
|
||||
return UserResponse(**user)
|
||||
except Exception as e:
|
||||
raise credentials_exception from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
68
middle/error_handler.py
Normal file
@ -0,0 +1,68 @@
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import HTTPException, RequestValidationError
|
||||
from mysql.connector import Error as MySQLError
|
||||
from jose import JWTError
|
||||
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""全局异常处理器: 所有未捕获的异常都会在这里统一处理"""
|
||||
# 请求参数验证错误(Pydantic 校验失败)
|
||||
if isinstance(exc, RequestValidationError):
|
||||
error_details = []
|
||||
for err in exc.errors():
|
||||
error_details.append(f"{err['loc'][1]}: {err['msg']}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=APIResponse(
|
||||
code=400,
|
||||
message=f"请求参数错误: {'; '.join(error_details)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# HTTP 异常(主动抛出的业务错误、如 401/404)
|
||||
if isinstance(exc, HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=APIResponse(
|
||||
code=exc.status_code,
|
||||
message=exc.detail,
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# JWT 相关错误(Token 无效/过期)
|
||||
if isinstance(exc, JWTError):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content=APIResponse(
|
||||
code=401,
|
||||
message="Token 无效或已过期",
|
||||
data=None
|
||||
).model_dump(),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# MySQL 数据库错误
|
||||
if isinstance(exc, MySQLError):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=APIResponse(
|
||||
code=500,
|
||||
message=f"数据库错误: {str(exc)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# 其他未知错误
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=APIResponse(
|
||||
code=500,
|
||||
message=f"服务器内部错误: {str(exc)}",
|
||||
data=None
|
||||
).model_dump()
|
||||
)
|
||||
21
requirements.txt
Normal file
@ -0,0 +1,21 @@
|
||||
fastapi==0.116.1
|
||||
insightface==0.7.3
|
||||
numpy==1.26.1
|
||||
opencv_contrib_python==4.6.0.66
|
||||
opencv_python==4.8.1.78
|
||||
opencv_python_headless==4.11.0.86
|
||||
paddleocr==2.6.0.1
|
||||
paddlepaddle-gpu==2.5.2
|
||||
passlib==1.7.4
|
||||
Pillow==11.3.0
|
||||
pycryptodome==3.23.0
|
||||
pydantic==2.11.7
|
||||
python_jose==3.5.0
|
||||
torch==2.3.1+cu118
|
||||
torchaudio==2.3.1+cu118
|
||||
torchvision==0.18.1+cu118
|
||||
ultralytics==8.3.198
|
||||
uvicorn==0.35.0
|
||||
insightface==0.7.3
|
||||
onnxruntime-gpu==1.15.1
|
||||
uvicorn[standard]
|
||||
BIN
router/__pycache__/device_danger_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/device_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/face_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/file_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/model_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/sensitive_router.cpython-310.pyc
Normal file
BIN
router/__pycache__/user_router.cpython-310.pyc
Normal file
112
router/device_action_router.py
Normal file
@ -0,0 +1,112 @@
|
||||
from fastapi import APIRouter, Query, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_action_schema import (
|
||||
DeviceActionResponse,
|
||||
DeviceActionListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
# 路由配置
|
||||
router = APIRouter(
|
||||
prefix="/api/device/actions",
|
||||
tags=["设备操作记录"]
|
||||
)
|
||||
|
||||
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
|
||||
@encrypt_response()
|
||||
async def get_device_action_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认1"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
|
||||
client_ip: str = Query(None, description="按客户端IP筛选"),
|
||||
action: int = Query(None, ge=0, le=1, description="按状态筛选(0=离线、1=上线)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
# 构建筛选条件(参数化查询、避免注入)
|
||||
where_clause = []
|
||||
params = []
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if action is not None:
|
||||
where_clause.append("action = %s")
|
||||
params.append(action)
|
||||
|
||||
# 查询总记录数(用于返回 total)
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页查询记录(按创建时间倒序、确保最新记录在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM device_action"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 仅返回 total + device_actions
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
|
||||
@encrypt_response()
|
||||
async def get_device_actions_by_ip(
|
||||
client_ip: str = Path(..., description="客户端IP地址")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 查询总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
|
||||
cursor.execute(count_sql, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 查询该IP的所有记录(按创建时间倒序)
|
||||
list_sql = """
|
||||
SELECT * FROM device_action
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
cursor.execute(list_sql, (client_ip,))
|
||||
action_list = cursor.fetchall()
|
||||
|
||||
# 3. 返回结果
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=DeviceActionListResponse(
|
||||
total=total,
|
||||
device_actions=[DeviceActionResponse(**item) for item in action_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
164
router/device_danger_router.py
Normal file
@ -0,0 +1,164 @@
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_danger_schema import (
|
||||
DeviceDangerResponse, DeviceDangerListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/devices/dangers",
|
||||
tags=["设备管理-危险记录"]
|
||||
)
|
||||
|
||||
# 获取危险记录列表
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
|
||||
@encrypt_response()
|
||||
async def get_danger_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
|
||||
danger_type: str = Query(None, max_length=255, alias="type", description="按危险类型筛选"),
|
||||
start_date: date = Query(None, description="按创建时间筛选(开始日期、格式YYYY-MM-DD)"),
|
||||
end_date: date = Query(None, description="按创建时间筛选(结束日期、格式YYYY-MM-DD)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 构建筛选条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if client_ip:
|
||||
where_clause.append("client_ip = %s")
|
||||
params.append(client_ip)
|
||||
if danger_type:
|
||||
where_clause.append("type = %s")
|
||||
params.append(danger_type)
|
||||
if start_date:
|
||||
where_clause.append("DATE(created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 1. 统计符合条件的总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询记录(按创建时间倒序、最新的在前)
|
||||
offset = (page - 1) * page_size
|
||||
list_query = "SELECT * FROM device_danger"
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
# 转换为响应模型
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录列表成功",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取单个设备的所有危险记录
|
||||
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
|
||||
# @encrypt_response()
|
||||
async def get_device_dangers(
|
||||
client_ip: str = Path(..., max_length=100, description="设备IP地址"),
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间")
|
||||
):
|
||||
# 先检查设备是否存在
|
||||
from service.device_danger_service import check_device_exist
|
||||
if not check_device_exist(client_ip):
|
||||
raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 统计该设备的危险记录总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
|
||||
cursor.execute(count_query, (client_ip,))
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 2. 分页查询该设备的危险记录
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT * FROM device_danger
|
||||
WHERE client_ip = %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s OFFSET %s
|
||||
"""
|
||||
cursor.execute(list_query, (client_ip, page_size, offset))
|
||||
danger_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
|
||||
data=DeviceDangerListResponse(
|
||||
total=total,
|
||||
dangers=[DeviceDangerResponse(**item) for item in danger_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 根据ID获取单个危险记录详情
|
||||
# ------------------------------
|
||||
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
|
||||
@encrypt_response()
|
||||
async def get_danger_detail(
|
||||
danger_id: int = Path(..., ge=1, description="危险记录ID")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询单个危险记录
|
||||
query = "SELECT * FROM device_danger WHERE id = %s"
|
||||
cursor.execute(query, (danger_id,))
|
||||
danger = cursor.fetchone()
|
||||
|
||||
if not danger:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取危险记录详情成功",
|
||||
data=DeviceDangerResponse(**danger)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
329
router/device_router.py
Normal file
@ -0,0 +1,329 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Query, HTTPException, Request, Path
|
||||
from mysql.connector import Error as MySQLError
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.device_schema import (
|
||||
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
|
||||
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from service.device_service import update_online_status_by_ip
|
||||
from ws.ws import get_current_time_str, aes_encrypt, is_client_connected
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/devices",
|
||||
tags=["设备管理"]
|
||||
)
|
||||
|
||||
|
||||
# 创建设备信息接口
|
||||
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
|
||||
@encrypt_response()
|
||||
async def create_device(device_data: DeviceCreateRequest, request: Request):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否已存在
|
||||
cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
|
||||
existing_device = cursor.fetchone()
|
||||
if existing_device:
|
||||
# 更新设备为在线状态
|
||||
from service.device_service import update_online_status_by_ip
|
||||
update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
|
||||
data=DeviceResponse(**existing_device)
|
||||
)
|
||||
|
||||
# 通过 User-Agent 判断设备类型
|
||||
user_agent = request.headers.get("User-Agent", "").lower()
|
||||
device_type = "unknown"
|
||||
if user_agent == "default":
|
||||
device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
|
||||
elif "windows" in user_agent:
|
||||
device_type = "windows"
|
||||
elif "android" in user_agent:
|
||||
device_type = "android"
|
||||
elif "linux" in user_agent:
|
||||
device_type = "linux"
|
||||
|
||||
device_params_json = json.dumps(device_data.params) if device_data.params else None
|
||||
|
||||
# 插入新设备
|
||||
insert_query = """
|
||||
INSERT INTO devices
|
||||
(client_ip, hostname, device_online_status, device_type, alarm_count, params, is_need_handler)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
device_data.ip,
|
||||
device_data.hostname,
|
||||
0,
|
||||
device_type,
|
||||
0,
|
||||
device_params_json,
|
||||
0
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取新设备并返回
|
||||
device_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
|
||||
new_device = cursor.fetchone()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="设备创建成功",
|
||||
data=DeviceResponse(**new_device)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"创建设备失败: {str(e)}") from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 获取设备列表接口
|
||||
# ------------------------------
|
||||
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
|
||||
@encrypt_response()
|
||||
async def get_device_list(
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
device_type: str = Query(None, description="按设备类型筛选"),
|
||||
online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
if device_type:
|
||||
where_clause.append("device_type = %s")
|
||||
params.append(device_type)
|
||||
if online_status is not None:
|
||||
where_clause.append("device_online_status = %s")
|
||||
params.append(online_status)
|
||||
|
||||
# 统计总数
|
||||
count_query = "SELECT COUNT(*) AS total FROM devices"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页查询列表
|
||||
offset = (page - 1) * page_size
|
||||
list_query = "SELECT * FROM devices"
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
device_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备列表成功",
|
||||
data=DeviceListResponse(
|
||||
total=total,
|
||||
devices=[DeviceResponse(**device) for device in device_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 获取设备上下线记录接口
|
||||
# ------------------------------
|
||||
@router.get("/status-history", response_model=APIResponse, summary="获取设备上下线记录")
|
||||
@encrypt_response()
|
||||
async def get_device_status_history(
|
||||
client_ip: str = Query(None, description="客户端IP地址(非必填,为空时返回所有设备记录)"),
|
||||
page: int = Query(1, ge=1, description="页码、默认第1页"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
start_date: date = Query(None, description="开始日期、格式YYYY-MM-DD"),
|
||||
end_date: date = Query(None, description="结束日期、格式YYYY-MM-DD")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 检查设备是否存在(仅传IP时):强制指定Collation
|
||||
if client_ip is not None:
|
||||
# 关键调整1:WHERE条件中给d.client_ip指定Collation(与da一致或反之)
|
||||
check_query = """
|
||||
SELECT id FROM devices
|
||||
WHERE client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci
|
||||
"""
|
||||
cursor.execute(check_query, (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail=f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 2. 构建WHERE条件
|
||||
where_clause = []
|
||||
params = []
|
||||
|
||||
# 关键调整2:传IP时,强制指定da.client_ip的Collation
|
||||
if client_ip is not None:
|
||||
where_clause.append("da.client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci")
|
||||
params.append(client_ip)
|
||||
if start_date:
|
||||
where_clause.append("DATE(da.created_at) >= %s")
|
||||
params.append(start_date.strftime("%Y-%m-%d"))
|
||||
if end_date:
|
||||
where_clause.append("DATE(da.created_at) <= %s")
|
||||
params.append(end_date.strftime("%Y-%m-%d"))
|
||||
|
||||
# 3. 统计总数:JOIN时强制统一Collation
|
||||
count_query = """
|
||||
SELECT COUNT(*) AS total
|
||||
FROM device_action da
|
||||
LEFT JOIN devices d
|
||||
ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
|
||||
"""
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 4. 分页查询:JOIN时强制统一Collation
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT da.*, d.id AS device_id
|
||||
FROM device_action da
|
||||
LEFT JOIN devices d
|
||||
ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
|
||||
"""
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY da.created_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
cursor.execute(list_query, params)
|
||||
history_list = cursor.fetchall()
|
||||
|
||||
# 后续格式化响应逻辑不变...
|
||||
formatted_history = []
|
||||
for item in history_list:
|
||||
formatted_item = {
|
||||
"id": item["id"],
|
||||
"device_id": item["device_id"], # 可能为None(IP无对应设备)
|
||||
"client_ip": item["client_ip"],
|
||||
"status": item["action"],
|
||||
"status_time": item["created_at"]
|
||||
}
|
||||
formatted_history.append(formatted_item)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取设备上下线记录成功",
|
||||
data=DeviceStatusHistoryListResponse(
|
||||
total=total,
|
||||
history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# ------------------------------
|
||||
# 通过客户端IP设置设备is_need_handler为0接口
|
||||
# ------------------------------
|
||||
@router.post("/need-handler/reset", response_model=APIResponse, summary="解封客户端")
|
||||
@encrypt_response()
|
||||
async def reset_device_need_handler(
|
||||
client_ip: str = Query(..., description="目标设备的客户端IP地址(必填)")
|
||||
):
|
||||
try:
|
||||
from service.device_service import update_is_need_handler_by_client_ip
|
||||
success = update_is_need_handler_by_client_ip(
|
||||
client_ip=client_ip,
|
||||
is_need_handler=0 # 固定设置为0(不需要处理)
|
||||
)
|
||||
|
||||
if success:
|
||||
online_status = is_client_connected(client_ip)
|
||||
|
||||
# 如果设备在线,则发送消息给前端
|
||||
if online_status:
|
||||
# 调用 ws 发送一个消息给前端、告诉他已解锁
|
||||
unlock_msg = {
|
||||
"type": "unlock",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": client_ip
|
||||
}
|
||||
from ws.ws import send_message_to_client
|
||||
await send_message_to_client(client_ip, json.dumps(unlock_msg))
|
||||
|
||||
# 休眠 100 ms
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
frame_permit_msg = {
|
||||
"type": "frame",
|
||||
"timestamp": get_current_time_str(),
|
||||
"client_ip": client_ip
|
||||
}
|
||||
await send_message_to_client(client_ip, json.dumps(frame_permit_msg))
|
||||
|
||||
# 更新设备在线状态为1
|
||||
update_online_status_by_ip(client_ip, 1)
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"设备已解封",
|
||||
data={
|
||||
"client_ip": client_ip,
|
||||
"is_need_handler": 0,
|
||||
"status_desc": "设备已解封"
|
||||
}
|
||||
)
|
||||
|
||||
# 捕获工具方法抛出的业务异常(如IP为空、设备不存在)
|
||||
except ValueError as e:
|
||||
# 业务异常返回400/404状态码(与现有接口异常规范一致)
|
||||
raise HTTPException(
|
||||
status_code=404 if "设备不存在" in str(e) else 400,
|
||||
detail=str(e)
|
||||
) from e
|
||||
|
||||
# 捕获数据库层面异常(如连接失败、SQL执行错误)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"设置is_need_handler失败:数据库操作异常 - {str(e)}") from e
|
||||
|
||||
# 捕获其他未知异常
|
||||
except Exception as e:
|
||||
raise Exception(f"设置is_need_handler失败:未知错误 - {str(e)}") from e
|
||||
|
||||
|
||||
|
||||
326
router/face_router.py
Normal file
@ -0,0 +1,326 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.face_schema import (
|
||||
FaceCreateRequest,
|
||||
FaceResponse,
|
||||
FaceListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from service.face_service import update_face_data
|
||||
from util.face_util import add_binary_data
|
||||
from service.file_service import save_source_file
|
||||
|
||||
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
|
||||
|
||||
|
||||
# 创建人脸记录
|
||||
@router.post("", response_model=APIResponse, summary="创建人脸记录")
|
||||
@encrypt_response()
|
||||
async def create_face(
|
||||
request: Request,
|
||||
name: str = Form(None, max_length=255, description="名称(可选)"),
|
||||
file: UploadFile = File(..., description="人脸文件(必传)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
face_create = FaceCreateRequest(name=name)
|
||||
client_ip = request.client.host if request.client else ""
|
||||
if not client_ip:
|
||||
raise HTTPException(status_code=400, detail="无法获取客户端IP")
|
||||
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
# 先读取文件内容
|
||||
file_content = await file.read()
|
||||
# 将文件指针重置到开头
|
||||
await file.seek(0)
|
||||
# 再保存文件
|
||||
path = save_source_file(file, "face")
|
||||
# 提取人脸特征
|
||||
detect_success, detect_result = add_binary_data(file_content)
|
||||
if not detect_success:
|
||||
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
|
||||
eigenvalue = detect_result
|
||||
|
||||
# 插入数据库
|
||||
insert_query = """
|
||||
INSERT INTO face (name, eigenvalue, address)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
|
||||
conn.commit()
|
||||
|
||||
# 查询新记录
|
||||
cursor.execute("""
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = LAST_INSERT_ID()
|
||||
""")
|
||||
created_face = cursor.fetchone()
|
||||
if not created_face:
|
||||
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
|
||||
|
||||
|
||||
# TODO 重新加载人脸模型
|
||||
update_face_data()
|
||||
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"人脸记录创建成功(ID: {created_face['id']})",
|
||||
data=FaceResponse(**created_face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取单个人脸记录
|
||||
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
|
||||
@encrypt_response()
|
||||
async def get_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
WHERE id = %s
|
||||
"""
|
||||
cursor.execute(query, (face_id,))
|
||||
face = cursor.fetchone()
|
||||
|
||||
if not face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="查询成功",
|
||||
data=FaceResponse(**face)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取人脸列表
|
||||
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
|
||||
@encrypt_response()
|
||||
async def get_face_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
has_eigenvalue: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if has_eigenvalue is not None:
|
||||
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
|
||||
|
||||
# 总记录数
|
||||
count_query = "SELECT COUNT(*) AS total FROM face"
|
||||
if where_clause:
|
||||
count_query += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_query, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 列表数据
|
||||
offset = (page - 1) * page_size
|
||||
list_query = """
|
||||
SELECT id, name, address, created_at, updated_at
|
||||
FROM face
|
||||
"""
|
||||
if where_clause:
|
||||
list_query += " WHERE " + " AND ".join(where_clause)
|
||||
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_query, params)
|
||||
face_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取成功(共{total}条)",
|
||||
data=FaceListResponse(
|
||||
total=total,
|
||||
faces=[FaceResponse(**face) for face in face_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 删除人脸记录
|
||||
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
|
||||
@encrypt_response()
|
||||
async def delete_face(face_id: int):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
|
||||
exist_face = cursor.fetchone()
|
||||
if not exist_face:
|
||||
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
|
||||
old_db_path = exist_face["address"]
|
||||
|
||||
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
|
||||
conn.commit()
|
||||
|
||||
# 删除图片
|
||||
if old_db_path:
|
||||
old_abs_path = Path(old_db_path).resolve()
|
||||
if old_abs_path.exists():
|
||||
try:
|
||||
old_abs_path.unlink()
|
||||
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
|
||||
extra_msg = "(已同步删除图片)"
|
||||
except Exception as e:
|
||||
print(f"[FaceRouter] 删除图片失败:{str(e)}")
|
||||
extra_msg = "(图片删除失败)"
|
||||
else:
|
||||
extra_msg = "(图片不存在)"
|
||||
else:
|
||||
extra_msg = "(无关联图片)"
|
||||
|
||||
|
||||
# TODO 重新加载人脸模型
|
||||
update_face_data()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
|
||||
# @encrypt_response()
|
||||
async def batch_import_faces(
|
||||
folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
success_count = 0 # 成功导入数量
|
||||
fail_list = [] # 失败记录(包含文件名、错误原因)
|
||||
try:
|
||||
# 1. 验证文件夹有效性
|
||||
folder = Path(folder_path)
|
||||
if not folder.exists() or not folder.is_dir():
|
||||
raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")
|
||||
|
||||
# 2. 定义支持的图片格式
|
||||
supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}
|
||||
|
||||
# 3. 数据库连接初始化
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 4. 遍历文件夹内所有文件
|
||||
for file_path in folder.iterdir():
|
||||
if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
|
||||
file_name = file_path.stem # 提取文件名(不含后缀)作为 `name`
|
||||
try:
|
||||
# 4.1 读取文件二进制内容
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
|
||||
# 4.2 构造模拟的 UploadFile 对象(用于兼容 `save_source_file`)
|
||||
mock_file = UploadFile(
|
||||
filename=file_path.name,
|
||||
file=BytesIO(file_content)
|
||||
)
|
||||
|
||||
# 4.3 保存文件到指定目录
|
||||
saved_path = save_source_file(mock_file, "face")
|
||||
|
||||
# 4.4 提取人脸特征
|
||||
detect_success, detect_result = add_binary_data(file_content)
|
||||
if not detect_success:
|
||||
fail_list.append({
|
||||
"name": file_name,
|
||||
"file_path": str(file_path),
|
||||
"error": f"人脸检测失败:{detect_result}"
|
||||
})
|
||||
continue # 跳过当前文件,处理下一个
|
||||
eigenvalue = detect_result
|
||||
|
||||
# 4.5 插入数据库
|
||||
insert_sql = """
|
||||
INSERT INTO face (name, eigenvalue, address)
|
||||
VALUES (%s, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
|
||||
conn.commit() # 提交当前文件的插入操作
|
||||
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
# 捕获单文件处理的异常,记录后继续处理其他文件
|
||||
fail_list.append({
|
||||
"name": file_name,
|
||||
"file_path": str(file_path),
|
||||
"error": str(e)
|
||||
})
|
||||
if conn:
|
||||
conn.rollback() # 回滚当前失败文件的插入
|
||||
|
||||
# 5. 重新加载人脸模型(确保新增数据生效)
|
||||
update_face_data()
|
||||
|
||||
# 6. 构造返回结果
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)} 条",
|
||||
data={
|
||||
"success_count": success_count,
|
||||
"fail_details": fail_list
|
||||
}
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
|
||||
except HTTPException:
|
||||
raise # 直接抛出400等由业务逻辑触发的HTTP异常
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
31
router/file_router.py
Normal file
@ -0,0 +1,31 @@
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Path, APIRouter
|
||||
from fastapi.responses import FileResponse
|
||||
from service.file_service import UPLOAD_ROOT
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/file",
|
||||
tags=["文件管理"]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/download/{relative_path:path}", summary="下载文件")
|
||||
async def download_file(
|
||||
relative_path: str = Path(..., description="文件的相对路径")
|
||||
):
|
||||
file_path = os.path.abspath(os.path.join(UPLOAD_ROOT, relative_path))
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=400, detail="路径指向的不是文件")
|
||||
|
||||
if not file_path.startswith(os.path.abspath(UPLOAD_ROOT)):
|
||||
raise HTTPException(status_code=403, detail="无权访问该文件")
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=os.path.basename(file_path),
|
||||
media_type="application/octet-stream"
|
||||
)
|
||||
269
router/model_router.py
Normal file
@ -0,0 +1,269 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from service.file_service import save_source_file
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.model_schema import (
|
||||
ModelResponse,
|
||||
ModelListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model
|
||||
|
||||
router = APIRouter(prefix="/api/models", tags=["模型管理"])
|
||||
|
||||
|
||||
# 上传模型
|
||||
@router.post("", response_model=APIResponse, summary="上传YOLO模型(.pt格式)")
|
||||
@encrypt_response()
|
||||
async def upload_model(
|
||||
name: str = Form(..., description="模型名称"),
|
||||
description: str = Form(None, description="模型描述"),
|
||||
file: UploadFile = File(..., description=f"YOLO模型文件(.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 校验文件格式
|
||||
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
|
||||
if file_ext not in ALLOWED_MODEL_EXT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}"
|
||||
)
|
||||
|
||||
# 校验文件大小
|
||||
if file.size > MAX_MODEL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB"
|
||||
)
|
||||
# 保存文件
|
||||
file_path = save_source_file(file, "model")
|
||||
|
||||
# 数据库操作
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
insert_sql = """
|
||||
INSERT INTO model (name, path, is_default, description, file_size)
|
||||
VALUES (%s, %s, 0, %s, %s)
|
||||
"""
|
||||
cursor.execute(insert_sql, (name, file_path, description, file.size))
|
||||
conn.commit()
|
||||
|
||||
# 获取新增记录
|
||||
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
|
||||
new_model = cursor.fetchone()
|
||||
if not new_model:
|
||||
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型上传成功",
|
||||
data=ModelResponse(**new_model)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取模型列表
|
||||
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
|
||||
@encrypt_response()
|
||||
async def get_model_list(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
name: str = Query(None),
|
||||
is_default: bool = Query(None)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%")
|
||||
if is_default is not None:
|
||||
where_clause.append("is_default = %s")
|
||||
params.append(1 if is_default else 0)
|
||||
|
||||
# 总记录数
|
||||
count_sql = "SELECT COUNT(*) AS total FROM model"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 分页数据
|
||||
offset = (page - 1) * page_size
|
||||
list_sql = "SELECT * FROM model"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
model_list = cursor.fetchall()
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"获取成功!",
|
||||
data=ModelListResponse(
|
||||
total=total,
|
||||
models=[ModelResponse(**model) for model in model_list]
|
||||
)
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 更换默认模型
|
||||
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型")
|
||||
@encrypt_response()
|
||||
async def set_default_model(
|
||||
model_id: int
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
conn.autocommit = False
|
||||
|
||||
# 校验目标模型是否存在
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
target_model = cursor.fetchone()
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail=f"目标模型不存在!")
|
||||
|
||||
# 检查是否已为默认模型
|
||||
if target_model["is_default"]:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"已是默认模型、无需更换",
|
||||
data=ModelResponse(**target_model)
|
||||
)
|
||||
|
||||
# 数据库事务:更新默认模型状态
|
||||
try:
|
||||
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
|
||||
cursor.execute(
|
||||
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
|
||||
(model_id,)
|
||||
)
|
||||
conn.commit()
|
||||
except MySQLError as e:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
|
||||
) from e
|
||||
|
||||
# 更新模型
|
||||
load_yolo_model()
|
||||
# 返回成功响应
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"更换成功",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
if conn:
|
||||
conn.autocommit = True
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 路由文件(如 model_router.py)中的删除接口
|
||||
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
|
||||
@encrypt_response()
|
||||
async def delete_model(model_id: int):
|
||||
# 1. 正确导入 model_service 中的全局变量(关键修复:变量名匹配)
|
||||
from service.model_service import (
|
||||
current_yolo_model,
|
||||
current_model_absolute_path,
|
||||
load_yolo_model # 用于删除后重新加载模型(可选)
|
||||
)
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 2. 查询待删除模型信息
|
||||
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
|
||||
exist_model = cursor.fetchone()
|
||||
if not exist_model:
|
||||
raise HTTPException(status_code=404, detail=f"模型不存在!")
|
||||
|
||||
# 3. 关键判断:①默认模型不可删 ②正在使用的模型不可删
|
||||
if exist_model["is_default"]:
|
||||
raise HTTPException(status_code=400, detail="默认模型不可删除!")
|
||||
|
||||
# 计算待删除模型的绝对路径(与 model_service 逻辑一致)
|
||||
from service.file_service import get_absolute_path
|
||||
del_model_abs_path = get_absolute_path(exist_model["path"])
|
||||
|
||||
# 判断是否正在使用(对比 current_model_absolute_path)
|
||||
if current_model_absolute_path and del_model_abs_path == current_model_absolute_path:
|
||||
raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!")
|
||||
|
||||
# 4. 先删除数据库记录(避免文件删除失败导致数据不一致)
|
||||
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
|
||||
conn.commit()
|
||||
|
||||
# 5. 再删除本地文件(捕获文件删除异常,不影响数据库删除结果)
|
||||
extra_msg = ""
|
||||
try:
|
||||
if os.path.exists(del_model_abs_path):
|
||||
os.remove(del_model_abs_path) # 或用 Path(del_model_abs_path).unlink()
|
||||
extra_msg = "(本地文件已同步删除)"
|
||||
else:
|
||||
extra_msg = "(本地文件不存在,无需删除)"
|
||||
except Exception as e:
|
||||
extra_msg = f"(本地文件删除失败:{str(e)})"
|
||||
|
||||
# 6. 若删除后当前模型为空(极端情况),重新加载默认模型(可选优化)
|
||||
if current_yolo_model is None:
|
||||
try:
|
||||
load_yolo_model()
|
||||
print(f"[模型删除后] 重新加载默认模型成功")
|
||||
except Exception as e:
|
||||
print(f"[模型删除后] 重新加载默认模型失败:{str(e)}")
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"模型删除成功!",
|
||||
data=None
|
||||
)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
306
router/sensitive_router.py
Normal file
@ -0,0 +1,306 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, File, UploadFile
|
||||
from mysql.connector import Error as MySQLError
|
||||
from typing import Optional
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from schema.sensitive_schema import (
|
||||
SensitiveCreateRequest,
|
||||
SensitiveResponse,
|
||||
SensitiveListResponse
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from middle.auth_middleware import get_current_user
|
||||
from schema.user_schema import UserResponse
|
||||
from service.ocr_service import set_forbidden_words
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/sensitives",
|
||||
tags=["敏感信息管理"]
|
||||
)
|
||||
|
||||
|
||||
# 创建敏感信息记录
|
||||
@router.post("", response_model=APIResponse, summary="创建敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def create_sensitive(
|
||||
sensitive: SensitiveCreateRequest,
|
||||
current_user: UserResponse = Depends(get_current_user)
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入新敏感信息记录到数据库(不包含ID、由数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO sensitives (name, created_at, updated_at)
|
||||
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (sensitive.name,))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚插入记录的ID(使用LAST_INSERT_ID()函数)
|
||||
new_id = cursor.lastrowid
|
||||
|
||||
# 查询刚创建的记录并返回
|
||||
select_query = "SELECT * FROM sensitives WHERE id = %s"
|
||||
cursor.execute(select_query, (new_id,))
|
||||
created_sensitive = cursor.fetchone()
|
||||
|
||||
# 重新加载最新的敏感词
|
||||
set_forbidden_words(get_all_sensitive_words())
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="敏感信息记录创建成功",
|
||||
data=SensitiveResponse(**created_sensitive)
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"创建敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取敏感信息分页列表
|
||||
@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
|
||||
@encrypt_response()
|
||||
async def get_sensitive_list(
|
||||
page: int = Query(1, ge=1, description="页码(默认1、最小1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数(默认10、1-100)"),
|
||||
name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)")
|
||||
):
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 构建查询条件(支持关键词搜索)
|
||||
where_clause = []
|
||||
params = []
|
||||
if name:
|
||||
where_clause.append("name LIKE %s")
|
||||
params.append(f"%{name}%") # 模糊匹配关键词
|
||||
|
||||
# 2. 查询总记录数(用于分页计算)
|
||||
count_sql = "SELECT COUNT(*) AS total FROM sensitives"
|
||||
if where_clause:
|
||||
count_sql += " WHERE " + " AND ".join(where_clause)
|
||||
cursor.execute(count_sql, params.copy()) # 复制参数列表、避免后续污染
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 计算分页偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 4. 分页查询敏感词数据(按更新时间倒序、最新的在前)
|
||||
list_sql = "SELECT * FROM sensitives"
|
||||
if where_clause:
|
||||
list_sql += " WHERE " + " AND ".join(where_clause)
|
||||
# 排序+分页(LIMIT 条数 OFFSET 偏移量)
|
||||
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
|
||||
# 补充分页参数(page_size和offset)
|
||||
params.extend([page_size, offset])
|
||||
|
||||
cursor.execute(list_sql, params)
|
||||
sensitive_list = cursor.fetchall()
|
||||
|
||||
# 5. 构造分页响应数据
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"敏感信息列表查询成功(共{total}条记录、当前第{page}页)",
|
||||
data=SensitiveListResponse(
|
||||
total=total,
|
||||
sensitives=[SensitiveResponse(**item) for item in sensitive_list]
|
||||
)
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"查询敏感信息列表失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 删除敏感信息记录
|
||||
@router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录")
|
||||
@encrypt_response()
|
||||
async def delete_sensitive(
|
||||
sensitive_id: int,
|
||||
current_user: UserResponse = Depends(get_current_user) # 需登录认证
|
||||
):
|
||||
"""
|
||||
删除敏感信息记录:
|
||||
- 需登录认证
|
||||
- 根据ID删除敏感信息记录
|
||||
- 返回删除成功信息
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 检查记录是否存在
|
||||
check_query = "SELECT id FROM sensitives WHERE id = %s"
|
||||
cursor.execute(check_query, (sensitive_id,))
|
||||
existing_sensitive = cursor.fetchone()
|
||||
if not existing_sensitive:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ID为 {sensitive_id} 的敏感信息记录不存在"
|
||||
)
|
||||
|
||||
# 2. 执行删除操作
|
||||
delete_query = "DELETE FROM sensitives WHERE id = %s"
|
||||
cursor.execute(delete_query, (sensitive_id,))
|
||||
conn.commit()
|
||||
|
||||
# 重新加载最新的敏感词
|
||||
set_forbidden_words(get_all_sensitive_words())
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"ID为 {sensitive_id} 的敏感信息记录删除成功",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"删除敏感信息记录失败: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 批量导入敏感信息(从txt文件)
|
||||
@router.post("/batch-import", response_model=APIResponse, summary="批量导入敏感信息(从txt文件)")
|
||||
@encrypt_response()
|
||||
async def batch_import_sensitives(
|
||||
file: UploadFile = File(..., description="包含敏感词的txt文件,每行一个敏感词"),
|
||||
# current_user: UserResponse = Depends(get_current_user) # 添加认证依赖
|
||||
):
|
||||
"""
|
||||
批量导入敏感信息:
|
||||
- 需登录认证
|
||||
- 接收txt文件,文件中每行一个敏感词
|
||||
- 批量插入到数据库中(仅插入不存在的敏感词)
|
||||
- 返回导入结果统计
|
||||
"""
|
||||
# 检查文件类型
|
||||
filename = file.filename or ""
|
||||
if not filename.lower().endswith(".txt"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"请上传txt格式的文件,当前文件格式: {filename.split('.')[-1] if '.' in filename else '未知'}"
|
||||
)
|
||||
|
||||
# 检查文件大小
|
||||
file_size = await file.read(1) # 读取1字节获取文件信息
|
||||
await file.seek(0) # 重置文件指针
|
||||
if not file_size: # 文件为空
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="上传的文件为空,请提供有效的敏感词文件"
|
||||
)
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 读取文件内容
|
||||
contents = await file.read()
|
||||
# 按行分割内容,处理不同操作系统的换行符
|
||||
lines = contents.decode("utf-8", errors="replace").splitlines()
|
||||
|
||||
# 过滤空行和仅含空白字符的行
|
||||
sensitive_words = [line.strip() for line in lines if line.strip()]
|
||||
|
||||
if not sensitive_words:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="文件中没有有效的敏感词",
|
||||
data={"imported": 0, "total": 0}
|
||||
)
|
||||
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 先查询数据库中已存在的敏感词
|
||||
query = "SELECT name FROM sensitives WHERE name IN (%s)"
|
||||
# 处理参数,根据敏感词数量生成占位符
|
||||
placeholders = ', '.join(['%s'] * len(sensitive_words))
|
||||
cursor.execute(query % placeholders, sensitive_words)
|
||||
existing_words = {row['name'] for row in cursor.fetchall()}
|
||||
|
||||
# 过滤掉已存在的敏感词
|
||||
new_words = [word for word in sensitive_words if word not in existing_words]
|
||||
|
||||
if not new_words:
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="所有敏感词均已存在于数据库中",
|
||||
data={
|
||||
"total": len(sensitive_words),
|
||||
"imported": 0,
|
||||
"duplicates": len(sensitive_words),
|
||||
"message": f"共处理{len(sensitive_words)}个敏感词,全部已存在,未导入任何新敏感词"
|
||||
}
|
||||
)
|
||||
|
||||
# 批量插入新的敏感词
|
||||
insert_query = """
|
||||
INSERT INTO sensitives (name, created_at, updated_at)
|
||||
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
|
||||
# 准备参数列表
|
||||
params = [(word,) for word in new_words]
|
||||
|
||||
# 执行批量插入
|
||||
cursor.executemany(insert_query, params)
|
||||
conn.commit()
|
||||
|
||||
# 重新加载最新的敏感词
|
||||
set_forbidden_words(get_all_sensitive_words())
|
||||
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message=f"敏感词批量导入成功",
|
||||
data={
|
||||
"total": len(sensitive_words),
|
||||
"imported": len(new_words),
|
||||
"duplicates": len(sensitive_words) - len(new_words),
|
||||
"message": f"共处理{len(sensitive_words)}个敏感词,成功导入{len(new_words)}个,{len(sensitive_words) - len(new_words)}个已存在"
|
||||
}
|
||||
)
|
||||
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="文件编码格式错误,请使用UTF-8编码的txt文件"
|
||||
)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"批量导入敏感词失败: {str(e)}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"处理文件时发生错误: {str(e)}"
|
||||
) from e
|
||||
finally:
|
||||
await file.close()
|
||||
db.close_connection(conn, cursor)
|
||||
246
router/user_router.py
Normal file
@ -0,0 +1,246 @@
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from encryption.encrypt_decorator import encrypt_response
|
||||
from middle.auth_middleware import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
get_current_user
|
||||
)
|
||||
from schema.response_schema import APIResponse
|
||||
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/users",
|
||||
tags=["用户管理"]
|
||||
)
|
||||
|
||||
|
||||
# 用户注册接口
|
||||
@router.post("/register", response_model=APIResponse, summary="用户注册")
|
||||
@encrypt_response()
|
||||
async def user_register(request: UserRegisterRequest):
|
||||
"""
|
||||
用户注册:
|
||||
- 校验用户名是否已存在
|
||||
- 加密密码后插入数据库
|
||||
- 返回注册成功信息
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 1. 检查用户名是否已存在(唯一索引)
|
||||
check_query = "SELECT username FROM users WHERE username = %s"
|
||||
cursor.execute(check_query, (request.username,))
|
||||
existing_user = cursor.fetchone()
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
|
||||
)
|
||||
|
||||
# 2. 加密密码
|
||||
hashed_password = get_password_hash(request.password)
|
||||
|
||||
# 3. 插入新用户到数据库
|
||||
insert_query = """
|
||||
INSERT INTO users (username, password)
|
||||
VALUES (%s, %s)
|
||||
"""
|
||||
cursor.execute(insert_query, (request.username, hashed_password))
|
||||
conn.commit() # 提交事务
|
||||
|
||||
# 4. 返回注册成功响应
|
||||
return APIResponse(
|
||||
code=200, # 200 表示资源创建成功
|
||||
message=f"用户 '{request.username}' 注册成功",
|
||||
data=None
|
||||
)
|
||||
except MySQLError as e:
|
||||
conn.rollback() # 数据库错误时回滚事务
|
||||
raise Exception(f"注册失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 用户登录接口
|
||||
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token)")
|
||||
@encrypt_response()
|
||||
async def user_login(request: UserLoginRequest):
|
||||
"""
|
||||
用户登录:
|
||||
- 校验用户名是否存在
|
||||
- 校验密码是否正确
|
||||
- 生成 JWT Token 并返回
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 修复: SQL查询添加 created_at 和 updated_at 字段
|
||||
query = """
|
||||
SELECT id, username, password, created_at, updated_at
|
||||
FROM users
|
||||
WHERE username = %s
|
||||
"""
|
||||
cursor.execute(query, (request.username,))
|
||||
user = cursor.fetchone()
|
||||
|
||||
# 2. 校验用户名和密码
|
||||
if not user or not verify_password(request.password, user["password"]):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户名或密码错误",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 3. 生成 Token(过期时间从配置读取)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user["username"]},
|
||||
expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
# 4. 返回 Token 和用户基本信息
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="登录成功",
|
||||
data={
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
"user": UserResponse(
|
||||
id=user["id"],
|
||||
username=user["username"],
|
||||
created_at=user.get("created_at"),
|
||||
updated_at=user.get("updated_at")
|
||||
)
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"登录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取当前登录用户信息(需认证)
|
||||
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
|
||||
@encrypt_response()
|
||||
async def get_current_user_info(
|
||||
current_user: UserResponse = Depends(get_current_user) # 依赖认证中间件
|
||||
):
|
||||
"""
|
||||
获取当前登录用户信息:
|
||||
- 需在请求头携带 Token(格式: Bearer <token>)
|
||||
- 认证通过后返回用户信息
|
||||
"""
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取用户信息成功",
|
||||
data=current_user
|
||||
)
|
||||
|
||||
|
||||
# 获取用户列表(仅需登录权限)
|
||||
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
|
||||
@encrypt_response()
|
||||
async def get_user_list(
|
||||
page: int = Query(1, ge=1, description="页码、从1开始"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
|
||||
username: Optional[str] = Query(None, description="用户名模糊搜索"),
|
||||
current_user: UserResponse = Depends(get_current_user) # 仅需登录即可访问(移除管理员校验)
|
||||
):
|
||||
"""
|
||||
获取用户列表:
|
||||
- 需登录权限(请求头携带 Token: Bearer <token>)
|
||||
- 支持分页查询(page=页码、page_size=每页条数)
|
||||
- 支持用户名模糊搜索(如输入"test"可匹配"test123"、"admin_test"等)
|
||||
- 仅返回用户ID、用户名、创建时间、更新时间(不包含密码等敏感信息)
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 计算分页偏移量(page从1开始、偏移量=(页码-1)*每页条数)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 基础查询(仅查非敏感字段)
|
||||
base_query = """
|
||||
SELECT id, username, created_at, updated_at
|
||||
FROM users
|
||||
"""
|
||||
# 总条数查询(用于分页计算)
|
||||
count_query = "SELECT COUNT(*) as total FROM users"
|
||||
|
||||
# 条件拼接(支持用户名模糊搜索)
|
||||
conditions = []
|
||||
params = []
|
||||
if username:
|
||||
conditions.append("username LIKE %s")
|
||||
params.append(f"%{username}%") # 模糊匹配:%表示任意字符
|
||||
|
||||
# 构建最终查询语句
|
||||
if conditions:
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
|
||||
final_count_query = f"{count_query}{where_clause}"
|
||||
params.extend([page_size, offset]) # 追加分页参数
|
||||
else:
|
||||
final_query = f"{base_query} LIMIT %s OFFSET %s"
|
||||
final_count_query = count_query
|
||||
params = [page_size, offset]
|
||||
|
||||
# 1. 查询用户列表数据
|
||||
cursor.execute(final_query, params)
|
||||
users = cursor.fetchall()
|
||||
|
||||
# 2. 查询总条数(用于计算总页数)
|
||||
count_params = [f"%{username}%"] if username else []
|
||||
cursor.execute(final_count_query, count_params)
|
||||
total = cursor.fetchone()["total"]
|
||||
|
||||
# 3. 转换为UserResponse模型(确保字段匹配)
|
||||
user_list = [
|
||||
UserResponse(
|
||||
id=user["id"],
|
||||
username=user["username"],
|
||||
created_at=user["created_at"],
|
||||
updated_at=user["updated_at"]
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
# 4. 计算总页数(向上取整、如11条数据每页10条=2页)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# 返回结果(包含列表和分页信息)
|
||||
return APIResponse(
|
||||
code=200,
|
||||
message="获取用户列表成功",
|
||||
data={
|
||||
"users": user_list,
|
||||
"pagination": {
|
||||
"page": page, # 当前页码
|
||||
"page_size": page_size, # 每页条数
|
||||
"total": total, # 总数据量
|
||||
"total_pages": total_pages # 总页数
|
||||
}
|
||||
}
|
||||
)
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取用户列表失败: {str(e)}") from e
|
||||
finally:
|
||||
# 无论成功失败、都关闭数据库连接
|
||||
db.close_connection(conn, cursor)
|
||||
BIN
schema/__pycache__/device_action_schema.cpython-310.pyc
Normal file
BIN
schema/__pycache__/device_danger_schema.cpython-310.pyc
Normal file
BIN
schema/__pycache__/device_schema.cpython-310.pyc
Normal file
BIN
schema/__pycache__/face_schema.cpython-310.pyc
Normal file
BIN
schema/__pycache__/model_schema.cpython-310.pyc
Normal file
BIN
schema/__pycache__/response_schema.cpython-310.pyc
Normal file
BIN
schema/__pycache__/sensitive_schema.cpython-310.pyc
Normal file
BIN
schema/__pycache__/user_schema.cpython-310.pyc
Normal file
36
schema/device_action_schema.py
Normal file
@ -0,0 +1,36 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceActionCreate(BaseModel):
|
||||
"""设备操作记录创建模型(0=离线、1=上线)"""
|
||||
client_ip: str = Field(..., description="客户端IP")
|
||||
action: int = Field(..., ge=0, le=1, description="操作状态(0=离线、1=上线)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(单条记录)
|
||||
# ------------------------------
|
||||
class DeviceActionResponse(BaseModel):
|
||||
"""设备操作记录响应模型(与自增表对齐)"""
|
||||
id: int = Field(..., description="自增主键ID")
|
||||
client_ip: Optional[str] = Field(None, description="客户端IP")
|
||||
action: Optional[int] = Field(None, description="操作状态(0=离线、1=上线)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
# 支持从数据库结果直接转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 列表响应模型(仅含 total + device_actions)
|
||||
# ------------------------------
|
||||
class DeviceActionListResponse(BaseModel):
|
||||
"""设备操作记录列表(仅核心返回字段)"""
|
||||
total: int = Field(..., description="总记录数")
|
||||
device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表")
|
||||
33
schema/device_danger_schema.py
Normal file
@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceDangerCreateRequest(BaseModel):
|
||||
"""设备危险记录创建请求模型"""
|
||||
client_ip: str = Field(..., max_length=100, description="设备IP地址(必须与devices表中IP对应)")
|
||||
type: str = Field(..., max_length=255, description="危险类型(如:病毒检测、端口异常、权限泄露等)")
|
||||
result: str = Field(..., description="危险检测结果/处理结果(如:检测到木马病毒、已隔离;端口22异常开放、已关闭)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型
|
||||
# ------------------------------
|
||||
class DeviceDangerResponse(BaseModel):
|
||||
"""单条设备危险记录响应模型(与device_danger表字段对齐、updated_at允许为null)"""
|
||||
id: int = Field(..., description="危险记录主键ID")
|
||||
client_ip: str = Field(..., max_length=100, description="设备IP地址")
|
||||
type: str = Field(..., max_length=255, description="危险类型")
|
||||
result: str = Field(..., description="危险检测结果/处理结果")
|
||||
created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)")
|
||||
updated_at: Optional[datetime] = Field(None, description="记录更新时间(数据库中该字段当前为null)")
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceDangerListResponse(BaseModel):
|
||||
"""设备危险记录列表响应模型(带分页)"""
|
||||
total: int = Field(..., description="危险记录总数")
|
||||
dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表")
|
||||
56
schema/device_schema.py
Normal file
@ -0,0 +1,56 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型
|
||||
# ------------------------------
|
||||
class DeviceCreateRequest(BaseModel):
|
||||
"""设备创建请求模型"""
|
||||
ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
|
||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||
params: Optional[Dict] = Field(None, description="设备扩展参数(JSON格式)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型
|
||||
# ------------------------------
|
||||
class DeviceResponse(BaseModel):
|
||||
"""单设备信息响应模型(与数据库表字段对齐)"""
|
||||
id: int = Field(..., description="设备主键ID")
|
||||
device_id: Optional[int] = Field(None, description="关联设备ID(若历史记录IP无对应设备则为None)")
|
||||
hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
|
||||
device_online_status: int = Field(..., description="在线状态(1-在线、0-离线)")
|
||||
device_type: Optional[str] = Field(None, description="设备类型")
|
||||
alarm_count: int = Field(..., description="报警次数")
|
||||
params: Optional[str] = Field(None, description="扩展参数(JSON字符串)")
|
||||
client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
|
||||
is_need_handler: int = Field(..., description="是否需要处理(1-需要、0-不需要)")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
model_config = {"from_attributes": True} # 支持从数据库结果直接转换
|
||||
|
||||
|
||||
class DeviceListResponse(BaseModel):
|
||||
"""设备列表响应模型"""
|
||||
total: int = Field(..., description="设备总数")
|
||||
devices: List[DeviceResponse] = Field(..., description="设备列表")
|
||||
|
||||
|
||||
class DeviceStatusHistoryResponse(BaseModel):
|
||||
"""设备上下线记录响应模型"""
|
||||
id: int = Field(..., description="记录ID")
|
||||
device_id: int = Field(..., description="关联设备ID")
|
||||
client_ip: Optional[str] = Field(None, description="设备IP地址")
|
||||
status: int = Field(..., description="状态(1-在线、0-离线)")
|
||||
status_time: datetime = Field(..., description="状态变更时间")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DeviceStatusHistoryListResponse(BaseModel):
|
||||
"""设备上下线记录列表响应模型"""
|
||||
total: int = Field(..., description="记录总数")
|
||||
history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表")
|
||||
41
schema/face_schema.py
Normal file
@ -0,0 +1,41 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型(前端传参校验)- 保留update的eigenvalue(如需更新特征值)
|
||||
# ------------------------------
|
||||
class FaceCreateRequest(BaseModel):
|
||||
"""创建人脸记录请求模型(无需ID、由数据库自增)"""
|
||||
name: Optional[str] = Field(None, max_length=255, description="名称(可选、最长255字符)")
|
||||
|
||||
|
||||
class FaceUpdateRequest(BaseModel):
|
||||
"""更新人脸记录请求模型 - 保留eigenvalue(如需更新特征值、不影响返回)"""
|
||||
name: Optional[str] = Field(None, max_length=255, description="名称(可选)")
|
||||
eigenvalue: Optional[str] = Field(None, description="特征值(可选、文件处理后可更新)") # 保留更新能力
|
||||
address: Optional[str] = Field(None, description="图片完整路径(可选、更新图片时使用)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(后端返回数据)- 核心修改:删除eigenvalue字段
|
||||
# ------------------------------
|
||||
class FaceResponse(BaseModel):
|
||||
"""人脸记录响应模型(仅返回需要的字段、移除eigenvalue)"""
|
||||
id: int = Field(..., description="主键ID(数据库自增)")
|
||||
name: Optional[str] = Field(None, description="名称")
|
||||
address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)") # 仅保留address
|
||||
created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)")
|
||||
updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)")
|
||||
|
||||
# 关键配置:支持从数据库查询结果(字典)直接转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class FaceListResponse(BaseModel):
|
||||
"""人脸列表分页响应模型(结构不变、内部FaceResponse已移除eigenvalue)"""
|
||||
total: int = Field(..., description="筛选后的总记录数")
|
||||
faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
37
schema/model_schema.py
Normal file
@ -0,0 +1,37 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# 请求模型
|
||||
class ModelCreateRequest(BaseModel):
|
||||
name: str = Field(..., max_length=255, description="模型名称(必填、如:yolo-v8s-car)")
|
||||
description: Optional[str] = Field(None, description="模型描述(可选)")
|
||||
is_default: Optional[bool] = Field(False, description="是否设为默认模型")
|
||||
|
||||
|
||||
class ModelUpdateRequest(BaseModel):
|
||||
name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)")
|
||||
description: Optional[str] = Field(None, description="模型描述(可选修改)")
|
||||
is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)")
|
||||
|
||||
|
||||
# 响应模型
|
||||
class ModelResponse(BaseModel):
|
||||
id: int = Field(..., description="模型ID")
|
||||
name: str = Field(..., description="模型名称")
|
||||
path: str = Field(..., description="模型文件相对路径")
|
||||
is_default: bool = Field(..., description="是否默认模型")
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
file_size: Optional[int] = Field(None, description="文件大小(字节)")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
total: int = Field(..., description="总记录数")
|
||||
models: List[ModelResponse] = Field(..., description="当前页模型列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
13
schema/response_schema.py
Normal file
@ -0,0 +1,13 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class APIResponse(BaseModel):
|
||||
"""统一 API 响应模型(所有接口必返此格式)"""
|
||||
code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
|
||||
message: str = Field(..., description="响应信息: 成功/错误描述")
|
||||
data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
|
||||
|
||||
# Pydantic V2 配置(支持从 ORM 对象转换)
|
||||
model_config = {"from_attributes": True}
|
||||
38
schema/sensitive_schema.py
Normal file
@ -0,0 +1,38 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型(前端传参校验)
|
||||
# ------------------------------
|
||||
class SensitiveCreateRequest(BaseModel):
|
||||
"""创建敏感信息记录请求模型"""
|
||||
name: str = Field(..., max_length=255, description="敏感词内容(必填)")
|
||||
|
||||
|
||||
class SensitiveUpdateRequest(BaseModel):
|
||||
"""更新敏感信息记录请求模型"""
|
||||
name: Optional[str] = Field(None, max_length=255, description="敏感词内容(可选修改)")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(后端返回数据)
|
||||
# ------------------------------
|
||||
class SensitiveResponse(BaseModel):
|
||||
"""敏感信息单条记录响应模型"""
|
||||
id: int = Field(..., description="主键ID")
|
||||
name: str = Field(..., description="敏感词内容")
|
||||
created_at: datetime = Field(..., description="记录创建时间")
|
||||
updated_at: datetime = Field(..., description="记录更新时间")
|
||||
|
||||
# 支持从数据库查询结果(字典/对象)自动转换
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class SensitiveListResponse(BaseModel):
|
||||
"""敏感信息分页列表响应模型(新增)"""
|
||||
total: int = Field(..., description="敏感词总记录数")
|
||||
sensitives: List[SensitiveResponse] = Field(..., description="当前页敏感词列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
40
schema/user_schema.py
Normal file
@ -0,0 +1,40 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 请求模型(前端传参校验)
|
||||
# ------------------------------
|
||||
class UserRegisterRequest(BaseModel):
|
||||
"""用户注册请求模型"""
|
||||
username: str = Field(..., min_length=3, max_length=50, description="用户名(3-50字符)")
|
||||
password: str = Field(..., min_length=6, max_length=100, description="密码(6-100字符)")
|
||||
|
||||
|
||||
class UserLoginRequest(BaseModel):
|
||||
"""用户登录请求模型"""
|
||||
username: str = Field(..., description="用户名")
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 响应模型(后端返回用户数据)
|
||||
# ------------------------------
|
||||
class UserResponse(BaseModel):
|
||||
"""用户信息响应模型(隐藏密码等敏感字段)"""
|
||||
id: int = Field(..., description="用户ID")
|
||||
username: str = Field(..., description="用户名")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
# Pydantic V2 配置(支持从数据库查询结果转换)
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
"""用户列表分页响应模型(与设备/人脸列表结构对齐)"""
|
||||
total: int = Field(..., description="用户总数")
|
||||
users: List[UserResponse] = Field(..., description="当前页用户列表")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
BIN
service/__pycache__/device_action_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/device_danger_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/device_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/face_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/file_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/model_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/ocr_service.cpython-310.pyc
Normal file
BIN
service/__pycache__/sensitive_service.cpython-310.pyc
Normal file
43
service/device_action_service.py
Normal file
@ -0,0 +1,43 @@
|
||||
from ds.db import db
|
||||
from schema.device_action_schema import DeviceActionCreate, DeviceActionResponse
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
|
||||
# 新增设备操作记录
|
||||
def add_device_action(client_ip: str, action: int) -> DeviceActionResponse:
|
||||
"""
|
||||
新增设备操作记录(内部方法、非接口)
|
||||
:param action_data: 含client_ip和action(0/1)
|
||||
:return: 新增的完整记录
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入SQL(id自增、依赖数据库自动生成)
|
||||
insert_query = """
|
||||
INSERT INTO device_action
|
||||
(client_ip, action, created_at, updated_at)
|
||||
VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
client_ip,
|
||||
action
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取新增记录(通过自增ID查询)
|
||||
new_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
|
||||
new_action = cursor.fetchone()
|
||||
|
||||
return DeviceActionResponse(**new_action)
|
||||
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"新增记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
78
service/device_danger_service.py
Normal file
@ -0,0 +1,78 @@
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from schema.device_danger_schema import DeviceDangerCreateRequest, DeviceDangerResponse
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 检查设备是否存在(复用设备表逻辑)
|
||||
# ------------------------------
|
||||
def check_device_exist(client_ip: str) -> bool:
|
||||
"""
|
||||
检查指定IP的设备是否在devices表中存在
|
||||
|
||||
:param client_ip: 设备IP地址
|
||||
:return: 存在返回True、不存在返回False
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("设备IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
return cursor.fetchone() is not None
|
||||
except MySQLError as e:
|
||||
raise Exception(f"检查设备存在性失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# 内部工具方法 - 创建设备危险记录(核心插入逻辑)
|
||||
# ------------------------------
|
||||
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
|
||||
"""
|
||||
内部工具方法:向device_danger表插入新的危险记录
|
||||
|
||||
:param danger_data: 危险记录创建请求数据
|
||||
:return: 创建成功的危险记录模型对象
|
||||
"""
|
||||
# 先检查设备是否存在
|
||||
if not check_device_exist(danger_data.client_ip):
|
||||
raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在、无法创建危险记录")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 插入危险记录(id自增、时间自动填充)
|
||||
insert_query = """
|
||||
INSERT INTO device_danger
|
||||
(client_ip, type, result, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
"""
|
||||
cursor.execute(insert_query, (
|
||||
danger_data.client_ip,
|
||||
danger_data.type,
|
||||
danger_data.result
|
||||
))
|
||||
conn.commit()
|
||||
|
||||
# 获取刚创建的记录(用自增ID查询)
|
||||
danger_id = cursor.lastrowid
|
||||
cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
|
||||
new_danger = cursor.fetchone()
|
||||
|
||||
return DeviceDangerResponse(**new_danger)
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"插入危险记录失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
250
service/device_service.py
Normal file
@ -0,0 +1,250 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
from service.device_action_service import add_device_action
|
||||
_last_alarm_timestamps: dict[str, float] = {}
|
||||
_timestamp_lock = threading.Lock()
|
||||
|
||||
# 获取所有去重的客户端IP列表
|
||||
def get_unique_client_ips() -> list[str]:
|
||||
"""获取所有去重的客户端IP列表"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
query = "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
results = cursor.fetchall()
|
||||
return [item['client_ip'] for item in results]
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 通过客户端IP更新设备是否需要处理
|
||||
def update_is_need_handler_by_client_ip(client_ip: str, is_need_handler: int) -> bool:
|
||||
"""
|
||||
通过客户端IP更新设备的「是否需要处理」状态(is_need_handler字段)
|
||||
"""
|
||||
# 参数合法性校验
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
# 校验is_need_handler取值(需与数据库字段类型匹配、通常为0/1 tinyint)
|
||||
if is_need_handler not in (0, 1):
|
||||
raise ValueError("是否需要处理(is_need_handler)必须是0(不需要)或1(需要)")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 2. 获取数据库连接与游标
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 3. 先校验设备是否存在(通过client_ip定位)
|
||||
cursor.execute(
|
||||
"SELECT id FROM devices WHERE client_ip = %s",
|
||||
(client_ip,)
|
||||
)
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在、无法更新「是否需要处理」状态")
|
||||
|
||||
# 4. 执行更新操作(同时更新时间戳、保持与其他更新逻辑一致性)
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET is_need_handler = %s,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (is_need_handler, client_ip))
|
||||
|
||||
# 5. 确认更新生效(判断影响行数、避免无意义更新)
|
||||
if cursor.rowcount <= 0:
|
||||
raise Exception(f"更新失败:客户端IP {client_ip} 的设备未发生状态变更(可能已为目标值)")
|
||||
|
||||
# 6. 提交事务
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
except MySQLError as e:
|
||||
# 数据库异常时回滚事务
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"数据库操作失败:更新设备「是否需要处理」状态时出错 - {str(e)}") from e
|
||||
finally:
|
||||
# 无论成功失败、都关闭数据库连接(避免连接泄漏)
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
def increment_alarm_count_by_ip(client_ip: str) -> bool:
|
||||
"""
|
||||
通过客户端IP增加设备的报警次数,相同IP 200ms内重复调用会被忽略
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:return: 操作是否成功(是否实际执行了数据库更新)
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
current_time = time.time() # 获取当前时间戳(秒,含小数)
|
||||
with _timestamp_lock: # 确保线程安全的字典操作
|
||||
last_time: Optional[float] = _last_alarm_timestamps.get(client_ip)
|
||||
|
||||
# 如果存在最近记录且间隔小于200ms,直接返回False(不执行更新)
|
||||
if last_time is not None and (current_time - last_time) < 0.2:
|
||||
return False
|
||||
|
||||
# 更新当前IP的最近调用时间
|
||||
_last_alarm_timestamps[client_ip] = current_time
|
||||
|
||||
# 2. 执行数据库更新操作(只有通过时间校验才会执行)
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在
|
||||
cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 报警次数加1、并更新时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET alarm_count = alarm_count + 1,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (client_ip,))
|
||||
conn.commit()
|
||||
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新报警次数失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 通过客户端IP更新设备在线状态
|
||||
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
|
||||
"""
|
||||
通过客户端IP更新设备的在线状态
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:param online_status: 在线状态(1-在线、0-离线)
|
||||
:return: 操作是否成功
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
if online_status not in (0, 1):
|
||||
raise ValueError("在线状态必须是0(离线)或1(在线)")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 检查设备是否存在并获取设备ID
|
||||
cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,))
|
||||
device = cursor.fetchone()
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
# 状态无变化则不操作
|
||||
if device['device_online_status'] == online_status:
|
||||
return True
|
||||
|
||||
# 更新在线状态和时间戳
|
||||
update_query = """
|
||||
UPDATE devices
|
||||
SET device_online_status = %s,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_ip = %s
|
||||
"""
|
||||
cursor.execute(update_query, (online_status, client_ip))
|
||||
|
||||
# 记录状态变更历史
|
||||
add_device_action(client_ip, online_status)
|
||||
|
||||
conn.commit()
|
||||
return True
|
||||
except MySQLError as e:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
raise Exception(f"更新设备在线状态失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 通过客户端IP查询设备在数据库中存在
|
||||
def is_device_exist_by_ip(client_ip: str) -> bool:
|
||||
"""
|
||||
通过客户端IP查询设备在数据库中是否存在
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询设备是否存在
|
||||
cursor.execute(
|
||||
"SELECT id FROM devices WHERE client_ip = %s",
|
||||
(client_ip,)
|
||||
)
|
||||
device = cursor.fetchone()
|
||||
|
||||
# 如果查询到结果则存在,否则不存在
|
||||
return bool(device)
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询设备是否存在失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
# 根据客户端IP获取是否需要处理
|
||||
def get_is_need_handler_by_ip(client_ip: str) -> int:
|
||||
"""
|
||||
通过客户端IP查询设备的is_need_handler状态
|
||||
|
||||
:param client_ip: 客户端IP地址
|
||||
:return: 设备的is_need_handler状态(0-不需要处理,1-需要处理)
|
||||
"""
|
||||
if not client_ip:
|
||||
raise ValueError("客户端IP不能为空")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 查询设备的is_need_handler状态
|
||||
cursor.execute(
|
||||
"SELECT is_need_handler FROM devices WHERE client_ip = %s",
|
||||
(client_ip,)
|
||||
)
|
||||
device = cursor.fetchone()
|
||||
|
||||
if not device:
|
||||
raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
|
||||
|
||||
return device['is_need_handler']
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"查询设备is_need_handler状态失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
339
service/face_service.py
Normal file
@ -0,0 +1,339 @@
|
||||
import cv2
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
|
||||
import numpy as np
|
||||
import threading
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
# 全局变量定义
|
||||
_insightface_app = None
|
||||
_known_faces_embeddings = {} # 存储已知人脸特征 {姓名: 特征向量}
|
||||
_known_faces_names = [] # 存储已知人脸姓名列表
|
||||
|
||||
|
||||
def init_insightface():
|
||||
"""初始化InsightFace引擎"""
|
||||
global _insightface_app
|
||||
if _insightface_app is not None:
|
||||
print("InsightFace引擎已初始化,无需重复执行")
|
||||
return _insightface_app
|
||||
|
||||
try:
|
||||
print("正在初始化 InsightFace 引擎(模型:buffalo_l)...")
|
||||
# 初始化引擎,指定模型路径和计算 providers
|
||||
app = FaceAnalysis(
|
||||
name='buffalo_l',
|
||||
root='~/.insightface',
|
||||
providers=['CPUExecutionProvider'] # 如需GPU可添加'CUDAExecutionProvider'
|
||||
)
|
||||
app.prepare(ctx_id=0, det_size=(640, 640)) # 调整检测尺寸
|
||||
print("InsightFace 引擎初始化完成")
|
||||
|
||||
# 初始化时加载人脸特征库
|
||||
init_face_data()
|
||||
|
||||
_insightface_app = app
|
||||
return app
|
||||
except Exception as e:
|
||||
print(f"InsightFace 初始化失败:{str(e)}")
|
||||
_insightface_app = None
|
||||
return None
|
||||
|
||||
|
||||
def init_face_data():
|
||||
"""初始化或更新人脸特征库(清空旧数据,避免重复)"""
|
||||
global _known_faces_embeddings, _known_faces_names
|
||||
# 清空原有数据,防止重复加载
|
||||
_known_faces_embeddings.clear()
|
||||
_known_faces_names.clear()
|
||||
|
||||
try:
|
||||
face_data = get_all_face_name_with_eigenvalue() # 假设该函数已定义
|
||||
print(f"已加载 {len(face_data)} 个人脸数据")
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 解析特征值(支持numpy数组或字符串格式)
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 增强字符串解析:支持逗号/空格分隔,清理特殊字符
|
||||
cleaned = (eigenvalue_data
|
||||
.replace("[", "").replace("]", "")
|
||||
.replace("\n", "").replace(",", " ")
|
||||
.strip())
|
||||
values = [v for v in cleaned.split() if v] # 过滤空字符串
|
||||
if not values:
|
||||
print(f"特征值解析失败(空值),跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
|
||||
continue
|
||||
|
||||
# 特征值归一化(确保相似度计算一致性)
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm == 0:
|
||||
print(f"特征值为零向量,跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
# 更新全局特征库
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库")
|
||||
except Exception as e:
|
||||
print(f"加载人脸特征库失败: {e}")
|
||||
|
||||
|
||||
def update_face_data():
|
||||
"""更新人脸特征库(清空旧数据,加载最新数据)"""
|
||||
global _known_faces_embeddings, _known_faces_names
|
||||
|
||||
print("开始更新人脸特征库...")
|
||||
|
||||
# 清空原有数据
|
||||
_known_faces_embeddings.clear()
|
||||
_known_faces_names.clear()
|
||||
|
||||
try:
|
||||
# 获取最新人脸数据
|
||||
face_data = get_all_face_name_with_eigenvalue()
|
||||
print(f"获取到 {len(face_data)} 条最新人脸数据")
|
||||
|
||||
# 处理并加载新数据(复用原有解析逻辑)
|
||||
for person_name, eigenvalue_data in face_data.items():
|
||||
# 解析特征值(支持numpy数组或字符串格式)
|
||||
if isinstance(eigenvalue_data, np.ndarray):
|
||||
eigenvalue = eigenvalue_data.astype(np.float32)
|
||||
elif isinstance(eigenvalue_data, str):
|
||||
# 增强字符串解析:支持逗号/空格分隔,清理特殊字符
|
||||
cleaned = (eigenvalue_data
|
||||
.replace("[", "").replace("]", "")
|
||||
.replace("\n", "").replace(",", " ")
|
||||
.strip())
|
||||
values = [v for v in cleaned.split() if v] # 过滤空字符串
|
||||
if not values:
|
||||
print(f"特征值解析失败(空值),跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
|
||||
else:
|
||||
print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
|
||||
continue
|
||||
|
||||
# 特征值归一化(确保相似度计算一致性)
|
||||
norm = np.linalg.norm(eigenvalue)
|
||||
if norm == 0:
|
||||
print(f"特征值为零向量,跳过 {person_name}")
|
||||
continue
|
||||
eigenvalue = eigenvalue / norm
|
||||
|
||||
# 更新全局特征库
|
||||
_known_faces_embeddings[person_name] = eigenvalue
|
||||
_known_faces_names.append(person_name)
|
||||
|
||||
print(f"人脸特征库更新完成,共加载 {len(_known_faces_names)} 个人脸数据")
|
||||
return True # 更新成功
|
||||
except Exception as e:
|
||||
print(f"人脸特征库更新失败: {e}")
|
||||
return False # 更新失败
|
||||
|
||||
|
||||
def detect(frame, similarity_threshold=0.4):
|
||||
global _insightface_app, _known_faces_embeddings
|
||||
|
||||
# 校验输入有效性
|
||||
if frame is None or frame.size == 0:
|
||||
return (False, "无效的输入帧数据")
|
||||
|
||||
# 校验引擎和特征库状态
|
||||
if not _insightface_app:
|
||||
return (False, "人脸引擎未初始化")
|
||||
if not _known_faces_embeddings:
|
||||
return (False, "人脸特征库为空")
|
||||
|
||||
try:
|
||||
# 执行人脸检测与特征提取
|
||||
faces = _insightface_app.get(frame)
|
||||
except Exception as e:
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
|
||||
result_parts = []
|
||||
has_matched_known_face = False # 是否有匹配到已知人脸
|
||||
|
||||
for face in faces:
|
||||
# 归一化当前人脸特征
|
||||
face_embedding = face.embedding.astype(np.float32)
|
||||
norm = np.linalg.norm(face_embedding)
|
||||
if norm == 0:
|
||||
result_parts.append("检测到人脸但特征值为零向量(忽略)")
|
||||
continue
|
||||
face_embedding = face_embedding / norm
|
||||
|
||||
# 与已知特征库比对
|
||||
max_similarity, best_match_name = -1.0, "Unknown"
|
||||
for name, known_emb in _known_faces_embeddings.items():
|
||||
similarity = np.dot(face_embedding, known_emb) # 余弦相似度
|
||||
if similarity > max_similarity:
|
||||
max_similarity = similarity
|
||||
best_match_name = name
|
||||
|
||||
# 判断是否匹配成功
|
||||
is_matched = max_similarity >= similarity_threshold
|
||||
|
||||
if is_matched:
|
||||
has_matched_known_face = True
|
||||
|
||||
# 记录结果(边界框转为整数列表)
|
||||
bbox = face.bbox.astype(int).tolist()
|
||||
result_parts.append(
|
||||
f"{'匹配' if is_matched else '未匹配'}: {best_match_name} "
|
||||
f"(相似度: {max_similarity:.2f}, 边界框: {bbox})"
|
||||
)
|
||||
|
||||
# 构建最终结果
|
||||
result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts)
|
||||
return (has_matched_known_face, result_str)
|
||||
|
||||
|
||||
# 上传图片并提取特征
|
||||
def add_binary_data(binary_data):
|
||||
"""
|
||||
接收单张图片的二进制数据、提取特征并保存
|
||||
返回:(True, 特征值numpy数组) 或 (False, 错误信息字符串)
|
||||
"""
|
||||
global _insightface_app, _feature_list
|
||||
|
||||
# 1. 先检查引擎是否初始化成功
|
||||
if not _insightface_app:
|
||||
init_result = init_insightface() # 尝试重新初始化
|
||||
if not init_result:
|
||||
error_msg = "InsightFace引擎未初始化、无法检测人脸"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
try:
|
||||
# 2. 验证二进制数据有效性
|
||||
if len(binary_data) < 1024: # 过滤过小的无效图片(小于1KB)
|
||||
error_msg = f"图片过小({len(binary_data)}字节)、可能不是有效图片"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 3. 二进制数据转CV2格式(关键步骤、避免通道错误)
|
||||
try:
|
||||
img = Image.open(BytesIO(binary_data)).convert("RGB") # 强制转RGB
|
||||
frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) # InsightFace需要BGR格式
|
||||
except Exception as e:
|
||||
error_msg = f"图片格式转换失败:{str(e)}"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 4. 检查图片尺寸(避免极端尺寸导致检测失败)
|
||||
height, width = frame.shape[:2]
|
||||
if height < 64 or width < 64: # 人脸检测最小建议尺寸
|
||||
error_msg = f"图片尺寸过小({width}x{height})、需至少64x64像素"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 5. 调用InsightFace检测人脸
|
||||
print(f"开始检测人脸(图片尺寸:{width}x{height}、格式:BGR)")
|
||||
faces = _insightface_app.get(frame)
|
||||
|
||||
if not faces:
|
||||
error_msg = "未检测到人脸(请确保图片包含清晰正面人脸、无遮挡、不模糊)"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
# 6. 提取特征并保存
|
||||
current_feature = faces[0].embedding
|
||||
_feature_list.append(current_feature)
|
||||
print(f"人脸检测成功、提取特征值(维度:{current_feature.shape[0]})、累计特征数:{len(_feature_list)}")
|
||||
return True, current_feature
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"处理图片时发生异常:{str(e)}"
|
||||
print(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
|
||||
# 获取数据库最新的人脸及其特征
|
||||
def get_all_face_name_with_eigenvalue() -> dict:
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
query = "SELECT name, eigenvalue FROM face WHERE name IS NOT NULL"
|
||||
cursor.execute(query)
|
||||
faces = cursor.fetchall()
|
||||
|
||||
name_to_eigenvalues = {}
|
||||
for face in faces:
|
||||
name = face["name"]
|
||||
eigenvalue = face["eigenvalue"]
|
||||
if name in name_to_eigenvalues:
|
||||
name_to_eigenvalues[name].append(eigenvalue)
|
||||
else:
|
||||
name_to_eigenvalues[name] = [eigenvalue]
|
||||
|
||||
face_dict = {}
|
||||
for name, eigenvalues in name_to_eigenvalues.items():
|
||||
if len(eigenvalues) > 1:
|
||||
face_dict[name] = get_average_feature(eigenvalues)
|
||||
else:
|
||||
face_dict[name] = eigenvalues[0]
|
||||
|
||||
return face_dict
|
||||
|
||||
except MySQLError as e:
|
||||
raise Exception(f"获取人脸特征失败: {str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
|
||||
|
||||
# 获取平均特征值
|
||||
def get_average_feature(features=None):
|
||||
global _feature_list
|
||||
try:
|
||||
if features is None:
|
||||
features = _feature_list
|
||||
if not isinstance(features, list) or len(features) == 0:
|
||||
print("输入必须是包含至少一个特征值的列表")
|
||||
return None
|
||||
|
||||
processed_features = []
|
||||
for i, embedding in enumerate(features):
|
||||
try:
|
||||
if isinstance(embedding, str):
|
||||
embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
|
||||
embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
|
||||
embedding_np = np.array(embedding_list, dtype=np.float32)
|
||||
else:
|
||||
embedding_np = np.array(embedding, dtype=np.float32)
|
||||
|
||||
if len(embedding_np.shape) == 1:
|
||||
processed_features.append(embedding_np)
|
||||
print(f"已添加第 {i + 1} 个特征值用于计算平均值")
|
||||
else:
|
||||
print(f"跳过第 {i + 1} 个特征值:不是一维数组")
|
||||
except Exception as e:
|
||||
print(f"处理第 {i + 1} 个特征值时出错:{str(e)}")
|
||||
|
||||
if not processed_features:
|
||||
print("没有有效的特征值用于计算平均值")
|
||||
return None
|
||||
|
||||
dims = {feat.shape[0] for feat in processed_features}
|
||||
if len(dims) > 1:
|
||||
print(f"特征值维度不一致:{dims}、无法计算平均值")
|
||||
return None
|
||||
|
||||
avg_feature = np.mean(processed_features, axis=0)
|
||||
print(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]})")
|
||||
return avg_feature
|
||||
except Exception as e:
|
||||
print(f"计算平均特征值出错:{str(e)}")
|
||||
return None
|
||||
343
service/file_service.py
Normal file
@ -0,0 +1,343 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from PIL import ImageDraw, ImageFont
|
||||
from fastapi import UploadFile
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# 上传根目录
|
||||
UPLOAD_ROOT = "upload"
|
||||
PRE = "/api/file/download/"
|
||||
|
||||
# 确保上传根目录存在
|
||||
os.makedirs(UPLOAD_ROOT, exist_ok=True)
|
||||
|
||||
|
||||
|
||||
def save_detect_file(client_ip: str, image_np: np.ndarray, file_type: str) -> str:
|
||||
"""保存numpy数组格式的PNG图片到detect目录,返回下载路径"""
|
||||
today = datetime.now()
|
||||
year = today.strftime("%Y")
|
||||
month = today.strftime("%m")
|
||||
day = today.strftime("%d")
|
||||
|
||||
# 构建目录路径: upload/detect/客户端IP/type/年/月/日(包含UPLOAD_ROOT)
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT,
|
||||
"detect",
|
||||
client_ip,
|
||||
file_type,
|
||||
year,
|
||||
month,
|
||||
day
|
||||
)
|
||||
|
||||
# 创建目录(确保目录存在)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
# 生成当前时间戳作为文件名,确保唯一性
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
|
||||
filename = f"{timestamp}.png"
|
||||
|
||||
# 1. 完整路径:用于实际保存文件(包含UPLOAD_ROOT)
|
||||
full_path = os.path.join(file_dir, filename)
|
||||
# 2. 相对路径:用于返回给前端(移除UPLOAD_ROOT前缀)
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
|
||||
# 保存numpy数组为PNG图片
|
||||
try:
|
||||
# -------- 新增/修改:处理颜色通道和数据类型 --------
|
||||
# 1. 数据类型转换:确保是uint8(若为float32且范围0-1,需转成0-255的uint8)
|
||||
if image_np.dtype != np.uint8:
|
||||
image_np = (image_np * 255).astype(np.uint8)
|
||||
|
||||
# 2. 通道顺序转换:若为OpenCV的BGR格式,转成PIL需要的RGB格式
|
||||
image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 3. 转换为PIL Image并保存
|
||||
img = Image.fromarray(image_rgb)
|
||||
img.save(full_path, format='PNG')
|
||||
except Exception as e:
|
||||
# 处理可能的异常(如数组格式不正确)
|
||||
raise Exception(f"保存图片失败: {str(e)}")
|
||||
|
||||
# 统一路径分隔符为/,拼接前缀返回
|
||||
return PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
|
||||
def save_detect_yolo_file(
|
||||
client_ip: str,
|
||||
image_np: np.ndarray,
|
||||
detection_results: list,
|
||||
file_type: str = "yolo"
|
||||
) -> str:
|
||||
|
||||
|
||||
print("......................")
|
||||
"""
|
||||
保存YOLO检测结果图片(在原图上绘制边界框+标签),返回前端可访问的下载路径
|
||||
"""
|
||||
# 输入参数验证
|
||||
if not isinstance(image_np, np.ndarray):
|
||||
raise ValueError(f"输入image_np必须是numpy数组,当前类型:{type(image_np)}")
|
||||
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||
raise ValueError(f"输入图像必须是 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}")
|
||||
|
||||
if not isinstance(detection_results, list):
|
||||
raise ValueError(f"detection_results必须是列表,当前类型:{type(detection_results)}")
|
||||
for idx, result in enumerate(detection_results):
|
||||
required_keys = {"class", "confidence", "bbox"}
|
||||
if not isinstance(result, dict) or not required_keys.issubset(result.keys()):
|
||||
raise ValueError(
|
||||
f"detection_results第{idx}个元素格式错误,需包含键:{required_keys},"
|
||||
f"当前元素:{result}"
|
||||
)
|
||||
bbox = result["bbox"]
|
||||
if not (isinstance(bbox, (tuple, list)) and len(bbox) == 4 and all(isinstance(x, int) for x in bbox)):
|
||||
raise ValueError(
|
||||
f"detection_results第{idx}个元素的bbox格式错误,需为(x1,y1,x2,y2)整数元组,"
|
||||
f"当前bbox:{bbox}"
|
||||
)
|
||||
|
||||
#图像预处理(数据类型+通道)
|
||||
draw_image = image_np.copy()
|
||||
if draw_image.dtype != np.uint8:
|
||||
draw_image = np.clip(draw_image * 255, 0, 255).astype(np.uint8)
|
||||
|
||||
#绘制边界框+标签
|
||||
# 遍历所有检测结果,逐个绘制
|
||||
for result in detection_results:
|
||||
class_name = result["class"]
|
||||
confidence = result["confidence"]
|
||||
x1, y1, x2, y2 = result["bbox"]
|
||||
cv2.rectangle(draw_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
|
||||
label = f"{class_name}: {confidence:.2f}"
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.5
|
||||
font_thickness = 2
|
||||
(label_width, label_height), baseline = cv2.getTextSize(
|
||||
label, font, font_scale, font_thickness
|
||||
)
|
||||
|
||||
bg_top_left = (x1, y1 - label_height - 10)
|
||||
bg_bottom_right = (x1 + label_width, y1)
|
||||
if bg_top_left[1] < 0:
|
||||
bg_top_left = (x1, 0)
|
||||
bg_bottom_right = (x1 + label_width, label_height + 10)
|
||||
cv2.rectangle(draw_image, bg_top_left, bg_bottom_right, color=(0, 0, 0), thickness=-1)
|
||||
|
||||
text_origin = (x1, y1 - 5)
|
||||
if bg_top_left[1] == 0:
|
||||
text_origin = (x1, label_height + 5)
|
||||
cv2.putText(
|
||||
draw_image, label, text_origin,
|
||||
font, font_scale, color=(255, 255, 255), thickness=font_thickness
|
||||
)
|
||||
|
||||
#保存图片
|
||||
try:
|
||||
today = datetime.now()
|
||||
year = today.strftime("%Y")
|
||||
month = today.strftime("%m")
|
||||
day = today.strftime("%d")
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT, "detect", client_ip, file_type, year, month, day
|
||||
)
|
||||
|
||||
#创建目录(若不存在则创建,支持多级目录)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
#生成唯一文件名
|
||||
timestamp = today.strftime("%Y%m%d%H%M%S%f")
|
||||
filename = f"{timestamp}.png"
|
||||
|
||||
# 4.4 构建完整保存路径和前端访问路径
|
||||
full_path = os.path.join(file_dir, filename) # 本地完整路径
|
||||
# 相对路径:移除UPLOAD_ROOT前缀,统一用"/"作为分隔符(兼容Windows/Linux)
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
download_path = PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
# 4.5 保存图片(CV2绘制的是BGR,需转RGB后用PIL保存,与原逻辑一致)
|
||||
image_rgb = cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB)
|
||||
img_pil = Image.fromarray(image_rgb)
|
||||
img_pil.save(full_path, format="PNG", quality=95) # PNG格式无压缩,quality可忽略
|
||||
|
||||
print(f"YOLO检测图片保存成功 | 本地路径:{full_path} | 下载路径:{download_path}")
|
||||
return download_path
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"YOLO检测图片保存失败:{str(e)}") from e
|
||||
|
||||
|
||||
def save_detect_face_file(
|
||||
client_ip: str,
|
||||
image_np: np.ndarray,
|
||||
face_result: str,
|
||||
file_type: str = "face",
|
||||
matched_color: tuple = (0, 255, 0)
|
||||
) -> str:
|
||||
"""
|
||||
保存人脸识别结果图片(仅为「匹配成功」的人脸画框,标签不包含“匹配”二字)
|
||||
"""
|
||||
#输入参数验证
|
||||
if not isinstance(image_np, np.ndarray) or image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||
raise ValueError(f"输入图像需为 (h, w, 3) 的BGR数组,当前shape:{image_np.shape}")
|
||||
if not isinstance(face_result, str) or face_result.strip() == "":
|
||||
raise ValueError("face_result必须是非空字符串")
|
||||
|
||||
# 解析face_result提取人脸信息
|
||||
face_info_list = []
|
||||
if face_result.strip() != "未检测到人脸":
|
||||
face_pattern = re.compile(
|
||||
r"(匹配|未匹配):\s*([^\s(]+)\s*\(相似度:\s*(\d+\.\d+),\s*边界框:\s*\[(\d+,\s*\d+,\s*\d+,\s*\d+)\]\)"
|
||||
)
|
||||
for part in [p.strip() for p in face_result.split(";") if p.strip()]:
|
||||
match = face_pattern.match(part)
|
||||
if match:
|
||||
status, name, similarity, bbox_str = match.groups()
|
||||
bbox = list(map(int, bbox_str.replace(" ", "").split(",")))
|
||||
if len(bbox) == 4:
|
||||
face_info_list.append({
|
||||
"status": status,
|
||||
"name": name,
|
||||
"similarity": float(similarity),
|
||||
"bbox": bbox
|
||||
})
|
||||
|
||||
# 图像格式转换(OpenCV→PIL)
|
||||
image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(image_rgb)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
# 绘制边界框和标签
|
||||
font_size = 12
|
||||
try:
|
||||
font = ImageFont.truetype("simhei", font_size)
|
||||
except:
|
||||
try:
|
||||
font = ImageFont.truetype("simsun", font_size)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
print("警告:未找到指定中文字体,使用PIL默认字体(可能影响中文显示)")
|
||||
|
||||
for face_info in face_info_list:
|
||||
status = face_info["status"]
|
||||
if status != "匹配":
|
||||
print(f"跳过未匹配人脸:{face_info['name']}(相似度:{face_info['similarity']:.2f})")
|
||||
continue
|
||||
|
||||
name = face_info["name"]
|
||||
similarity = face_info["similarity"]
|
||||
x1, y1, x2, y2 = face_info["bbox"]
|
||||
|
||||
# 4.1 绘制边界框(绿色)
|
||||
img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
cv2.rectangle(img_cv, (x1, y1), (x2, y2), color=matched_color, thickness=2)
|
||||
pil_img = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
label = f"{name} (相似度: {similarity:.2f})"
|
||||
|
||||
# 4.3 计算标签尺寸(文本变短后会自动适配,无需额外调整)
|
||||
label_bbox = draw.textbbox((0, 0), label, font=font)
|
||||
label_width = label_bbox[2] - label_bbox[0]
|
||||
label_height = label_bbox[3] - label_bbox[1]
|
||||
|
||||
# 4.4 计算标签背景位置(避免超出图像)
|
||||
bg_x1, bg_y1 = x1, y1 - label_height - 10
|
||||
bg_x2, bg_y2 = x1 + label_width, y1
|
||||
if bg_y1 < 0:
|
||||
bg_y1, bg_y2 = y2 + 5, y2 + label_height + 15
|
||||
|
||||
# 4.5 绘制标签背景(黑色)和文本(白色)
|
||||
draw.rectangle([(bg_x1, bg_y1), (bg_x2, bg_y2)], fill=(0, 0, 0))
|
||||
text_x = bg_x1
|
||||
text_y = bg_y1 if bg_y1 >= 0 else bg_y1 + label_height
|
||||
draw.text((text_x, text_y), label, font=font, fill=(255, 255, 255))
|
||||
|
||||
#保存图片
|
||||
try:
|
||||
today = datetime.now()
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT, "detect", client_ip, file_type,
|
||||
today.strftime("%Y"), today.strftime("%m"), today.strftime("%d")
|
||||
)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
timestamp = today.strftime("%Y%m%d%H%M%S%f")
|
||||
filename = f"{timestamp}.png"
|
||||
full_path = os.path.join(file_dir, filename)
|
||||
|
||||
pil_img.save(full_path, format="PNG", quality=100)
|
||||
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
download_path = PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
matched_count = sum(1 for info in face_info_list if info["status"] == "匹配")
|
||||
print(f"人脸检测图片保存成功 | 客户端IP:{client_ip} | 匹配人脸数:{matched_count} | 保存路径:{download_path}")
|
||||
return download_path
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"人脸检测图片保存失败(客户端IP:{client_ip}):{str(e)}") from e
|
||||
|
||||
def save_source_file(upload_file: UploadFile, file_type: str) -> str:
|
||||
"""保存上传的文件到source目录,返回下载路径"""
|
||||
today = datetime.now()
|
||||
year = today.strftime("%Y")
|
||||
month = today.strftime("%m")
|
||||
day = today.strftime("%d")
|
||||
|
||||
# 生成精确到微秒的时间戳,确保文件名唯一
|
||||
timestamp = today.strftime("%Y%m%d%H%M%S%f")
|
||||
# 构建新文件名:时间戳_原文件名
|
||||
unique_filename = f"{timestamp}_{upload_file.filename}"
|
||||
|
||||
# 构建目录路径: upload/source/type/年/月/日(包含UPLOAD_ROOT)
|
||||
file_dir = os.path.join(
|
||||
UPLOAD_ROOT,
|
||||
"source",
|
||||
file_type,
|
||||
year,
|
||||
month,
|
||||
day
|
||||
)
|
||||
|
||||
# 创建目录(确保目录存在)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
|
||||
# 1. 完整路径:用于实际保存文件(使用带时间戳的唯一文件名)
|
||||
full_path = os.path.join(file_dir, unique_filename)
|
||||
# 2. 相对路径:用于返回给前端
|
||||
relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
|
||||
|
||||
# 保存文件(使用完整路径)
|
||||
try:
|
||||
with open(full_path, "wb") as buffer:
|
||||
shutil.copyfileobj(upload_file.file, buffer)
|
||||
finally:
|
||||
upload_file.file.close()
|
||||
|
||||
# 统一路径分隔符为/
|
||||
return PRE + relative_path.replace(os.sep, "/")
|
||||
|
||||
|
||||
def get_absolute_path(relative_path: str) -> str:
|
||||
"""
|
||||
根据相对路径获取服务器上的绝对路径
|
||||
"""
|
||||
path_without_pre = relative_path.replace(PRE, "", 1)
|
||||
|
||||
# 将相对路径转换为系统兼容的格式
|
||||
normalized_path = os.path.normpath(path_without_pre)
|
||||
|
||||
# 拼接基础路径和相对路径,得到绝对路径
|
||||
absolute_path = os.path.abspath(os.path.join(UPLOAD_ROOT, normalized_path))
|
||||
|
||||
# 安全检查:确保生成的路径在UPLOAD_ROOT目录下,防止路径遍历
|
||||
if not absolute_path.startswith(os.path.abspath(UPLOAD_ROOT)):
|
||||
raise ValueError("无效的相对路径,可能试图访问上传目录之外的内容")
|
||||
|
||||
return absolute_path
|
||||
131
service/model_service.py
Normal file
@ -0,0 +1,131 @@
|
||||
from http.client import HTTPException
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from MySQLdb import MySQLError
|
||||
from ultralytics import YOLO
|
||||
import os
|
||||
|
||||
from ds.db import db
|
||||
from service.file_service import get_absolute_path
|
||||
|
||||
# 全局变量
|
||||
current_yolo_model = None
|
||||
current_model_absolute_path = None # 存储模型绝对路径,不依赖model实例
|
||||
|
||||
ALLOWED_MODEL_EXT = {"pt"}
|
||||
MAX_MODEL_SIZE = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
|
||||
def load_yolo_model():
|
||||
"""加载模型并存储绝对路径"""
|
||||
global current_yolo_model, current_model_absolute_path
|
||||
model_rel_path = get_enabled_model_rel_path()
|
||||
print(f"[模型初始化] 加载模型:{model_rel_path}")
|
||||
|
||||
# 计算并存储绝对路径
|
||||
current_model_absolute_path = get_absolute_path(model_rel_path)
|
||||
print(f"[模型初始化] 绝对路径:{current_model_absolute_path}")
|
||||
|
||||
# 检查模型文件
|
||||
if not os.path.exists(current_model_absolute_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}")
|
||||
|
||||
try:
|
||||
new_model = YOLO(current_model_absolute_path)
|
||||
if torch.cuda.is_available():
|
||||
new_model.to('cuda')
|
||||
print("模型已移动到GPU")
|
||||
else:
|
||||
print("使用CPU进行推理")
|
||||
current_yolo_model = new_model
|
||||
print(f"成功加载模型: {current_model_absolute_path}")
|
||||
return current_yolo_model
|
||||
except Exception as e:
|
||||
print(f"模型加载失败:{str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_current_model():
|
||||
"""获取当前模型实例"""
|
||||
if current_yolo_model is None:
|
||||
raise ValueError("尚未加载任何YOLO模型,请先调用load_yolo_model加载模型")
|
||||
return current_yolo_model
|
||||
|
||||
|
||||
def detect(image_np, conf_threshold=0.8):
|
||||
# 1. 输入格式验证
|
||||
if not isinstance(image_np, np.ndarray):
|
||||
raise ValueError("输入必须是numpy数组(BGR图像)")
|
||||
if image_np.ndim != 3 or image_np.shape[-1] != 3:
|
||||
raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组,当前shape: {image_np.shape}")
|
||||
detection_results = []
|
||||
try:
|
||||
model = get_current_model()
|
||||
if not current_model_absolute_path:
|
||||
raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"检测设备:{device} | 置信度阈值:{conf_threshold}")
|
||||
|
||||
# 图像尺寸信息
|
||||
img_height, img_width = image_np.shape[:2]
|
||||
print(f"输入图像尺寸:{img_width}x{img_height}")
|
||||
|
||||
# YOLO检测
|
||||
print("执行YOLO检测")
|
||||
results = model.predict(
|
||||
image_np,
|
||||
conf=conf_threshold,
|
||||
device=device,
|
||||
show=False,
|
||||
)
|
||||
|
||||
# 4. 整理检测结果(仅保留Chest类别,ID=2)
|
||||
for box in results[0].boxes:
|
||||
class_id = int(box.cls[0]) # 类别ID
|
||||
class_name = model.names[class_id]
|
||||
confidence = float(box.conf[0])
|
||||
bbox = tuple(map(int, box.xyxy[0]))
|
||||
|
||||
# 过滤条件:置信度达标 + 类别为Chest(class_id=2)
|
||||
# and class_id == 2
|
||||
if confidence >= conf_threshold:
|
||||
detection_results.append({
|
||||
"class": class_name,
|
||||
"confidence": confidence,
|
||||
"bbox": bbox
|
||||
})
|
||||
|
||||
# 判断是否有目标
|
||||
has_content = len(detection_results) > 0
|
||||
return has_content, detection_results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"检测过程出错:{str(e)}"
|
||||
print(error_msg)
|
||||
return False, None
|
||||
|
||||
|
||||
def get_enabled_model_rel_path():
|
||||
"""获取数据库中启用的模型相对路径"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
query = "SELECT path FROM model WHERE is_default = 1 LIMIT 1"
|
||||
cursor.execute(query)
|
||||
result = cursor.fetchone()
|
||||
|
||||
if not result or not result.get('path'):
|
||||
raise HTTPException(status_code=404, detail="未找到启用的默认模型")
|
||||
|
||||
return result['path']
|
||||
except MySQLError as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e
|
||||
finally:
|
||||
db.close_connection(conn, cursor)
|
||||
131
service/ocr_service.py
Normal file
@ -0,0 +1,131 @@
|
||||
# 首先添加NumPy兼容处理
|
||||
import numpy as np
|
||||
|
||||
# 修复np.int已弃用的问题
|
||||
if not hasattr(np, 'int'):
|
||||
np.int = int
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
from service.sensitive_service import get_all_sensitive_words
|
||||
|
||||
_ocr_engine = None
|
||||
_forbidden_words = set()
|
||||
_conf_threshold = 0.5
|
||||
|
||||
def set_forbidden_words(new_words):
|
||||
global _forbidden_words
|
||||
if not isinstance(new_words, (set, list, tuple)):
|
||||
raise TypeError("新违禁词必须是集合、列表或元组类型")
|
||||
_forbidden_words = set(new_words) # 确保是集合类型
|
||||
print(f"已通过函数更新违禁词,当前数量: {len(_forbidden_words)}")
|
||||
|
||||
def load_forbidden_words():
|
||||
global _forbidden_words
|
||||
try:
|
||||
_forbidden_words = get_all_sensitive_words()
|
||||
print(f"加载的违禁词数量: {len(_forbidden_words)}")
|
||||
except Exception as e:
|
||||
print(f"Forbidden words load error: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def init_ocr_engine():
|
||||
global _ocr_engine
|
||||
try:
|
||||
_ocr_engine = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang="ch",
|
||||
show_log=False,
|
||||
use_gpu=True,
|
||||
max_text_length=1024
|
||||
)
|
||||
load_result = load_forbidden_words()
|
||||
if not load_result:
|
||||
print("警告:违禁词加载失败,可能影响检测功能")
|
||||
print("OCR引擎初始化完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"OCR引擎初始化错误: {e}")
|
||||
_ocr_engine = None
|
||||
return False
|
||||
|
||||
|
||||
def detect(frame, conf_threshold=0.8):
|
||||
print("开始进行OCR检测...")
|
||||
try:
|
||||
ocr_res = _ocr_engine.ocr(frame, cls=True)
|
||||
if not ocr_res or not isinstance(ocr_res, list):
|
||||
return (False, "无OCR结果")
|
||||
|
||||
texts = []
|
||||
confs = []
|
||||
for line in ocr_res:
|
||||
if line is None:
|
||||
continue
|
||||
if isinstance(line, list):
|
||||
items_to_process = line
|
||||
else:
|
||||
items_to_process = [line]
|
||||
|
||||
for item in items_to_process:
|
||||
if isinstance(item, list) and len(item) == 4:
|
||||
is_coordinate = True
|
||||
for point in item:
|
||||
if not (isinstance(point, list) and len(point) == 2 and
|
||||
all(isinstance(coord, (int, float)) for coord in point)):
|
||||
is_coordinate = False
|
||||
break
|
||||
if is_coordinate:
|
||||
continue
|
||||
if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
|
||||
continue
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
text, conf = item
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
if isinstance(item, list) and len(item) >= 2:
|
||||
text_data = item[1]
|
||||
if isinstance(text_data, tuple) and len(text_data) == 2:
|
||||
text, conf = text_data
|
||||
if isinstance(text, str) and isinstance(conf, (int, float)):
|
||||
texts.append(text.strip())
|
||||
confs.append(float(conf))
|
||||
continue
|
||||
elif isinstance(text_data, str):
|
||||
texts.append(text_data.strip())
|
||||
confs.append(1.0)
|
||||
continue
|
||||
print(f"无法解析的OCR结果格式: {item}")
|
||||
|
||||
if len(texts) != len(confs):
|
||||
return (False, "OCR结果格式异常")
|
||||
|
||||
# 收集所有识别到的违禁词(去重且保持出现顺序)
|
||||
vio_words = []
|
||||
for txt, conf in zip(texts, confs):
|
||||
if conf < _conf_threshold: # 过滤低置信度结果
|
||||
continue
|
||||
# 提取当前文本中包含的违禁词
|
||||
matched = [w for w in _forbidden_words if w in txt]
|
||||
# 仅添加未记录过的违禁词(去重)
|
||||
for word in matched:
|
||||
if word not in vio_words:
|
||||
vio_words.append(word)
|
||||
|
||||
has_text = len(texts) > 0
|
||||
has_violation = len(vio_words) > 0
|
||||
|
||||
if not has_text:
|
||||
return (False, "未识别到文本")
|
||||
elif has_violation:
|
||||
# 多个违禁词用逗号拼接
|
||||
return (True, ", ".join(vio_words))
|
||||
else:
|
||||
return (False, "未检测到违禁词")
|
||||
|
||||
except Exception as e:
|
||||
print(f"OCR detect error: {e}")
|
||||
return (False, f"检测错误: {str(e)}")
|
||||
36
service/sensitive_service.py
Normal file
@ -0,0 +1,36 @@
|
||||
from mysql.connector import Error as MySQLError
|
||||
|
||||
from ds.db import db
|
||||
|
||||
|
||||
def get_all_sensitive_words() -> list[str]:
|
||||
"""
|
||||
获取所有敏感词(返回纯字符串列表、用于过滤业务)
|
||||
|
||||
返回:
|
||||
list[str]: 包含所有敏感词的数组
|
||||
|
||||
异常:
|
||||
MySQLError: 数据库操作相关错误
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
# 获取数据库连接
|
||||
conn = db.get_connection()
|
||||
cursor = conn.cursor(dictionary=True)
|
||||
|
||||
# 执行查询(只获取敏感词字段、按ID排序)
|
||||
query = "SELECT name FROM sensitives ORDER BY id"
|
||||
cursor.execute(query)
|
||||
sensitive_records = cursor.fetchall()
|
||||
|
||||
# 提取敏感词到纯字符串数组
|
||||
return [record['name'] for record in sensitive_records]
|
||||
|
||||
except MySQLError as e:
|
||||
# 数据库错误向上抛出、由调用方处理
|
||||
raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e
|
||||
finally:
|
||||
# 确保数据库连接正确释放
|
||||
db.close_connection(conn, cursor)
|
||||
|
After Width: | Height: | Size: 652 KiB |
|
After Width: | Height: | Size: 550 KiB |
|
After Width: | Height: | Size: 549 KiB |
|
After Width: | Height: | Size: 562 KiB |
|
After Width: | Height: | Size: 546 KiB |
|
After Width: | Height: | Size: 585 KiB |
|
After Width: | Height: | Size: 584 KiB |
|
After Width: | Height: | Size: 561 KiB |
|
After Width: | Height: | Size: 561 KiB |
|
After Width: | Height: | Size: 562 KiB |
|
After Width: | Height: | Size: 546 KiB |
|
After Width: | Height: | Size: 386 KiB |
|
After Width: | Height: | Size: 338 KiB |
|
After Width: | Height: | Size: 338 KiB |
|
After Width: | Height: | Size: 341 KiB |
|
After Width: | Height: | Size: 347 KiB |
|
After Width: | Height: | Size: 337 KiB |
|
After Width: | Height: | Size: 250 KiB |
|
After Width: | Height: | Size: 247 KiB |
|
After Width: | Height: | Size: 250 KiB |
|
After Width: | Height: | Size: 822 KiB |
|
After Width: | Height: | Size: 805 KiB |
|
After Width: | Height: | Size: 791 KiB |
|
After Width: | Height: | Size: 1.1 MiB |
|
After Width: | Height: | Size: 877 KiB |
|
After Width: | Height: | Size: 814 KiB |
|
After Width: | Height: | Size: 749 KiB |
|
After Width: | Height: | Size: 197 KiB |
|
After Width: | Height: | Size: 200 KiB |