193 lines
7.3 KiB
Python
193 lines
7.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from ultralytics import YOLO
|
|
import random
|
|
import os
|
|
from tqdm import tqdm
|
|
import multiprocessing as mp
|
|
import time
|
|
|
|
# --- Configuration ---
# Location of the segmentation weights and of the input/output videos.
MODEL_PATH = os.path.join('YOLOv9e-seg', 'weights', 'best.pt')
VIDEO_IN_PATH = 'LC_Video_1.mp4'
VIDEO_OUT_PATH = 'output_parallel_segmented.mp4'

# Minimum detection confidence handed to the model when predicting.
CONFIDENCE_THRESHOLD = 0.3
# Weight of the colored mask overlay when it is blended onto a frame.
MASK_ALPHA = 0.4

# --- Parallel Processing Configuration ---
# Number of worker processes: one less than the CPU count, clamped to the
# range [1, 5]. Adjust the bounds to match your system's resources.
NUM_WORKERS = max(1, min(5, mp.cpu_count() - 1))
|
|
|
|
|
|
def process_frame_worker(task_queue, results_queue, model_path, confidence_threshold):
    """Consume frames from ``task_queue`` and publish one result per frame.

    Each worker loads its own model instance. Tasks are ``(frame_index, frame)``
    tuples; a ``None`` task tells the worker to stop. For every task consumed,
    exactly one ``(frame_index, result)`` tuple is put on ``results_queue``,
    where ``result`` is the first prediction result, or ``None`` when the frame
    could not be processed.

    Args:
        task_queue: Queue yielding ``(frame_index, frame)`` tuples or ``None``.
        results_queue: Queue receiving ``(frame_index, result_or_None)`` tuples.
        model_path: Path of the weights file handed to ``YOLO``.
        confidence_threshold: Value passed to ``model.predict`` as ``conf``.
    """
    try:
        model = YOLO(model_path)
    except Exception as e:
        print(f"[Worker PID: {os.getpid()}] Error loading model: {e}")
        # Do not exit here: a worker that silently disappears leaves the
        # consumer waiting for results that will never arrive. Keep draining
        # tasks and answer each one with the "failed" placeholder instead.
        model = None

    while True:
        task = task_queue.get()
        # A None task is a signal to terminate the worker
        if task is None:
            break

        frame_index, frame = task
        if model is None:
            results_queue.put((frame_index, None))
            continue

        try:
            results = model.predict(source=frame, conf=confidence_threshold, verbose=False)
            # We only need the first result as we process one frame at a time
            results_queue.put((frame_index, results[0]))
        except Exception as e:
            print(f"[Worker PID: {os.getpid()}] Error processing frame {frame_index}: {e}")
            # Put a placeholder to avoid deadlocks
            results_queue.put((frame_index, None))
|
|
|
|
|
|
def _wait_for_result(results_queue, processes, poll_seconds=1.0):
    """Return the next ``(frame_index, result)`` tuple produced by a worker.

    Polls with a timeout instead of blocking forever, so a pool in which every
    worker has exited is reported rather than hanging the main process.

    Raises:
        RuntimeError: If no worker is alive and the results queue is empty.
    """
    import queue  # function-scope import: only queue.Empty is needed here

    while True:
        try:
            return results_queue.get(timeout=poll_seconds)
        except queue.Empty:
            if any(p.is_alive() for p in processes):
                continue
            # A worker may have published its last result just before exiting.
            try:
                return results_queue.get_nowait()
            except queue.Empty:
                raise RuntimeError(
                    "all worker processes exited before returning every result"
                ) from None


def _draw_detections(frame, result, class_names, class_colors):
    """Draw the masks, boxes and labels of ``result`` onto ``frame`` in place.

    Does nothing when ``result`` is ``None`` (the worker failed on this frame)
    or carries no masks.
    """
    if result is None or result.masks is None:
        return

    overlay = frame.copy()
    for box, mask in zip(result.boxes, result.masks):
        class_name = class_names[int(box.cls)]
        color = class_colors[class_name]

        # Draw segmentation mask (skip degenerate, empty polygons)
        points = np.array(mask.xy, dtype=np.int32)
        if points.size:
            cv2.fillPoly(overlay, [points], color)

        # Draw bounding box and label
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        label = f"{class_name}: {float(box.conf):.2f}"
        (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(frame, (x1, y1 - h - 5), (x1 + w, y1), color, -1)
        cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    # Blend the mask overlay with the annotated frame
    cv2.addWeighted(overlay, MASK_ALPHA, frame, 1 - MASK_ALPHA, 0, frame)


def main():
    """Segment ``VIDEO_IN_PATH`` with a pool of workers and write ``VIDEO_OUT_PATH``.

    Frames are read sequentially, dispatched to ``NUM_WORKERS`` processes and
    written back in their original order. Dispatching and collecting are
    interleaved over a bounded window of frames, so neither of the two bounded
    queues can fill up while the main process is blocked on the other one.

    Raises:
        RuntimeError: If every worker process exits before all dispatched
            frames have been answered.
    """
    # 1. Initialize video capture and get properties
    cap = cv2.VideoCapture(VIDEO_IN_PATH)
    if not cap.isOpened():
        print(f"Error: Could not open video file: {VIDEO_IN_PATH}")
        return

    out = None
    progress = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Only used for reporting/progress: the loop below counts the frames
        # that were actually read, because this property can be inaccurate.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Input video: {frame_width}x{frame_height} @ {fps:.2f} FPS, {total_frames} total frames")
        print(f"Using {NUM_WORKERS} parallel workers for processing.")

        # 2. Load a temporary model instance just to get the class names.
        # Done before the writer is created so that a missing model does not
        # leave an empty output file behind.
        try:
            class_names = YOLO(MODEL_PATH).names
        except Exception as e:
            print(f"Fatal Error: Could not load model to get class names. Aborting. Details: {e}")
            return

        # 3. Generate stable random colors for each class (fixed seed, local
        # generator so the global random state is left untouched)
        rng = random.Random(42)
        class_colors = {name: [rng.randint(0, 255) for _ in range(3)] for name in class_names.values()}
        print(f"Model recognizes {len(class_names)} classes.")

        # 4. Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(VIDEO_OUT_PATH, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            print(f"Error: Could not open video writer for: {VIDEO_OUT_PATH}")
            return

        progress = tqdm(total=total_frames if total_frames > 0 else None, desc="Processing & Writing")

        # 5. Set up multiprocessing queues and processes
        with mp.Manager() as manager:
            # At most `window` frames are dispatched-but-unwritten at any time,
            # which is also the capacity of each queue: no put() can block for good.
            window = NUM_WORKERS * 2
            task_queue = manager.Queue(maxsize=window)
            results_queue = manager.Queue(maxsize=window)

            processes = [
                mp.Process(
                    target=process_frame_worker,
                    args=(task_queue, results_queue, MODEL_PATH, CONFIDENCE_THRESHOLD),
                )
                for _ in range(NUM_WORKERS)
            ]
            for p in processes:
                p.start()

            pending = {}   # frame index -> original frame, dispatched but not yet written
            finished = {}  # frame index -> result (or None), waiting for its turn
            frame_index = 0
            next_frame_to_write = 0
            shutdown_sent = False

            def collect_result():
                """Take one result, then write every frame that is now in order."""
                nonlocal next_frame_to_write
                res_index, result = _wait_for_result(results_queue, processes)
                finished[res_index] = result
                while next_frame_to_write in finished:
                    frame = pending.pop(next_frame_to_write)
                    _draw_detections(frame, finished.pop(next_frame_to_write), class_names, class_colors)
                    out.write(frame)
                    progress.update(1)
                    next_frame_to_write += 1

            try:
                # 6. Read frames, dispatch them, and write results as they arrive
                while True:
                    success, frame = cap.read()
                    if not success:
                        break
                    while frame_index - next_frame_to_write >= window:
                        collect_result()
                    pending[frame_index] = frame
                    task_queue.put((frame_index, frame))
                    frame_index += 1

                # Drain the results of the frames still in flight
                while next_frame_to_write < frame_index:
                    collect_result()

                # Signal workers to terminate by sending None tasks. The task
                # queue is empty here, so these puts cannot block.
                for _ in processes:
                    task_queue.put(None)
                shutdown_sent = True
            finally:
                # 7. Stop the workers before the manager (and its queues) goes away
                for p in processes:
                    if not shutdown_sent and p.is_alive():
                        p.terminate()
                    p.join()
    finally:
        if progress is not None:
            progress.close()
        cap.release()
        if out is not None:
            out.release()

    print("Processing complete!")
    print(f"Output video saved to: {VIDEO_OUT_PATH}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Select the 'spawn' start method explicitly instead of relying on the
    # platform default: 'fork' is not the default on Windows and macOS, where
    # 'spawn' or 'forkserver' are the safer (and required) choices.
    mp.set_start_method("spawn", force=True)
    main()