Files
Seg_Data_Server/Seg_Predict_YoloModel/yolo_Seg_Video-V2-UnVisible.py
2026-05-20 15:05:35 +08:00

193 lines
7.3 KiB
Python

# -*- coding: utf-8 -*-
import cv2
import numpy as np
from ultralytics import YOLO
import random
import os
from tqdm import tqdm
import multiprocessing as mp
import time
# --- Configuration ---
# Paths are relative to the current working directory.
MODEL_PATH = os.path.join('YOLOv9e-seg', 'weights', 'best.pt')
VIDEO_IN_PATH = 'LC_Video_1.mp4'
VIDEO_OUT_PATH = 'output_parallel_segmented.mp4'
# Minimum confidence passed to model.predict().
CONFIDENCE_THRESHOLD = 0.3
# Weight of the mask overlay when it is blended with the frame.
MASK_ALPHA = 0.4
# --- Parallel Processing Configuration ---
# Number of worker processes: one less than the CPU count, clamped to the
# range [1, 5]. Adjust the upper bound to match your system's resources.
NUM_WORKERS = max(1, min(mp.cpu_count() - 1, 5))
def process_frame_worker(task_queue, results_queue, model_path, confidence_threshold):
    """
    Worker loop for parallel processing.

    The worker builds its own model instance from ``model_path`` and then
    consumes tasks from ``task_queue`` until it receives ``None``. Every task
    is a ``(frame_index, frame)`` tuple, and for every task exactly one
    ``(frame_index, result)`` tuple is put on ``results_queue``. ``result`` is
    the first element returned by ``model.predict`` or ``None`` when the
    frame could not be processed.

    Args:
        task_queue: Queue yielding ``(frame_index, frame)`` tuples, terminated
            by a ``None`` sentinel.
        results_queue: Queue receiving ``(frame_index, result)`` tuples.
        model_path: Path handed to ``YOLO(...)`` to load the weights.
        confidence_threshold: Value passed as ``conf`` to ``model.predict``.
    """
    pid = os.getpid()
    try:
        model = YOLO(model_path)
    except Exception as e:
        print(f"[Worker PID: {pid}] Error loading model: {e}")
        # Do not exit here: the consumer expects one result per dispatched
        # task. Keep draining the queue and answer with placeholders so the
        # consumer never waits for results that will not arrive.
        model = None

    while True:
        task = task_queue.get()
        # A None task is a signal to terminate the worker
        if task is None:
            break
        frame_index, frame = task

        if model is None:
            results_queue.put((frame_index, None))
            continue

        try:
            results = model.predict(source=frame, conf=confidence_threshold, verbose=False)
            # We only need the first result as we process one frame at a time
            results_queue.put((frame_index, results[0]))
        except Exception as e:
            print(f"[Worker PID: {pid}] Error processing frame {frame_index}: {e}")
            # Put a placeholder to avoid deadlocks
            results_queue.put((frame_index, None))
def _draw_detections(frame, result, class_names, class_colors):
    """
    Draw the masks, boxes and labels of one prediction onto ``frame`` in place.

    Args:
        frame: BGR image to annotate.
        result: Prediction for this frame, or ``None`` if inference failed.
        class_names: Mapping from class id to class name.
        class_colors: Mapping from class name to a 3-element color list.

    Returns:
        The same ``frame`` object; it is unchanged when ``result`` is ``None``
        or carries no masks.
    """
    if result is None or result.masks is None:
        return frame

    # Masks are painted on a copy so that they can be alpha-blended afterwards.
    overlay = frame.copy()
    for box, mask in zip(result.boxes, result.masks):
        class_id = int(box.cls)
        class_name = class_names[class_id]
        color = class_colors[class_name]

        # Draw segmentation mask (skip polygons without any points)
        points = np.array(mask.xy, dtype=np.int32)
        if points.size:
            cv2.fillPoly(overlay, [points], color)

        # Draw bounding box and label
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        confidence = float(box.conf)
        label = f"{class_name}: {confidence:.2f}"
        (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(frame, (x1, y1 - h - 5), (x1 + w, y1), color, -1)
        cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    # Blend the overlay with the original frame
    cv2.addWeighted(overlay, MASK_ALPHA, frame, 1 - MASK_ALPHA, 0, frame)
    return frame


def main():
    """
    Main function to execute video segmentation using parallel processing.

    Frames are read from ``VIDEO_IN_PATH``, dispatched to ``NUM_WORKERS``
    worker processes, annotated in the main process and written in their
    original order to ``VIDEO_OUT_PATH``.

    Reading and collecting are interleaved, and at most ``NUM_WORKERS * 2``
    frames are dispatched-but-unwritten at any time. Because both queues have
    that same capacity, neither ``put`` on the task queue nor ``put`` on the
    results queue can block indefinitely, and memory use stays bounded.
    """
    # Local import: only needed for the timeout handling below.
    from queue import Empty

    # 1. Initialize video capture and get properties
    cap = cv2.VideoCapture(VIDEO_IN_PATH)
    if not cap.isOpened():
        print(f"Error: Could not open video file: {VIDEO_IN_PATH}")
        return

    out = None
    manager = None
    pbar = None
    processes = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Only used for display; the container's frame count can be inexact.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Input video: {frame_width}x{frame_height} @ {fps:.2f} FPS, {total_frames} total frames")
        print(f"Using {NUM_WORKERS} parallel workers for processing.")

        # 2. Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(VIDEO_OUT_PATH, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            print(f"Error: Could not open video writer for: {VIDEO_OUT_PATH}")
            return

        # 3. Generate stable random colors for each class
        # We load a temporary model instance just to get the class names
        try:
            temp_model = YOLO(MODEL_PATH)
            class_names = temp_model.names
            del temp_model
        except Exception as e:
            print(f"Fatal Error: Could not load model to get class names. Aborting. Details: {e}")
            return
        random.seed(42)
        class_colors = {name: [random.randint(0, 255) for _ in range(3)] for name in class_names.values()}
        print(f"Model recognizes {len(class_names)} classes.")

        # 4. Set up multiprocessing queues and processes
        max_in_flight = NUM_WORKERS * 2
        manager = mp.Manager()
        task_queue = manager.Queue(maxsize=max_in_flight)
        results_queue = manager.Queue(maxsize=max_in_flight)
        for _ in range(NUM_WORKERS):
            p = mp.Process(
                target=process_frame_worker,
                args=(task_queue, results_queue, MODEL_PATH, CONFIDENCE_THRESHOLD),
            )
            p.start()
            processes.append(p)

        # 5. Interleaved pipeline: dispatch, collect, write in order
        awaiting = {}   # frame_index -> frame dispatched, result not yet received
        finished = {}   # frame_index -> (frame, result) waiting for its turn
        next_to_read = 0
        next_to_write = 0
        end_of_video = False
        workers_gone = False
        pbar = tqdm(total=total_frames if total_frames > 0 else None, desc="Processing & Writing")

        while True:
            # Top up the window of frames that are dispatched but not yet written.
            while not end_of_video and next_to_read - next_to_write < max_in_flight:
                success, frame = cap.read()
                if not success:
                    end_of_video = True
                    break
                if workers_gone:
                    # Nobody is left to process it: pass the frame through.
                    finished[next_to_read] = (frame, None)
                else:
                    awaiting[next_to_read] = frame
                    task_queue.put((next_to_read, frame))
                next_to_read += 1

            # Every frame that was read has been written: we are done.
            if next_to_write == next_to_read:
                break

            if awaiting:
                try:
                    # Results may arrive out of order.
                    res_index, result = results_queue.get(timeout=1.0)
                except Empty:
                    if not any(p.is_alive() for p in processes):
                        # No result can arrive any more; fall back to writing
                        # the outstanding frames without annotations.
                        print("Warning: all workers exited; remaining frames are written without segmentation.")
                        workers_gone = True
                        for idx, pending_frame in awaiting.items():
                            finished[idx] = (pending_frame, None)
                        awaiting.clear()
                else:
                    original = awaiting.pop(res_index, None)
                    if original is not None:
                        finished[res_index] = (original, result)

            # Write all consecutive frames that are now available
            while next_to_write in finished:
                frame, result = finished.pop(next_to_write)
                out.write(_draw_detections(frame, result, class_names, class_colors))
                next_to_write += 1
                pbar.update(1)

        # 6. Signal workers to terminate by sending None tasks. When the loop
        # ended normally every task has been consumed, so the queue has room.
        if not workers_gone:
            for _ in processes:
                task_queue.put(None)
            for p in processes:
                p.join()

        print("Processing complete!")
        print(f"Output video saved to: {VIDEO_OUT_PATH}")
    finally:
        # 7. Clean up all resources, also on early return or error
        if pbar is not None:
            pbar.close()
        for p in processes:
            if p.is_alive():
                p.terminate()
            p.join()
        cap.release()
        if out is not None:
            out.release()
        if manager is not None:
            manager.shutdown()
if __name__ == "__main__":
    # Pin the start method instead of relying on the platform default:
    # 'fork' is not the default on Windows and macOS, and 'spawn' works on
    # every platform. force=True overrides a start method set earlier.
    mp.set_start_method(method="spawn", force=True)
    main()