Files
Seg_Data_Server/Seg_Predict_YoloModel/yolo_Seg_Video-V1-Visible.py
2026-05-20 15:05:35 +08:00

181 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
import cv2
import numpy as np
from ultralytics import YOLO
import random
import os
from tqdm import tqdm # 引入进度条库
# --- 配置参数 ---
# 模型文件路径
# MODEL_PATH = os.path.join('YOLOv9e-seg', 'weights', 'best.pt')
MODEL_PATH = os.path.join('YOLO11l-seg', 'best.pt')
# 输入视频文件路径
VIDEO_IN_PATH = 'LC_Video_1.mp4'
# 输出视频文件路径
VIDEO_OUT_PATH = 'output_a_segmented_fps_controlled.mp4'
# --- 新增配置 ---
# 要处理的视频帧率。例如,设置为 5 表示每秒大约只对 5 帧进行模型推理。
# 设置为 None 则处理所有帧。
PROCESS_FPS = None
# 推理参数
CONFIDENCE_THRESHOLD = 0.3 # 置信度阈值,低于此值的检测将被忽略
MASK_ALPHA = 0.4 # 分割掩码的透明度 (0.0 完全透明, 1.0 完全不透明)
# --- 主处理逻辑 ---
def main():
"""
主函数,执行视频分割的完整流程。
"""
# 1. 加载预训练的 YOLOv9 分割模型
# YOLO() 类会自动处理模型的加载、权重初始化以及设备选择(优先使用可用的 GPU
print(f"正在加载模型: {MODEL_PATH}...")
try:
model = YOLO(MODEL_PATH)
except Exception as e:
print(f"错误: 无法加载模型。请检查路径是否正确以及 Ultralytics 是否已正确安装。")
print(f"详细错误: {e}")
return
# 获取模型能够识别的所有类别名称
class_names = model.names
print(f"模型已加载,可识别 {len(class_names)} 个类别。")
# 2. 初始化视频读写对象
# 使用 OpenCV 打开输入视频文件
cap = cv2.VideoCapture(VIDEO_IN_PATH)
if not cap.isOpened():
print(f"错误: 无法打开视频文件: {VIDEO_IN_PATH}")
return
# 获取输入视频的属性(帧率、宽度、高度、总帧数)
fps = cap.get(cv2.CAP_PROP_FPS)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 获取总帧数用于进度条
print(f"输入视频属性: {frame_width}x{frame_height} @ {fps:.2f} FPS, 总共 {total_frames}")
# 定义视频编码器并创建 VideoWriter 对象以保存输出视频
# 'mp4v' 是适用于.mp4 文件的常用编码器
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(VIDEO_OUT_PATH, fourcc, fps, (frame_width, frame_height))
# 3. 为每个类别生成一个稳定的随机颜色
# 这确保了在整个视频中,同一类别的对象掩码颜色保持一致
random.seed(42) # 固定随机种子以保证每次运行颜色相同
class_colors = {
name: [random.randint(0, 255) for _ in range(3)]
for name in class_names.values()
}
# --- 新增:计算帧处理间隔 ---
frame_counter = 0
frame_skip_interval = 1
if PROCESS_FPS is not None and PROCESS_FPS > 0:
frame_skip_interval = round(fps / PROCESS_FPS)
# 确保至少为1避免除零错误
if frame_skip_interval < 1:
frame_skip_interval = 1
print(f"处理帧率设置为 {PROCESS_FPS} FPS. 将每 {frame_skip_interval} 帧运行一次模型推理。")
else:
print("将处理所有帧。")
# 4. 逐帧处理视频,并使用 tqdm 显示进度条
print("开始逐帧处理视频... (在预览窗口激活时按 'q' 键可随时退出)")
for _ in tqdm(range(total_frames), desc="处理进度"):
# 从视频中读取一帧
success, frame = cap.read()
if not success:
# 如果视频读取结束,则跳出循环
print("\n视频读取完毕或发生错误。")
break
frame_counter += 1
# 5. 对当前帧执行 YOLOv9 推理 (有条件地)
# 仅在达到处理间隔时才运行模型,否则将重用上一次的结果
if frame_counter % frame_skip_interval == 0:
results = model.predict(source=frame, conf=CONFIDENCE_THRESHOLD, verbose=False)
# 检查是否有任何有效的 'results' 可供绘制
# 'results' 可能来自当前帧的推理,或来自之前的帧
# `locals()` 检查 `results` 变量是否已在当前作用域中定义
if 'results' in locals() and results is not None:
# ------------------- FIX START -------------------
# `results` 是一个列表,因为我们处理的是单张图片,所以需要取出第一个元素
result = results[0]
# -------------------- FIX END --------------------
# 6. 可视化:绘制半透明掩码和边界框
# 创建一个与原始帧相同的副本,用于绘制掩码。这是实现透明效果的关键步骤。
overlay = frame.copy()
# 检查是否存在检测到的掩码
if result.masks is not None:
# 遍历每个检测到的对象
# zip() 函数将边界框和掩码一一对应起来
for box, mask in zip(result.boxes, result.masks):
# 获取类别 ID 和类别名称
class_id = int(box.cls)
class_name = class_names[class_id]
# 获取该类别的预定义颜色
color = class_colors[class_name]
# --- 绘制分割掩码 ---
# [修正] mask.xy返回的是一个列表的列表需要用np.array处理
points = np.array(mask.xy, dtype=np.int32)
# 在 overlay 层上填充多边形(完全不透明)
cv2.fillPoly(overlay, [points], color)
# --- 绘制边界框和标签 ---
# 获取边界框坐标
x1, y1, x2, y2 = map(int, box.xyxy[0]) # box.xyxy也是列表取第一个
# 在原始帧上绘制矩形边界框
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
# 准备标签文本(类别名 + 置信度)
confidence = float(box.conf)
label = f"{class_name}: {confidence:.2f}"
# 计算文本尺寸以绘制背景
(text_width, text_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
# 绘制文本背景框
cv2.rectangle(frame, (x1, y1 - text_height - baseline), (x1 + text_width, y1), color, -1)
# 在背景框上放置文本
cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# 7. 融合原始帧和掩码层
# cv2.addWeighted() 将 overlay 和 frame 按照指定的权重(透明度)混合
# 公式为: output = frame * (1 - alpha) + overlay * alpha + 0
cv2.addWeighted(overlay, MASK_ALPHA, frame, 1 - MASK_ALPHA, 0, frame)
# 8. 将处理后的帧写入输出视频文件
out.write(frame)
# # 9. [新增] 实时显示处理后的帧
# cv2.imshow('Real-time Segmentation Preview', frame)
# # 检测 'q' 键是否被按下,如果是则退出循环
# # cv2.waitKey(1) 对于视频处理至关重要,它等待 1ms允许 OpenCV 刷新窗口
# if cv2.waitKey(1) & 0xFF == ord('q'):
# print("\n用户中断处理。")
# break
# 10. 释放资源
cap.release()
out.release()
# cv2.destroyAllWindows()
print(f"处理完成!")
print(f"输出视频已保存至: {VIDEO_OUT_PATH}")
if __name__ == "__main__":
main()