Files
Face-Verifying-HighPost/app/milvus_helpers.py
2025-07-29 18:15:35 +08:00

69 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
from .config import settings
# 定义集合名称和相关参数
COLLECTION_NAME = "face_features_collection"
FACE_VECTOR_DIM = 512 # InsightFace 'buffalo_l' 模型的特征维度
class MilvusHelper:
def __init__(self):
try:
# 连接 Milvus 服务
connections.connect("default", host=settings.MILVUS_HOST, port=settings.MILVUS_PORT)
print("成功连接到 Milvus 服务。")
except Exception as e:
print(f"连接 Milvus 服务失败: {e}")
raise
def has_collection(self):
"""检查集合是否存在"""
return utility.has_collection(COLLECTION_NAME)
def create_collection(self):
"""创建一个新的集合来存储人脸特征"""
if self.has_collection():
print(f"集合 '{COLLECTION_NAME}' 已存在。")
return
# 定义字段
# 主键字段Milvus 会自动生成ID
pk_field = FieldSchema(name="feature_id", dtype=DataType.INT64, is_primary=True, auto_id=True)
# 对应的用户ID字段
user_id_field = FieldSchema(name="user_id", dtype=DataType.INT64)
# 人脸特征向量字段
embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=FACE_VECTOR_DIM)
# 创建集合 Schema
schema = CollectionSchema(
fields=[pk_field, user_id_field, embedding_field],
description="人脸识别特征集合",
enable_dynamic_field=False
)
# 创建集合
self.collection = Collection(name=COLLECTION_NAME, schema=schema)
print(f"集合 '{COLLECTION_NAME}' 创建成功。")
# 为向量字段创建索引以加速搜索
index_params = {
"metric_type": "IP", # IP (Inner Product) 等价于归一化向量的余弦相似度
"index_type": "IVF_FLAT",
"params": {"nlist": 1024} # nlist 的值需要根据数据量调整
}
self.collection.create_index(field_name="embedding", index_params=index_params)
print("向量索引创建成功。")
return self.collection
def get_collection(self):
"""获取集合对象并加载到内存中以便搜索"""
if not self.has_collection():
self.create_collection()
collection = Collection(COLLECTION_NAME)
collection.load()
return collection
# 在模块加载时创建一个全局实例
milvus_client = MilvusHelper()
collection = milvus_client.get_collection()