216 lines
8.9 KiB
Python
216 lines
8.9 KiB
Python
|
||
import cv2
|
||
import numpy as np
|
||
from insightface.app import FaceAnalysis
|
||
from . import database
|
||
import base64
|
||
import requests
|
||
|
||
class FaceRecognitionService:
|
||
def __init__(self):
|
||
# 初始化InsightFace分析器
|
||
self.app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
|
||
self.app.prepare(ctx_id=0, det_size=(640, 640))
|
||
|
||
def _decode_image(self, image_source):
|
||
"""从URL或Base64解码图片"""
|
||
img = None
|
||
if image_source.face_data:
|
||
try:
|
||
img_data = base64.b64decode(image_source.face_data)
|
||
nparr = np.frombuffer(img_data, np.uint8)
|
||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||
except Exception as e:
|
||
raise ValueError(f"Base64解码失败: {e}")
|
||
elif image_source.url:
|
||
try:
|
||
response = requests.get(image_source.url, timeout=10)
|
||
response.raise_for_status()
|
||
img_array = np.asarray(bytearray(response.content), dtype=np.uint8)
|
||
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
|
||
except Exception as e:
|
||
raise ValueError(f"从URL获取图片失败: {e}")
|
||
|
||
if img is None:
|
||
raise ValueError("无法加载图片")
|
||
return img
|
||
|
||
def register_new_face(self, user_id: int, name: str, image_source):
|
||
"""注册新的人脸"""
|
||
with database.get_db_connection() as conn:
|
||
cursor = conn.cursor()
|
||
|
||
# 1. 检查ID是否已存在
|
||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||
user = cursor.fetchone()
|
||
|
||
if user and user['name'] != name:
|
||
raise ValueError(f"ID {user_id} 已被 '{user['name']}' 注册,无法更改为 '{name}'。")
|
||
|
||
# 2. 解码和处理图片
|
||
image = self._decode_image(image_source)
|
||
faces = self.app.get(image)
|
||
|
||
if not faces:
|
||
raise ValueError("图片中未检测到人脸。")
|
||
if len(faces) > 1:
|
||
raise ValueError("注册图片中只能包含一张人脸。")
|
||
|
||
embedding = faces[0].embedding
|
||
|
||
# 3. 存储到数据库
|
||
if not user:
|
||
# 如果是新用户,先在users表创建记录
|
||
cursor.execute("INSERT INTO users (id, name) VALUES (?, ?)", (user_id, name))
|
||
|
||
# 插入新的人脸特征
|
||
cursor.execute("INSERT INTO face_features (user_id, embedding) VALUES (?, ?)", (user_id, embedding.tobytes()))
|
||
|
||
conn.commit()
|
||
|
||
# 4. 返回用户信息
|
||
cursor.execute("SELECT COUNT(*) FROM face_features WHERE user_id = ?", (user_id,))
|
||
count = cursor.fetchone()[0]
|
||
|
||
return {"id": user_id, "name": name, "registered_faces_count": count}
|
||
|
||
|
||
def detect_faces(self, image_source):
|
||
"""1:N 人脸识别"""
|
||
image = self._decode_image(image_source)
|
||
detected_faces = self.app.get(image)
|
||
|
||
if not detected_faces:
|
||
return []
|
||
|
||
with database.get_db_connection() as conn:
|
||
cursor = conn.cursor()
|
||
cursor.execute("SELECT user_id, embedding FROM face_features")
|
||
db_features = cursor.fetchall()
|
||
|
||
if not db_features:
|
||
# 如果数据库为空,所有检测到的人脸都是未知的
|
||
# (此处简化处理,实际可返回带位置的未知人脸列表)
|
||
return []
|
||
|
||
# 将数据库特征加载到Numpy数组中以便快速计算
|
||
db_user_ids = np.array([f['user_id'] for f in db_features])
|
||
db_embeddings = np.array([np.frombuffer(f['embedding'], dtype=np.float32) for f in db_features])
|
||
|
||
results = []
|
||
for face in detected_faces:
|
||
embedding = face.embedding
|
||
|
||
# 计算与数据库中所有特征的余弦相似度
|
||
# (insightface的特征是归一化的,点积等价于余弦相似度)
|
||
similarities = np.dot(db_embeddings, embedding)
|
||
|
||
best_match_index = np.argmax(similarities)
|
||
best_similarity = similarities[best_match_index]
|
||
|
||
# 设置一个阈值来判断是否为已知人脸
|
||
# ArcFace 官方建议的阈值通常在 0.4 到 0.6 之间,这里我们用0.5
|
||
RECOGNITION_THRESHOLD = 0.5
|
||
|
||
if best_similarity > RECOGNITION_THRESHOLD:
|
||
matched_user_id = db_user_ids[best_match_index]
|
||
|
||
# 查询用户信息
|
||
cursor.execute("SELECT name, (SELECT COUNT(*) FROM face_features WHERE user_id=?) FROM users WHERE id=?", (matched_user_id, matched_user_id))
|
||
user_info = cursor.fetchone()
|
||
|
||
x1, y1, x2, y2 = face.bbox.astype(int)
|
||
results.append({
|
||
"id": matched_user_id,
|
||
"name": user_info['name'],
|
||
"registered_faces_count": user_info[2],
|
||
"confidence": float(best_similarity),
|
||
"location": {"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}
|
||
})
|
||
|
||
return results
|
||
|
||
|
||
def verify_face(self, user_id: int, image_source):
|
||
"""1:1 人脸认证"""
|
||
# 1. 解码图片并检测人脸
|
||
image = self._decode_image(image_source)
|
||
detected_faces = self.app.get(image)
|
||
|
||
if not detected_faces:
|
||
raise ValueError("图片中未检测到人脸。")
|
||
if len(detected_faces) > 1:
|
||
raise ValueError("用于认证的图片中只能包含一张人脸。")
|
||
|
||
embedding = detected_faces[0].embedding
|
||
|
||
# 2. 从数据库获取该ID对应的所有人脸特征
|
||
with database.get_db_connection() as conn:
|
||
cursor = conn.cursor()
|
||
cursor.execute("SELECT embedding FROM face_features WHERE user_id = ?", (user_id,))
|
||
db_features = cursor.fetchall()
|
||
|
||
if not db_features:
|
||
raise ValueError(f"数据库中不存在ID为 {user_id} 的用户,或该用户未注册任何人脸。")
|
||
|
||
db_embeddings = np.array([np.frombuffer(f['embedding'], dtype=np.float32) for f in db_features])
|
||
|
||
# 3. 计算与该ID所有特征的相似度,取最高值
|
||
similarities = np.dot(db_embeddings, embedding)
|
||
best_similarity = np.max(similarities)
|
||
|
||
# 1:1 认证通常使用比 1:N 更严格的阈值
|
||
VERIFICATION_THRESHOLD = 0.6
|
||
|
||
match = bool(best_similarity > VERIFICATION_THRESHOLD)
|
||
|
||
return {"match": match, "confidence": float(best_similarity)}
|
||
|
||
|
||
def delete_user(self, user_id: int):
|
||
"""根据ID删除用户及其所有的人脸数据"""
|
||
with database.get_db_connection() as conn:
|
||
cursor = conn.cursor()
|
||
cursor.execute("SELECT name FROM users WHERE id = ?", (user_id,))
|
||
user = cursor.fetchone()
|
||
if not user:
|
||
raise ValueError(f"ID为 {user_id} 的用户不存在。")
|
||
|
||
# 使用了外键的 ON DELETE CASCADE,删除users表中的记录会自动删除face_features中的相关记录
|
||
cursor.execute("DELETE FROM users WHERE id = ?", (user_id,))
|
||
conn.commit()
|
||
|
||
# cursor.rowcount 在 SQLite 中可能不总是可靠,我们确认用户存在即认为删除成功
|
||
return {"id": user_id, "name": user['name']}
|
||
|
||
def get_user(self, user_id: int):
|
||
"""根据ID获取用户信息"""
|
||
with database.get_db_connection() as conn:
|
||
cursor = conn.cursor()
|
||
cursor.execute("SELECT id, name, (SELECT COUNT(*) FROM face_features WHERE user_id=?) FROM users WHERE id=?", (user_id, user_id))
|
||
user = cursor.fetchone()
|
||
if not user:
|
||
return None
|
||
return {"id": user['id'], "name": user['name'], "registered_faces_count": user[2]}
|
||
|
||
def list_all_users(self, skip: int = 0, limit: int = 100):
|
||
"""列出所有已注册的用户,支持分页"""
|
||
with database.get_db_connection() as conn:
|
||
cursor = conn.cursor()
|
||
cursor.execute("""
|
||
SELECT u.id, u.name, COUNT(f.feature_id) as face_count
|
||
FROM users u
|
||
LEFT JOIN face_features f ON u.id = f.user_id
|
||
GROUP BY u.id, u.name
|
||
ORDER BY u.id
|
||
LIMIT ? OFFSET ?
|
||
""", (limit, skip))
|
||
users = cursor.fetchall()
|
||
return [{"id": u['id'], "name": u['name'], "registered_faces_count": u['face_count']} for u in users]
|
||
|
||
# 在文件末尾实例化服务,方便API层调用
|
||
face_service = FaceRecognitionService()
|
||
|
||
|
||
|