192 lines
7.6 KiB
Python
192 lines
7.6 KiB
Python
|
||
import cv2
|
||
import numpy as np
|
||
from insightface.app import FaceAnalysis
|
||
import base64
|
||
import requests
|
||
from .database import database, users
|
||
from .milvus_helpers import collection as milvus_collection
|
||
|
||
class FaceRecognitionService:
    """Encapsulates the core face-recognition and database business logic.

    User metadata (id, name) is stored in the relational ``users`` table via
    ``database``; face embeddings are stored in the Milvus collection
    ``milvus_collection`` with a ``user_id`` field linking them to the user.

    NOTE(review): model inference (``self.app.get``), ``requests.get`` and the
    pymilvus calls are synchronous and are not awaited, so they run on the
    event-loop thread inside these async methods; consider offloading them to
    a worker thread if request latency matters.
    """

    # Similarity thresholds for the inner-product (IP) metric. They are only
    # meaningful for L2-normalized embeddings, where IP equals cosine
    # similarity in [-1, 1]; all inserts and searches below use
    # ``normed_embedding`` for that reason.
    RECOGNITION_THRESHOLD = 0.5
    VERIFICATION_THRESHOLD = 0.6

    def __init__(self):
        self.app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    @staticmethod
    def _search_params():
        """Return a fresh Milvus search-parameter dict (IP metric, nprobe=10)."""
        return {"metric_type": "IP", "params": {"nprobe": 10}}

    @staticmethod
    def _strip_data_uri(payload):
        """Drop an optional ``data:<mime>;base64,`` prefix from a base64 string.

        Plain base64 never contains ``,``, so payloads without the prefix are
        returned unchanged. Non-``str`` payloads are returned as-is.
        """
        if isinstance(payload, str) and payload.startswith("data:"):
            _, _, payload = payload.partition(",")
        return payload

    @staticmethod
    def _bbox_to_location(bbox):
        """Convert an ``(x1, y1, x2, y2)`` box into a JSON-safe location dict.

        Coordinates are truncated with ``astype(int)`` (as before) and then
        converted to plain Python ``int``. numpy integer scalars are not
        serializable by the standard JSON encoder.
        """
        x1, y1, x2, y2 = (int(v) for v in np.asarray(bbox).astype(int))
        return {"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}

    def _count_faces(self, user_id):
        """Return how many embeddings Milvus holds for ``user_id``."""
        # int() keeps the string-built Milvus expression limited to an integer literal.
        res = milvus_collection.query(expr=f"user_id == {int(user_id)}", output_fields=["user_id"])
        return len(res)

    def _decode_image(self, image_source):
        """Load a BGR image from ``image_source.face_data`` (base64) or ``image_source.url``.

        ``face_data`` takes precedence when both are set. A leading
        ``data:...;base64,`` prefix on ``face_data`` is accepted.

        Raises:
            ValueError: if base64 decoding or the URL download fails, or if
                OpenCV cannot decode the bytes into an image.
        """
        img = None
        if image_source.face_data:
            try:
                img_data = base64.b64decode(self._strip_data_uri(image_source.face_data))
                nparr = np.frombuffer(img_data, np.uint8)
                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            except Exception as e:
                raise ValueError(f"Base64解码失败: {e}") from e
        elif image_source.url:
            try:
                response = requests.get(image_source.url, timeout=10)
                response.raise_for_status()
                img_array = np.frombuffer(response.content, dtype=np.uint8)
                img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
            except Exception as e:
                raise ValueError(f"从URL获取图片失败: {e}") from e

        # cv2.imdecode returns None (rather than raising) for undecodable bytes.
        if img is None:
            raise ValueError("无法加载图片")
        return img

    async def register_new_face(self, user_id: int, name: str, image_source):
        """Register a face: metadata goes to PostgreSQL, the embedding to Milvus.

        A new ``user_id`` is created with ``name``. An existing one must carry
        the same ``name`` and receives an additional embedding.

        Returns:
            dict with ``id``, ``name`` and ``registered_faces_count``.

        Raises:
            ValueError: on a name conflict, an undecodable image, or when the
                image contains zero or more than one face.
        """
        query = users.select().where(users.c.id == user_id)
        user = await database.fetch_one(query)

        if user and user.name != name:
            raise ValueError(f"ID {user_id} 已被 '{user.name}' 注册,无法更改为 '{name}'。")

        image = self._decode_image(image_source)
        faces = self.app.get(image)

        if not faces:
            raise ValueError("图片中未检测到人脸。")
        if len(faces) > 1:
            raise ValueError("注册图片中只能包含一张人脸。")

        # Store the L2-normalized vector so IP search yields cosine similarity.
        # NOTE(review): vectors registered earlier with the raw `embedding`
        # are not normalized and should be re-registered.
        embedding = faces[0].normed_embedding

        async with database.transaction():
            if not user:
                await database.execute(users.insert().values(id=user_id, name=name))

            # Column-ordered insert: [user_id column, embedding column].
            # An exception here propagates out of the transaction block.
            milvus_collection.insert([[user_id], [embedding]])
            milvus_collection.flush()

            face_count = self._count_faces(user_id)

        return {"id": user_id, "name": name, "registered_faces_count": face_count}

    async def detect_faces(self, image_source):
        """1:N recognition: match every detected face against all Milvus embeddings.

        Returns:
            A list with one dict per face whose best match exceeds
            ``RECOGNITION_THRESHOLD`` and whose user still exists. Each dict has
            ``id``, ``name``, ``registered_faces_count``, ``confidence`` and
            ``location``. Returns an empty list when no face matches.
        """
        image = self._decode_image(image_source)
        detected_faces = self.app.get(image)

        if not detected_faces:
            return []

        search_vectors = [face.normed_embedding for face in detected_faces]
        results = milvus_collection.search(
            data=search_vectors, anns_field="embedding", param=self._search_params(),
            limit=1, output_fields=["user_id"],
        )

        # Keep (face index, user_id, similarity) for best hits above the
        # threshold; only those users need a database lookup.
        matches = []
        for i, hits in enumerate(results):
            if not hits:
                continue
            best_hit = hits[0]
            if best_hit.distance > self.RECOGNITION_THRESHOLD:
                matches.append((i, best_hit.entity.get('user_id'), best_hit.distance))

        if not matches:
            return []

        # Batch-fetch user records for all matched ids in a single query.
        user_ids = list({user_id for _, user_id, _ in matches})
        user_records = await database.fetch_all(users.select().where(users.c.id.in_(user_ids)))
        user_map = {user.id: user.name for user in user_records}

        face_counts = {}  # user_id -> embedding count, so each user is queried once
        final_results = []
        for i, matched_user_id, similarity in matches:
            user_name = user_map.get(matched_user_id)
            if not user_name:
                continue
            if matched_user_id not in face_counts:
                face_counts[matched_user_id] = self._count_faces(matched_user_id)
            final_results.append({
                "id": matched_user_id,
                "name": user_name,
                "registered_faces_count": face_counts[matched_user_id],
                "confidence": float(similarity),
                "location": self._bbox_to_location(detected_faces[i].bbox),
            })
        return final_results

    async def verify_face(self, user_id: int, image_source):
        """1:1 verification: compare the single face in the image with ``user_id``'s embeddings.

        Returns:
            dict with ``match`` (best similarity > ``VERIFICATION_THRESHOLD``)
            and ``confidence`` (best similarity, 0.0 if the user has no
            embeddings).

        Raises:
            ValueError: if the image is undecodable or does not contain exactly one face.
        """
        image = self._decode_image(image_source)
        detected_faces = self.app.get(image)

        if not detected_faces:
            raise ValueError("图片中未检测到人脸。")
        if len(detected_faces) > 1:
            raise ValueError("用于认证的图片中只能包含一张人脸。")

        embedding = detected_faces[0].normed_embedding

        # Restrict the search to embeddings belonging to this user_id.
        results = milvus_collection.search(
            data=[embedding], anns_field="embedding", param=self._search_params(),
            limit=1, expr=f"user_id == {int(user_id)}", output_fields=["user_id"],
        )

        best_similarity = 0.0
        if results and results[0]:
            best_similarity = results[0][0].distance

        match = bool(best_similarity > self.VERIFICATION_THRESHOLD)
        return {"match": match, "confidence": float(best_similarity)}

    async def delete_user(self, user_id: int):
        """Delete a user from PostgreSQL and all of its embeddings from Milvus.

        Returns:
            dict with the deleted user's ``id`` and ``name``.

        Raises:
            ValueError: if no user with ``user_id`` exists.
        """
        user = await database.fetch_one(users.select().where(users.c.id == user_id))
        if not user:
            raise ValueError(f"ID为 {user_id} 的用户不存在。")

        async with database.transaction():
            await database.execute(users.delete().where(users.c.id == user_id))
            # Done inside the transaction block so an exception from Milvus
            # aborts it instead of leaving the row deleted.
            milvus_collection.delete(f"user_id == {int(user_id)}")

        return {"id": user.id, "name": user.name}

    async def get_user(self, user_id: int):
        """Return a user's metadata and Milvus embedding count, or None if the user is unknown."""
        user = await database.fetch_one(users.select().where(users.c.id == user_id))
        if not user:
            return None
        return {"id": user.id, "name": user.name, "registered_faces_count": self._count_faces(user_id)}

    async def list_all_users(self, skip: int = 0, limit: int = 100):
        """List users (paginated by ``skip``/``limit``) with their Milvus embedding counts.

        Users are ordered by id. OFFSET/LIMIT without ORDER BY does not give
        stable pages in PostgreSQL.
        """
        query = users.select().order_by(users.c.id).offset(skip).limit(limit)
        user_records = await database.fetch_all(query)

        return [
            {"id": user.id, "name": user.name, "registered_faces_count": self._count_faces(user.id)}
            for user in user_records
        ]
|
||
|
||
# Module-level shared service instance. Constructing it builds the InsightFace
# FaceAnalysis("buffalo_l") app and calls prepare(), so that work happens once,
# when this module is imported.
face_service = FaceRecognitionService()
|