This commit is contained in:
2025-09-03 14:38:42 +08:00
parent eb5cf715ec
commit b7773f5f00
19 changed files with 546 additions and 168 deletions

187
ocr/face_recognizer.py Normal file
View File

@ -0,0 +1,187 @@
import os
import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
class FaceRecognizer:
"""
封装InsightFace人脸识别功能支持从文件夹加载已知人脸。
"""
def __init__(self, known_faces_dir: str):
self.known_faces_dir = known_faces_dir
self.app = self._initialize_insightface()
self.known_faces_embeddings = {}
self.known_faces_names = []
self._load_known_faces()
def _initialize_insightface(self):
"""
初始化InsightFace FaceAnalysis应用。
默认使用CPU如果检测到CUDA会自动使用GPU。
"""
print("正在初始化InsightFace人脸识别引擎...")
try:
# 默认模型是 'buffalo_l',包含检测、对齐、识别功能
# 如果需要更小的模型,可以尝试 'buffalo_s' 或 'buffalo_m'
# ctx_id=0 表示使用GPUctx_id=-1 表示使用CPU
# InsightFace会自动检测CUDA并选择GPU所以通常不需要手动设置ctx_id
app = FaceAnalysis(name='buffalo_l', root='~/.insightface') # 模型下载到用户目录
app.prepare(ctx_id=0, det_size=(640, 640)) # det_size影响检测性能和精度
print("InsightFace人脸识别引擎初始化成功。")
return app
except Exception as e:
print(f"InsightFace人脸识别引擎初始化失败: {e}")
print("请确保已安装insightface和onnxruntime并且模型文件已下载或可访问。")
return None
def _load_known_faces(self):
"""
扫描已知人脸目录,加载每个人的照片并计算人脸特征。
"""
if not os.path.exists(self.known_faces_dir):
print(f"警告: 已知人脸目录 '{self.known_faces_dir}' 不存在。请创建并放入照片。")
os.makedirs(self.known_faces_dir, exist_ok=True)
return
print(f"正在加载已知人脸特征从: '{self.known_faces_dir}'...")
for person_name in os.listdir(self.known_faces_dir):
person_dir = os.path.join(self.known_faces_dir, person_name)
if os.path.isdir(person_dir):
print(f" 加载人物: {person_name}")
embeddings = []
for filename in os.listdir(person_dir):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(person_dir, filename)
try:
img = cv2.imread(image_path)
if img is None:
print(f" 警告: 无法读取图片 '{image_path}',已跳过。")
continue
# 查找人脸并提取特征
faces = self.app.get(img)
if faces:
# 通常一张照片只有一个人脸,取第一个
embeddings.append(faces[0].embedding)
print(f" 成功提取 '{filename}' 的人脸特征。")
else:
print(f" 警告: 在图片 '{filename}' 中未检测到人脸,已跳过。")
except Exception as e:
print(f" 处理图片 '{image_path}' 时发生错误: {e}")
if embeddings:
# 将多张照片的特征取平均,作为该人物的最终特征
self.known_faces_embeddings[person_name] = np.array(embeddings).mean(axis=0)
self.known_faces_names.append(person_name)
print(f" 人物 '{person_name}' 加载完成,共 {len(embeddings)} 张照片。")
else:
print(f" 警告: 人物 '{person_name}' 没有有效的人脸特征,已跳过。")
print(f"已知人脸加载完成。共 {len(self.known_faces_names)} 个人物。")
def recognize(self, frame, threshold=0.4):
"""
在视频帧中识别人脸。
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)。
threshold (float): 识别相似度阈值。0.0到1.0,越高越严格。
Returns:
tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度。
"""
if not self.app or not self.known_faces_names:
return False, None, None
faces = self.app.get(frame) # 在帧中检测并提取所有人的脸
if not faces:
return False, None, None
for face in faces:
# 遍历已知人脸,进行比对
for known_name in self.known_faces_names:
known_embedding = self.known_faces_embeddings[known_name]
# --- 关键修改:手动计算余弦相似度 ---
# 确保embedding是float32类型避免潜在的类型不匹配问题
embedding1 = face.embedding.astype(np.float32)
embedding2 = known_embedding.astype(np.float32)
# 计算点积
dot_product = np.dot(embedding1, embedding2)
# 计算L2范数向量长度
norm_embedding1 = np.linalg.norm(embedding1)
norm_embedding2 = np.linalg.norm(embedding2)
# 避免除以零
if norm_embedding1 == 0 or norm_embedding2 == 0:
similarity = 0.0
else:
similarity = dot_product / (norm_embedding1 * norm_embedding2)
# -------------------------------------
if similarity >= threshold:
print(f"!!! 人脸识别检测到已知人物: '{known_name}' (相似度: {similarity:.4f}) !!!")
return True, known_name, similarity # 只要检测到一个就立即返回
return False, None, None # 没有检测到已知人脸
# def test_single_image(self, image_path: str, threshold=0.4):
# """
# 测试单张图片的人脸识别效果
#
# Args:
# image_path: 图片路径
# threshold: 识别阈值
#
# Returns:
# tuple[bool, str|None, float|None]: 是否检测到已知人脸,检测到的人名,以及相似度
# """
# if not os.path.exists(image_path):
# print(f"错误: 图片 '{image_path}' 不存在")
# return False, None, None
#
# # 读取图片
# frame = cv2.imread(image_path)
# if frame is None:
# print(f"错误: 无法读取图片 '{image_path}'")
# return False, None, None
#
# # 调用识别方法
# result, name, similarity = self.recognize(frame, threshold)
#
# # 显示结果
# if result:
# print(f"测试结果: 在图片中识别到 {name},相似度: {similarity:.4f}")
#
# # 绘制识别结果并显示图片
# faces = self.app.get(frame)
# for face in faces:
# bbox = face.bbox.astype(int)
# # 绘制 bounding box
# cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
# # 绘制姓名和相似度
# text = f"{name}: {similarity:.2f}"
# cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
# cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
#
# # 显示图片
# cv2.imshow('Recognition Result', frame)
# print("按任意键关闭图片窗口...")
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# else:
# print("测试结果: 未在图片中识别到已知人脸")
#
# return result, name, similarity
# if __name__ == "__main__":
# # 初始化人脸识别器,指定已知人脸目录
# recognizer = FaceRecognizer(known_faces_dir="known_faces")
#
# # 测试单张图片
# test_image_path = r"F:\OCR\RapidOCR-main\known_faces\B\14sino-qiu02-master1050.jpg" # 替换为你的测试图片路径
# recognizer.test_single_image(test_image_path, threshold=0.4)

BIN
ocr/images/img.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 195 KiB

BIN
ocr/images/img_7.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 208 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 657 KiB

BIN
ocr/known_faces/B/104-1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 386 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

View File

@ -0,0 +1,133 @@
import cv2
from logger_config import logger
from ocr_violation_detector import OCRViolationDetector
from yolo_violation_detector import ViolationDetector as YoloViolationDetector
from face_recognizer import FaceRecognizer
class MultiModelViolationDetector:
"""
多模型违规检测封装类串行调用OCR、人脸识别和YOLO模型调整为YOLO最后检测任一模型检测到违规即返回结果
"""
def __init__(self,
forbidden_words_path: str,
ocr_config_path: str, # 新增OCR配置文件路径参数
yolo_model_path: str,
known_faces_dir: str,
ocr_confidence_threshold: float = 0.5):
"""
初始化所有检测模型
Args:
forbidden_words_path: 违禁词文件路径
ocr_config_path: OCR配置文件1.yaml路径
yolo_model_path: YOLO模型文件路径
known_faces_dir: 已知人脸目录路径
ocr_confidence_threshold: OCR置信度阈值
"""
# 初始化OCR检测器传入配置文件路径
self.ocr_detector = OCRViolationDetector(
forbidden_words_path=forbidden_words_path,
ocr_config_path=ocr_config_path, # 传递配置文件路径
ocr_confidence_threshold=ocr_confidence_threshold
)
# 初始化人脸识别器
self.face_recognizer = FaceRecognizer(
known_faces_dir=known_faces_dir
)
# 初始化YOLO检测器调整为最后初始化
self.yolo_detector = YoloViolationDetector(
model_path=yolo_model_path
)
logger.info("多模型违规检测器初始化完成")
def detect_violations(self, frame):
"""
串行调用三个检测模型OCR → 人脸识别 → YOLO任一检测到违规即返回结果
Args:
frame: 输入视频帧 (NumPy数组, BGR格式)
Returns:
tuple: (是否有违规, 违规类型, 违规详情)
违规类型: 'ocr' | 'yolo' | 'face' | None
违规详情: 对应模型的检测结果
"""
# 1. 首先进行OCR违禁词检测
try:
ocr_has_violation, ocr_words, ocr_confs = self.ocr_detector.detect(frame)
if ocr_has_violation:
details = {
"words": ocr_words,
"confidences": ocr_confs
}
logger.warning(f"OCR检测到违禁内容: {details}")
return (True, "ocr", details)
except Exception as e:
logger.error(f"OCR检测出错: {str(e)}", exc_info=True)
# 2. 接着进行人脸识别检测
try:
face_has_violation, face_name, face_similarity = self.face_recognizer.recognize(frame)
if face_has_violation:
details = {
"name": face_name,
"similarity": face_similarity
}
logger.warning(f"人脸识别到违规人员: {details}")
return (True, "face", details)
except Exception as e:
logger.error(f"人脸识别出错: {str(e)}", exc_info=True)
# 3. 最后进行YOLO目标检测调整为最后检测
try:
yolo_results = self.yolo_detector.detect(frame)
# 检查是否有检测结果(根据实际业务定义何为违规目标)
if len(yolo_results.boxes) > 0:
# 提取检测到的目标信息
details = {
"classes": yolo_results.names,
"boxes": yolo_results.boxes.xyxy.tolist(), # 边界框坐标
"confidences": yolo_results.boxes.conf.tolist(), # 置信度
"class_ids": yolo_results.boxes.cls.tolist() # 类别ID
}
logger.warning(f"YOLO检测到违规目标: {details}")
return (True, "yolo", details)
except Exception as e:
logger.error(f"YOLO检测出错: {str(e)}", exc_info=True)
# 所有检测均未发现违规
return (False, None, None)
# # 使用示例
# if __name__ == "__main__":
# # 配置文件路径(根据实际情况修改)
# FORBIDDEN_WORDS_PATH = r"D:\Git\bin\video\ocr\forbidden_words.txt"
# OCR_CONFIG_PATH = r"D:\Git\bin\video\ocr\config\1.yaml" # 新增OCR配置文件路径
# YOLO_MODEL_PATH = r"D:\Git\bin\video\ocr\models\best.pt"
# KNOWN_FACES_DIR = r"D:\Git\bin\video\ocr\known_faces"
#
# # 初始化多模型检测器
# detector = MultiModelViolationDetector(
# forbidden_words_path=FORBIDDEN_WORDS_PATH,
# ocr_config_path=OCR_CONFIG_PATH, # 传入OCR配置文件路径
# yolo_model_path=YOLO_MODEL_PATH,
# known_faces_dir=KNOWN_FACES_DIR,
# ocr_confidence_threshold=0.5
# )
#
# # 读取测试图像(可替换为视频帧读取逻辑)
# test_image_path = r"D:\Git\bin\video\ocr\images\img.png"
# frame = cv2.imread(test_image_path)
#
# if frame is not None:
# has_violation, violation_type, details = detector.detect_violations(frame)
# if has_violation:
# print(f"检测到违规 - 类型: {violation_type}, 详情: {details}")
# else:
# print("未检测到任何违规内容")
# else:
# print(f"无法读取测试图像: {test_image_path}")

BIN
ocr/models/best.pt Normal file

Binary file not shown.

View File

@ -7,227 +7,248 @@ from rapidocr import RapidOCR
class OCRViolationDetector: class OCRViolationDetector:
""" """
封装RapidOCR引擎用于检测图像帧中的违禁词。 封装RapidOCR引擎用于检测图像帧中的违禁词。
核心功能加载违禁词、初始化OCR引擎、单帧图像违禁词检测
""" """
def __init__(self, forbidden_words_path: str, ocr_confidence_threshold: float = 0.5, def __init__(self,
log_level: int = logging.INFO, log_file: str = None): forbidden_words_path: str,
ocr_config_path: str,
ocr_confidence_threshold: float = 0.5,
log_level: int = logging.INFO,
log_file: str = None):
""" """
初始化OCR引擎、违禁词列表和日志配置。 初始化OCR引擎、违禁词列表和日志配置。
Args: Args:
forbidden_words_path (str): 违禁词列表 .txt 文件的路径。 forbidden_words_path (str): 违禁词列表 .txt 文件的路径。
ocr_confidence_threshold (float): OCR识别结果的置信度阈值 ocr_config_path (str): OCR配置文件如1.yaml的路径
log_level (int): 日志级别默认为logging.INFO ocr_confidence_threshold (float): OCR识别结果的置信度阈值0~1
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台 log_level (int): 日志级别默认为logging.INFO。
log_file (str, optional): 日志文件路径,如不提供则只输出到控制台。
""" """
# 初始化日志 # 初始化日志(确保先初始化日志,后续操作可正常打日志)
self.logger = self._setup_logger(log_level, log_file) self.logger = self._setup_logger(log_level, log_file)
# 加载违禁词 # 加载违禁词(优先级:先加载配置,再初始化引擎)
self.forbidden_words = self._load_forbidden_words(forbidden_words_path) self.forbidden_words = self._load_forbidden_words(forbidden_words_path)
# 初始化OCR引擎 # 初始化RapidOCR引擎传入配置文件路径
self.ocr_engine = self._initialize_ocr() self.ocr_engine = self._initialize_ocr(ocr_config_path)
# 设置置信度阈值 # 校验核心依赖是否就绪
self.OCR_CONFIDENCE_THRESHOLD = ocr_confidence_threshold self._check_dependencies()
self.logger.info(f"OCR置信度阈值设置为: {ocr_confidence_threshold}")
# 设置置信度阈值限制在0~1范围避免非法值
self.OCR_CONFIDENCE_THRESHOLD = max(0.0, min(ocr_confidence_threshold, 1.0))
self.logger.info(f"OCR置信度阈值已设置范围0~1: {self.OCR_CONFIDENCE_THRESHOLD:.4f}")
def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger: def _setup_logger(self, log_level: int, log_file: str = None) -> logging.Logger:
""" """
配置日志系统 配置日志系统(避免重复添加处理器,支持控制台+文件双输出)
Args: Args:
log_level: 日志级别 log_level: 日志级别如logging.DEBUG、logging.INFO
log_file: 日志文件路径,为None则只输出到控制台 log_file: 日志文件路径为None时仅输出到控制台
Returns: Returns:
配置好的logger实例 logging.Logger: 配置好的日志实例
""" """
# 创建logger
logger = logging.getLogger('OCRViolationDetector') logger = logging.getLogger('OCRViolationDetector')
logger.setLevel(log_level) logger.setLevel(log_level)
# 避免重复添加处理器 # 避免重复添加处理器(防止日志重复输出)
if logger.handlers: if logger.handlers:
return logger return logger
# 定义日志格式 # 定义日志格式(包含时间、模块名、级别、内容)
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s' '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
) )
# 添加控制台处理器 # 1. 添加控制台处理器
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
logger.addHandler(console_handler) logger.addHandler(console_handler)
# 如果提供了日志文件路径,则添加文件处理器 # 2. 若指定日志文件,添加文件处理器(自动创建目录)
if log_file: if log_file:
try: try:
# 确保日志目录存在
log_dir = os.path.dirname(log_file) log_dir = os.path.dirname(log_file)
# 若日志目录不存在,自动创建
if log_dir and not os.path.exists(log_dir): if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir) os.makedirs(log_dir, exist_ok=True)
self.logger.debug(f"自动创建日志目录: {log_dir}")
file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
logger.addHandler(file_handler) logger.addHandler(file_handler)
logger.info(f"日志文件将保存至: {log_file}") logger.info(f"日志文件已配置: {log_file}")
except Exception as e: except Exception as e:
logger.warning(f"无法创建日志文件处理器: {str(e)},仅输出至控制台") logger.warning(f"创建日志文件失败(仅控制台输出): {str(e)}")
return logger return logger
def _load_forbidden_words(self, path): def _load_forbidden_words(self, path: str) -> set:
"""从txt文件加载违禁词列表""" """
words = set() 从TXT文件加载违禁词去重、过滤空行支持UTF-8编码
if not os.path.exists(path):
self.logger.warning(f"警告:未找到违禁词文件 {path},将跳过违禁词检测")
return words
Args:
path (str): 违禁词TXT文件路径。
Returns:
set: 去重后的违禁词集合(空集合表示加载失败)。
"""
forbidden_words = set()
# 第一步:检查文件是否存在
if not os.path.exists(path):
self.logger.error(f"违禁词文件不存在: {path}")
return forbidden_words
# 第二步:读取文件并处理内容
try: try:
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
# 去除每行首尾空格和换行符,过滤空行 # 过滤空行、去除首尾空格、去重
words = {line.strip() for line in f if line.strip()} forbidden_words = {
self.logger.info(f"成功加载 {len(words)} 个违禁词。") line.strip() for line in f
if line.strip() # 跳过空行或纯空格行
}
self.logger.info(f"成功加载违禁词: {len(forbidden_words)} 个(已去重)")
self.logger.debug(f"违禁词列表: {forbidden_words}")
except UnicodeDecodeError:
self.logger.error(f"违禁词文件编码错误需UTF-8: {path}")
except PermissionError:
self.logger.error(f"无权限读取违禁词文件: {path}")
except Exception as e: except Exception as e:
self.logger.error(f"加载违禁词文件失败:{str(e)},将跳过违禁词检测") self.logger.error(f"加载违禁词失败: {str(e)}", exc_info=True)
return words
def _initialize_ocr(self): return forbidden_words
"""初始化RapidOCR引擎"""
self.logger.info("正在初始化RapidOCR引擎...")
config_path = r"D:\Git\bin\video\ocr\config\1.yaml" def _initialize_ocr(self, config_path: str) -> RapidOCR | None:
try: """
# 检查配置文件是否存在 初始化RapidOCR引擎校验配置文件、捕获初始化异常
if not os.path.exists(config_path):
self.logger.error(f"RapidOCR配置文件不存在: {config_path}")
return None
engine = RapidOCR( Args:
config_path=config_path config_path (str): RapidOCR配置文件如1.yaml路径。
)
self.logger.info("RapidOCR引擎初始化成功。") Returns:
return engine RapidOCR | None: OCR引擎实例None表示初始化失败
except Exception as e: """
self.logger.error(f"RapidOCR引擎初始化失败: {e}") self.logger.info("开始初始化RapidOCR引擎...")
# 第一步:检查配置文件是否存在
if not os.path.exists(config_path):
self.logger.error(f"OCR配置文件不存在: {config_path}")
return None return None
def detect(self, frame): # 第二步初始化OCR引擎捕获RapidOCR相关异常
""" try:
对单帧图像进行OCR检测所有出现的违禁词并返回列表 ocr_engine = RapidOCR(config_path=config_path)
返回格式:(是否有违禁词, 违禁词列表, 对应的置信度列表) self.logger.info("RapidOCR引擎初始化成功")
""" return ocr_engine
print("收到帧") except ImportError:
if not self.ocr_engine or not self.forbidden_words: self.logger.error("RapidOCR依赖未安装需执行pip install rapidocr-onnxruntime")
return False, [], [] except Exception as e:
self.logger.error(f"RapidOCR初始化失败: {str(e)}", exc_info=True)
all_prohibited = [] # 存储所有检测到的违禁词 return None
all_confidences = [] # 存储对应违禁词的置信度
def _check_dependencies(self) -> None:
"""校验OCR引擎和违禁词列表是否就绪输出警告日志"""
if not self.ocr_engine:
self.logger.warning("⚠️ OCR引擎未就绪违禁词检测功能将禁用")
if not self.forbidden_words:
self.logger.warning("⚠️ 违禁词列表为空,违禁词检测功能将禁用")
def detect(self, frame) -> tuple[bool, list, list]:
"""
对单帧图像进行OCR违禁词检测核心方法
Args:
frame: 输入图像帧NumPy数组BGR格式cv2读取的图像
Returns:
tuple[bool, list, list]:
- 第一个元素是否检测到违禁词True/False
- 第二个元素:检测到的违禁词列表(空列表表示无违禁词);
- 第三个元素:对应违禁词的置信度列表(与违禁词列表一一对应)。
"""
# 初始化返回结果
has_violation = False
violation_words = []
violation_confs = []
# 前置校验1. 图像帧是否有效 2. OCR引擎是否就绪 3. 违禁词是否存在
if frame is None or frame.size == 0:
self.logger.warning("输入图像帧为空或无效跳过OCR检测")
return has_violation, violation_words, violation_confs
if not self.ocr_engine or not self.forbidden_words:
self.logger.debug("OCR引擎未就绪或违禁词为空跳过OCR检测")
return has_violation, violation_words, violation_confs
try: try:
# 执行OCR识别 # 1. 执行OCR识别获取RapidOCR原始结果
result = self.ocr_engine(frame) self.logger.debug("开始执行OCR识别...")
self.logger.debug(f"RapidOCR 原始返回结果: {result}") ocr_result = self.ocr_engine(frame)
self.logger.debug(f"RapidOCR原始结果: {ocr_result}")
if result is None: # 2. 校验OCR结果是否有效避免None或格式异常
return False, [], [] if ocr_result is None:
self.logger.debug("OCR识别未返回任何结果图像无文本或识别失败")
return has_violation, violation_words, violation_confs
# 提取文本和置信度适配RapidOCR的结果格式 # 3. 检查txts和scores是否存在且不为None
texts = result.txts if hasattr(result, 'txts') else [] if not hasattr(ocr_result, 'txts') or ocr_result.txts is None:
confidences = result.scores if hasattr(result, 'scores') else [] self.logger.warning("OCR结果中txts为None或不存在")
return has_violation, violation_words, violation_confs
# 遍历所有识别结果,收集所有违禁词 if not hasattr(ocr_result, 'scores') or ocr_result.scores is None:
self.logger.warning("OCR结果中scores为None或不存在")
return has_violation, violation_words, violation_confs
# 4. 转为列表并去None防止单个元素为None
# 确保txts是可迭代的如果不是则转为空列表
if not isinstance(ocr_result.txts, (list, tuple)):
self.logger.warning(f"OCR txts不是可迭代类型实际类型: {type(ocr_result.txts)}")
texts = []
else:
texts = [txt.strip() for txt in ocr_result.txts if txt and isinstance(txt, str)]
# 确保scores是可迭代的如果不是则转为空列表
if not isinstance(ocr_result.scores, (list, tuple)):
self.logger.warning(f"OCR scores不是可迭代类型实际类型: {type(ocr_result.scores)}")
confidences = []
else:
confidences = [conf for conf in ocr_result.scores if conf and isinstance(conf, (int, float))]
# 5. 校验文本和置信度列表长度是否一致避免zip迭代错误
if len(texts) != len(confidences):
self.logger.warning(
f"OCR文本与置信度数量不匹配文本{len(texts)}个,置信度{len(confidences)}个),跳过检测")
return has_violation, violation_words, violation_confs
if len(texts) == 0:
self.logger.debug("OCR未识别到任何有效文本")
return has_violation, violation_words, violation_confs
# 6. 遍历识别结果,筛选违禁词(按置信度阈值过滤)
self.logger.debug(f"开始筛选违禁词(阈值{self.OCR_CONFIDENCE_THRESHOLD:.4f}")
for text, conf in zip(texts, confidences): for text, conf in zip(texts, confidences):
# 过滤低置信度结果
if conf < self.OCR_CONFIDENCE_THRESHOLD: if conf < self.OCR_CONFIDENCE_THRESHOLD:
self.logger.debug(f"文本 '{text}' 置信度 {conf:.4f} 低于阈值,跳过") self.logger.debug(f"文本 '{text}' 置信度{conf:.4f} < 阈值,跳过")
continue continue
# 检查当前文本是否包含违禁词(支持一个文本含多个违禁词)
# 检查当前文本中是否包含多个违禁词 matched_words = [word for word in self.forbidden_words if word in text]
for word in self.forbidden_words: if matched_words:
if word in text: has_violation = True
self.logger.warning(f"OCR检测到违禁词: '{word}' (来自文本: '{text}') 置信度: {conf:.4f}") # 记录所有匹配的违禁词和对应置信度
all_prohibited.append(word) violation_words.extend(matched_words)
all_confidences.append(conf) violation_confs.extend([conf] * len(matched_words)) # 一个文本对应多个违禁词时,置信度复用
self.logger.warning(f"检测到违禁词: {matched_words}(来源文本: '{text}',置信度: {conf:.4f}")
except Exception as e: except Exception as e:
self.logger.error(f"OCR检测过程中发生错误: {e}", exc_info=True) # 捕获所有异常,确保不中断上层调用
self.logger.error(f"OCR检测过程异常: {str(e)}", exc_info=True)
# 返回检测结果(是否有违禁词、所有违禁词列表、对应置信度列表) return has_violation, violation_words, violation_confs
return len(all_prohibited) > 0, all_prohibited, all_confidences
# def test_image(self, image_path: str, show_image: bool = True) -> tuple:
# """
# 对单张图片进行OCR违禁词检测并展示结果
#
# Args:
# image_path (str): 图片文件路径
# show_image (bool): 是否显示图片默认为True
#
# Returns:
# tuple: (是否有违禁词, 违禁词列表, 对应的置信度列表)
# """
# # 检查图片文件是否存在
# if not os.path.exists(image_path):
# self.logger.error(f"图片文件不存在: {image_path}")
# return False, [], []
#
# try:
# # 读取图片
# frame = cv2.imread(image_path)
# if frame is None:
# self.logger.error(f"无法读取图片: {image_path}")
# return False, [], []
#
# self.logger.info(f"开始处理图片: {image_path}")
#
# # 调用检测方法
# has_violation, violations, confidences = self.detect(frame)
#
# # 输出检测结果
# if has_violation:
# self.logger.info(f"在图片中检测到 {len(violations)} 个违禁词:")
# for word, conf in zip(violations, confidences):
# self.logger.info(f"- {word} (置信度: {conf:.4f})")
# else:
# self.logger.info("图片中未检测到违禁词")
# # 显示图片(如果需要)
# if show_image:
# # 调整图片大小以便于显示(如果太大)
# height, width = frame.shape[:2]
# max_size = 800
# if max(height, width) > max_size:
# scale = max_size / max(height, width)
# frame = cv2.resize(frame, None, fx=scale, fy=scale)
#
# cv2.imshow(f"OCR检测结果: {'发现违禁词' if has_violation else '未发现违禁词'}", frame)
# cv2.waitKey(0) # 等待用户按键
# cv2.destroyAllWindows()
#
# return has_violation, violations, confidences
#
# except Exception as e:
# self.logger.error(f"处理图片时发生错误: {str(e)}", exc_info=True)
# return False, [], []
#
#
# # 使用示例
# if __name__ == "__main__":
# # 配置参数
# forbidden_words_path = "forbidden_words.txt" # 违禁词文件路径
# test_image_path = r"D:\Git\bin\video\ocr\images\img_7.png" # 测试图片路径
# ocr_threshold = 0.6 # OCR置信度阈值
#
# # 创建检测器实例
# detector = OCRViolationDetector(
# forbidden_words_path=forbidden_words_path,
# ocr_confidence_threshold=ocr_threshold,
# log_level=logging.INFO,
# log_file="ocr_detection.log"
# )
#
# # 测试图片
# detector.test_image(test_image_path, show_image=True)

View File

@ -0,0 +1,48 @@
from ultralytics import YOLO
import cv2
from logger_config import logger
class ViolationDetector:
"""
用于加载YOLOv8 .pt模型并进行违规内容检测的类。
"""
def __init__(self, model_path):
"""
初始化检测器。
Args:
model_path (str): YOLO .pt模型的路径。
"""
logger.info(f"正在从 '{model_path}' 加载YOLO模型...")
self.model = YOLO(model_path)
logger.info("YOLO模型加载成功。")
def detect(self, frame):
"""
对单帧图像进行目标检测。
Args:
frame: 输入的图像帧 (NumPy数组, BGR格式)。
Returns:
ultralytics.engine.results.Results: YOLO的检测结果对象。
"""
# conf可以根据您的模型效果进行调整
# --- 为了测试,我们暂时将置信度调低,例如 0.2 ---
results = self.model(frame, conf=0.2)
return results[0]
def draw_boxes(self, frame, result):
"""
在图像帧上绘制检测框。
Args:
frame: 原始图像帧。
result: YOLO的检测结果对象。
Returns:
numpy.ndarray: 绘制了检测框的图像帧。
"""
# 使用YOLO自带的plot功能方便快捷
annotated_frame = result.plot()
return annotated_frame

View File

@ -6,17 +6,6 @@ import time
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration
from aiortc.mediastreams import MediaStreamTrack from aiortc.mediastreams import MediaStreamTrack
from ocr.ocr_violation_detector import OCRViolationDetector
import logging
# 创建检测器实例
detector = OCRViolationDetector(
forbidden_words_path=r"D:\Git\bin\video\ocr\forbidden_words.txt",
ocr_confidence_threshold=0.7,
log_level=logging.INFO,
log_file="ocr_detection.log"
)
# 创建一个长度为1的队列用于生产者和消费者之间的通信 # 创建一个长度为1的队列用于生产者和消费者之间的通信
frame_queue = queue.Queue(maxsize=1) frame_queue = queue.Queue(maxsize=1)