"""Background SAM video propagation runner used by Celery workers."""

import hashlib
import json
import logging
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import cv2
import numpy as np
from sqlalchemy.orm import Session

from minio_client import download_file
from models import Annotation, Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.sam_registry import ModelUnavailableError, sam_registry
from statuses import (
    TASK_STATUS_CANCELLED,
    TASK_STATUS_FAILED,
    TASK_STATUS_RUNNING,
    TASK_STATUS_SUCCESS,
)

logger = logging.getLogger(__name__)


class PropagationTaskCancelled(RuntimeError):
    """Raised internally when a persisted propagation task has been cancelled."""


def _now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Write the given fields to *task*, commit, reload it and publish a progress event.

    Only arguments that are not ``None`` are written; ``progress`` is clamped to 0-100.
    ``started`` / ``finished`` stamp ``started_at`` / ``finished_at`` with the current UTC time.
    """
    if status is not None:
        task.status = status
    if progress is not None:
        task.progress = max(0, min(100, progress))
    if message is not None:
        task.message = message
    if result is not None:
        task.result = result
    if error is not None:
        task.error = error
    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)


def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload *task* from the database and raise ``PropagationTaskCancelled`` if it was cancelled."""
    db.refresh(task)
    if task.status == TASK_STATUS_CANCELLED:
        raise PropagationTaskCancelled("Task was cancelled")


def _clamp01(value: float) -> float:
    """Convert *value* to float and clamp it into the closed range [0, 1]."""
    return min(max(float(value), 0.0), 1.0)


def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return the clamped ``[left, top, width, height]`` box enclosing *polygon*.

    Points with fewer than two coordinates are ignored, and an empty polygon yields a zero box
    instead of raising from ``min()`` on an empty sequence (previous masks are read from the
    database and may be malformed).
    """
    points = [point for point in polygon if len(point) >= 2]
    if not points:
        return [0.0, 0.0, 0.0, 0.0]
    xs = [_clamp01(point[0]) for point in points]
    ys = [_clamp01(point[1]) for point in points]
    left, right = min(xs), max(xs)
    top, bottom = min(ys), max(ys)
    return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]


def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
    """Drop points with fewer than two coordinates and clamp the rest into [0, 1]."""
    return [[_clamp01(point[0]), _clamp01(point[1])] for point in polygon if len(point) >= 2]


def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
    """Validate user smoothing options.

    Returns ``None`` when *value* is not a dict or its strength is not positive; otherwise a dict
    with ``strength`` clamped to 0-100 (rounded to 2 decimals) and ``method``. Unparseable
    strengths are treated as 0.
    """
    if not isinstance(value, dict):
        return None
    try:
        strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
    except (TypeError, ValueError):
        strength = 0.0
    if strength <= 0:
        return None
    # Chaikin corner cutting is the only implemented method, so any requested method maps to it.
    return {"strength": round(strength, 2), "method": "chaikin"}


def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
    """Map a 0-100 *strength* onto 0-1 as ``(strength / 100) ** curve``.

    With ``curve > 1`` low strengths map to proportionally smaller ratios.
    """
    normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
    return normalized ** curve


def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
    """Apply *iterations* rounds of Chaikin corner cutting to a closed polygon.

    Each edge (including the closing edge back to the first point) is replaced by two points at
    fractions ``q`` and ``1 - q`` along it, where ``q`` is *corner_cut* clamped to [0.02, 0.25].
    Stops early if fewer than 3 points remain.
    """
    points = _normalize_polygon(polygon)
    q = max(0.02, min(float(corner_cut), 0.25))
    for _ in range(max(0, iterations)):
        if len(points) < 3:
            break
        next_points: list[list[float]] = []
        for index, current in enumerate(points):
            following = points[(index + 1) % len(points)]
            next_points.append([
                _clamp01((1.0 - q) * current[0] + q * following[0]),
                _clamp01((1.0 - q) * current[1] + q * following[1]),
            ])
            next_points.append([
                _clamp01(q * current[0] + (1.0 - q) * following[0]),
                _clamp01(q * current[1] + (1.0 - q) * following[1]),
            ])
        points = next_points
    return points


def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
    """Reduce vertices with ``cv2.approxPolyDP`` on the closed contour.

    The tolerance grows with *strength* and is proportional to the contour perimeter. The input
    is returned unchanged if it has fewer than 3 points or if simplification would leave fewer
    than 3.
    """
    if len(polygon) < 3:
        return polygon
    contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
    arc_length = cv2.arcLength(contour, True)
    epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
    approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
    if len(approx) < 3:
        return polygon
    return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in approx]


def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
    """Smooth *polygon* according to normalized *smoothing* options.

    Pipeline: light pre-simplification, then Chaikin smoothing (1-4 iterations depending on
    strength), then simplification. If the result has more vertices than the input, increasingly
    strong simplification is attempted. Without smoothing, or if the result degenerates below
    3 points, the normalized input polygon is returned.
    """
    if not smoothing:
        return _normalize_polygon(polygon)
    strength = float(smoothing.get("strength") or 0.0)
    if strength <= 0:
        return _normalize_polygon(polygon)

    effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
    if effective_strength >= 85:
        iterations = 4
    elif effective_strength >= 55:
        iterations = 3
    elif effective_strength >= 25:
        iterations = 2
    else:
        iterations = 1
    corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22

    normalized = _normalize_polygon(polygon)
    pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
    smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
    simplified = _simplify_polygon(smoothed, effective_strength)
    if len(simplified) > len(normalized):
        # Smoothing should not inflate the stored vertex count; simplify harder until it doesn't.
        for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
            simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
            if len(simplified) <= len(normalized):
                break
    return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)


def _bbox_area(bbox: list[float]) -> float:
    """Return the area of an ``[x, y, w, h]`` box, treating negative sizes as 0."""
    return max(float(bbox[2]), 0.0) * max(float(bbox[3]), 0.0)


def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
    """Return intersection area divided by the smaller box's area (0 if either box is empty)."""
    ax1, ay1, aw, ah = a
    bx1, by1, bw, bh = b
    ax2 = ax1 + aw
    ay2 = ay1 + ah
    bx2 = bx1 + bw
    by2 = by1 + bh
    overlap_width = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_height = max(0.0, min(ay2, by2) - max(ay1, by1))
    overlap_area = overlap_width * overlap_height
    smallest_area = min(_bbox_area(a), _bbox_area(b))
    return overlap_area / smallest_area if smallest_area > 0 else 0.0


def _stable_json(value: Any) -> str:
    """Serialize *value* to compact JSON with sorted keys so equal values give equal strings."""
    return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))


def _canonicalize_signature_value(value: Any) -> Any:
    """Recursively round floats to 6 decimals and sort dict keys for stable hashing."""
    if isinstance(value, float):
        return round(value, 6)
    if isinstance(value, list):
        return [_canonicalize_signature_value(item) for item in value]
    if isinstance(value, dict):
        return {key: _canonicalize_signature_value(value[key]) for key in sorted(value)}
    return value


def _seed_signature(seed: dict[str, Any]) -> str:
    """Return a stable signature for seed geometry and semantic attrs.

    A signature already carried by the seed (``propagation_seed_signature``) is reused as-is;
    otherwise a SHA-256 of the canonicalized geometry, label, colour, class, template and
    smoothing options is computed.
    """
    inherited_signature = seed.get("propagation_seed_signature")
    if inherited_signature:
        return str(inherited_signature)
    signature_payload = {
        "polygons": seed.get("polygons") or [],
        "bbox": seed.get("bbox") or [],
        "points": seed.get("points") or [],
        "labels": seed.get("labels") or [],
        "label": seed.get("label"),
        "color": seed.get("color"),
        "class_metadata": seed.get("class_metadata") or {},
        "template_id": seed.get("template_id"),
        "smoothing": _normalize_smoothing_options(seed.get("smoothing")),
    }
    return hashlib.sha256(_stable_json(_canonicalize_signature_value(signature_payload)).encode("utf-8")).hexdigest()


def _seed_key(seed: dict[str, Any]) -> str:
    """Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
    source_annotation_id = seed.get("source_annotation_id")
    if source_annotation_id is not None:
        return f"annotation:{source_annotation_id}"
    source_mask_id = seed.get("source_mask_id")
    if source_mask_id:
        return f"mask:{source_mask_id}"
    # The raw (unnormalized) class id is kept here so keys stay identical to previously stored ones.
    class_metadata = seed.get("class_metadata") or {}
    class_id = class_metadata.get("id") or class_metadata.get("name")
    return _stable_json({
        "template_id": seed.get("template_id"),
        "class_id": class_id,
        "label": seed.get("label"),
        "color": seed.get("color"),
    })


def _class_identity(class_metadata: Any) -> str | None:
    """Return the class ``id`` (or ``name``) as a string for comparisons, or ``None`` if absent."""
    if not isinstance(class_metadata, dict):
        return None
    class_id = class_metadata.get("id") or class_metadata.get("name")
    return str(class_id) if class_id else None


def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
    """Best-effort match when a manually edited replacement lacks old lineage ids.

    Label and colour must match; class ids must match only when both sides have one.
    """
    previous_class_id = _class_identity(mask_data.get("class"))
    class_id = _class_identity(seed.get("class_metadata"))
    if previous_class_id and class_id and previous_class_id != class_id:
        return False
    return (
        mask_data.get("label") == seed.get("label")
        and mask_data.get("color") == seed.get("color")
    )


def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
    """Best-effort match for propagation annotations created before seed keys.

    Label, colour and class id (compared as strings, both absent counts as equal) must match.
    """
    return (
        mask_data.get("label") == seed.get("label")
        and mask_data.get("color") == seed.get("color")
        and _class_identity(mask_data.get("class")) == _class_identity(seed.get("class_metadata"))
    )


def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
    """Return True if the mask was produced by propagation with *model_id*."""
    return str(mask_data.get("source") or "") == f"{model_id}_propagation"


def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
    """Return True if a stored propagated mask belongs to the same seed lineage.

    Matches on the stored seed key, then on source annotation / mask ids. The semantic legacy
    fallback is only used for masks that predate seed keys: a mask that carries a different seed
    key belongs to another object (e.g. a second instance of the same class) and must not be
    treated as this seed's output.
    """
    previous_seed_key = mask_data.get("propagation_seed_key")
    if previous_seed_key == seed_key:
        return True
    source_annotation_id = seed.get("source_annotation_id")
    if source_annotation_id is not None and str(mask_data.get("source_annotation_id") or "") == str(source_annotation_id):
        return True
    source_mask_id = seed.get("source_mask_id")
    if source_mask_id and mask_data.get("source_mask_id") == source_mask_id:
        return True
    if previous_seed_key:
        return False
    return _legacy_seed_matches(mask_data, seed)


def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
    """Return True for propagation-produced annotations belonging to this seed's lineage."""
    mask_data = annotation.mask_data or {}
    source = str(mask_data.get("source") or "")
    if not source.endswith("_propagation"):
        return False
    return _seed_identity_matches(mask_data, seed_key, seed)


def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
    """Return True if the mask was propagated in *direction* or has no recorded direction."""
    previous_direction = mask_data.get("propagation_direction")
    return previous_direction in {None, direction}


def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
    """Use target-frame overlap as a final guard before replacing same-object propagation.

    True if any stored polygon's box overlaps the candidate box by at least 15% of the smaller box.
    """
    candidate_bbox = _polygon_bbox(polygon)
    for previous_polygon in (annotation.mask_data or {}).get("polygons") or []:
        if len(previous_polygon) < 3:
            continue
        if _bbox_overlap_ratio(_polygon_bbox(previous_polygon), candidate_bbox) >= 0.15:
            return True
    return False


def _delete_replaced_frame_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    frame_id: int,
    seed_key: str,
    seed: dict[str, Any],
    polygon: list[list[float]],
) -> int:
    """Delete old propagated masks for the same object immediately before writing a new result.

    A propagated mask on *frame_id* is removed if it shares this seed's lineage, or if it matches
    semantically and overlaps *polygon* spatially. Returns the number of deleted annotations.
    """
    previous_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id == frame_id)
        .all()
    )
    deleted_count = 0
    for annotation in previous_annotations:
        mask_data = annotation.mask_data or {}
        source = str(mask_data.get("source") or "")
        if not source.endswith("_propagation"):
            continue
        same_lineage = _seed_identity_matches(mask_data, seed_key, seed)
        same_manual_replacement = (
            _semantic_seed_matches(mask_data, seed)
            and _annotation_spatially_matches(annotation, polygon)
        )
        if same_lineage or same_manual_replacement:
            db.delete(annotation)
            deleted_count += 1
    if deleted_count:
        db.commit()
    return deleted_count


def _prepare_seed_propagation(
    db: Session,
    *,
    payload: dict[str, Any],
    model_id: str,
    seed: dict[str, Any],
    direction: str,
    target_frame_ids: set[int],
) -> dict[str, Any]:
    """Decide whether a seed must be re-propagated and clear its stale results.

    The seed is skipped when there are no target frames, or when every target frame already has
    a propagated mask from this seed lineage and direction with the same signature and model.
    Otherwise previously propagated masks of this seed on the target frames are deleted — unless
    ``save_annotations`` is False, in which case nothing would be written back and the existing
    masks are left untouched.
    """
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)
    if not target_frame_ids:
        return {
            "skip": True,
            "seed_key": seed_key,
            "seed_signature": seed_signature,
            "deleted_annotation_count": 0,
        }

    previous_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id.in_(target_frame_ids))
        .all()
    )
    matching = [
        annotation
        for annotation in previous_annotations
        if _is_propagation_annotation(annotation, seed_key, seed)
        and _direction_matches(annotation.mask_data or {}, direction)
    ]
    covered_frame_ids = {int(annotation.frame_id) for annotation in matching}
    unchanged = matching and all(
        (annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature
        and _source_model_matches(annotation.mask_data or {}, model_id)
        for annotation in matching
    )
    if unchanged and target_frame_ids.issubset(covered_frame_ids):
        return {
            "skip": True,
            "seed_key": seed_key,
            "seed_signature": seed_signature,
            "deleted_annotation_count": 0,
        }

    deleted_count = 0
    if matching and payload.get("save_annotations", True) is not False:
        for annotation in matching:
            db.delete(annotation)
            deleted_count += 1
        db.commit()
    return {
        "skip": False,
        "seed_key": seed_key,
        "seed_signature": seed_signature,
        "deleted_annotation_count": deleted_count,
    }


def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Select up to *max_frames* frames (including the source) in *direction*.

    Returns the selected frames in ascending order and the source frame's index within them:
    the last index for ``"backward"``, 0 otherwise.
    """
    count = max(1, min(max_frames, len(frames)))
    if direction == "backward":
        start = max(0, source_position - count + 1)
        return frames[start:source_position + 1], source_position - start
    end = min(len(frames), source_position + count)
    return frames[source_position:end], 0


def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download each frame image into *directory* as a sequentially numbered file.

    Returns the written file paths in the same order as *frames*.
    """
    paths = []
    for index, frame in enumerate(frames):
        data = download_file(frame.image_url)
        # SAM2VideoPredictor sorts frames by converting the filename stem to int.
        path = directory / f"{index:06d}.jpg"
        path.write_bytes(data)
        paths.append(str(path))
    return paths


def _save_propagated_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    selected_frames: list[Frame],
    source_frame: Frame,
    propagated: list[dict[str, Any]],
    seed: dict[str, Any],
) -> tuple[list[Annotation], int]:
    """Persist propagated polygons as annotations.

    For each frame result, the polygons are smoothed. Masks they replace are deleted once per
    frame, then one annotation is written per polygon. Returns ``(created_annotations,
    deleted_count)``. Nothing is written when ``save_annotations`` is False.
    """
    created: list[Annotation] = []
    if payload.get("save_annotations", True) is False:
        return created, 0

    class_metadata = seed.get("class_metadata")
    template_id = seed.get("template_id")
    label = seed.get("label") or "Propagated Mask"
    color = seed.get("color") or "#06b6d4"
    model_id = sam_registry.normalize_model_id(payload.get("model"))
    include_source = bool(payload.get("include_source", False))
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)
    source_annotation_id = seed.get("source_annotation_id")
    source_mask_id = seed.get("source_mask_id")
    smoothing = _normalize_smoothing_options(seed.get("smoothing"))
    direction = str(payload.get("current_direction") or "")
    deleted_count = 0
    cleaned_frame_ids: set[int] = set()

    for frame_result in propagated:
        relative_index = int(frame_result.get("frame_index", -1))
        if relative_index < 0 or relative_index >= len(selected_frames):
            continue
        frame = selected_frames[relative_index]
        if not include_source and frame.id == source_frame.id:
            continue

        result_polygons = frame_result.get("polygons") or []
        scores = frame_result.get("scores") or []
        # Keep each polygon paired with the score at its *raw* index; filtering out degenerate
        # polygons first and indexing the filtered list would shift scores onto the wrong polygon.
        smoothed: list[tuple[list[list[float]], Any]] = []
        for raw_index, raw_polygon in enumerate(result_polygons):
            if len(raw_polygon) < 3:
                continue
            polygon = _smooth_polygon(raw_polygon, smoothing)
            if len(polygon) < 3:
                continue
            score = scores[raw_index] if raw_index < len(scores) else None
            smoothed.append((polygon, score))

        if smoothed and int(frame.id) not in cleaned_frame_ids:
            deleted_count += _delete_replaced_frame_annotations(
                db,
                payload=payload,
                frame_id=int(frame.id),
                seed_key=seed_key,
                seed=seed,
                polygon=smoothed[0][0],
            )
            cleaned_frame_ids.add(int(frame.id))

        for polygon, score in smoothed:
            annotation = Annotation(
                project_id=int(payload["project_id"]),
                frame_id=frame.id,
                template_id=template_id,
                mask_data={
                    "polygons": [polygon],
                    "label": label,
                    "color": color,
                    "source": f"{model_id}_propagation",
                    "propagated_from_frame_id": source_frame.id,
                    "propagated_from_frame_index": source_frame.frame_index,
                    "propagation_seed_key": seed_key,
                    "propagation_seed_signature": seed_signature,
                    "propagation_direction": direction,
                    "source_annotation_id": source_annotation_id,
                    "source_mask_id": source_mask_id,
                    "score": score,
                    **({"geometry_smoothing": smoothing} if smoothing else {}),
                    **({"class": class_metadata} if class_metadata else {}),
                },
                points=None,
                bbox=_polygon_bbox(polygon),
            )
            db.add(annotation)
            created.append(annotation)

    db.commit()
    for annotation in created:
        db.refresh(annotation)
    return created, deleted_count


def _run_one_step(
    db: Session,
    *,
    payload: dict[str, Any],
    frames: list[Frame],
    source_frame: Frame,
    source_position: int,
    step: dict[str, Any],
    task: ProcessingTask | None = None,
) -> dict[str, Any]:
    """Propagate one seed in one direction and persist the result.

    When *task* is given, cancellation is re-checked after the (potentially long) model run so
    results are not written for a task the user has already cancelled.

    Raises:
        ValueError: for an invalid direction or a seed without polygons, bbox or points.
        PropagationTaskCancelled: if *task* was cancelled while the model was running.
    """
    direction = str(step.get("direction") or "forward").lower()
    if direction not in {"forward", "backward"}:
        raise ValueError("direction must be forward or backward")
    max_frames = max(1, min(int(step.get("max_frames") or payload.get("max_frames") or 30), 500))
    seed = step.get("seed") or {}
    if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
        raise ValueError("Propagation requires seed polygons, bbox, or points")

    model_id = sam_registry.normalize_model_id(payload.get("model"))
    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    include_source = bool(payload.get("include_source", False))
    target_frame_ids = {
        int(frame.id)
        for frame in selected_frames
        if include_source or frame.id != source_frame.id
    }
    seed_state = _prepare_seed_propagation(
        db,
        payload=payload,
        model_id=model_id,
        seed=seed,
        direction=direction,
        target_frame_ids=target_frame_ids,
    )
    if seed_state["skip"]:
        return {
            "model": model_id,
            "direction": direction,
            "processed_frame_count": 0,
            "created_annotation_count": 0,
            "deleted_annotation_count": 0,
            "skipped_seed_count": 1,
            "seed_label": seed.get("label"),
            "seed_key": seed_state["seed_key"],
        }

    with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
        frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
        propagated = sam_registry.propagate_video(
            model_id,
            frame_paths,
            source_relative_index,
            seed,
            direction,
            len(selected_frames),
        )

    if task is not None:
        _ensure_not_cancelled(db, task)

    save_payload = {**payload, "current_direction": direction}
    created, write_cleanup_count = _save_propagated_annotations(
        db,
        payload=save_payload,
        selected_frames=selected_frames,
        source_frame=source_frame,
        propagated=propagated,
        seed=seed,
    )
    return {
        "model": model_id,
        "direction": direction,
        "processed_frame_count": len(selected_frames),
        "created_annotation_count": len(created),
        "deleted_annotation_count": int(seed_state["deleted_annotation_count"]) + write_cleanup_count,
        "skipped_seed_count": 0,
        "seed_label": seed.get("label"),
        "seed_key": seed_state["seed_key"],
    }


def _mark_task_failed(db: Session, task: ProcessingTask, *, message: str, error: str) -> None:
    """Persist a FAILED state for *task* without masking the error that caused the failure.

    The session is rolled back first. After a database error it is unusable until rolled back,
    and uncommitted work (for example annotations added but not yet committed) must not be
    committed together with the failure state. If persisting the state itself fails, that is
    logged and swallowed so the caller can re-raise the original error.
    """
    try:
        db.rollback()
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message=message, error=error, finished=True)
    except Exception:  # noqa: BLE001
        logger.exception("Failed to persist failure state for propagation task")


def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
    """Run one queued SAM propagation task and update persisted progress.

    Returns the aggregated result dict on success, or a short status dict if the task was
    cancelled before or during the run.

    Raises:
        ValueError: if the task does not exist, or its payload, project, source frame or steps
            are invalid (the task is marked failed first when it exists).
        Exception: any propagation error is re-raised after the task is marked failed.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")
    if task.status == TASK_STATUS_CANCELLED:
        return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}

    # Work on a copy so the persisted payload is never mutated in place. The resolved project id
    # is made authoritative because the helpers read payload["project_id"] directly; a payload
    # relying on the task.project_id fallback would otherwise fail with KeyError mid-run.
    payload = dict(task.payload or {})
    try:
        project_id = int(payload.get("project_id") or task.project_id or 0)
        source_frame_id = int(payload.get("frame_id") or 0)
        model_id = sam_registry.normalize_model_id(payload.get("model"))
    except (TypeError, ValueError) as exc:
        _mark_task_failed(db, task, message="自动传播失败", error=str(exc))
        raise
    payload["project_id"] = project_id

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        _mark_task_failed(db, task, message="项目不存在", error="Project not found")
        raise ValueError(f"Project not found: {project_id}")

    source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
    if not source_frame:
        _mark_task_failed(db, task, message="参考帧不存在", error="Frame not found")
        raise ValueError(f"Frame not found: {source_frame_id}")

    frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        _mark_task_failed(
            db,
            task,
            message="参考帧不在项目帧序列中",
            error="Source frame is not in project frame sequence",
        )
        raise ValueError("Source frame is not in project frame sequence")

    steps = payload.get("steps") or []
    if not steps:
        _mark_task_failed(db, task, message="传播任务缺少步骤", error="Propagation task has no steps")
        raise ValueError("Propagation task has no steps")

    step_results: list[dict[str, Any]] = []
    created_count = 0
    processed_count = 0
    deleted_count = 0
    skipped_count = 0
    total_steps = len(steps)

    def build_result(completed_steps: int) -> dict[str, Any]:
        # Snapshot of the running totals; used for intermediate progress and the final result.
        return {
            "project_id": project_id,
            "source_frame_id": source_frame_id,
            "model": model_id,
            "total_steps": total_steps,
            "completed_steps": completed_steps,
            "processed_frame_count": processed_count,
            "created_annotation_count": created_count,
            "deleted_annotation_count": deleted_count,
            "skipped_seed_count": skipped_count,
            "steps": list(step_results),
        }

    try:
        # Inside the try so a cancellation landing right before start is reported as cancelled
        # rather than escaping as an unhandled PropagationTaskCancelled.
        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)

        for index, step in enumerate(steps, start=1):
            _ensure_not_cancelled(db, task)
            seed_label = (step.get("seed") or {}).get("label") or "mask"
            is_backward = str(step.get("direction") or "forward").lower() == "backward"
            direction_label = "向前传播" if is_backward else "向后传播"
            _set_task_state(
                db,
                task,
                progress=5 + int(((index - 1) / total_steps) * 90),
                message=f"{direction_label} {seed_label} ({index}/{total_steps})",
                result=build_result(index - 1),
            )

            step_result = _run_one_step(
                db,
                payload=payload,
                frames=frames,
                source_frame=source_frame,
                source_position=source_position,
                step=step,
                task=task,
            )
            step_results.append(step_result)
            created_count += int(step_result["created_annotation_count"])
            processed_count += int(step_result["processed_frame_count"])
            deleted_count += int(step_result.get("deleted_annotation_count") or 0)
            skipped_count += int(step_result.get("skipped_seed_count") or 0)

            _set_task_state(
                db,
                task,
                progress=5 + int((index / total_steps) * 90),
                message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
                result=build_result(index),
            )

        result = build_result(total_steps)
        if created_count > 0:
            final_message = "自动传播完成"
        elif skipped_count > 0:
            final_message = "自动传播完成,未改变的 mask 已跳过"
        else:
            final_message = "自动传播完成,但没有生成新的 mask"
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message=final_message,
            result=result,
            finished=True,
        )
        return result
    except PropagationTaskCancelled:
        task.status = TASK_STATUS_CANCELLED
        task.progress = 100
        task.message = task.message or "任务已取消"
        task.error = task.error or "Cancelled by user"
        task.finished_at = task.finished_at or _now()
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
    except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
        _mark_task_failed(db, task, message="自动传播失败", error=str(exc))
        raise
    except Exception as exc:  # noqa: BLE001
        logger.exception("Propagation task failed: task_id=%s", task_id)
        _mark_task_failed(db, task, message="自动传播失败", error=str(exc))
        raise