192 lines
7.6 KiB
Python
192 lines
7.6 KiB
Python
|
|
|||
|
import cv2
|
|||
|
import numpy as np
|
|||
|
from insightface.app import FaceAnalysis
|
|||
|
import base64
|
|||
|
import requests
|
|||
|
from .database import database, users
|
|||
|
from .milvus_helpers import collection as milvus_collection
|
|||
|
|
|||
|
class FaceRecognitionService:
    """Core face-recognition business logic and its database operations.

    User metadata lives in the ``users`` table (accessed through ``database``);
    face embeddings live in the Milvus collection, one row per registered face,
    tagged with ``user_id``.

    Embeddings are L2-normalised before they are stored or searched. The Milvus
    searches use the inner-product ("IP") metric, and only for unit vectors does
    the inner product equal cosine similarity, which is what thresholds in the
    0..1 range presuppose.

    NOTE(review): vectors inserted before normalisation was introduced are not
    unit length and will score incorrectly; such users need re-registering.
    NOTE(review): image download, model inference and Milvus calls are
    synchronous, so they run on the event-loop thread inside the async methods.
    """

    # Minimum cosine similarity for a 1:N match in ``detect_faces``.
    RECOGNITION_THRESHOLD = 0.5
    # Minimum cosine similarity for a 1:1 match in ``verify_face``.
    VERIFICATION_THRESHOLD = 0.6
    # ``nprobe`` value used for every Milvus vector search.
    SEARCH_NPROBE = 10

    def __init__(self, model_name="buffalo_l", providers=None, det_size=(640, 640), ctx_id=0):
        """Load and prepare the InsightFace ``FaceAnalysis`` pipeline.

        Args:
            model_name: InsightFace model pack name (default ``"buffalo_l"``).
            providers: ONNX Runtime execution providers; defaults to
                ``["CPUExecutionProvider"]``.
            det_size: Detector input size passed to ``prepare``.
            ctx_id: Context id passed to ``prepare``.
        """
        if providers is None:
            providers = ["CPUExecutionProvider"]
        self.app = FaceAnalysis(name=model_name, providers=list(providers))
        self.app.prepare(ctx_id=ctx_id, det_size=det_size)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_embedding(embedding):
        """Return *embedding* as a float32 vector scaled to unit L2 norm.

        A zero vector is returned unchanged instead of being divided by zero.
        """
        vector = np.asarray(embedding, dtype=np.float32)
        norm = float(np.linalg.norm(vector))
        if norm == 0.0:
            return vector
        return vector / norm

    @staticmethod
    def _user_filter(user_id):
        """Build the Milvus boolean expression selecting the rows of one user.

        Only integer ids are accepted, so a caller-supplied value cannot inject
        extra clauses into the expression string.

        Raises:
            ValueError: if *user_id* is not an integer.
        """
        if not isinstance(user_id, (int, np.integer)):
            raise ValueError(f"无效的用户ID: {user_id!r}")
        return f"user_id == {int(user_id)}"

    @classmethod
    def _search_params(cls):
        """Return a fresh Milvus search-parameter dict (inner-product metric)."""
        return {"metric_type": "IP", "params": {"nprobe": cls.SEARCH_NPROBE}}

    def _count_faces(self, user_id):
        """Return how many embedding rows Milvus holds for *user_id*."""
        # NOTE(review): Milvus may cap query results when no explicit limit is
        # given; confirm a single user can never exceed that cap.
        rows = milvus_collection.query(expr=self._user_filter(user_id), output_fields=["user_id"])
        return len(rows)

    @staticmethod
    def _bbox_to_location(bbox):
        """Convert an ``(x1, y1, x2, y2)`` box into a JSON-friendly location dict.

        Coordinates are truncated with ``astype(int)`` and then converted to
        built-in ``int``, so the result contains no numpy scalar types.
        """
        x1, y1, x2, y2 = (int(v) for v in np.asarray(bbox).astype(int))
        return {"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}

    @staticmethod
    def _strip_data_uri(data):
        """Drop a ``data:<mime>;base64,`` prefix from a Base64 string, if present."""
        if isinstance(data, str) and data.startswith("data:"):
            _, sep, payload = data.partition(",")
            if sep:
                return payload
        return data

    def _decode_image(self, image_source):
        """Decode a colour image from ``image_source.face_data`` (Base64) or ``image_source.url``.

        ``face_data`` takes precedence when both are set. A ``data:`` URI prefix
        on ``face_data`` is tolerated.

        Returns:
            The image decoded by ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.

        Raises:
            ValueError: if Base64 decoding or the download fails, or if no image
                could be produced (including when neither source is set).
        """
        img = None
        if image_source.face_data:
            try:
                img_bytes = base64.b64decode(self._strip_data_uri(image_source.face_data))
                img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
            except Exception as e:
                # Boundary translation: any decode failure becomes a ValueError.
                raise ValueError(f"Base64解码失败: {e}") from e
        elif image_source.url:
            try:
                response = requests.get(image_source.url, timeout=10)
                response.raise_for_status()
                img = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
            except Exception as e:
                raise ValueError(f"从URL获取图片失败: {e}") from e

        # imdecode returns None (rather than raising) for undecodable bytes.
        if img is None:
            raise ValueError("无法加载图片")
        return img

    # ---------------------------------------------------------------- public API

    async def register_new_face(self, user_id: int, name: str, image_source):
        """Register one face for *user_id*: metadata in PostgreSQL, embedding in Milvus.

        Creates the user row if it does not exist yet. An existing id may only
        be re-used with the same name. The image must contain exactly one face.

        Returns:
            dict with ``id``, ``name`` and ``registered_faces_count``.

        Raises:
            ValueError: on an invalid id, a name conflict, an undecodable image,
                or when the image contains zero or several faces.
        """
        # Validate the id before anything is written.
        self._user_filter(user_id)

        user = await database.fetch_one(users.select().where(users.c.id == user_id))
        if user and user.name != name:
            raise ValueError(f"ID {user_id} 已被 '{user.name}' 注册,无法更改为 '{name}'。")

        faces = self.app.get(self._decode_image(image_source))
        if not faces:
            raise ValueError("图片中未检测到人脸。")
        if len(faces) > 1:
            raise ValueError("注册图片中只能包含一张人脸。")

        embedding = self._normalize_embedding(faces[0].embedding)

        # The Milvus insert runs inside the SQL transaction so an exception from
        # Milvus also rolls back a freshly inserted user row.
        async with database.transaction():
            if not user:
                await database.execute(users.insert().values(id=user_id, name=name))
            milvus_collection.insert([[user_id], [embedding]])
            # Flush so the count query below sees the vector just inserted.
            milvus_collection.flush()

        return {"id": user_id, "name": name, "registered_faces_count": self._count_faces(user_id)}

    async def detect_faces(self, image_source):
        """1:N recognition: match every face in the image against all stored embeddings.

        Returns:
            A list of dicts (``id``, ``name``, ``registered_faces_count``,
            ``confidence``, ``location``). There is one entry per detected face
            whose best hit scores above ``RECOGNITION_THRESHOLD`` and whose user
            row exists with a non-empty name. The list is empty when nothing is
            recognised.
        """
        detected_faces = self.app.get(self._decode_image(image_source))
        if not detected_faces:
            return []

        results = milvus_collection.search(
            data=[self._normalize_embedding(face.embedding) for face in detected_faces],
            anns_field="embedding",
            param=self._search_params(),
            limit=1,
            output_fields=["user_id"],
        )

        # Keep only faces whose best hit clears the threshold, as
        # (face, user_id, similarity) tuples.
        matches = []
        for face, hits in zip(detected_faces, results):
            if not hits:
                continue
            best_hit = hits[0]
            if best_hit.distance > self.RECOGNITION_THRESHOLD:
                matches.append((face, best_hit.entity.get("user_id"), float(best_hit.distance)))
        if not matches:
            return []

        # Look up all matched users in one batched SQL query.
        user_ids = list({user_id for _, user_id, _ in matches})
        user_records = await database.fetch_all(users.select().where(users.c.id.in_(user_ids)))
        user_map = {record.id: record.name for record in user_records}

        face_counts = {}  # memoised per user: several faces may match the same user
        final_results = []
        for face, user_id, similarity in matches:
            user_name = user_map.get(user_id)
            if not user_name:
                continue  # no user row (or empty name) for the matched vector
            if user_id not in face_counts:
                face_counts[user_id] = self._count_faces(user_id)
            final_results.append({
                "id": user_id,
                "name": user_name,
                "registered_faces_count": face_counts[user_id],
                "confidence": similarity,
                "location": self._bbox_to_location(face.bbox),
            })
        return final_results

    async def verify_face(self, user_id: int, image_source):
        """1:1 verification: compare the single face in the image with *user_id*'s embeddings.

        Returns:
            ``{"match": bool, "confidence": float}``. Confidence is 0.0 when the
            search returns no hit (e.g. the user has no stored embeddings).

        Raises:
            ValueError: for an invalid id, an undecodable image, or an image
                with zero or several faces.
        """
        user_expr = self._user_filter(user_id)

        detected_faces = self.app.get(self._decode_image(image_source))
        if not detected_faces:
            raise ValueError("图片中未检测到人脸。")
        if len(detected_faces) > 1:
            raise ValueError("用于认证的图片中只能包含一张人脸。")

        embedding = self._normalize_embedding(detected_faces[0].embedding)

        # Restrict the search to this user's vectors.
        results = milvus_collection.search(
            data=[embedding],
            anns_field="embedding",
            param=self._search_params(),
            limit=1,
            expr=user_expr,
            output_fields=["user_id"],
        )

        best_similarity = float(results[0][0].distance) if results and results[0] else 0.0
        return {"match": best_similarity > self.VERIFICATION_THRESHOLD, "confidence": best_similarity}

    async def delete_user(self, user_id: int):
        """Delete *user_id* from PostgreSQL and remove all of its Milvus embeddings.

        Returns:
            dict with the deleted user's ``id`` and ``name``.

        Raises:
            ValueError: for an invalid id or if the user does not exist.
        """
        user_expr = self._user_filter(user_id)

        user = await database.fetch_one(users.select().where(users.c.id == user_id))
        if not user:
            raise ValueError(f"ID为 {user_id} 的用户不存在。")

        # Delete from Milvus inside the transaction: if it raises, the SQL delete
        # is rolled back instead of leaving orphaned vectors behind.
        async with database.transaction():
            await database.execute(users.delete().where(users.c.id == user_id))
            milvus_collection.delete(user_expr)

        return {"id": user.id, "name": user.name}

    async def get_user(self, user_id: int):
        """Return one user's metadata plus its Milvus face count, or ``None`` if absent."""
        user = await database.fetch_one(users.select().where(users.c.id == user_id))
        if not user:
            return None
        return {"id": user.id, "name": user.name, "registered_faces_count": self._count_faces(user.id)}

    async def list_all_users(self, skip: int = 0, limit: int = 100):
        """List users page by page, each with its Milvus face count.

        Rows are ordered by ``id`` so that ``skip``/``limit`` pagination is stable
        across calls (SQL gives no row order without ORDER BY).
        """
        query = users.select().order_by(users.c.id).offset(skip).limit(limit)
        user_records = await database.fetch_all(query)
        return [
            {"id": user.id, "name": user.name, "registered_faces_count": self._count_faces(user.id)}
            for user in user_records
        ]
|
|||
|
|
|||
|
# Shared module-level instance; constructing it loads and prepares the
# InsightFace model once, at import time.
face_service: FaceRecognitionService = FaceRecognitionService()
|