69 lines
2.6 KiB
Python
69 lines
2.6 KiB
Python
|
||
from pymilvus import connections, utility, Collection, CollectionSchema, FieldSchema, DataType
|
||
from .config import settings
|
||
|
||
# Collection name and related parameters.
# 512 is the feature dimension of the InsightFace 'buffalo_l' model.
FACE_VECTOR_DIM: int = 512
COLLECTION_NAME: str = "face_features_collection"
|
||
|
||
class MilvusHelper:
    """Helper around the pymilvus connection and the face-feature collection.

    Creating an instance opens the ``"default"`` pymilvus connection using
    ``settings.MILVUS_HOST`` and ``settings.MILVUS_PORT``.

    Attributes:
        collection: The most recent ``Collection`` handle obtained through
            ``create_collection`` or ``get_collection``; ``None`` until one
            of them has been called.
    """

    def __init__(self):
        """Connect to the Milvus service.

        Raises:
            Exception: Any error raised while connecting is printed and then
                re-raised unchanged.
        """
        # Define the attribute up front so that reading ``self.collection``
        # never raises AttributeError, whichever method is called first.
        self.collection = None
        try:
            # Connect to the Milvus service.
            connections.connect("default", host=settings.MILVUS_HOST, port=settings.MILVUS_PORT)
            print("成功连接到 Milvus 服务。")
        except Exception as e:
            print(f"连接 Milvus 服务失败: {e}")
            raise

    def has_collection(self) -> bool:
        """Return whether the collection named ``COLLECTION_NAME`` exists."""
        return utility.has_collection(COLLECTION_NAME)

    def create_collection(self) -> "Collection":
        """Create the collection that stores face features, with its index.

        If the collection already exists nothing is created; a handle to the
        existing collection is returned instead, so the return value is a
        ``Collection`` on every path.

        Returns:
            The ``Collection`` handle, which is also stored on
            ``self.collection``.
        """
        if self.has_collection():
            print(f"集合 '{COLLECTION_NAME}' 已存在。")
            # Without a schema argument, Collection() binds to the existing
            # collection rather than creating a new one.
            self.collection = Collection(COLLECTION_NAME)
            return self.collection

        # Field definitions.
        # Primary key; auto_id=True lets Milvus generate the ID.
        pk_field = FieldSchema(name="feature_id", dtype=DataType.INT64, is_primary=True, auto_id=True)
        # ID of the user the feature belongs to.
        user_id_field = FieldSchema(name="user_id", dtype=DataType.INT64)
        # Face feature vector.
        embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=FACE_VECTOR_DIM)

        # Collection schema.
        schema = CollectionSchema(
            fields=[pk_field, user_id_field, embedding_field],
            description="人脸识别特征集合",
            enable_dynamic_field=False,
        )

        # Create the collection.
        self.collection = Collection(name=COLLECTION_NAME, schema=schema)
        print(f"集合 '{COLLECTION_NAME}' 创建成功。")

        # Index the vector field to speed up searches.
        index_params = {
            # IP (inner product) is equivalent to cosine similarity when the
            # vectors are normalized.
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
            # nlist should be tuned to the amount of data.
            "params": {"nlist": 1024},
        }
        self.collection.create_index(field_name="embedding", index_params=index_params)
        print("向量索引创建成功。")
        return self.collection

    def get_collection(self) -> "Collection":
        """Return the collection, loaded so that it can be searched.

        The collection is created first (via ``create_collection``) when it
        does not exist yet.

        Returns:
            A ``Collection`` handle on which ``load()`` has been called; it is
            also stored on ``self.collection``.
        """
        if not self.has_collection():
            self.create_collection()
        collection = Collection(COLLECTION_NAME)
        collection.load()
        # Remember the loaded handle so the attribute reflects the latest state.
        self.collection = collection
        return collection
|
||
|
||
# Create the module-level instances at import time: constructing MilvusHelper
# connects to Milvus, and get_collection() creates the collection if it is
# missing and loads it. Importing this module therefore requires the Milvus
# service to be reachable.
milvus_client = MilvusHelper()
collection = milvus_client.get_collection()
|