内容安全审核

This commit is contained in:
2025-09-30 17:17:20 +08:00
commit cc6e66bbf8
523 changed files with 4853 additions and 0 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

8
.idea/Detect.iml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.10 (2)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -0,0 +1,98 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="76">
<item index="0" class="java.lang.String" itemvalue="scipy" />
<item index="1" class="java.lang.String" itemvalue="protobuf" />
<item index="2" class="java.lang.String" itemvalue="thop" />
<item index="3" class="java.lang.String" itemvalue="opencv-python" />
<item index="4" class="java.lang.String" itemvalue="PyYAML" />
<item index="5" class="java.lang.String" itemvalue="ipython" />
<item index="6" class="java.lang.String" itemvalue="torch" />
<item index="7" class="java.lang.String" itemvalue="numpy" />
<item index="8" class="java.lang.String" itemvalue="requests" />
<item index="9" class="java.lang.String" itemvalue="torchvision" />
<item index="10" class="java.lang.String" itemvalue="psutil" />
<item index="11" class="java.lang.String" itemvalue="tqdm" />
<item index="12" class="java.lang.String" itemvalue="pandas" />
<item index="13" class="java.lang.String" itemvalue="tensorboard" />
<item index="14" class="java.lang.String" itemvalue="seaborn" />
<item index="15" class="java.lang.String" itemvalue="matplotlib" />
<item index="16" class="java.lang.String" itemvalue="Pillow" />
<item index="17" class="java.lang.String" itemvalue="fastapi" />
<item index="18" class="java.lang.String" itemvalue="uvicorn" />
<item index="19" class="java.lang.String" itemvalue="python-jose" />
<item index="20" class="java.lang.String" itemvalue="passlib" />
<item index="21" class="java.lang.String" itemvalue="pydantic" />
<item index="22" class="java.lang.String" itemvalue="sqlalchemy" />
<item index="23" class="java.lang.String" itemvalue="imageio_ffmpeg" />
<item index="24" class="java.lang.String" itemvalue="ultralytics" />
<item index="25" class="java.lang.String" itemvalue="future" />
<item index="26" class="java.lang.String" itemvalue="jose" />
<item index="27" class="java.lang.String" itemvalue="ffmpeg-python" />
<item index="28" class="java.lang.String" itemvalue="setuptools" />
<item index="29" class="java.lang.String" itemvalue="opencv_python" />
<item index="30" class="java.lang.String" itemvalue="rsa" />
<item index="31" class="java.lang.String" itemvalue="greenlet" />
<item index="32" class="java.lang.String" itemvalue="networkx" />
<item index="33" class="java.lang.String" itemvalue="python-dateutil" />
<item index="34" class="java.lang.String" itemvalue="SQLAlchemy" />
<item index="35" class="java.lang.String" itemvalue="cffi" />
<item index="36" class="java.lang.String" itemvalue="python-dotenv" />
<item index="37" class="java.lang.String" itemvalue="h11" />
<item index="38" class="java.lang.String" itemvalue="py-cpuinfo" />
<item index="39" class="java.lang.String" itemvalue="cycler" />
<item index="40" class="java.lang.String" itemvalue="MarkupSafe" />
<item index="41" class="java.lang.String" itemvalue="pyasn1" />
<item index="42" class="java.lang.String" itemvalue="pycparser" />
<item index="43" class="java.lang.String" itemvalue="Jinja2" />
<item index="44" class="java.lang.String" itemvalue="sniffio" />
<item index="45" class="java.lang.String" itemvalue="ultralytics-thop" />
<item index="46" class="java.lang.String" itemvalue="fsspec" />
<item index="47" class="java.lang.String" itemvalue="filelock" />
<item index="48" class="java.lang.String" itemvalue="starlette" />
<item index="49" class="java.lang.String" itemvalue="certifi" />
<item index="50" class="java.lang.String" itemvalue="anyio" />
<item index="51" class="java.lang.String" itemvalue="urllib3" />
<item index="52" class="java.lang.String" itemvalue="pyparsing" />
<item index="53" class="java.lang.String" itemvalue="sympy" />
<item index="54" class="java.lang.String" itemvalue="annotated-types" />
<item index="55" class="java.lang.String" itemvalue="pydantic-settings" />
<item index="56" class="java.lang.String" itemvalue="six" />
<item index="57" class="java.lang.String" itemvalue="tzdata" />
<item index="58" class="java.lang.String" itemvalue="ecdsa" />
<item index="59" class="java.lang.String" itemvalue="kiwisolver" />
<item index="60" class="java.lang.String" itemvalue="packaging" />
<item index="61" class="java.lang.String" itemvalue="python-multipart" />
<item index="62" class="java.lang.String" itemvalue="click" />
<item index="63" class="java.lang.String" itemvalue="contourpy" />
<item index="64" class="java.lang.String" itemvalue="fonttools" />
<item index="65" class="java.lang.String" itemvalue="pydantic_core" />
<item index="66" class="java.lang.String" itemvalue="av" />
<item index="67" class="java.lang.String" itemvalue="colorama" />
<item index="68" class="java.lang.String" itemvalue="mpmath" />
<item index="69" class="java.lang.String" itemvalue="argon2-cffi-bindings" />
<item index="70" class="java.lang.String" itemvalue="typing_extensions" />
<item index="71" class="java.lang.String" itemvalue="charset-normalizer" />
<item index="72" class="java.lang.String" itemvalue="pillow" />
<item index="73" class="java.lang.String" itemvalue="argon2-cffi" />
<item index="74" class="java.lang.String" itemvalue="pytz" />
<item index="75" class="java.lang.String" itemvalue="idna" />
</list>
</value>
</option>
</inspection_tool>
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="N802" />
</list>
</option>
</inspection_tool>
<inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" />
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml generated Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="video" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (2)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Detect.iml" filepath="$PROJECT_DIR$/.idea/Detect.iml" />
</modules>
</component>
</project>

Binary file not shown.

20
config.ini Normal file
View File

@ -0,0 +1,20 @@
# Runtime configuration for the content-audit backend.
# NOTE(review): this file ships real credentials (MySQL password, JWT secret)
# in plaintext under version control — rotate them and load from env/secret
# storage instead of committing them.
[server]
port = 8000
[mysql]
host = 192.168.110.65
port = 6975
user = video_check
password = fsjPfhxCs8NrFGmL
database = video_check
charset = utf8mb4
[jwt]
secret_key = 6tsieyd87wefdw2wgeduwte23rfcsd
algorithm = HS256
access_token_expire_minutes = 30
[business]
# Per-detector confidence thresholds (0-1), read by core/detect.py.
ocr_conf = 0.6
face_conf = 0.6
yolo_conf = 0.7

Binary file not shown.

140
core/detect.py Normal file
View File

@ -0,0 +1,140 @@
import json
import re
from mysql.connector import Error as MySQLError
from ds.config import BUSINESS_CONFIG
from ds.db import db
from service.face_service import detect as faceDetect,init_insightface
from service.model_service import load_yolo_model,detect as yoloDetect
from service.ocr_service import detect as ocrDetect,init_ocr_engine
from service.file_service import save_detect_file, save_detect_yolo_file, save_detect_face_file
import asyncio
from concurrent.futures import ThreadPoolExecutor
# Thread pool for running blocking detection work off the event loop.
executor = ThreadPoolExecutor(max_workers=10)
def init():
    """Load all detection backends once at startup."""
    # Face recognition (InsightFace) models/embeddings.
    init_insightface()
    # OCR engine (used for prohibited-word detection).
    init_ocr_engine()
    # YOLO object-detection model.
    load_yolo_model()
def save_db(model_type, client_ip, result):
    """Persist one detection hit into the `device_danger` table.

    Args:
        model_type: violation category/label stored in the `type` column.
        client_ip:  IP of the device whose frame triggered the hit.
        result:     payload stored in `result` (here: saved image path).

    Raises:
        Exception: wrapping any MySQL error with a descriptive message.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        # dictionary=True keeps the cursor style consistent with the rest of the project.
        cursor = conn.cursor(dictionary=True)
        insert_query = """
            INSERT INTO device_danger (client_ip, type, result)
            VALUES (%s, %s, %s)
        """
        cursor.execute(insert_query, (client_ip, model_type, result))
        conn.commit()
    except MySQLError as e:
        # Fixed copy-paste error: the original message said "获取设备列表失败"
        # ("failed to fetch device list") although this function inserts a record.
        raise Exception(f"保存危险记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
def detectFrame(client_ip, frame):
    """Run all three detectors (YOLO, face, OCR) on one frame.

    Each detector that flags the frame triggers, in order: client alarm/lock
    messages (danger_handler), saving an annotated image, and a DB record.
    All three detectors always run, so a single frame can produce up to
    three danger records.
    """
    # YOLO object detection (e.g. explicit content classes).
    yolo_flag, yolo_result = yoloDetect(frame, float(BUSINESS_CONFIG["yolo_conf"]))
    if yolo_flag:
        print(f"❌ 检测到违规内容,保存图片,YOLO")
        danger_handler(client_ip)
        path = save_detect_yolo_file(client_ip, frame, yolo_result, "yolo")
        save_db(model_type="色情", client_ip=client_ip, result=str(path))
    # Face recognition against a watch list.
    face_flag, face_result = faceDetect(frame, float(BUSINESS_CONFIG["face_conf"]))
    if face_flag:
        print(f"❌ 检测到违规内容,保存图片,FACE")
        print("人脸识别内容:", face_result)
        # Store the matched names as the violation type.
        model_type = extract_face_names(face_result)
        danger_handler(client_ip)
        path = save_detect_face_file(client_ip, frame, face_result, "face")
        save_db(model_type=model_type, client_ip=client_ip, result=str(path))
    # OCR text extraction + prohibited-word screening.
    ocr_flag, ocr_result = ocrDetect(frame, float(BUSINESS_CONFIG["ocr_conf"]))
    if ocr_flag:
        print(f"❌ 检测到违规内容,保存图片,OCR")
        print("ocr识别内容:", ocr_result)
        danger_handler(client_ip)
        path = save_detect_file(client_ip, frame, "ocr")
        save_db(model_type=str(ocr_result), client_ip=client_ip, result=str(path))
    # Log only when every detector came back clean.
    if not (face_flag or yolo_flag or ocr_flag):
        print(f"所有模型未检测到任何违规内容")
def danger_handler(client_ip):
    """Notify and lock a client after a violation, then update its DB state.

    Sends a "danger" then a "lock" websocket message to the client,
    increments its alarm counter and marks it as needing manual handling.
    """
    # Local imports break a circular dependency with ws.ws / device_service.
    from ws.ws import send_message_to_client, get_current_time_str
    from service.device_service import increment_alarm_count_by_ip
    from service.device_service import update_is_need_handler_by_client_ip
    danger_msg = {
        "type": "danger",
        "timestamp": get_current_time_str(),
        "client_ip": client_ip,
    }
    # NOTE(review): asyncio.run() spins up a fresh event loop per call and
    # raises if the calling thread already runs a loop — presumably this is
    # only invoked from worker threads; confirm.
    asyncio.run(
        send_message_to_client(
            client_ip=client_ip,
            json_data=json.dumps(danger_msg)
        )
    )
    lock_msg = {
        "type": "lock",
        "timestamp": get_current_time_str(),
        "client_ip": client_ip
    }
    asyncio.run(
        send_message_to_client(
            client_ip=client_ip,
            json_data=json.dumps(lock_msg)
        )
    )
    # Bump the device's alarm counter.
    increment_alarm_count_by_ip(client_ip)
    # Flag the device as "needs handling" (1 = unhandled).
    update_is_need_handler_by_client_ip(client_ip, 1)
def extract_prohibited_words(ocr_result: str) -> str:
    """Extract all prohibited words from a multi-fragment OCR result string.

    The input consists of fragments shaped like
    ``文本: ... 包含违禁词: ...;``. Every word listed after the
    "包含违禁词" marker is collected, confidence annotations such as
    "(置信度: 1.00)" are stripped, and duplicates are removed.

    Args:
        ocr_result: raw OCR detection summary text.

    Returns:
        De-duplicated prohibited words joined by "," in first-seen order
        (the original ``list(set(...))`` produced a nondeterministic order).
    """
    # Match every "包含违禁词: ...;" fragment (non-greedy up to the fullwidth semicolon).
    pattern = r"包含违禁词: (.*?);"
    all_prohibited_segments = re.findall(pattern, ocr_result, re.DOTALL)
    all_words = []
    for segment in all_prohibited_segments:
        # Drop confidence annotations, e.g. "(置信度: 1.00)".
        cleaned = re.sub(r"\s*\([^)]*\)", "", segment.strip())
        # Split on the fullwidth comma and discard empty entries.
        all_words.extend(word.strip() for word in cleaned.split(",") if word.strip())
    # dict.fromkeys de-duplicates while preserving insertion order, making the
    # output deterministic across runs.
    return ",".join(dict.fromkeys(all_words))
def extract_face_names(face_result: str) -> str:
    """Extract matched person names from a face-detection summary.

    Names appear in fragments shaped like "匹配: <name> (<score>)".

    Args:
        face_result: raw face-detection summary text.

    Returns:
        De-duplicated names joined by "," in first-seen order (the original
        ``list(set(...))`` produced a nondeterministic order).
    """
    pattern = r"匹配: (.*?) \("
    all_names = re.findall(pattern, face_result)
    # Order-preserving de-duplication for deterministic output.
    return ",".join(dict.fromkeys(name.strip() for name in all_names if name.strip()))

Binary file not shown.

Binary file not shown.

17
ds/config.py Normal file
View File

@ -0,0 +1,17 @@
import configparser
import os
# Absolute path to config.ini, resolved relative to this file so imports work
# regardless of the process working directory.
config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../config.ini")
# Shared parser instance for the whole application.
config = configparser.ConfigParser()
# NOTE(review): config.read() silently returns an empty list if the file is
# missing; the section lookups below would then raise KeyError at import time.
config.read(config_path, encoding="utf-8")
# Expose the config sections for other modules to import.
SERVER_CONFIG = config["server"]
MYSQL_CONFIG = config["mysql"]
JWT_CONFIG = config["jwt"]
BUSINESS_CONFIG = config["business"]

59
ds/db.py Normal file
View File

@ -0,0 +1,59 @@
import mysql.connector
from mysql.connector import Error
from .config import MYSQL_CONFIG
_connection_pool = None
class Database:
    """MySQL connection helper.

    Passing pool_name/pool_size to mysql.connector.connect() makes the
    driver manage a connection pool internally and hand out pooled
    connections.
    """
    pool_config = {
        "host": MYSQL_CONFIG.get("host", "localhost"),
        "port": int(MYSQL_CONFIG.get("port", 3306)),
        "user": MYSQL_CONFIG.get("user", "root"),
        "password": MYSQL_CONFIG.get("password", ""),
        "database": MYSQL_CONFIG.get("database", ""),
        "charset": MYSQL_CONFIG.get("charset", "utf8mb4"),
        "pool_name": "fastapi_pool",
        "pool_size": 5,
        "pool_reset_session": True
    }

    @classmethod
    def get_connection(cls):
        """Return a live connection from the pool.

        Raises:
            Exception: if connecting fails, or the obtained connection is not
                usable. (The original implementation fell through and returned
                None in the latter case, crashing callers later with
                AttributeError.)
        """
        try:
            conn = mysql.connector.connect(**cls.pool_config)
        except Error as e:
            raise Exception(f"MySQL 连接失败: {str(e)}") from e
        if conn.is_connected():
            return conn
        # Never return None: fail loudly so callers see the real cause.
        raise Exception("MySQL 连接失败: 连接不可用")

    @classmethod
    def close_connection(cls, conn, cursor=None):
        """Close cursor then connection; both arguments may be None."""
        try:
            if cursor:
                cursor.close()
            if conn and conn.is_connected():
                conn.close()
        except Error as e:
            raise Exception(f"MySQL 连接关闭失败: {str(e)}") from e

    @classmethod
    def close_all_connections(cls):
        """Reset the pool reference before a service restart (best effort)."""
        try:
            # Check the attribute exists before testing its value.
            if hasattr(cls, "_connection_pool") and cls._connection_pool:
                cls._connection_pool = None  # old connections are cleaned up automatically
                print("[Database] 连接池已重置、旧连接将被自动清理")
            else:
                print("[Database] 连接池未初始化或已重置、无需操作")
        except Exception as e:
            print(f"[Database] 重置连接池失败: {str(e)}")


# Shared database helper instance used across the project.
db = Database()

Binary file not shown.

View File

@ -0,0 +1,40 @@
import json
from datetime import datetime
from functools import wraps
from typing import Any
from encryption.encryption import aes_encrypt
from schema.response_schema import APIResponse
from pydantic import BaseModel
def encrypt_response(field: str = "data"):
    """Decorator that AES-encrypts one field of an APIResponse.

    The wrapped coroutine's response has its ``field`` attribute serialized
    to standard JSON (Pydantic models and datetimes handled explicitly) and
    replaced by the encrypted payload. Falsy field values pass through
    unchanged.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            response: APIResponse = await func(*args, **kwargs)
            payload = getattr(response, field)
            if not payload:
                # Nothing to encrypt (None/empty) — return untouched.
                return response

            def json_default(obj: Any) -> Any:
                # Serialize Pydantic models and datetimes; fall back to str().
                if isinstance(obj, BaseModel):
                    return obj.model_dump()
                if isinstance(obj, datetime):
                    return obj.isoformat()
                return str(obj)

            serialized = json.dumps(payload, default=json_default)
            setattr(response, field, aes_encrypt(serialized))
            return response
        return wrapper
    return decorator

56
encryption/encryption.py Normal file
View File

@ -0,0 +1,56 @@
import os
import base64
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from fastapi import HTTPException
# Hard-coded AES key (32 bytes → AES-256).
# NOTE(review): shipping a static key in source defeats the encryption for
# anyone with repository access — consider loading it from config/env.
AES_SECRET_KEY = b"jr1vA6tfWMHOYi6UXw67UuO6fdak2rMa"
AES_BLOCK_SIZE = 16  # AES block size is always 16 bytes
# Fail fast at import time if the key length is not a valid AES key size.
valid_key_lengths = [16, 24, 32]
if len(AES_SECRET_KEY) not in valid_key_lengths:
    raise ValueError(
        f"AES密钥长度必须为{valid_key_lengths}字节、当前为{len(AES_SECRET_KEY)}字节"
    )
def aes_encrypt(plaintext: str) -> dict:
    """Encrypt a UTF-8 string with AES-CBC.

    Args:
        plaintext: string to encrypt.

    Returns:
        dict with Base64 "ciphertext", Base64 "iv" and "algorithm" keys.

    Raises:
        HTTPException: 500 on any encryption failure.
    """
    try:
        # Fresh random 16-byte IV per message (required for CBC security).
        iv = os.urandom(AES_BLOCK_SIZE)
        cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv)
        # PKCS#7-pad the UTF-8 bytes to the block size, then encrypt.
        padded_plaintext = pad(plaintext.encode("utf-8"), AES_BLOCK_SIZE)
        ciphertext = base64.b64encode(cipher.encrypt(padded_plaintext)).decode("utf-8")
        iv_base64 = base64.b64encode(iv).decode("utf-8")
        return {
            "ciphertext": ciphertext,
            "iv": iv_base64,
            "algorithm": "AES-CBC"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"AES加密失败{str(e)}") from e
def aes_decrypt(ciphertext: str, iv: str) -> str:
    """Decrypt an AES-CBC payload produced by aes_encrypt().

    Args:
        ciphertext: Base64-encoded ciphertext.
        iv: Base64-encoded initialization vector.

    Returns:
        The decrypted UTF-8 string.

    Raises:
        HTTPException: 500 on any decryption failure (bad Base64, wrong key,
            corrupt padding, …).
    """
    try:
        # Decode the Base64 transport encoding.
        ciphertext_bytes = base64.b64decode(ciphertext)
        iv_bytes = base64.b64decode(iv)
        cipher = AES.new(AES_SECRET_KEY, AES.MODE_CBC, iv_bytes)
        # Decrypt, then strip the PKCS#7 padding.
        decrypted_bytes = unpad(cipher.decrypt(ciphertext_bytes), AES_BLOCK_SIZE)
        return decrypted_bytes.decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"AES解密失败{str(e)}") from e

66
main.py Normal file
View File

@ -0,0 +1,66 @@
import uvicorn
import os  # NOTE(review): appears unused in this module — confirm before removing
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Application imports (config, error handling, routers, detection core, websockets).
from ds.config import SERVER_CONFIG
from middle.error_handler import global_exception_handler
from router.user_router import router as user_router
from router.sensitive_router import router as sensitive_router
from router.face_router import router as face_router
from router.device_router import router as device_router
from router.model_router import router as model_router
from router.file_router import router as file_router
from router.device_danger_router import router as device_danger_router
from core.detect import init
from ws.ws import ws_router, lifespan
# FastAPI application; lifespan hooks come from the websocket module.
app = FastAPI(
    title="内容安全审核后台",
    description="含图片访问服务和动态模型管理",
    version="1.0.0",
    lifespan=lifespan
)
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers per the CORS spec — list explicit origins if cookies
# or credentials are actually needed.
ALLOWED_ORIGINS = [
    "*"
]
# CORS middleware configuration.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,  # allowed frontend origins
    allow_credentials=True,         # allow cookies/credentials
    allow_methods=["*"],            # allow all HTTP methods
    allow_headers=["*"],            # allow all request headers
)
# Register all feature routers.
app.include_router(user_router)
app.include_router(device_router)
app.include_router(face_router)
app.include_router(sensitive_router)
app.include_router(model_router)
app.include_router(file_router)
app.include_router(device_danger_router)
app.include_router(ws_router)
# Catch-all exception handler (see middle/error_handler.py).
app.add_exception_handler(Exception, global_exception_handler)
# Entry point: load models, then serve on the configured port.
if __name__ == "__main__":
    port = int(SERVER_CONFIG.get("port", 8000))
    # Load all detection models before accepting traffic.
    init()
    uvicorn.run(
        app="main:app",
        host="0.0.0.0",
        port=port,
        workers=1,             # single worker: models are loaded in-process
        ws="websockets",
        reload=False
    )

Binary file not shown.

Binary file not shown.

102
middle/auth_middleware.py Normal file
View File

@ -0,0 +1,102 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from ds.config import JWT_CONFIG
from ds.db import db
# ------------------------------
# Password hashing configuration
# ------------------------------
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# ------------------------------
# JWT configuration (values from config.ini [jwt])
# ------------------------------
SECRET_KEY = JWT_CONFIG["secret_key"]
ALGORITHM = JWT_CONFIG["algorithm"]
ACCESS_TOKEN_EXPIRE_MINUTES = int(JWT_CONFIG["access_token_expire_minutes"])
# OAuth2 scheme: reads "Authorization: Bearer <token>"; tokens are issued by /users/login.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
# ------------------------------
# Password helpers
# ------------------------------
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if the plaintext password matches the bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with bcrypt."""
    return pwd_context.hash(password)
# ------------------------------
# JWT helpers
# ------------------------------
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT with an "exp" claim.

    Args:
        data: claims to encode (e.g. {"sub": username}).
        expires_delta: optional custom lifetime. When omitted, the configured
            ACCESS_TOKEN_EXPIRE_MINUTES is used. (The original hard-coded a
            15-minute fallback here, silently ignoring the config value it
            reads at module level.)

    Returns:
        The encoded JWT string.
    """
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        # Fall back to the configured lifetime instead of a hard-coded 15 minutes.
        expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Add the expiry claim and sign.
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
# ------------------------------
# Authentication dependency (resolve the current user)
# ------------------------------
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the current user from a Bearer token.

    Decodes the JWT, extracts the "sub" (username) claim, then verifies the
    user still exists in the `users` table. Raises HTTP 401 on any failure.
    """
    # Local import breaks a circular dependency with schema.user_schema.
    from schema.user_schema import UserResponse
    # Single 401 raised for every authentication failure mode.
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token 无效或已过期",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Verify the signature and decode the claims.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    # Confirm the user still exists in the database.
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        query = "SELECT id, username, created_at, updated_at FROM users WHERE username = %s"
        cursor.execute(query, (username,))
        user = cursor.fetchone()
        if user is None:
            raise credentials_exception
        # Pydantic model validates the row's fields.
        return UserResponse(**user)
    except Exception as e:
        # NOTE(review): this also converts genuine DB outages into 401s,
        # hiding the root cause from clients and logs — consider narrowing.
        raise credentials_exception from e
    finally:
        db.close_connection(conn, cursor)

68
middle/error_handler.py Normal file
View File

@ -0,0 +1,68 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import HTTPException, RequestValidationError
from mysql.connector import Error as MySQLError
from jose import JWTError
from schema.response_schema import APIResponse
async def global_exception_handler(request: Request, exc: Exception):
    """Map every uncaught exception to a uniform APIResponse JSON body.

    Checks in order: request-validation errors (400), HTTPException (original
    status code), JWT errors (401), MySQL errors (500), generic fallback (500).
    """
    # Pydantic request-validation failures → 400 with a readable field list.
    if isinstance(exc, RequestValidationError):
        error_details = []
        for err in exc.errors():
            # loc is a tuple such as ("body", "field", ...). The original read
            # err['loc'][1] unconditionally, which raised IndexError for
            # single-element locations (e.g. a malformed JSON body).
            loc = err.get("loc") or ()
            if len(loc) > 1:
                field = ".".join(str(part) for part in loc[1:])
            elif loc:
                field = str(loc[0])
            else:
                field = "request"
            error_details.append(f"{field}: {err['msg']}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=APIResponse(
                code=400,
                message=f"请求参数错误: {'; '.join(error_details)}",
                data=None
            ).model_dump()
        )
    # Deliberate business errors (401/404/...) keep their status code.
    if isinstance(exc, HTTPException):
        return JSONResponse(
            status_code=exc.status_code,
            content=APIResponse(
                code=exc.status_code,
                message=exc.detail,
                data=None
            ).model_dump()
        )
    # JWT errors (invalid/expired token) → 401 with the Bearer challenge header.
    if isinstance(exc, JWTError):
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content=APIResponse(
                code=401,
                message="Token 无效或已过期",
                data=None
            ).model_dump(),
            headers={"WWW-Authenticate": "Bearer"},
        )
    # MySQL errors → 500 with the driver message.
    if isinstance(exc, MySQLError):
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=APIResponse(
                code=500,
                message=f"数据库错误: {str(exc)}",
                data=None
            ).model_dump()
        )
    # Anything else → generic 500.
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content=APIResponse(
            code=500,
            message=f"服务器内部错误: {str(exc)}",
            data=None
        ).model_dump()
    )

21
requirements.txt Normal file
View File

@ -0,0 +1,21 @@
fastapi==0.116.1
# insightface was listed twice (both 0.7.3); kept a single pin.
insightface==0.7.3
# mysql-connector-python is imported throughout the code (ds/db.py) but was
# missing from this file.
mysql-connector-python
numpy==1.26.1
onnxruntime-gpu==1.15.1
opencv_contrib_python==4.6.0.66
opencv_python==4.8.1.78
opencv_python_headless==4.11.0.86
paddleocr==2.6.0.1
paddlepaddle-gpu==2.5.2
passlib==1.7.4
Pillow==11.3.0
pycryptodome==3.23.0
pydantic==2.11.7
python_jose==3.5.0
torch==2.3.1+cu118
torchaudio==2.3.1+cu118
torchvision==0.18.1+cu118
ultralytics==8.3.198
# "uvicorn==0.35.0" and an unpinned "uvicorn[standard]" were both listed;
# merged into one pinned extra-enabled requirement.
uvicorn[standard]==0.35.0

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,112 @@
from fastapi import APIRouter, Query, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_action_schema import (
DeviceActionResponse,
DeviceActionListResponse
)
from schema.response_schema import APIResponse
# Router for device operation (online/offline) history endpoints.
router = APIRouter(
    prefix="/api/device/actions",
    tags=["设备操作记录"]
)
@router.get("/list", response_model=APIResponse, summary="分页查询设备操作记录")
@encrypt_response()
async def get_device_action_list(
    page: int = Query(1, ge=1, description="页码、默认1"),
    page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100"),
    client_ip: str = Query(None, description="按客户端IP筛选"),
    action: int = Query(None, ge=0, le=1, description="按状态筛选0=离线、1=上线)")
):
    """Paginated device-action history with optional IP/action filters.

    Returns an encrypted APIResponse wrapping total count + rows.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Build the WHERE clause with placeholders (parameterized — no injection).
        where_clause = []
        params = []
        if client_ip:
            where_clause.append("client_ip = %s")
            params.append(client_ip)
        if action is not None:
            where_clause.append("action = %s")
            params.append(action)
        # Total row count for the pagination response.
        count_sql = "SELECT COUNT(*) AS total FROM device_action"
        if where_clause:
            count_sql += " WHERE " + " AND ".join(where_clause)
        cursor.execute(count_sql, params)
        total = cursor.fetchone()["total"]
        # Page of rows, newest first.
        offset = (page - 1) * page_size
        list_sql = "SELECT * FROM device_action"
        if where_clause:
            list_sql += " WHERE " + " AND ".join(where_clause)
        list_sql += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
        params.extend([page_size, offset])
        cursor.execute(list_sql, params)
        action_list = cursor.fetchall()
        # Response carries only total + device_actions.
        return APIResponse(
            code=200,
            message="查询成功",
            data=DeviceActionListResponse(
                total=total,
                device_actions=[DeviceActionResponse(**item) for item in action_list]
            )
        )
    except MySQLError as e:
        raise Exception(f"查询记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
@router.get("/{client_ip}", response_model=APIResponse, summary="根据IP查询设备操作记录")
@encrypt_response()
async def get_device_actions_by_ip(
    client_ip: str = Path(..., description="客户端IP地址")
):
    """Return ALL action records for one client IP (no pagination), newest first."""
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # 1. Total record count for this IP.
        count_sql = "SELECT COUNT(*) AS total FROM device_action WHERE client_ip = %s"
        cursor.execute(count_sql, (client_ip,))
        total = cursor.fetchone()["total"]
        # 2. All records for the IP, newest first.
        list_sql = """
            SELECT * FROM device_action
            WHERE client_ip = %s
            ORDER BY created_at DESC
        """
        cursor.execute(list_sql, (client_ip,))
        action_list = cursor.fetchall()
        # 3. Wrap in the shared list-response schema.
        return APIResponse(
            code=200,
            message="查询成功",
            data=DeviceActionListResponse(
                total=total,
                device_actions=[DeviceActionResponse(**item) for item in action_list]
            )
        )
    except MySQLError as e:
        raise Exception(f"查询记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)

View File

@ -0,0 +1,164 @@
from datetime import date
from fastapi import APIRouter, Query, HTTPException, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_danger_schema import (
DeviceDangerResponse, DeviceDangerListResponse
)
from schema.response_schema import APIResponse
# Router for device danger-record endpoints.
router = APIRouter(
    prefix="/api/devices/dangers",
    tags=["设备管理-危险记录"]
)
# List danger records with multi-condition filtering.
@router.get("/", response_model=APIResponse, summary="获取设备危险记录列表(多条件筛选)")
@encrypt_response()
async def get_danger_list(
    page: int = Query(1, ge=1, description="页码、默认第1页"),
    page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
    client_ip: str = Query(None, max_length=100, description="按设备IP筛选"),
    danger_type: str = Query(None, max_length=255, alias="type", description="按危险类型筛选"),
    start_date: date = Query(None, description="按创建时间筛选开始日期、格式YYYY-MM-DD"),
    end_date: date = Query(None, description="按创建时间筛选结束日期、格式YYYY-MM-DD")
):
    """Paginated danger-record list filtered by IP, type and creation-date range."""
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Build the parameterized WHERE clause from the provided filters.
        where_clause = []
        params = []
        if client_ip:
            where_clause.append("client_ip = %s")
            params.append(client_ip)
        if danger_type:
            where_clause.append("type = %s")
            params.append(danger_type)
        if start_date:
            where_clause.append("DATE(created_at) >= %s")
            params.append(start_date.strftime("%Y-%m-%d"))
        if end_date:
            where_clause.append("DATE(created_at) <= %s")
            params.append(end_date.strftime("%Y-%m-%d"))
        # 1. Count matching rows for pagination.
        count_query = "SELECT COUNT(*) AS total FROM device_danger"
        if where_clause:
            count_query += " WHERE " + " AND ".join(where_clause)
        cursor.execute(count_query, params)
        total = cursor.fetchone()["total"]
        # 2. Fetch the requested page, newest first.
        offset = (page - 1) * page_size
        list_query = "SELECT * FROM device_danger"
        if where_clause:
            list_query += " WHERE " + " AND ".join(where_clause)
        list_query += " ORDER BY created_at DESC LIMIT %s OFFSET %s"
        params.extend([page_size, offset])  # append pagination params
        cursor.execute(list_query, params)
        danger_list = cursor.fetchall()
        # Convert rows to the response schema.
        return APIResponse(
            code=200,
            message="获取危险记录列表成功",
            data=DeviceDangerListResponse(
                total=total,
                dangers=[DeviceDangerResponse(**item) for item in danger_list]
            )
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"查询危险记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# All danger records of a single device (paginated).
@router.get("/device/{client_ip}", response_model=APIResponse, summary="获取单个设备的所有危险记录")
# NOTE(review): @encrypt_response() is commented out here, unlike the sibling
# endpoints — this response goes out unencrypted; confirm that is intended.
# @encrypt_response()
async def get_device_dangers(
    client_ip: str = Path(..., max_length=100, description="设备IP地址"),
    page: int = Query(1, ge=1, description="页码、默认第1页"),
    page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间")
):
    """Paginated danger records for one device; 404 if the device is unknown."""
    # Verify the device exists before querying its records.
    from service.device_danger_service import check_device_exist
    if not check_device_exist(client_ip):
        raise HTTPException(status_code=404, detail=f"IP为 {client_ip} 的设备不存在")
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # 1. Total danger-record count for this device.
        count_query = "SELECT COUNT(*) AS total FROM device_danger WHERE client_ip = %s"
        cursor.execute(count_query, (client_ip,))
        total = cursor.fetchone()["total"]
        # 2. Fetch the requested page, newest first.
        offset = (page - 1) * page_size
        list_query = """
            SELECT * FROM device_danger
            WHERE client_ip = %s
            ORDER BY created_at DESC
            LIMIT %s OFFSET %s
        """
        cursor.execute(list_query, (client_ip, page_size, offset))
        danger_list = cursor.fetchall()
        return APIResponse(
            code=200,
            message=f"获取设备[{client_ip}]危险记录成功(共{total}条)",
            data=DeviceDangerListResponse(
                total=total,
                dangers=[DeviceDangerResponse(**item) for item in danger_list]
            )
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"查询设备[{client_ip}]危险记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# ------------------------------
# Single danger-record detail by ID
# ------------------------------
@router.get("/{danger_id}", response_model=APIResponse, summary="根据ID获取单个危险记录详情")
@encrypt_response()
async def get_danger_detail(
    danger_id: int = Path(..., ge=1, description="危险记录ID")
):
    """Return one danger record by primary key; 404 if it does not exist."""
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Look up the record by ID.
        query = "SELECT * FROM device_danger WHERE id = %s"
        cursor.execute(query, (danger_id,))
        danger = cursor.fetchone()
        if not danger:
            # 404 propagates (only MySQLError is caught below); finally still closes.
            raise HTTPException(status_code=404, detail=f"ID为 {danger_id} 的危险记录不存在")
        return APIResponse(
            code=200,
            message="获取危险记录详情成功",
            data=DeviceDangerResponse(**danger)
        )
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"查询危险记录详情失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)

329
router/device_router.py Normal file
View File

@ -0,0 +1,329 @@
import asyncio
import json
from datetime import date
from fastapi import APIRouter, Query, HTTPException, Request, Path
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.device_schema import (
DeviceCreateRequest, DeviceResponse, DeviceListResponse,
DeviceStatusHistoryResponse, DeviceStatusHistoryListResponse
)
from schema.response_schema import APIResponse
from service.device_service import update_online_status_by_ip
from ws.ws import get_current_time_str, aes_encrypt, is_client_connected
# Router for device management endpoints.
router = APIRouter(
    prefix="/api/devices",
    tags=["设备管理"]
)
# Create (or re-activate) a device record.
@router.post("/add", response_model=APIResponse, summary="创建设备信息")
@encrypt_response()
async def create_device(device_data: DeviceCreateRequest, request: Request):
    """Register a device by IP.

    If the IP already exists the device is marked online and the existing
    row is returned; otherwise a new row is inserted with the device type
    inferred from the User-Agent header (or from params["os"] when the
    User-Agent is literally "default").
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Existing device: flip it online and return it unchanged.
        cursor.execute("SELECT * FROM devices WHERE client_ip = %s", (device_data.ip,))
        existing_device = cursor.fetchone()
        if existing_device:
            # NOTE(review): redundant local import — update_online_status_by_ip
            # is already imported at module level.
            from service.device_service import update_online_status_by_ip
            update_online_status_by_ip(client_ip=device_data.ip, online_status=1)
            return APIResponse(
                code=200,
                message=f"设备IP {device_data.ip} 已存在、返回已有设备信息",
                data=DeviceResponse(**existing_device)
            )
        # Infer the device type from the User-Agent header.
        user_agent = request.headers.get("User-Agent", "").lower()
        device_type = "unknown"
        if user_agent == "default":
            # Literal "default" UA: fall back to the client-reported OS, if any.
            device_type = device_data.params.get("os") if (device_data.params and isinstance(device_data.params, dict)) else "unknown"
        elif "windows" in user_agent:
            device_type = "windows"
        elif "android" in user_agent:
            device_type = "android"
        elif "linux" in user_agent:
            device_type = "linux"
        device_params_json = json.dumps(device_data.params) if device_data.params else None
        # Insert the new device (offline, zero alarms, nothing to handle).
        insert_query = """
            INSERT INTO devices
            (client_ip, hostname, device_online_status, device_type, alarm_count, params, is_need_handler)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
        cursor.execute(insert_query, (
            device_data.ip,
            device_data.hostname,
            0,
            device_type,
            0,
            device_params_json,
            0
        ))
        conn.commit()
        # Read back and return the freshly inserted row.
        device_id = cursor.lastrowid
        cursor.execute("SELECT * FROM devices WHERE id = %s", (device_id,))
        new_device = cursor.fetchone()
        return APIResponse(
            code=200,
            message="设备创建成功",
            data=DeviceResponse(**new_device)
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise Exception(f"创建设备失败: {str(e)}") from e
    except json.JSONDecodeError as e:
        raise Exception(f"设备参数JSON序列化失败: {str(e)}") from e
    except Exception as e:
        if conn:
            conn.rollback()
        raise e
    finally:
        db.close_connection(conn, cursor)
# ------------------------------
# 获取设备列表接口
# ------------------------------
# ------------------------------
# 获取设备列表接口
# ------------------------------
@router.get("/", response_model=APIResponse, summary="获取设备列表(支持筛选分页)")
@encrypt_response()
async def get_device_list(
    page: int = Query(1, ge=1, description="页码、默认第1页"),
    page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
    device_type: str = Query(None, description="按设备类型筛选"),
    online_status: int = Query(None, ge=0, le=1, description="按在线状态筛选")
):
    """Paged device listing with optional type / online-status filters."""
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Collect WHERE fragments and their bound values in parallel lists.
        filters = []
        bound = []
        if device_type:
            filters.append("device_type = %s")
            bound.append(device_type)
        if online_status is not None:
            filters.append("device_online_status = %s")
            bound.append(online_status)
        where_sql = (" WHERE " + " AND ".join(filters)) if filters else ""
        # Total row count for the pager.
        cursor.execute("SELECT COUNT(*) AS total FROM devices" + where_sql, bound)
        total = cursor.fetchone()["total"]
        # One page of rows, newest id first.
        cursor.execute(
            "SELECT * FROM devices" + where_sql + " ORDER BY id DESC LIMIT %s OFFSET %s",
            bound + [page_size, (page - 1) * page_size]
        )
        rows = cursor.fetchall()
        return APIResponse(
            code=200,
            message="获取设备列表成功",
            data=DeviceListResponse(
                total=total,
                devices=[DeviceResponse(**row) for row in rows]
            )
        )
    except MySQLError as e:
        raise Exception(f"获取设备列表失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# ------------------------------
# 获取设备上下线记录接口
# ------------------------------
# ------------------------------
# 获取设备上下线记录接口
# ------------------------------
@router.get("/status-history", response_model=APIResponse, summary="获取设备上下线记录")
@encrypt_response()
async def get_device_status_history(
    client_ip: str = Query(None, description="客户端IP地址非必填为空时返回所有设备记录"),
    page: int = Query(1, ge=1, description="页码、默认第1页"),
    page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
    start_date: date = Query(None, description="开始日期、格式YYYY-MM-DD"),
    end_date: date = Query(None, description="结束日期、格式YYYY-MM-DD")
):
    """Paged device on/offline history read from device_action.

    Optional filters: exact client IP and an inclusive created_at date window.
    Every client_ip comparison (and the LEFT JOIN below) forces COLLATE
    utf8mb4_general_ci on both sides to avoid "illegal mix of collations"
    errors when devices.client_ip and device_action.client_ip use different
    collations.
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # 1. When an IP filter is given, 404 early if no such device exists.
        if client_ip is not None:
            # Force a common collation on both sides of the comparison.
            check_query = """
                SELECT id FROM devices
                WHERE client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci
            """
            cursor.execute(check_query, (client_ip,))
            device = cursor.fetchone()
            if not device:
                raise HTTPException(status_code=404, detail=f"客户端IP为 {client_ip} 的设备不存在")
        # 2. Build the WHERE fragments shared by the count and page queries.
        where_clause = []
        params = []
        if client_ip is not None:
            where_clause.append("da.client_ip COLLATE utf8mb4_general_ci = %s COLLATE utf8mb4_general_ci")
            params.append(client_ip)
        if start_date:
            where_clause.append("DATE(da.created_at) >= %s")
            params.append(start_date.strftime("%Y-%m-%d"))
        if end_date:
            where_clause.append("DATE(da.created_at) <= %s")
            params.append(end_date.strftime("%Y-%m-%d"))
        # 3. Total row count (JOIN collation unified as above).
        count_query = """
            SELECT COUNT(*) AS total
            FROM device_action da
            LEFT JOIN devices d
            ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
        """
        if where_clause:
            count_query += " WHERE " + " AND ".join(where_clause)
        cursor.execute(count_query, params)
        total = cursor.fetchone()["total"]
        # 4. One page of rows, newest first.
        offset = (page - 1) * page_size
        list_query = """
            SELECT da.*, d.id AS device_id
            FROM device_action da
            LEFT JOIN devices d
            ON da.client_ip COLLATE utf8mb4_general_ci = d.client_ip COLLATE utf8mb4_general_ci
        """
        if where_clause:
            list_query += " WHERE " + " AND ".join(where_clause)
        list_query += " ORDER BY da.created_at DESC LIMIT %s OFFSET %s"
        params.extend([page_size, offset])
        cursor.execute(list_query, params)
        history_list = cursor.fetchall()
        # 5. Re-shape the joined rows into the response schema.
        formatted_history = []
        for item in history_list:
            formatted_item = {
                "id": item["id"],
                "device_id": item["device_id"],  # may be None when the IP has no matching device row
                "client_ip": item["client_ip"],
                "status": item["action"],
                "status_time": item["created_at"]
            }
            formatted_history.append(formatted_item)
        return APIResponse(
            code=200,
            message="获取设备上下线记录成功",
            data=DeviceStatusHistoryListResponse(
                total=total,
                history=[DeviceStatusHistoryResponse(**item) for item in formatted_history]
            )
        )
    except MySQLError as e:
        raise Exception(f"获取设备上下线记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# ------------------------------
# 通过客户端IP设置设备is_need_handler为0接口
# ------------------------------
# ------------------------------
# 通过客户端IP设置设备is_need_handler为0接口
# ------------------------------
@router.post("/need-handler/reset", response_model=APIResponse, summary="解封客户端")
@encrypt_response()
async def reset_device_need_handler(
    client_ip: str = Query(..., description="目标设备的客户端IP地址必填")
):
    """Clear a device's is_need_handler flag (un-ban it) by client IP.

    When the client is currently connected over the websocket, an "unlock"
    message is pushed first, followed (after a 100 ms pause so the messages
    arrive in order) by a "frame" message re-permitting frame uploads, and
    the device is marked online again.
    """
    try:
        # Local import mirrors the service-layer helper; it raises ValueError
        # for business errors (empty IP / unknown device) — see handler below.
        from service.device_service import update_is_need_handler_by_client_ip
        success = update_is_need_handler_by_client_ip(
            client_ip=client_ip,
            is_need_handler=0  # fixed value: "no handling required"
        )
        # NOTE(review): when success is falsy the handler falls through and
        # returns None — confirm this is the intended behavior.
        if success:
            online_status = is_client_connected(client_ip)
            # Only push websocket notifications when the client is connected.
            if online_status:
                # Tell the client it has been unlocked.
                unlock_msg = {
                    "type": "unlock",
                    "timestamp": get_current_time_str(),
                    "client_ip": client_ip
                }
                from ws.ws import send_message_to_client
                await send_message_to_client(client_ip, json.dumps(unlock_msg))
                # Short pause so the two messages are delivered in order.
                await asyncio.sleep(0.1)
                frame_permit_msg = {
                    "type": "frame",
                    "timestamp": get_current_time_str(),
                    "client_ip": client_ip
                }
                await send_message_to_client(client_ip, json.dumps(frame_permit_msg))
                # Mark the device online again.
                update_online_status_by_ip(client_ip, 1)
            return APIResponse(
                code=200,
                message=f"设备已解封",
                data={
                    "client_ip": client_ip,
                    "is_need_handler": 0,
                    "status_desc": "设备已解封"
                }
            )
    # Business errors raised by the service helper (empty IP / unknown device).
    except ValueError as e:
        # 404 for "device not found", 400 for other validation failures —
        # consistent with the other endpoints in this file.
        raise HTTPException(
            status_code=404 if "设备不存在" in str(e) else 400,
            detail=str(e)
        ) from e
    # Database-level failures (connection errors, SQL execution errors).
    except MySQLError as e:
        raise Exception(f"设置is_need_handler失败数据库操作异常 - {str(e)}") from e
    # Any other unexpected failure.
    except Exception as e:
        raise Exception(f"设置is_need_handler失败未知错误 - {str(e)}") from e

326
router/face_router.py Normal file
View File

@ -0,0 +1,326 @@
from io import BytesIO
from pathlib import Path
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query, Request
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.face_schema import (
FaceCreateRequest,
FaceResponse,
FaceListResponse
)
from schema.response_schema import APIResponse
from service.face_service import update_face_data
from util.face_util import add_binary_data
from service.file_service import save_source_file
router = APIRouter(prefix="/api/faces", tags=["人脸管理"])
# 创建人脸记录
@router.post("", response_model=APIResponse, summary="创建人脸记录")
@encrypt_response()
async def create_face(
request: Request,
name: str = Form(None, max_length=255, description="名称(可选)"),
file: UploadFile = File(..., description="人脸文件(必传)")
):
conn = None
cursor = None
try:
face_create = FaceCreateRequest(name=name)
client_ip = request.client.host if request.client else ""
if not client_ip:
raise HTTPException(status_code=400, detail="无法获取客户端IP")
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 先读取文件内容
file_content = await file.read()
# 将文件指针重置到开头
await file.seek(0)
# 再保存文件
path = save_source_file(file, "face")
# 提取人脸特征
detect_success, detect_result = add_binary_data(file_content)
if not detect_success:
raise HTTPException(status_code=400, detail=f"人脸检测失败:{detect_result}")
eigenvalue = detect_result
# 插入数据库
insert_query = """
INSERT INTO face (name, eigenvalue, address)
VALUES (%s, %s, %s)
"""
cursor.execute(insert_query, (face_create.name, str(eigenvalue), path))
conn.commit()
# 查询新记录
cursor.execute("""
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = LAST_INSERT_ID()
""")
created_face = cursor.fetchone()
if not created_face:
raise HTTPException(status_code=500, detail="创建成功但无法获取记录")
# TODO 重新加载人脸模型
update_face_data()
return APIResponse(
code=200,
message=f"人脸记录创建成功ID: {created_face['id']}",
data=FaceResponse(**created_face)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"创建失败: {str(e)}") from e
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
finally:
await file.close()
db.close_connection(conn, cursor)
# 获取单个人脸记录
@router.get("/{face_id}", response_model=APIResponse, summary="获取单个人脸记录")
@encrypt_response()
async def get_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
query = """
SELECT id, name, address, created_at, updated_at
FROM face
WHERE id = %s
"""
cursor.execute(query, (face_id,))
face = cursor.fetchone()
if not face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
return APIResponse(
code=200,
message="查询成功",
data=FaceResponse(**face)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 获取人脸列表
@router.get("", response_model=APIResponse, summary="获取人脸列表(分页+筛选)")
@encrypt_response()
async def get_face_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
has_eigenvalue: bool = Query(None)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if has_eigenvalue is not None:
where_clause.append("eigenvalue IS NOT NULL" if has_eigenvalue else "eigenvalue IS NULL")
# 总记录数
count_query = "SELECT COUNT(*) AS total FROM face"
if where_clause:
count_query += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_query, params)
total = cursor.fetchone()["total"]
# 列表数据
offset = (page - 1) * page_size
list_query = """
SELECT id, name, address, created_at, updated_at
FROM face
"""
if where_clause:
list_query += " WHERE " + " AND ".join(where_clause)
list_query += " ORDER BY id DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_query, params)
face_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取成功(共{total}条)",
data=FaceListResponse(
total=total,
faces=[FaceResponse(**face) for face in face_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 删除人脸记录
@router.delete("/{face_id}", response_model=APIResponse, summary="删除人脸记录")
@encrypt_response()
async def delete_face(face_id: int):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
cursor.execute("SELECT id, address FROM face WHERE id = %s", (face_id,))
exist_face = cursor.fetchone()
if not exist_face:
raise HTTPException(status_code=404, detail=f"ID为 {face_id} 的记录不存在")
old_db_path = exist_face["address"]
cursor.execute("DELETE FROM face WHERE id = %s", (face_id,))
conn.commit()
# 删除图片
if old_db_path:
old_abs_path = Path(old_db_path).resolve()
if old_abs_path.exists():
try:
old_abs_path.unlink()
print(f"[FaceRouter] 已删除图片:{old_abs_path}")
extra_msg = "(已同步删除图片)"
except Exception as e:
print(f"[FaceRouter] 删除图片失败:{str(e)}")
extra_msg = "(图片删除失败)"
else:
extra_msg = "(图片不存在)"
else:
extra_msg = "(无关联图片)"
# TODO 重新加载人脸模型
update_face_data()
return APIResponse(
code=200,
message=f"ID为 {face_id} 的记录删除成功 {extra_msg}",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
@router.post("/batch-import", response_model=APIResponse, summary="批量导入文件夹下的人脸图片")
# @encrypt_response()
async def batch_import_faces(
    folder_path: str = Form(..., description="人脸图片所在的**服务器本地文件夹路径**")
):
    """Bulk-import every supported image under a server-local folder.

    Each file is processed independently: failures are collected into
    ``fail_list`` and the loop continues, so one bad image never aborts the
    batch. Each successful insert is committed individually.
    """
    conn = None
    cursor = None
    success_count = 0  # number of images imported successfully
    fail_list = []     # failure records: file name, path and error reason
    try:
        # 1. The folder must exist and actually be a directory.
        folder = Path(folder_path)
        if not folder.exists() or not folder.is_dir():
            raise HTTPException(status_code=400, detail=f"文件夹 {folder_path} 不存在或不是有效目录")
        # 2. Accepted image extensions.
        supported_extensions = {".png", ".jpg", ".jpeg", ".webp"}
        # 3. One DB connection reused for the whole batch.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # 4. Walk the folder (non-recursive) and import each supported image.
        for file_path in folder.iterdir():
            if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
                file_name = file_path.stem  # file name without extension becomes `name`
                try:
                    # 4.1 Read the raw image bytes.
                    with open(file_path, "rb") as f:
                        file_content = f.read()
                    # 4.2 Wrap the bytes in an UploadFile so save_source_file() can reuse
                    #     the same storage path as the single-upload endpoint.
                    mock_file = UploadFile(
                        filename=file_path.name,
                        file=BytesIO(file_content)
                    )
                    # 4.3 Persist the image to the upload directory.
                    saved_path = save_source_file(mock_file, "face")
                    # 4.4 Extract the face eigenvalue.
                    detect_success, detect_result = add_binary_data(file_content)
                    if not detect_success:
                        fail_list.append({
                            "name": file_name,
                            "file_path": str(file_path),
                            "error": f"人脸检测失败:{detect_result}"
                        })
                        continue  # skip this file, move on to the next one
                    eigenvalue = detect_result
                    # 4.5 Insert and commit this file's row.
                    insert_sql = """
                        INSERT INTO face (name, eigenvalue, address)
                        VALUES (%s, %s, %s)
                    """
                    cursor.execute(insert_sql, (file_name, str(eigenvalue), saved_path))
                    conn.commit()  # per-file commit so earlier successes survive later failures
                    success_count += 1
                except Exception as e:
                    # Record the failure and keep processing the remaining files.
                    fail_list.append({
                        "name": file_name,
                        "file_path": str(file_path),
                        "error": str(e)
                    })
                    if conn:
                        conn.rollback()  # discard this file's uncommitted insert, if any
        # 5. Reload the in-memory face model so new faces take effect.
        update_face_data()
        # 6. Summary response with per-file failure details.
        return APIResponse(
            code=200,
            message=f"批量导入完成,成功 {success_count} 条,失败 {len(fail_list)}",
            data={
                "success_count": success_count,
                "fail_details": fail_list
            }
        )
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise HTTPException(status_code=500, detail=f"数据库操作失败: {str(e)}") from e
    except HTTPException:
        raise  # let deliberate 4xx responses (e.g. bad folder) pass through unchanged
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)

31
router/file_router.py Normal file
View File

@ -0,0 +1,31 @@
import os
from fastapi import FastAPI, HTTPException, Path, APIRouter
from fastapi.responses import FileResponse
from service.file_service import UPLOAD_ROOT
# Router for file download endpoints, mounted under /api/file.
router = APIRouter(
    prefix="/api/file",
    tags=["文件管理"]
)
@router.get("/download/{relative_path:path}", summary="下载文件")
async def download_file(
    relative_path: str = Path(..., description="文件的相对路径")
):
    """Stream a file from below UPLOAD_ROOT as an attachment.

    Security fixes over the original:
    - The containment check now runs *before* any filesystem probing, so a
      path-traversal attempt ("../..") is rejected with 403 instead of first
      revealing whether files exist outside the upload directory.
    - ``os.path.commonpath`` replaces the plain ``startswith`` prefix test,
      which would also have accepted sibling directories such as "<root>_evil".
    - The 404 detail no longer leaks the absolute server-side path.
    """
    upload_root = os.path.abspath(UPLOAD_ROOT)
    file_path = os.path.abspath(os.path.join(upload_root, relative_path))
    # Reject anything that resolves outside the upload root, before touching disk.
    if os.path.commonpath([upload_root, file_path]) != upload_root:
        raise HTTPException(status_code=403, detail="无权访问该文件")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail=f"文件不存在: {relative_path}")
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=400, detail="路径指向的不是文件")
    return FileResponse(
        path=file_path,
        filename=os.path.basename(file_path),
        media_type="application/octet-stream"
    )

269
router/model_router.py Normal file
View File

@ -0,0 +1,269 @@
import os
from pathlib import Path
from service.file_service import save_source_file
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Query
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.model_schema import (
ModelResponse,
ModelListResponse
)
from schema.response_schema import APIResponse
from service.model_service import ALLOWED_MODEL_EXT, MAX_MODEL_SIZE, load_yolo_model
router = APIRouter(prefix="/api/models", tags=["模型管理"])
# 上传模型
@router.post("", response_model=APIResponse, summary="上传YOLO模型.pt格式")
@encrypt_response()
async def upload_model(
name: str = Form(..., description="模型名称"),
description: str = Form(None, description="模型描述"),
file: UploadFile = File(..., description=f"YOLO模型文件.pt、最大{MAX_MODEL_SIZE // 1024 // 1024}MB")
):
conn = None
cursor = None
try:
# 校验文件格式
file_ext = file.filename.split(".")[-1].lower() if "." in file.filename else ""
if file_ext not in ALLOWED_MODEL_EXT:
raise HTTPException(
status_code=400,
detail=f"仅支持{ALLOWED_MODEL_EXT}格式、当前:{file_ext}"
)
# 校验文件大小
if file.size > MAX_MODEL_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件过大!最大{MAX_MODEL_SIZE // 1024 // 1024}MB、当前{file.size // 1024 // 1024}MB"
)
# 保存文件
file_path = save_source_file(file, "model")
# 数据库操作
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
insert_sql = """
INSERT INTO model (name, path, is_default, description, file_size)
VALUES (%s, %s, 0, %s, %s)
"""
cursor.execute(insert_sql, (name, file_path, description, file.size))
conn.commit()
# 获取新增记录
cursor.execute("SELECT * FROM model WHERE id = LAST_INSERT_ID()")
new_model = cursor.fetchone()
if not new_model:
raise HTTPException(status_code=500, detail="上传成功但无法获取记录")
return APIResponse(
code=200,
message=f"模型上传成功",
data=ModelResponse(**new_model)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
except Exception as e:
raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}") from e
finally:
await file.close()
db.close_connection(conn, cursor)
# 获取模型列表
@router.get("", response_model=APIResponse, summary="获取模型列表(分页)")
@encrypt_response()
async def get_model_list(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=100),
name: str = Query(None),
is_default: bool = Query(None)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%")
if is_default is not None:
where_clause.append("is_default = %s")
params.append(1 if is_default else 0)
# 总记录数
count_sql = "SELECT COUNT(*) AS total FROM model"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params)
total = cursor.fetchone()["total"]
# 分页数据
offset = (page - 1) * page_size
list_sql = "SELECT * FROM model"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
params.extend([page_size, offset])
cursor.execute(list_sql, params)
model_list = cursor.fetchall()
return APIResponse(
code=200,
message=f"获取成功!",
data=ModelListResponse(
total=total,
models=[ModelResponse(**model) for model in model_list]
)
)
except MySQLError as e:
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 更换默认模型
@router.put("/{model_id}/set-default", response_model=APIResponse, summary="更换默认模型")
@encrypt_response()
async def set_default_model(
model_id: int
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
conn.autocommit = False
# 校验目标模型是否存在
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
target_model = cursor.fetchone()
if not target_model:
raise HTTPException(status_code=404, detail=f"目标模型不存在!")
# 检查是否已为默认模型
if target_model["is_default"]:
return APIResponse(
code=200,
message=f"已是默认模型、无需更换",
data=ModelResponse(**target_model)
)
# 数据库事务:更新默认模型状态
try:
cursor.execute("UPDATE model SET is_default = 0, updated_at = CURRENT_TIMESTAMP")
cursor.execute(
"UPDATE model SET is_default = 1, updated_at = CURRENT_TIMESTAMP WHERE id = %s",
(model_id,)
)
conn.commit()
except MySQLError as e:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"更新默认模型状态失败(已回滚):{str(e)}"
) from e
# 更新模型
load_yolo_model()
# 返回成功响应
return APIResponse(
code=200,
message=f"更换成功",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
if conn:
conn.autocommit = True
db.close_connection(conn, cursor)
# 路由文件(如 model_router.py中的删除接口
@router.delete("/{model_id}", response_model=APIResponse, summary="删除模型")
@encrypt_response()
async def delete_model(model_id: int):
# 1. 正确导入 model_service 中的全局变量(关键修复:变量名匹配)
from service.model_service import (
current_yolo_model,
current_model_absolute_path,
load_yolo_model # 用于删除后重新加载模型(可选)
)
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 2. 查询待删除模型信息
cursor.execute("SELECT * FROM model WHERE id = %s", (model_id,))
exist_model = cursor.fetchone()
if not exist_model:
raise HTTPException(status_code=404, detail=f"模型不存在!")
# 3. 关键判断:①默认模型不可删 ②正在使用的模型不可删
if exist_model["is_default"]:
raise HTTPException(status_code=400, detail="默认模型不可删除!")
# 计算待删除模型的绝对路径(与 model_service 逻辑一致)
from service.file_service import get_absolute_path
del_model_abs_path = get_absolute_path(exist_model["path"])
# 判断是否正在使用(对比 current_model_absolute_path
if current_model_absolute_path and del_model_abs_path == current_model_absolute_path:
raise HTTPException(status_code=400, detail="该模型正在使用中,禁止删除!")
# 4. 先删除数据库记录(避免文件删除失败导致数据不一致)
cursor.execute("DELETE FROM model WHERE id = %s", (model_id,))
conn.commit()
# 5. 再删除本地文件(捕获文件删除异常,不影响数据库删除结果)
extra_msg = ""
try:
if os.path.exists(del_model_abs_path):
os.remove(del_model_abs_path) # 或用 Path(del_model_abs_path).unlink()
extra_msg = "(本地文件已同步删除)"
else:
extra_msg = "(本地文件不存在,无需删除)"
except Exception as e:
extra_msg = f"(本地文件删除失败:{str(e)}"
# 6. 若删除后当前模型为空(极端情况),重新加载默认模型(可选优化)
if current_yolo_model is None:
try:
load_yolo_model()
print(f"[模型删除后] 重新加载默认模型成功")
except Exception as e:
print(f"[模型删除后] 重新加载默认模型失败:{str(e)}")
return APIResponse(
code=200,
message=f"模型删除成功!",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(status_code=500, detail=f"数据库错误:{str(e)}") from e
finally:
db.close_connection(conn, cursor)

306
router/sensitive_router.py Normal file
View File

@ -0,0 +1,306 @@
from fastapi import APIRouter, Depends, HTTPException, Query, File, UploadFile
from mysql.connector import Error as MySQLError
from typing import Optional
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from schema.sensitive_schema import (
SensitiveCreateRequest,
SensitiveResponse,
SensitiveListResponse
)
from schema.response_schema import APIResponse
from middle.auth_middleware import get_current_user
from schema.user_schema import UserResponse
from service.ocr_service import set_forbidden_words
from service.sensitive_service import get_all_sensitive_words
# Router for sensitive-word management endpoints, mounted under /api/sensitives.
router = APIRouter(
    prefix="/api/sensitives",
    tags=["敏感信息管理"]
)
# 创建敏感信息记录
@router.post("", response_model=APIResponse, summary="创建敏感信息记录")
@encrypt_response()
async def create_sensitive(
sensitive: SensitiveCreateRequest,
current_user: UserResponse = Depends(get_current_user)
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 插入新敏感信息记录到数据库不包含ID、由数据库自动生成
insert_query = """
INSERT INTO sensitives (name, created_at, updated_at)
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
cursor.execute(insert_query, (sensitive.name,))
conn.commit()
# 获取刚插入记录的ID使用LAST_INSERT_ID()函数)
new_id = cursor.lastrowid
# 查询刚创建的记录并返回
select_query = "SELECT * FROM sensitives WHERE id = %s"
cursor.execute(select_query, (new_id,))
created_sensitive = cursor.fetchone()
# 重新加载最新的敏感词
set_forbidden_words(get_all_sensitive_words())
return APIResponse(
code=200,
message="敏感信息记录创建成功",
data=SensitiveResponse(**created_sensitive)
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"创建敏感信息记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# 获取敏感信息分页列表
@router.get("", response_model=APIResponse, summary="获取敏感信息分页列表(支持关键词搜索)")
@encrypt_response()
async def get_sensitive_list(
page: int = Query(1, ge=1, description="页码默认1、最小1"),
page_size: int = Query(10, ge=1, le=100, description="每页条数默认10、1-100"),
name: Optional[str] = Query(None, description="敏感词关键词搜索(模糊匹配)")
):
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 构建查询条件(支持关键词搜索)
where_clause = []
params = []
if name:
where_clause.append("name LIKE %s")
params.append(f"%{name}%") # 模糊匹配关键词
# 2. 查询总记录数(用于分页计算)
count_sql = "SELECT COUNT(*) AS total FROM sensitives"
if where_clause:
count_sql += " WHERE " + " AND ".join(where_clause)
cursor.execute(count_sql, params.copy()) # 复制参数列表、避免后续污染
total = cursor.fetchone()["total"]
# 3. 计算分页偏移量
offset = (page - 1) * page_size
# 4. 分页查询敏感词数据(按更新时间倒序、最新的在前)
list_sql = "SELECT * FROM sensitives"
if where_clause:
list_sql += " WHERE " + " AND ".join(where_clause)
# 排序+分页LIMIT 条数 OFFSET 偏移量)
list_sql += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
# 补充分页参数page_size和offset
params.extend([page_size, offset])
cursor.execute(list_sql, params)
sensitive_list = cursor.fetchall()
# 5. 构造分页响应数据
return APIResponse(
code=200,
message=f"敏感信息列表查询成功(共{total}条记录、当前第{page}页)",
data=SensitiveListResponse(
total=total,
sensitives=[SensitiveResponse(**item) for item in sensitive_list]
)
)
except MySQLError as e:
raise HTTPException(
status_code=500,
detail=f"查询敏感信息列表失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# 删除敏感信息记录
@router.delete("/{sensitive_id}", response_model=APIResponse, summary="删除敏感信息记录")
@encrypt_response()
async def delete_sensitive(
sensitive_id: int,
current_user: UserResponse = Depends(get_current_user) # 需登录认证
):
"""
删除敏感信息记录:
- 需登录认证
- 根据ID删除敏感信息记录
- 返回删除成功信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 检查记录是否存在
check_query = "SELECT id FROM sensitives WHERE id = %s"
cursor.execute(check_query, (sensitive_id,))
existing_sensitive = cursor.fetchone()
if not existing_sensitive:
raise HTTPException(
status_code=404,
detail=f"ID为 {sensitive_id} 的敏感信息记录不存在"
)
# 2. 执行删除操作
delete_query = "DELETE FROM sensitives WHERE id = %s"
cursor.execute(delete_query, (sensitive_id,))
conn.commit()
# 重新加载最新的敏感词
set_forbidden_words(get_all_sensitive_words())
return APIResponse(
code=200,
message=f"ID为 {sensitive_id} 的敏感信息记录删除成功",
data=None
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"删除敏感信息记录失败: {str(e)}"
) from e
finally:
db.close_connection(conn, cursor)
# 批量导入敏感信息从txt文件
@router.post("/batch-import", response_model=APIResponse, summary="批量导入敏感信息从txt文件")
@encrypt_response()
async def batch_import_sensitives(
file: UploadFile = File(..., description="包含敏感词的txt文件每行一个敏感词"),
# current_user: UserResponse = Depends(get_current_user) # 添加认证依赖
):
"""
批量导入敏感信息:
- 需登录认证
- 接收txt文件文件中每行一个敏感词
- 批量插入到数据库中(仅插入不存在的敏感词)
- 返回导入结果统计
"""
# 检查文件类型
filename = file.filename or ""
if not filename.lower().endswith(".txt"):
raise HTTPException(
status_code=400,
detail=f"请上传txt格式的文件当前文件格式: {filename.split('.')[-1] if '.' in filename else '未知'}"
)
# 检查文件大小
file_size = await file.read(1) # 读取1字节获取文件信息
await file.seek(0) # 重置文件指针
if not file_size: # 文件为空
raise HTTPException(
status_code=400,
detail="上传的文件为空,请提供有效的敏感词文件"
)
conn = None
cursor = None
try:
# 读取文件内容
contents = await file.read()
# 按行分割内容,处理不同操作系统的换行符
lines = contents.decode("utf-8", errors="replace").splitlines()
# 过滤空行和仅含空白字符的行
sensitive_words = [line.strip() for line in lines if line.strip()]
if not sensitive_words:
return APIResponse(
code=200,
message="文件中没有有效的敏感词",
data={"imported": 0, "total": 0}
)
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 先查询数据库中已存在的敏感词
query = "SELECT name FROM sensitives WHERE name IN (%s)"
# 处理参数,根据敏感词数量生成占位符
placeholders = ', '.join(['%s'] * len(sensitive_words))
cursor.execute(query % placeholders, sensitive_words)
existing_words = {row['name'] for row in cursor.fetchall()}
# 过滤掉已存在的敏感词
new_words = [word for word in sensitive_words if word not in existing_words]
if not new_words:
return APIResponse(
code=200,
message="所有敏感词均已存在于数据库中",
data={
"total": len(sensitive_words),
"imported": 0,
"duplicates": len(sensitive_words),
"message": f"共处理{len(sensitive_words)}个敏感词,全部已存在,未导入任何新敏感词"
}
)
# 批量插入新的敏感词
insert_query = """
INSERT INTO sensitives (name, created_at, updated_at)
VALUES (%s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
"""
# 准备参数列表
params = [(word,) for word in new_words]
# 执行批量插入
cursor.executemany(insert_query, params)
conn.commit()
# 重新加载最新的敏感词
set_forbidden_words(get_all_sensitive_words())
return APIResponse(
code=200,
message=f"敏感词批量导入成功",
data={
"total": len(sensitive_words),
"imported": len(new_words),
"duplicates": len(sensitive_words) - len(new_words),
"message": f"共处理{len(sensitive_words)}个敏感词,成功导入{len(new_words)}个,{len(sensitive_words) - len(new_words)}个已存在"
}
)
except UnicodeDecodeError:
raise HTTPException(
status_code=400,
detail="文件编码格式错误请使用UTF-8编码的txt文件"
)
except MySQLError as e:
if conn:
conn.rollback()
raise HTTPException(
status_code=500,
detail=f"批量导入敏感词失败: {str(e)}"
) from e
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"处理文件时发生错误: {str(e)}"
) from e
finally:
await file.close()
db.close_connection(conn, cursor)

246
router/user_router.py Normal file
View File

@ -0,0 +1,246 @@
from datetime import timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from mysql.connector import Error as MySQLError
from ds.db import db
from encryption.encrypt_decorator import encrypt_response
from middle.auth_middleware import (
get_password_hash,
verify_password,
create_access_token,
ACCESS_TOKEN_EXPIRE_MINUTES,
get_current_user
)
from schema.response_schema import APIResponse
from schema.user_schema import UserRegisterRequest, UserLoginRequest, UserResponse
# Router for user account endpoints, mounted under /api/users.
router = APIRouter(
    prefix="/api/users",
    tags=["用户管理"]
)
# 用户注册接口
@router.post("/register", response_model=APIResponse, summary="用户注册")
@encrypt_response()
async def user_register(request: UserRegisterRequest):
"""
用户注册:
- 校验用户名是否已存在
- 加密密码后插入数据库
- 返回注册成功信息
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 1. 检查用户名是否已存在(唯一索引)
check_query = "SELECT username FROM users WHERE username = %s"
cursor.execute(check_query, (request.username,))
existing_user = cursor.fetchone()
if existing_user:
raise HTTPException(
status_code=400,
detail=f"用户名 '{request.username}' 已存在、请更换其他用户名"
)
# 2. 加密密码
hashed_password = get_password_hash(request.password)
# 3. 插入新用户到数据库
insert_query = """
INSERT INTO users (username, password)
VALUES (%s, %s)
"""
cursor.execute(insert_query, (request.username, hashed_password))
conn.commit() # 提交事务
# 4. 返回注册成功响应
return APIResponse(
code=200, # 200 表示资源创建成功
message=f"用户 '{request.username}' 注册成功",
data=None
)
except MySQLError as e:
conn.rollback() # 数据库错误时回滚事务
raise Exception(f"注册失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# 用户登录接口
@router.post("/login", response_model=APIResponse, summary="用户登录(获取 Token")
@encrypt_response()
async def user_login(request: UserLoginRequest):
"""
用户登录:
- 校验用户名是否存在
- 校验密码是否正确
- 生成 JWT Token 并返回
"""
conn = None
cursor = None
try:
conn = db.get_connection()
cursor = conn.cursor(dictionary=True)
# 修复: SQL查询添加 created_at 和 updated_at 字段
query = """
SELECT id, username, password, created_at, updated_at
FROM users
WHERE username = %s
"""
cursor.execute(query, (request.username,))
user = cursor.fetchone()
# 2. 校验用户名和密码
if not user or not verify_password(request.password, user["password"]):
raise HTTPException(
status_code=401,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 3. 生成 Token过期时间从配置读取
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user["username"]},
expires_delta=access_token_expires
)
# 4. 返回 Token 和用户基本信息
return APIResponse(
code=200,
message="登录成功",
data={
"access_token": access_token,
"token_type": "bearer",
"user": UserResponse(
id=user["id"],
username=user["username"],
created_at=user.get("created_at"),
updated_at=user.get("updated_at")
)
}
)
except MySQLError as e:
raise Exception(f"登录失败: {str(e)}") from e
finally:
db.close_connection(conn, cursor)
# Current-user endpoint (authentication required)
@router.get("/me", response_model=APIResponse, summary="获取当前用户信息")
@encrypt_response()
async def get_current_user_info(
    current_user: UserResponse = Depends(get_current_user)  # resolved by the auth middleware
):
    """
    Return the profile of the currently authenticated user.

    The caller must send an ``Authorization: Bearer <token>`` header; the
    dependency resolves the token into a UserResponse before this body runs.
    """
    return APIResponse(code=200, message="获取用户信息成功", data=current_user)
# User list endpoint (login required, no admin check)
@router.get("/list", response_model=APIResponse, summary="获取用户列表")
@encrypt_response()
async def get_user_list(
    page: int = Query(1, ge=1, description="页码、从1开始"),
    page_size: int = Query(10, ge=1, le=100, description="每页条数、1-100之间"),
    username: Optional[str] = Query(None, description="用户名模糊搜索"),
    current_user: UserResponse = Depends(get_current_user)  # any logged-in user may call
):
    """
    Paginated user listing.

    - Requires authentication (Bearer token); no admin role is enforced.
    - Supports page/page_size pagination and optional fuzzy username search.
    - Returns only non-sensitive columns (no password hashes).
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Pagination offset: page is 1-based.
        offset = (page - 1) * page_size
        # Base query: only non-sensitive columns.
        base_query = """
            SELECT id, username, created_at, updated_at
            FROM users
        """
        # Count query used for the pagination math.
        count_query = "SELECT COUNT(*) as total FROM users"
        # Optional fuzzy username filter.
        conditions = []
        params = []
        if username:
            conditions.append("username LIKE %s")
            params.append(f"%{username}%")  # % wildcards on both sides
        # Assemble the final statements.
        if conditions:
            where_clause = " WHERE " + " AND ".join(conditions)
            final_query = f"{base_query}{where_clause} LIMIT %s OFFSET %s"
            final_count_query = f"{count_query}{where_clause}"
            params.extend([page_size, offset])  # append pagination params
        else:
            final_query = f"{base_query} LIMIT %s OFFSET %s"
            final_count_query = count_query
            params = [page_size, offset]
        # 1. Fetch the requested page of users.
        cursor.execute(final_query, params)
        users = cursor.fetchall()
        # 2. Fetch the total row count.
        # NOTE(review): the filter parameter is rebuilt here instead of
        # reusing `conditions`/`params`; keep the two code paths in sync if
        # more filters are ever added.
        count_params = [f"%{username}%"] if username else []
        cursor.execute(final_count_query, count_params)
        total = cursor.fetchone()["total"]
        # 3. Convert rows into response models (field names match columns).
        user_list = [
            UserResponse(
                id=user["id"],
                username=user["username"],
                created_at=user["created_at"],
                updated_at=user["updated_at"]
            )
            for user in users
        ]
        # 4. Ceiling division for the page count (e.g. 11 rows / 10 = 2 pages).
        total_pages = (total + page_size - 1) // page_size
        # Envelope with the page of users plus pagination metadata.
        return APIResponse(
            code=200,
            message="获取用户列表成功",
            data={
                "users": user_list,
                "pagination": {
                    "page": page,              # current page number
                    "page_size": page_size,    # rows per page
                    "total": total,            # total matching rows
                    "total_pages": total_pages # total page count
                }
            }
        )
    except MySQLError as e:
        raise Exception(f"获取用户列表失败: {str(e)}") from e
    finally:
        # Always release the connection, success or failure.
        db.close_connection(conn, cursor)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,36 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ------------------------------
# Request models
# ------------------------------
class DeviceActionCreate(BaseModel):
    """Creation payload for a device action record (0 = offline, 1 = online)."""
    client_ip: str = Field(..., description="客户端IP")
    action: int = Field(..., ge=0, le=1, description="操作状态0=离线、1=上线)")
# ------------------------------
# Response model (single record)
# ------------------------------
class DeviceActionResponse(BaseModel):
    """Single device action record, aligned with the auto-increment table."""
    id: int = Field(..., description="自增主键ID")
    client_ip: Optional[str] = Field(None, description="客户端IP")
    action: Optional[int] = Field(None, description="操作状态0=离线、1=上线)")
    created_at: datetime = Field(..., description="记录创建时间")
    updated_at: datetime = Field(..., description="记录更新时间")
    # Allow construction directly from DB rows / ORM objects.
    model_config = {"from_attributes": True}
# ------------------------------
# List response model (total + device_actions only)
# ------------------------------
class DeviceActionListResponse(BaseModel):
    """Device action list wrapper (core fields only)."""
    total: int = Field(..., description="总记录数")
    device_actions: List[DeviceActionResponse] = Field(..., description="操作记录列表")

View File

@ -0,0 +1,33 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
# ------------------------------
# Request models
# ------------------------------
class DeviceDangerCreateRequest(BaseModel):
    """Creation payload for a device danger record."""
    client_ip: str = Field(..., max_length=100, description="设备IP地址必须与devices表中IP对应")
    type: str = Field(..., max_length=255, description="危险类型(如:病毒检测、端口异常、权限泄露等)")
    result: str = Field(..., description="危险检测结果/处理结果检测到木马病毒、已隔离端口22异常开放、已关闭")
# ------------------------------
# Response models
# ------------------------------
class DeviceDangerResponse(BaseModel):
    """Single danger record, aligned with device_danger (updated_at may be null)."""
    id: int = Field(..., description="危险记录主键ID")
    client_ip: str = Field(..., max_length=100, description="设备IP地址")
    type: str = Field(..., max_length=255, description="危险类型")
    result: str = Field(..., description="危险检测结果/处理结果")
    created_at: datetime = Field(..., description="记录创建时间(危险发生/检测时间)")
    updated_at: Optional[datetime] = Field(None, description="记录更新时间数据库中该字段当前为null")
    # Allow construction directly from DB rows / ORM objects.
    model_config = {"from_attributes": True}
class DeviceDangerListResponse(BaseModel):
    """Paginated danger-record list wrapper."""
    total: int = Field(..., description="危险记录总数")
    dangers: List[DeviceDangerResponse] = Field(..., description="设备危险记录列表")

56
schema/device_schema.py Normal file
View File

@ -0,0 +1,56 @@
from datetime import datetime
from typing import Optional, List, Dict
from pydantic import BaseModel, Field
# ------------------------------
# Request models
# ------------------------------
class DeviceCreateRequest(BaseModel):
    """Creation payload for a device row."""
    ip: Optional[str] = Field(..., max_length=100, description="设备IP地址")
    hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
    params: Optional[Dict] = Field(None, description="设备扩展参数JSON格式")
# ------------------------------
# Response models
# ------------------------------
class DeviceResponse(BaseModel):
    """Single device record, aligned with the devices table columns."""
    id: int = Field(..., description="设备主键ID")
    device_id: Optional[int] = Field(None, description="关联设备ID若历史记录IP无对应设备则为None")
    hostname: Optional[str] = Field(None, max_length=100, description="设备别名")
    device_online_status: int = Field(..., description="在线状态1-在线、0-离线)")
    device_type: Optional[str] = Field(None, description="设备类型")
    alarm_count: int = Field(..., description="报警次数")
    params: Optional[str] = Field(None, description="扩展参数JSON字符串")
    client_ip: Optional[str] = Field(None, max_length=100, description="设备IP地址")
    is_need_handler: int = Field(..., description="是否需要处理1-需要、0-不需要)")
    created_at: datetime = Field(..., description="记录创建时间")
    updated_at: datetime = Field(..., description="记录更新时间")
    model_config = {"from_attributes": True}  # allow building from DB rows
class DeviceListResponse(BaseModel):
    """Device list wrapper."""
    total: int = Field(..., description="设备总数")
    devices: List[DeviceResponse] = Field(..., description="设备列表")
class DeviceStatusHistoryResponse(BaseModel):
    """Single device online/offline history record."""
    id: int = Field(..., description="记录ID")
    device_id: int = Field(..., description="关联设备ID")
    client_ip: Optional[str] = Field(None, description="设备IP地址")
    status: int = Field(..., description="状态1-在线、0-离线)")
    status_time: datetime = Field(..., description="状态变更时间")
    model_config = {"from_attributes": True}
class DeviceStatusHistoryListResponse(BaseModel):
    """Device online/offline history list wrapper."""
    total: int = Field(..., description="记录总数")
    history: List[DeviceStatusHistoryResponse] = Field(..., description="上下线记录列表")

41
schema/face_schema.py Normal file
View File

@ -0,0 +1,41 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------
# Request models (input validation) — update keeps eigenvalue for feature refresh
# ------------------------------
class FaceCreateRequest(BaseModel):
    """Creation payload for a face record (id is auto-generated by the DB)."""
    name: Optional[str] = Field(None, max_length=255, description="名称可选、最长255字符")
class FaceUpdateRequest(BaseModel):
    """Update payload; eigenvalue stays updatable without being echoed in responses."""
    name: Optional[str] = Field(None, max_length=255, description="名称(可选)")
    eigenvalue: Optional[str] = Field(None, description="特征值(可选、文件处理后可更新)")  # update capability kept
    address: Optional[str] = Field(None, description="图片完整路径(可选、更新图片时使用)")
# ------------------------------
# Response models — eigenvalue intentionally omitted
# ------------------------------
class FaceResponse(BaseModel):
    """Face record response; excludes the (large) eigenvalue field."""
    id: int = Field(..., description="主键ID数据库自增")
    name: Optional[str] = Field(None, description="名称")
    address: Optional[str] = Field(None, description="人脸图片完整保存路径(数据库新增字段)")  # only address is exposed
    created_at: datetime = Field(..., description="记录创建时间(数据库自动生成)")
    updated_at: datetime = Field(..., description="记录更新时间(数据库自动生成)")
    # Allow construction directly from DB query results (dicts/objects).
    model_config = {"from_attributes": True}
class FaceListResponse(BaseModel):
    """Paginated face list; items already exclude eigenvalue."""
    total: int = Field(..., description="筛选后的总记录数")
    faces: List[FaceResponse] = Field(..., description="当前页的人脸记录列表")
    model_config = {"from_attributes": True}

37
schema/model_schema.py Normal file
View File

@ -0,0 +1,37 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# Request models
class ModelCreateRequest(BaseModel):
    """Creation payload for a detection-model entry."""
    name: str = Field(..., max_length=255, description="模型名称必填、如yolo-v8s-car")
    description: Optional[str] = Field(None, description="模型描述(可选)")
    is_default: Optional[bool] = Field(False, description="是否设为默认模型")
class ModelUpdateRequest(BaseModel):
    """Partial-update payload; every field is optional."""
    name: Optional[str] = Field(None, max_length=255, description="模型名称(可选修改)")
    description: Optional[str] = Field(None, description="模型描述(可选修改)")
    is_default: Optional[bool] = Field(None, description="是否设为默认模型(可选切换)")
# Response models
class ModelResponse(BaseModel):
    """Single model record."""
    id: int = Field(..., description="模型ID")
    name: str = Field(..., description="模型名称")
    path: str = Field(..., description="模型文件相对路径")
    is_default: bool = Field(..., description="是否默认模型")
    description: Optional[str] = Field(None, description="模型描述")
    file_size: Optional[int] = Field(None, description="文件大小(字节)")
    created_at: datetime = Field(..., description="创建时间")
    updated_at: datetime = Field(..., description="更新时间")
    model_config = {"from_attributes": True}  # allow building from DB rows
class ModelListResponse(BaseModel):
    """Paginated model list wrapper."""
    total: int = Field(..., description="总记录数")
    models: List[ModelResponse] = Field(..., description="当前页模型列表")
    model_config = {"from_attributes": True}

13
schema/response_schema.py Normal file
View File

@ -0,0 +1,13 @@
from typing import Optional, Any
from pydantic import BaseModel, Field
class APIResponse(BaseModel):
    """Uniform API envelope returned by every endpoint."""
    code: int = Field(..., description="状态码: 200=成功、4xx=客户端错误、5xx=服务端错误")
    message: str = Field(..., description="响应信息: 成功/错误描述")
    data: Optional[Any] = Field(None, description="响应数据: 成功时返回、错误时为 None")
    # Pydantic V2 config: allow building from ORM objects.
    model_config = {"from_attributes": True}

View File

@ -0,0 +1,38 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------
# Request models (input validation)
# ------------------------------
class SensitiveCreateRequest(BaseModel):
    """Creation payload for a sensitive-word record."""
    name: str = Field(..., max_length=255, description="敏感词内容(必填)")
class SensitiveUpdateRequest(BaseModel):
    """Update payload for a sensitive-word record."""
    name: Optional[str] = Field(None, max_length=255, description="敏感词内容(可选修改)")
# ------------------------------
# Response models
# ------------------------------
class SensitiveResponse(BaseModel):
    """Single sensitive-word record."""
    id: int = Field(..., description="主键ID")
    name: str = Field(..., description="敏感词内容")
    created_at: datetime = Field(..., description="记录创建时间")
    updated_at: datetime = Field(..., description="记录更新时间")
    # Allow construction directly from DB query results (dicts/objects).
    model_config = {"from_attributes": True}
class SensitiveListResponse(BaseModel):
    """Paginated sensitive-word list wrapper."""
    total: int = Field(..., description="敏感词总记录数")
    sensitives: List[SensitiveResponse] = Field(..., description="当前页敏感词列表")
    model_config = {"from_attributes": True}

40
schema/user_schema.py Normal file
View File

@ -0,0 +1,40 @@
from datetime import datetime
from pydantic import BaseModel, Field
from typing import List, Optional
# ------------------------------
# Request models (input validation)
# ------------------------------
class UserRegisterRequest(BaseModel):
    """Registration payload."""
    username: str = Field(..., min_length=3, max_length=50, description="用户名3-50字符")
    password: str = Field(..., min_length=6, max_length=100, description="密码6-100字符")
class UserLoginRequest(BaseModel):
    """Login payload."""
    username: str = Field(..., description="用户名")
    password: str = Field(..., description="密码")
# ------------------------------
# Response models
# ------------------------------
class UserResponse(BaseModel):
    """Public user profile (password and other sensitive fields omitted)."""
    id: int = Field(..., description="用户ID")
    username: str = Field(..., description="用户名")
    created_at: datetime = Field(..., description="创建时间")
    updated_at: datetime = Field(..., description="更新时间")
    # Pydantic V2 config: allow building from DB query results.
    model_config = {"from_attributes": True}
class UserListResponse(BaseModel):
    """Paginated user list (mirrors the device/face list structure)."""
    total: int = Field(..., description="用户总数")
    users: List[UserResponse] = Field(..., description="当前页用户列表")
    model_config = {"from_attributes": True}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,43 @@
from ds.db import db
from schema.device_action_schema import DeviceActionCreate, DeviceActionResponse
from mysql.connector import Error as MySQLError
# Insert one device action record
def add_device_action(client_ip: str, action: int) -> DeviceActionResponse:
    """
    Insert one device action row (internal helper, not an HTTP endpoint).

    :param client_ip: client IP the action belongs to
    :param action: 0 = offline, 1 = online
    :return: the freshly inserted row as a DeviceActionResponse
    :raises Exception: wraps any MySQL error (transaction rolled back first)
    """
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # id auto-increments; timestamps are filled by the database.
        insert_query = """
            INSERT INTO device_action
            (client_ip, action, created_at, updated_at)
            VALUES (%s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
        """
        cursor.execute(insert_query, (
            client_ip,
            action
        ))
        conn.commit()
        # Re-read the inserted row via its auto-increment id.
        new_id = cursor.lastrowid
        cursor.execute("SELECT * FROM device_action WHERE id = %s", (new_id,))
        new_action = cursor.fetchone()
        return DeviceActionResponse(**new_action)
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise Exception(f"新增记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)

View File

@ -0,0 +1,78 @@
from mysql.connector import Error as MySQLError
from ds.db import db
from schema.device_danger_schema import DeviceDangerCreateRequest, DeviceDangerResponse
# ------------------------------
# Internal helper - device existence check (shared devices-table logic)
# ------------------------------
def check_device_exist(client_ip: str) -> bool:
    """
    Check whether a device with the given IP exists in the devices table.

    :param client_ip: device IP address
    :return: True when a matching row exists, False otherwise
    :raises ValueError: when client_ip is empty
    :raises Exception: wraps any MySQL error
    """
    if not client_ip:
        raise ValueError("设备IP不能为空")
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # A single indexed lookup suffices; only existence matters.
        cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
        row = cursor.fetchone()
        return row is not None
    except MySQLError as e:
        raise Exception(f"检查设备存在性失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# ------------------------------
# Internal helper - insert a device danger record
# ------------------------------
def create_danger_record(danger_data: DeviceDangerCreateRequest) -> DeviceDangerResponse:
    """
    Insert a new danger record into device_danger.

    :param danger_data: creation payload (client_ip, type, result)
    :return: the inserted row as a DeviceDangerResponse
    :raises ValueError: when no device with that IP exists
    :raises Exception: wraps any MySQL error (transaction rolled back first)
    """
    # The device must exist before a danger record can reference it.
    if not check_device_exist(danger_data.client_ip):
        raise ValueError(f"IP为 {danger_data.client_ip} 的设备不存在、无法创建危险记录")
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # id auto-increments; timestamps are filled by the database.
        insert_query = """
            INSERT INTO device_danger
            (client_ip, type, result, created_at, updated_at)
            VALUES (%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
        """
        cursor.execute(insert_query, (
            danger_data.client_ip,
            danger_data.type,
            danger_data.result
        ))
        conn.commit()
        # Re-read the inserted row via its auto-increment id.
        danger_id = cursor.lastrowid
        cursor.execute("SELECT * FROM device_danger WHERE id = %s", (danger_id,))
        new_danger = cursor.fetchone()
        return DeviceDangerResponse(**new_danger)
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise Exception(f"插入危险记录失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)

250
service/device_service.py Normal file
View File

@ -0,0 +1,250 @@
import threading
import time
from typing import Optional
from mysql.connector import Error as MySQLError
from ds.db import db
from service.device_action_service import add_device_action
# Debounce state for increment_alarm_count_by_ip: most recent alarm timestamp
# per client IP (epoch seconds). Access is serialized by _timestamp_lock.
_last_alarm_timestamps: dict[str, float] = {}
_timestamp_lock = threading.Lock()
# Fetch the de-duplicated list of client IPs
def get_unique_client_ips() -> list[str]:
    """Return every distinct non-NULL client_ip stored in the devices table."""
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute(
            "SELECT DISTINCT client_ip FROM devices WHERE client_ip IS NOT NULL"
        )
        return [row["client_ip"] for row in cursor.fetchall()]
    except MySQLError as e:
        raise Exception(f"获取客户端IP列表失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# Update a device's "needs handling" flag by client IP
def update_is_need_handler_by_client_ip(client_ip: str, is_need_handler: int) -> bool:
    """
    Update a device's is_need_handler flag, located by client IP.

    :param client_ip: client IP identifying the device
    :param is_need_handler: 0 = no handling needed, 1 = handling needed
    :return: True when the update was applied
    :raises ValueError: empty IP, invalid flag value, or unknown device
    :raises Exception: wraps any MySQL error (transaction rolled back first)
    """
    # Validate arguments.
    if not client_ip:
        raise ValueError("客户端IP不能为空")
    # Flag must match the 0/1 tinyint column in the database.
    if is_need_handler not in (0, 1):
        raise ValueError("是否需要处理is_need_handler必须是0不需要或1需要")
    conn = None
    cursor = None
    try:
        # 2. Open connection and cursor.
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # 3. Make sure the device exists before updating.
        cursor.execute(
            "SELECT id FROM devices WHERE client_ip = %s",
            (client_ip,)
        )
        device = cursor.fetchone()
        if not device:
            raise ValueError(f"客户端IP为 {client_ip} 的设备不存在、无法更新「是否需要处理」状态")
        # 4. Apply the update (also refresh updated_at, matching other updates).
        update_query = """
            UPDATE devices
            SET is_need_handler = %s,
                updated_at = CURRENT_TIMESTAMP
            WHERE client_ip = %s
        """
        cursor.execute(update_query, (is_need_handler, client_ip))
        # 5. Verify the update touched a row.
        # NOTE(review): updated_at always changes above, so rowcount should
        # be > 0 whenever the device exists; this branch looks defensive only.
        if cursor.rowcount <= 0:
            raise Exception(f"更新失败客户端IP {client_ip} 的设备未发生状态变更(可能已为目标值)")
        # 6. Commit the transaction.
        conn.commit()
        return True
    except MySQLError as e:
        # Roll back on database errors.
        if conn:
            conn.rollback()
        raise Exception(f"数据库操作失败:更新设备「是否需要处理」状态时出错 - {str(e)}") from e
    finally:
        # Always release the connection (avoid leaks).
        db.close_connection(conn, cursor)
def increment_alarm_count_by_ip(client_ip: str) -> bool:
    """
    Increment a device's alarm counter, debounced per IP: repeat calls for
    the same IP within 200 ms of the previous call are ignored.

    :param client_ip: client IP of the device
    :return: True when the database was actually updated, False when the
        call was suppressed by the 200 ms debounce window
    :raises ValueError: empty IP or unknown device
    :raises Exception: wraps any MySQL error (transaction rolled back first)
    """
    if not client_ip:
        raise ValueError("客户端IP不能为空")
    current_time = time.time()  # epoch seconds as float
    with _timestamp_lock:  # serialize access to the debounce map
        last_time: Optional[float] = _last_alarm_timestamps.get(client_ip)
        # Suppress when the previous call was less than 200 ms ago.
        if last_time is not None and (current_time - last_time) < 0.2:
            return False
        # Record this call's timestamp.
        # NOTE(review): the timestamp is recorded before the DB update runs,
        # so a failed update still opens a debounce window; the map also
        # grows by one entry per distinct IP and is never pruned.
        _last_alarm_timestamps[client_ip] = current_time
    # 2. Perform the update (only reached after passing the debounce check).
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Ensure the device exists.
        cursor.execute("SELECT id FROM devices WHERE client_ip = %s", (client_ip,))
        device = cursor.fetchone()
        if not device:
            raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
        # Increment the counter and refresh updated_at.
        update_query = """
            UPDATE devices
            SET alarm_count = alarm_count + 1,
                updated_at = CURRENT_TIMESTAMP
            WHERE client_ip = %s
        """
        cursor.execute(update_query, (client_ip,))
        conn.commit()
        return True
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise Exception(f"更新报警次数失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# Update a device's online status by client IP
def update_online_status_by_ip(client_ip: str, online_status: int) -> bool:
    """
    Update a device's online status and record the change in the action log.

    :param client_ip: client IP of the device
    :param online_status: 1 = online, 0 = offline
    :return: True on success (including the no-op case when unchanged)
    :raises ValueError: empty IP, invalid status, or unknown device
    :raises Exception: wraps any MySQL error (transaction rolled back first)
    """
    if not client_ip:
        raise ValueError("客户端IP不能为空")
    if online_status not in (0, 1):
        raise ValueError("在线状态必须是0离线或1在线")
    conn = None
    cursor = None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Confirm the device exists and read its current status.
        cursor.execute("SELECT id, device_online_status FROM devices WHERE client_ip = %s", (client_ip,))
        device = cursor.fetchone()
        if not device:
            raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
        # Skip the write when the status has not changed.
        if device['device_online_status'] == online_status:
            return True
        # Update status and timestamp.
        update_query = """
            UPDATE devices
            SET device_online_status = %s,
                updated_at = CURRENT_TIMESTAMP
            WHERE client_ip = %s
        """
        cursor.execute(update_query, (online_status, client_ip))
        # Record the status change in the history table.
        # NOTE(review): add_device_action opens its own connection and
        # commits independently BEFORE conn.commit() below — if that commit
        # fails, a history row exists without the status change. Confirm
        # whether this ordering is intended.
        add_device_action(client_ip, online_status)
        conn.commit()
        return True
    except MySQLError as e:
        if conn:
            conn.rollback()
        raise Exception(f"更新设备在线状态失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# Check whether a device with the given client IP exists
def is_device_exist_by_ip(client_ip: str) -> bool:
    """
    Return True when a device row with the given client IP exists.

    :param client_ip: client IP of the device
    :raises ValueError: when client_ip is empty
    :raises Exception: wraps any MySQL error
    """
    if not client_ip:
        raise ValueError("客户端IP不能为空")
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        # Only existence matters, so fetching the id alone is enough.
        cursor.execute(
            "SELECT id FROM devices WHERE client_ip = %s",
            (client_ip,)
        )
        return cursor.fetchone() is not None
    except MySQLError as e:
        raise Exception(f"查询设备是否存在失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# Look up a device's "needs handling" flag by client IP
def get_is_need_handler_by_ip(client_ip: str) -> int:
    """
    Fetch a device's is_need_handler flag by client IP.

    :param client_ip: client IP of the device
    :return: 0 (no handling needed) or 1 (handling needed)
    :raises ValueError: empty IP or unknown device
    :raises Exception: wraps any MySQL error
    """
    if not client_ip:
        raise ValueError("客户端IP不能为空")
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute(
            "SELECT is_need_handler FROM devices WHERE client_ip = %s",
            (client_ip,)
        )
        row = cursor.fetchone()
        if not row:
            raise ValueError(f"客户端IP为 {client_ip} 的设备不存在")
        return row['is_need_handler']
    except MySQLError as e:
        raise Exception(f"查询设备is_need_handler状态失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)

339
service/face_service.py Normal file
View File

@ -0,0 +1,339 @@
import cv2
from io import BytesIO
from PIL import Image
from mysql.connector import Error as MySQLError
from ds.db import db
import numpy as np
import threading
from insightface.app import FaceAnalysis
# Module-level state shared by the face-recognition helpers below.
_insightface_app = None  # lazily created FaceAnalysis engine (see init_insightface)
_known_faces_embeddings = {}  # known-face features: {name: normalized feature vector}
_known_faces_names = []  # names of the loaded known faces
def init_insightface():
    """
    Lazily initialize the global InsightFace engine (model: buffalo_l).

    Safe to call repeatedly: returns the existing engine when already set up.
    Also loads the known-face feature library as part of initialization.

    :return: the FaceAnalysis instance, or None when initialization fails
    """
    global _insightface_app
    if _insightface_app is not None:
        print("InsightFace引擎已初始化无需重复执行")
        return _insightface_app
    try:
        print("正在初始化 InsightFace 引擎模型buffalo_l...")
        # Build the engine with an explicit model root and execution provider.
        app = FaceAnalysis(
            name='buffalo_l',
            root='~/.insightface',
            providers=['CPUExecutionProvider']  # add 'CUDAExecutionProvider' for GPU
        )
        app.prepare(ctx_id=0, det_size=(640, 640))  # detection input size
        print("InsightFace 引擎初始化完成")
        # Populate the known-face feature library at startup.
        init_face_data()
        _insightface_app = app
        return app
    except Exception as e:
        print(f"InsightFace 初始化失败:{str(e)}")
        _insightface_app = None
        return None
def init_face_data():
    """
    (Re)load the known-face feature library from the database.

    Clears the module-level embedding dict and name list, then parses each
    stored eigenvalue (numpy array or bracketed string), L2-normalizes it,
    and stores it under the person's name. Entries that fail to parse or
    are zero vectors are skipped with a log message.
    """
    global _known_faces_embeddings, _known_faces_names
    # Clear existing data so reloading never duplicates entries.
    _known_faces_embeddings.clear()
    _known_faces_names.clear()
    try:
        face_data = get_all_face_name_with_eigenvalue()  # {name: array or string}
        print(f"已加载 {len(face_data)} 个人脸数据")
        for person_name, eigenvalue_data in face_data.items():
            # Parse the stored feature: numpy array or delimited string.
            if isinstance(eigenvalue_data, np.ndarray):
                eigenvalue = eigenvalue_data.astype(np.float32)
            elif isinstance(eigenvalue_data, str):
                # Robust string parsing: strip brackets/newlines and accept
                # comma or whitespace separators.
                cleaned = (eigenvalue_data
                           .replace("[", "").replace("]", "")
                           .replace("\n", "").replace(",", " ")
                           .strip())
                values = [v for v in cleaned.split() if v]  # drop empty tokens
                if not values:
                    print(f"特征值解析失败(空值),跳过 {person_name}")
                    continue
                eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
            else:
                print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
                continue
            # L2-normalize so cosine similarity reduces to a dot product.
            norm = np.linalg.norm(eigenvalue)
            if norm == 0:
                print(f"特征值为零向量,跳过 {person_name}")
                continue
            eigenvalue = eigenvalue / norm
            # Publish into the module-level library.
            _known_faces_embeddings[person_name] = eigenvalue
            _known_faces_names.append(person_name)
        print(f"成功加载 {len(_known_faces_names)} 个人脸的特征库")
    except Exception as e:
        print(f"加载人脸特征库失败: {e}")
def update_face_data():
    """
    Refresh the known-face feature library with the latest database rows.

    :return: True when the reload succeeded, False otherwise.

    NOTE(review): the parse/normalize loop below duplicates init_face_data;
    consider extracting a shared helper so the two cannot drift apart.
    """
    global _known_faces_embeddings, _known_faces_names
    print("开始更新人脸特征库...")
    # Drop the current library before reloading.
    _known_faces_embeddings.clear()
    _known_faces_names.clear()
    try:
        # Fetch the latest face data from the database.
        face_data = get_all_face_name_with_eigenvalue()
        print(f"获取到 {len(face_data)} 条最新人脸数据")
        # Parse and load each entry (same logic as init_face_data).
        for person_name, eigenvalue_data in face_data.items():
            # Feature may arrive as a numpy array or a delimited string.
            if isinstance(eigenvalue_data, np.ndarray):
                eigenvalue = eigenvalue_data.astype(np.float32)
            elif isinstance(eigenvalue_data, str):
                # Strip brackets/newlines; accept comma or whitespace separators.
                cleaned = (eigenvalue_data
                           .replace("[", "").replace("]", "")
                           .replace("\n", "").replace(",", " ")
                           .strip())
                values = [v for v in cleaned.split() if v]  # drop empty tokens
                if not values:
                    print(f"特征值解析失败(空值),跳过 {person_name}")
                    continue
                eigenvalue = np.array(list(map(float, values)), dtype=np.float32)
            else:
                print(f"不支持的特征值类型({type(eigenvalue_data)}),跳过 {person_name}")
                continue
            # L2-normalize so cosine similarity reduces to a dot product.
            norm = np.linalg.norm(eigenvalue)
            if norm == 0:
                print(f"特征值为零向量,跳过 {person_name}")
                continue
            eigenvalue = eigenvalue / norm
            # Publish into the module-level library.
            _known_faces_embeddings[person_name] = eigenvalue
            _known_faces_names.append(person_name)
        print(f"人脸特征库更新完成,共加载 {len(_known_faces_names)} 个人脸数据")
        return True  # reload succeeded
    except Exception as e:
        print(f"人脸特征库更新失败: {e}")
        return False  # reload failed
def detect(frame, similarity_threshold=0.4):
    """
    Detect faces in a frame and match them against the known-face library.

    :param frame: image as a numpy array — assumed BGR as produced by
        OpenCV (TODO confirm against callers)
    :param similarity_threshold: cosine-similarity cutoff for a match
    :return: tuple (has_matched_known_face: bool, description: str)
    """
    global _insightface_app, _known_faces_embeddings
    # Validate input.
    if frame is None or frame.size == 0:
        return (False, "无效的输入帧数据")
    # Engine and feature library must both be ready.
    if not _insightface_app:
        return (False, "人脸引擎未初始化")
    if not _known_faces_embeddings:
        return (False, "人脸特征库为空")
    try:
        # Run face detection + feature extraction.
        faces = _insightface_app.get(frame)
    except Exception as e:
        return (False, f"检测错误: {str(e)}")
    result_parts = []
    has_matched_known_face = False  # did any face match a known person
    for face in faces:
        # Normalize the detected face's embedding.
        face_embedding = face.embedding.astype(np.float32)
        norm = np.linalg.norm(face_embedding)
        if norm == 0:
            result_parts.append("检测到人脸但特征值为零向量(忽略)")
            continue
        face_embedding = face_embedding / norm
        # Compare against every known embedding; with unit vectors the dot
        # product equals the cosine similarity.
        max_similarity, best_match_name = -1.0, "Unknown"
        for name, known_emb in _known_faces_embeddings.items():
            similarity = np.dot(face_embedding, known_emb)
            if similarity > max_similarity:
                max_similarity = similarity
                best_match_name = name
        # A match requires the best similarity to clear the threshold.
        is_matched = max_similarity >= similarity_threshold
        if is_matched:
            has_matched_known_face = True
        # Append a human-readable summary (bbox as integer list).
        bbox = face.bbox.astype(int).tolist()
        result_parts.append(
            f"{'匹配' if is_matched else '未匹配'}: {best_match_name} "
            f"(相似度: {max_similarity:.2f}, 边界框: {bbox})"
        )
    # Summarize all faces, or report that none were found.
    result_str = "未检测到人脸" if not result_parts else "; ".join(result_parts)
    return (has_matched_known_face, result_str)
# Extract and accumulate a face feature from raw image bytes
def add_binary_data(binary_data):
    """
    Extract a face feature from one image supplied as raw bytes.

    :param binary_data: image file content (bytes)
    :return: (True, feature numpy array) on success,
             (False, error message string) on failure
    """
    global _insightface_app, _feature_list
    # BUG FIX: _feature_list was referenced here (and in get_average_feature)
    # but never defined at module level, so the first successful extraction
    # raised NameError. Create the accumulator lazily.
    if "_feature_list" not in globals():
        _feature_list = []
    # 1. Make sure the engine is ready (attempt on-demand initialization).
    if not _insightface_app:
        init_result = init_insightface()
        if not init_result:
            error_msg = "InsightFace引擎未初始化、无法检测人脸"
            print(error_msg)
            return False, error_msg
    try:
        # 2. Reject obviously invalid payloads (< 1 KB is not a real photo).
        if len(binary_data) < 1024:
            error_msg = f"图片过小({len(binary_data)}字节)、可能不是有效图片"
            print(error_msg)
            return False, error_msg
        # 3. Decode bytes -> RGB PIL image -> BGR numpy array (forcing RGB
        # first avoids palette/alpha surprises; InsightFace expects BGR).
        try:
            img = Image.open(BytesIO(binary_data)).convert("RGB")
            frame = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        except Exception as e:
            error_msg = f"图片格式转换失败:{str(e)}"
            print(error_msg)
            return False, error_msg
        # 4. Reject tiny images that cannot contain a detectable face.
        height, width = frame.shape[:2]
        if height < 64 or width < 64:
            error_msg = f"图片尺寸过小({width}x{height}、需至少64x64像素"
            print(error_msg)
            return False, error_msg
        # 5. Run face detection.
        print(f"开始检测人脸(图片尺寸:{width}x{height}、格式BGR")
        faces = _insightface_app.get(frame)
        if not faces:
            error_msg = "未检测到人脸(请确保图片包含清晰正面人脸、无遮挡、不模糊)"
            print(error_msg)
            return False, error_msg
        # 6. Keep the first face's embedding and accumulate it.
        current_feature = faces[0].embedding
        _feature_list.append(current_feature)
        print(f"人脸检测成功、提取特征值(维度:{current_feature.shape[0]})、累计特征数:{len(_feature_list)}")
        return True, current_feature
    except Exception as e:
        error_msg = f"处理图片时发生异常:{str(e)}"
        print(error_msg)
        return False, error_msg
# Load the latest faces and their features from the database
def get_all_face_name_with_eigenvalue() -> dict:
    """
    Load every named face row and collapse duplicates to one feature per name.

    Names that appear in several rows get the average of their stored
    eigenvalues (via get_average_feature); single-row names keep the raw
    stored value.

    :return: {name: eigenvalue or averaged numpy vector}
    :raises Exception: wraps any MySQL error
    """
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT name, eigenvalue FROM face WHERE name IS NOT NULL")
        rows = cursor.fetchall()
        # Group the stored eigenvalues by person name.
        grouped = {}
        for row in rows:
            grouped.setdefault(row["name"], []).append(row["eigenvalue"])
        # Average multi-row names; pass single values through unchanged.
        face_dict = {}
        for name, eigenvalues in grouped.items():
            if len(eigenvalues) > 1:
                face_dict[name] = get_average_feature(eigenvalues)
            else:
                face_dict[name] = eigenvalues[0]
        return face_dict
    except MySQLError as e:
        raise Exception(f"获取人脸特征失败: {str(e)}") from e
    finally:
        db.close_connection(conn, cursor)
# Compute the mean of several face-feature vectors
def get_average_feature(features=None):
    """
    Average a list of face-feature vectors into one vector.

    :param features: list of 1-D vectors — numpy arrays, number sequences,
        or strings like "[1.0, 2.0, ...]". When None, falls back to the
        module-level _feature_list accumulated by add_binary_data.
    :return: np.ndarray mean vector, or None when no valid input exists or
        the inputs have inconsistent dimensions.
    """
    try:
        if features is None:
            # BUG FIX: _feature_list is created lazily elsewhere; reading the
            # bare global raised NameError (silently swallowed by the outer
            # except, returning None with a misleading message) when no image
            # had been processed yet. Fall back to an empty list instead.
            features = globals().get("_feature_list") or []
        if not isinstance(features, list) or len(features) == 0:
            print("输入必须是包含至少一个特征值的列表")
            return None
        processed_features = []
        for i, embedding in enumerate(features):
            try:
                if isinstance(embedding, str):
                    # Parse bracketed, comma/space-separated number strings.
                    embedding_str = embedding.replace('[', '').replace(']', '').replace(',', ' ').strip()
                    embedding_list = [float(num) for num in embedding_str.split() if num.strip()]
                    embedding_np = np.array(embedding_list, dtype=np.float32)
                else:
                    embedding_np = np.array(embedding, dtype=np.float32)
                # Only 1-D vectors participate in the mean.
                if len(embedding_np.shape) == 1:
                    processed_features.append(embedding_np)
                    print(f"已添加第 {i + 1} 个特征值用于计算平均值")
                else:
                    print(f"跳过第 {i + 1} 个特征值:不是一维数组")
            except Exception as e:
                print(f"处理第 {i + 1} 个特征值时出错:{str(e)}")
        if not processed_features:
            print("没有有效的特征值用于计算平均值")
            return None
        # All vectors must share one dimension for element-wise averaging.
        dims = {feat.shape[0] for feat in processed_features}
        if len(dims) > 1:
            print(f"特征值维度不一致:{dims}、无法计算平均值")
            return None
        avg_feature = np.mean(processed_features, axis=0)
        print(f"计算成功:{len(processed_features)} 个特征值的平均向量(维度:{avg_feature.shape[0]}")
        return avg_feature
    except Exception as e:
        print(f"计算平均特征值出错:{str(e)}")
        return None

343
service/file_service.py Normal file
View File

@ -0,0 +1,343 @@
import os
import re
import shutil
from datetime import datetime
from PIL import ImageDraw, ImageFont
from fastapi import UploadFile
import cv2
from PIL import Image
import numpy as np
# Root directory on disk for all uploaded/generated files.
UPLOAD_ROOT = "upload"
# URL prefix prepended to relative file paths returned to the frontend.
PRE = "/api/file/download/"
# Ensure the upload root exists at import time.
os.makedirs(UPLOAD_ROOT, exist_ok=True)
def save_detect_file(client_ip: str, image_np: np.ndarray, file_type: str) -> str:
    """
    Save a numpy image (BGR) as a PNG under the detect directory tree and
    return the frontend download path.

    Directory layout: upload/detect/<client_ip>/<file_type>/<YYYY>/<MM>/<DD>/
    Filename: microsecond timestamp, so concurrent saves do not collide.

    :param client_ip: client IP used as a directory component
    :param image_np: image array; non-uint8 arrays are scaled by 255 before
        casting — assumes float input lies in [0, 1] (TODO confirm callers)
    :param file_type: sub-directory name describing the detection type
    :return: download URL path (PRE + path relative to UPLOAD_ROOT)
    :raises Exception: when encoding/writing the PNG fails
    """
    today = datetime.now()
    year = today.strftime("%Y")
    month = today.strftime("%m")
    day = today.strftime("%d")
    # Target directory: upload/detect/IP/type/year/month/day (includes UPLOAD_ROOT).
    file_dir = os.path.join(
        UPLOAD_ROOT,
        "detect",
        client_ip,
        file_type,
        year,
        month,
        day
    )
    # Make sure the directory exists.
    os.makedirs(file_dir, exist_ok=True)
    # Microsecond-resolution timestamp keeps filenames unique.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    filename = f"{timestamp}.png"
    # 1. Full path (includes UPLOAD_ROOT) used for the actual write.
    full_path = os.path.join(file_dir, filename)
    # 2. Relative path (UPLOAD_ROOT stripped) returned to the frontend.
    relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
    # Save the numpy array as a PNG.
    try:
        # -------- dtype and channel-order handling --------
        # 1. Cast to uint8; floats are scaled by 255 (assumed 0-1 range).
        if image_np.dtype != np.uint8:
            image_np = (image_np * 255).astype(np.uint8)
        # 2. Convert OpenCV BGR to the RGB order PIL expects.
        image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        # 3. Save through PIL as PNG.
        img = Image.fromarray(image_rgb)
        img.save(full_path, format='PNG')
    except Exception as e:
        # e.g. wrong array shape/dtype for the conversions above
        raise Exception(f"保存图片失败: {str(e)}")
    # Normalize separators to '/' and prepend the download prefix.
    return PRE + relative_path.replace(os.sep, "/")
def save_detect_yolo_file(
        client_ip: str,
        image_np: np.ndarray,
        detection_results: list,
        file_type: str = "yolo"
) -> str:
    """Draw YOLO detections (bounding boxes + labels) on a copy of the image
    and save it as a PNG under the detect tree.

    Args:
        client_ip: client address, used as a directory segment.
        image_np: (h, w, 3) BGR image; non-uint8 arrays are assumed to be in
            [0, 1] and are scaled/clipped to uint8.
        detection_results: list of dicts, each with keys "class" (str),
            "confidence" (float) and "bbox" (x1, y1, x2, y2) integer pixels.
        file_type: sub-directory name under the client's detect tree.

    Returns:
        Frontend download path (PRE + relative path, '/'-separated).

    Raises:
        ValueError: when any input is malformed.
        Exception: when writing the PNG fails.
    """
    # ---- input validation ----
    if not isinstance(image_np, np.ndarray):
        raise ValueError(f"输入image_np必须是numpy数组当前类型{type(image_np)}")
    if image_np.ndim != 3 or image_np.shape[-1] != 3:
        raise ValueError(f"输入图像必须是 (h, w, 3) 的BGR数组当前shape{image_np.shape}")
    if not isinstance(detection_results, list):
        raise ValueError(f"detection_results必须是列表当前类型{type(detection_results)}")
    for idx, result in enumerate(detection_results):
        required_keys = {"class", "confidence", "bbox"}
        if not isinstance(result, dict) or not required_keys.issubset(result.keys()):
            raise ValueError(
                f"detection_results第{idx}个元素格式错误,需包含键:{required_keys}"
                f"当前元素:{result}"
            )
        bbox = result["bbox"]
        # Generalized: also accept numpy integer scalars, which YOLO box
        # tensors commonly produce after int() mapping upstream is skipped.
        if not (isinstance(bbox, (tuple, list)) and len(bbox) == 4
                and all(isinstance(x, (int, np.integer)) for x in bbox)):
            raise ValueError(
                f"detection_results第{idx}个元素的bbox格式错误需为(x1,y1,x2,y2)整数元组,"
                f"当前bbox{bbox}"
            )
    # ---- dtype normalization (draw on a copy, never on the caller's array) ----
    draw_image = image_np.copy()
    if draw_image.dtype != np.uint8:
        draw_image = np.clip(draw_image * 255, 0, 255).astype(np.uint8)
    # ---- draw one box + label per detection ----
    for result in detection_results:
        class_name = result["class"]
        confidence = result["confidence"]
        x1, y1, x2, y2 = result["bbox"]
        cv2.rectangle(draw_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
        label = f"{class_name}: {confidence:.2f}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        font_thickness = 2
        (label_width, label_height), baseline = cv2.getTextSize(
            label, font, font_scale, font_thickness
        )
        # Label background sits above the box; clamp to the top edge if the
        # box is too close to y=0.
        bg_top_left = (x1, y1 - label_height - 10)
        bg_bottom_right = (x1 + label_width, y1)
        if bg_top_left[1] < 0:
            bg_top_left = (x1, 0)
            bg_bottom_right = (x1 + label_width, label_height + 10)
        cv2.rectangle(draw_image, bg_top_left, bg_bottom_right, color=(0, 0, 0), thickness=-1)
        text_origin = (x1, y1 - 5)
        if bg_top_left[1] == 0:
            text_origin = (x1, label_height + 5)
        cv2.putText(
            draw_image, label, text_origin,
            font, font_scale, color=(255, 255, 255), thickness=font_thickness
        )
    # ---- save ----
    try:
        today = datetime.now()
        file_dir = os.path.join(
            UPLOAD_ROOT, "detect", client_ip, file_type,
            today.strftime("%Y"), today.strftime("%m"), today.strftime("%d")
        )
        os.makedirs(file_dir, exist_ok=True)
        # Microsecond timestamp keeps file names unique.
        timestamp = today.strftime("%Y%m%d%H%M%S%f")
        filename = f"{timestamp}.png"
        full_path = os.path.join(file_dir, filename)
        # Relative path (UPLOAD_ROOT stripped, '/'-separated) for the frontend.
        relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
        download_path = PRE + relative_path.replace(os.sep, "/")
        # OpenCV drew in BGR; convert to RGB for PIL before saving.
        image_rgb = cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB)
        img_pil = Image.fromarray(image_rgb)
        img_pil.save(full_path, format="PNG")  # 'quality' is meaningless for PNG, so it is not passed
        print(f"YOLO检测图片保存成功 | 本地路径:{full_path} | 下载路径:{download_path}")
        return download_path
    except Exception as e:
        raise Exception(f"YOLO检测图片保存失败{str(e)}") from e
def save_detect_face_file(
        client_ip: str,
        image_np: np.ndarray,
        face_result: str,
        file_type: str = "face",
        matched_color: tuple = (0, 255, 0)
) -> str:
    """Save a face-recognition result image: only faces whose status is
    "匹配" (matched) get a bounding box and a "name (similarity)" label.

    face_result is the semicolon-separated summary string produced by the
    recognition step, e.g. '匹配: 张三 (相似度: 0.92, 边界框: [10, 20, 30, 40])'.
    Returns the frontend download path (PRE + relative path); raises
    ValueError on bad input and Exception when saving fails.
    """
    # ---- input validation ----
    if not isinstance(image_np, np.ndarray) or image_np.ndim != 3 or image_np.shape[-1] != 3:
        raise ValueError(f"输入图像需为 (h, w, 3) 的BGR数组当前shape{image_np.shape}")
    if not isinstance(face_result, str) or face_result.strip() == "":
        raise ValueError("face_result必须是非空字符串")
    # ---- parse face_result into per-face dicts ----
    face_info_list = []
    if face_result.strip() != "未检测到人脸":
        face_pattern = re.compile(
            r"(匹配|未匹配):\s*([^\s(]+)\s*\(相似度:\s*(\d+\.\d+),\s*边界框:\s*\[(\d+,\s*\d+,\s*\d+,\s*\d+)\]\)"
        )
        for part in [p.strip() for p in face_result.split(";") if p.strip()]:
            match = face_pattern.match(part)
            if match:
                status, name, similarity, bbox_str = match.groups()
                bbox = list(map(int, bbox_str.replace(" ", "").split(",")))
                if len(bbox) == 4:
                    face_info_list.append({
                        "status": status,
                        "name": name,
                        "similarity": float(similarity),
                        "bbox": bbox
                    })
    # ---- OpenCV (BGR) -> PIL (RGB); PIL is used so CJK labels render with a font ----
    image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(image_rgb)
    draw = ImageDraw.Draw(pil_img)
    # ---- font selection: prefer Chinese-capable fonts, fall back to default ----
    font_size = 12
    try:
        font = ImageFont.truetype("simhei", font_size)
    except:
        try:
            font = ImageFont.truetype("simsun", font_size)
        except:
            font = ImageFont.load_default()
            print("警告未找到指定中文字体使用PIL默认字体可能影响中文显示")
    for face_info in face_info_list:
        status = face_info["status"]
        if status != "匹配":
            print(f"跳过未匹配人脸:{face_info['name']}(相似度:{face_info['similarity']:.2f}")
            continue
        name = face_info["name"]
        similarity = face_info["similarity"]
        x1, y1, x2, y2 = face_info["bbox"]
        # Box is drawn with OpenCV, which forces a PIL->cv2->PIL round trip
        # per face. NOTE(review): this per-face conversion looks wasteful —
        # candidate for a single conversion outside the loop; verify output
        # is unchanged before refactoring.
        img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        cv2.rectangle(img_cv, (x1, y1), (x2, y2), color=matched_color, thickness=2)
        pil_img = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_img)
        label = f"{name} (相似度: {similarity:.2f})"
        # Measure the label so the background rectangle fits the text exactly.
        label_bbox = draw.textbbox((0, 0), label, font=font)
        label_width = label_bbox[2] - label_bbox[0]
        label_height = label_bbox[3] - label_bbox[1]
        # Background above the box; move it below the box when it would
        # leave the top of the image.
        bg_x1, bg_y1 = x1, y1 - label_height - 10
        bg_x2, bg_y2 = x1 + label_width, y1
        if bg_y1 < 0:
            bg_y1, bg_y2 = y2 + 5, y2 + label_height + 15
        # Black background, white text.
        draw.rectangle([(bg_x1, bg_y1), (bg_x2, bg_y2)], fill=(0, 0, 0))
        text_x = bg_x1
        text_y = bg_y1 if bg_y1 >= 0 else bg_y1 + label_height
        draw.text((text_x, text_y), label, font=font, fill=(255, 255, 255))
    # ---- save under upload/detect/<ip>/<type>/<date>/ ----
    try:
        today = datetime.now()
        file_dir = os.path.join(
            UPLOAD_ROOT, "detect", client_ip, file_type,
            today.strftime("%Y"), today.strftime("%m"), today.strftime("%d")
        )
        os.makedirs(file_dir, exist_ok=True)
        timestamp = today.strftime("%Y%m%d%H%M%S%f")
        filename = f"{timestamp}.png"
        full_path = os.path.join(file_dir, filename)
        pil_img.save(full_path, format="PNG", quality=100)
        relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
        download_path = PRE + relative_path.replace(os.sep, "/")
        matched_count = sum(1 for info in face_info_list if info["status"] == "匹配")
        print(f"人脸检测图片保存成功 | 客户端IP{client_ip} | 匹配人脸数:{matched_count} | 保存路径:{download_path}")
        return download_path
    except Exception as e:
        raise Exception(f"人脸检测图片保存失败客户端IP{client_ip}{str(e)}") from e
def save_source_file(upload_file: UploadFile, file_type: str) -> str:
    """Store an uploaded file under upload/source/<type>/<YYYY>/<MM>/<DD>/
    and return the frontend download path."""
    now = datetime.now()
    # Microsecond timestamp prefixed to the original name keeps it unique.
    unique_filename = f"{now.strftime('%Y%m%d%H%M%S%f')}_{upload_file.filename}"
    target_dir = os.path.join(
        UPLOAD_ROOT,
        "source",
        file_type,
        now.strftime("%Y"),
        now.strftime("%m"),
        now.strftime("%d"),
    )
    os.makedirs(target_dir, exist_ok=True)
    # Full path is where the bytes land; the relative path goes to the frontend.
    full_path = os.path.join(target_dir, unique_filename)
    relative_path = full_path.replace(UPLOAD_ROOT, "", 1).lstrip(os.sep)
    try:
        with open(full_path, "wb") as sink:
            shutil.copyfileobj(upload_file.file, sink)
    finally:
        upload_file.file.close()
    return PRE + relative_path.replace(os.sep, "/")
def get_absolute_path(relative_path: str) -> str:
    """Map a frontend download path (or plain relative path) to an absolute
    path inside UPLOAD_ROOT.

    Strips the PRE prefix if present, normalizes the remainder and verifies
    the result stays inside the upload root.

    Raises:
        ValueError: when the path would escape UPLOAD_ROOT.
    """
    path_without_pre = relative_path.replace(PRE, "", 1)
    # Normalize to the platform's separator/.. form.
    normalized_path = os.path.normpath(path_without_pre)
    absolute_path = os.path.abspath(os.path.join(UPLOAD_ROOT, normalized_path))
    upload_root_abs = os.path.abspath(UPLOAD_ROOT)
    # BUG FIX: a plain startswith() containment check accepts sibling
    # directories that share the string prefix (e.g. "uploadX" vs "upload").
    # commonpath() compares whole path components, closing that traversal hole.
    if os.path.commonpath([upload_root_abs, absolute_path]) != upload_root_abs:
        raise ValueError("无效的相对路径,可能试图访问上传目录之外的内容")
    return absolute_path

131
service/model_service.py Normal file
View File

@ -0,0 +1,131 @@
# BUG FIX: HTTPException was imported from http.client, whose class accepts
# no status_code/detail keyword arguments — every raise in this module
# (HTTPException(status_code=..., detail=...)) would itself TypeError.
# The intended type is FastAPI's HTTPException.
from fastapi import HTTPException
import os

import numpy as np
import torch
from MySQLdb import MySQLError
from ultralytics import YOLO

from ds.db import db
from service.file_service import get_absolute_path

# ---- module-level state ----
current_yolo_model = None  # loaded YOLO model instance (None until load_yolo_model runs)
current_model_absolute_path = None  # absolute model path; kept independently of the model object

# ---- upload constraints for model files ----
ALLOWED_MODEL_EXT = {"pt"}
MAX_MODEL_SIZE = 100 * 1024 * 1024  # 100MB
def load_yolo_model():
    """Load the default YOLO model from the DB-configured path and cache it.

    Side effects: sets the module globals current_yolo_model and
    current_model_absolute_path, and moves the model to GPU when available.

    Raises:
        FileNotFoundError: when the model file is missing on disk.
        Exception: any loader error is logged and re-raised.
    """
    global current_yolo_model, current_model_absolute_path
    model_rel_path = get_enabled_model_rel_path()
    print(f"[模型初始化] 加载模型:{model_rel_path}")
    # Resolve and remember the absolute path (kept even if loading fails later).
    current_model_absolute_path = get_absolute_path(model_rel_path)
    print(f"[模型初始化] 绝对路径:{current_model_absolute_path}")
    # Fail fast when the file is gone.
    if not os.path.exists(current_model_absolute_path):
        raise FileNotFoundError(f"模型文件不存在: {current_model_absolute_path}")
    try:
        new_model = YOLO(current_model_absolute_path)
        if torch.cuda.is_available():
            new_model.to('cuda')
            print("模型已移动到GPU")
        else:
            print("使用CPU进行推理")
        current_yolo_model = new_model
        print(f"成功加载模型: {current_model_absolute_path}")
        return current_yolo_model
    except Exception as e:
        print(f"模型加载失败:{str(e)}")
        raise
def get_current_model():
    """Return the cached YOLO model instance, raising ValueError when no
    model has been loaded yet (call load_yolo_model first)."""
    if current_yolo_model is not None:
        return current_yolo_model
    raise ValueError("尚未加载任何YOLO模型请先调用load_yolo_model加载模型")
def detect(image_np, conf_threshold=0.8):
    """Run the cached YOLO model on a BGR image.

    Args:
        image_np: (h, w, 3) BGR numpy array.
        conf_threshold: minimum confidence for a box to be kept (also passed
            to model.predict as its conf filter).

    Returns:
        (has_content, detection_results): has_content is True when at least
        one box passed the filter; detection_results is a list of
        {"class", "confidence", "bbox"} dicts.
        NOTE(review): on any internal error this returns (False, None) rather
        than (False, []) — callers iterating the second element must handle
        None; confirm before changing.
    """
    # 1. input validation
    if not isinstance(image_np, np.ndarray):
        raise ValueError("输入必须是numpy数组BGR图像")
    if image_np.ndim != 3 or image_np.shape[-1] != 3:
        raise ValueError(f"输入图像格式错误,需为 (h, w, 3) 的BGR数组当前shape: {image_np.shape}")
    detection_results = []
    try:
        model = get_current_model()
        if not current_model_absolute_path:
            raise RuntimeError("模型未初始化!请先调用 load_yolo_model 加载模型")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"检测设备:{device} | 置信度阈值:{conf_threshold}")
        # Log input size for debugging.
        img_height, img_width = image_np.shape[:2]
        print(f"输入图像尺寸:{img_width}x{img_height}")
        # Run YOLO inference.
        print("执行YOLO检测")
        results = model.predict(
            image_np,
            conf=conf_threshold,
            device=device,
            show=False,
        )
        # Collect boxes above the confidence threshold. A class filter
        # ("and class_id == 2", Chest only) was deliberately disabled —
        # the commented condition below documents it.
        for box in results[0].boxes:
            class_id = int(box.cls[0])  # class index into model.names
            class_name = model.names[class_id]
            confidence = float(box.conf[0])
            bbox = tuple(map(int, box.xyxy[0]))
            # and class_id == 2
            if confidence >= conf_threshold:
                detection_results.append({
                    "class": class_name,
                    "confidence": confidence,
                    "bbox": bbox
                })
        # Any detection at all?
        has_content = len(detection_results) > 0
        return has_content, detection_results
    except Exception as e:
        error_msg = f"检测过程出错:{str(e)}"
        print(error_msg)
        return False, None
def get_enabled_model_rel_path():
    """Return the relative path of the model marked is_default=1 in the DB.

    Raises:
        HTTPException: 404 when no default model exists; 500 on database or
            unexpected errors.
    """
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT path FROM model WHERE is_default = 1 LIMIT 1")
        row = cursor.fetchone()
        if row and row.get('path'):
            return row['path']
        raise HTTPException(status_code=404, detail="未找到启用的默认模型")
    except MySQLError as e:
        raise HTTPException(status_code=500, detail=f"查询默认模型时发生数据库错误:{str(e)}") from e
    except Exception as e:
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(status_code=500, detail=f"获取默认模型路径失败:{str(e)}") from e
    finally:
        # Always release the connection/cursor pair.
        db.close_connection(conn, cursor)

131
service/ocr_service.py Normal file
View File

@ -0,0 +1,131 @@
# NumPy compatibility shim — must run before paddleocr is imported.
import numpy as np
# np.int was removed in NumPy 1.24; alias it back for older paddle code.
if not hasattr(np, 'int'):
    np.int = int
from paddleocr import PaddleOCR
from service.sensitive_service import get_all_sensitive_words
# Lazily-initialized PaddleOCR engine (see init_ocr_engine).
_ocr_engine = None
# Global forbidden-word store consulted by detect().
_forbidden_words = set()
# Minimum OCR confidence for a recognized line to be considered.
_conf_threshold = 0.5
def set_forbidden_words(new_words):
    """Replace the global forbidden-word set with the supplied words.

    Accepts a set, list or tuple (duplicates collapse into the set); any
    other type raises TypeError. Returns None.
    """
    global _forbidden_words
    if isinstance(new_words, (set, list, tuple)):
        _forbidden_words = set(new_words)
        print(f"已通过函数更新违禁词,当前数量: {len(_forbidden_words)}")
    else:
        raise TypeError("新违禁词必须是集合、列表或元组类型")
def load_forbidden_words():
    """Reload the global forbidden-word collection from the database.

    Returns:
        bool: True on success; False on failure (the previous words are kept).
    """
    global _forbidden_words
    try:
        # CONSISTENCY FIX: store a set, matching what set_forbidden_words
        # enforces (the DB helper returns a list); also deduplicates.
        _forbidden_words = set(get_all_sensitive_words())
        print(f"加载的违禁词数量: {len(_forbidden_words)}")
    except Exception as e:
        print(f"Forbidden words load error: {e}")
        return False
    return True
def init_ocr_engine():
    """Create the global PaddleOCR engine and load the forbidden-word list.

    Returns:
        bool: True on success; on failure resets _ocr_engine to None and
            returns False.
    """
    global _ocr_engine
    try:
        _ocr_engine = PaddleOCR(
            use_angle_cls=True,
            lang="ch",
            show_log=False,
            use_gpu=True,
            max_text_length=1024
        )
        # A failed word load is only a warning — the engine itself still works.
        load_result = load_forbidden_words()
        if not load_result:
            print("警告:违禁词加载失败,可能影响检测功能")
        print("OCR引擎初始化完成")
        return True
    except Exception as e:
        print(f"OCR引擎初始化错误: {e}")
        _ocr_engine = None
        return False
def detect(frame, conf_threshold=0.8):
    """OCR the frame and report whether it contains forbidden words.

    Returns a (flag, message) tuple: (True, "w1, w2") when forbidden words
    are found, otherwise (False, <reason>).

    NOTE(review): the conf_threshold parameter is never used — filtering
    relies on the module-level _conf_threshold (0.5); confirm which one is
    intended. Also assumes init_ocr_engine() has run; if _ocr_engine is
    still None the AttributeError is swallowed by the broad except below.
    """
    print("开始进行OCR检测...")
    try:
        ocr_res = _ocr_engine.ocr(frame, cls=True)
        if not ocr_res or not isinstance(ocr_res, list):
            return (False, "无OCR结果")
        texts = []
        confs = []
        # PaddleOCR result layouts vary between versions; the loop below
        # sniffs each item's shape and extracts (text, confidence) pairs.
        for line in ocr_res:
            if line is None:
                continue
            if isinstance(line, list):
                items_to_process = line
            else:
                items_to_process = [line]
            for item in items_to_process:
                # Skip a 4-point coordinate box: [[x, y], [x, y], [x, y], [x, y]].
                if isinstance(item, list) and len(item) == 4:
                    is_coordinate = True
                    for point in item:
                        if not (isinstance(point, list) and len(point) == 2 and
                                all(isinstance(coord, (int, float)) for coord in point)):
                            is_coordinate = False
                            break
                    if is_coordinate:
                        continue
                # Skip a flat list of numbers (another coordinate form).
                if isinstance(item, list) and all(isinstance(x, (int, float)) for x in item):
                    continue
                # (text, confidence) tuple.
                if isinstance(item, tuple) and len(item) == 2:
                    text, conf = item
                    if isinstance(text, str) and isinstance(conf, (int, float)):
                        texts.append(text.strip())
                        confs.append(float(conf))
                        continue
                # [box, (text, confidence)] or [box, text] list forms.
                if isinstance(item, list) and len(item) >= 2:
                    text_data = item[1]
                    if isinstance(text_data, tuple) and len(text_data) == 2:
                        text, conf = text_data
                        if isinstance(text, str) and isinstance(conf, (int, float)):
                            texts.append(text.strip())
                            confs.append(float(conf))
                            continue
                    elif isinstance(text_data, str):
                        # Bare text with no confidence: treat as fully confident.
                        texts.append(text_data.strip())
                        confs.append(1.0)
                        continue
                print(f"无法解析的OCR结果格式: {item}")
        if len(texts) != len(confs):
            return (False, "OCR结果格式异常")
        # Collect every forbidden word seen (deduplicated, order preserved).
        vio_words = []
        for txt, conf in zip(texts, confs):
            if conf < _conf_threshold:  # drop low-confidence lines
                continue
            # Substring match against the global forbidden-word store.
            matched = [w for w in _forbidden_words if w in txt]
            # Append only words not recorded yet (dedup).
            for word in matched:
                if word not in vio_words:
                    vio_words.append(word)
        has_text = len(texts) > 0
        has_violation = len(vio_words) > 0
        if not has_text:
            return (False, "未识别到文本")
        elif has_violation:
            # Join multiple forbidden words with commas.
            return (True, ", ".join(vio_words))
        else:
            return (False, "未检测到违禁词")
    except Exception as e:
        print(f"OCR detect error: {e}")
        return (False, f"检测错误: {str(e)}")

View File

@ -0,0 +1,36 @@
from mysql.connector import Error as MySQLError
from ds.db import db
def get_all_sensitive_words() -> list[str]:
    """Fetch every sensitive word as a plain list of strings.

    Returns:
        list[str]: all words from the `sensitives` table, ordered by id.

    Raises:
        MySQLError: wrapped database errors, re-raised for the caller.
    """
    conn, cursor = None, None
    try:
        conn = db.get_connection()
        cursor = conn.cursor(dictionary=True)
        cursor.execute("SELECT name FROM sensitives ORDER BY id")
        return [row['name'] for row in cursor.fetchall()]
    except MySQLError as e:
        # Bubble DB failures up with context; the caller decides how to react.
        raise MySQLError(f"查询敏感词列表失败: {str(e)}") from e
    finally:
        # Always release the connection/cursor pair.
        db.close_connection(conn, cursor)

Binary file not shown.

After

Width:  |  Height:  |  Size: 652 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 550 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 549 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 562 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 546 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 585 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 584 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 561 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 561 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 562 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 546 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 386 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 341 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 347 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 337 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 250 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 247 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 250 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 822 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 805 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 791 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 877 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 814 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 749 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 197 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 200 KiB

Some files were not shown because too many files have changed in this diff Show More