Files
Pre_Seg_Server/backend/services/propagation_task_runner.py
admin 0485ce4d92 保持传播多区域结果为单个遮罩
- 后端传播落库时将同一 seed 在同一目标帧的多个不连通 polygon 保存到同一 annotation
- 同步任务传播和兼容同步传播接口的多 polygon 保存逻辑
- 传播结果 bbox 改为覆盖全部不连通 polygon,并保留多 polygon scores 与 holes
- 前端回显单条多 polygon annotation 时使用组合 bbox 和真实 polygon 面积
- 补充后端传播 worker 回归测试,验证不连通结果只生成一个 annotation
- 补充前端 API 回归测试,验证多 polygon annotation 回显为一个 mask
- 更新项目指南和设计冻结文档
2026-05-04 02:32:31 +08:00

808 lines
31 KiB
Python

"""Background SAM video propagation runner used by Celery workers."""
import hashlib
import json
import logging
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from sqlalchemy.orm import Session
from minio_client import download_file
from models import Annotation, Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.sam_registry import ModelUnavailableError, sam_registry
from statuses import (
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
logger = logging.getLogger(__name__)
class PropagationTaskCancelled(RuntimeError):
    """Internal signal that the persisted task was found in the cancelled state.

    Raised by ``_ensure_not_cancelled`` and handled inside
    ``run_propagate_project_task``.
    """
def _now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Apply the given field updates to *task*, persist them and publish an event.

    Only arguments that are not ``None`` are written, so a field cannot be
    reset to ``None`` through this helper. ``progress`` is clamped to 0..100.

    Args:
        db: Session used to commit and refresh the task.
        task: Task row to mutate.
        status: New status value, if any.
        progress: New progress percentage, if any.
        message: New human-readable message, if any.
        result: New result payload, if any.
        error: New error text, if any.
        started: When true, stamp ``started_at`` with the current UTC time.
        finished: When true, stamp ``finished_at`` with the current UTC time.
    """
    field_updates: list[tuple[str, Any]] = [
        ("status", status),
        ("progress", None if progress is None else max(0, min(100, progress))),
        ("message", message),
        ("result", result),
        ("error", error),
    ]
    for attribute, new_value in field_updates:
        if new_value is not None:
            setattr(task, attribute, new_value)

    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload *task* from the database and abort if it is marked cancelled.

    Raises:
        PropagationTaskCancelled: If the refreshed status equals
            ``TASK_STATUS_CANCELLED``.
    """
    db.refresh(task)
    if task.status != TASK_STATUS_CANCELLED:
        return
    raise PropagationTaskCancelled("Task was cancelled")
def _clamp01(value: float) -> float:
    """Convert *value* to ``float`` and clamp it into the range [0.0, 1.0]."""
    number = float(value)
    if number < 0.0:
        return 0.0
    if number > 1.0:
        return 1.0
    return number
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return the ``[left, top, width, height]`` bounding box of one polygon.

    Coordinates are clamped to [0, 1] before the box is computed. Points with
    fewer than two coordinates are ignored, and a polygon without any usable
    point yields an all-zero box instead of raising, matching the behaviour of
    ``_polygons_bbox``.

    Args:
        polygon: Sequence of ``[x, y]`` points.

    Returns:
        ``[left, top, width, height]`` with non-negative width and height.
    """
    points = [point for point in polygon if len(point) >= 2]
    if not points:
        # min()/max() on an empty sequence would raise ValueError.
        return [0.0, 0.0, 0.0, 0.0]
    xs = [_clamp01(point[0]) for point in points]
    ys = [_clamp01(point[1]) for point in points]
    left, right = min(xs), max(xs)
    top, bottom = min(ys), max(ys)
    return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
def _polygons_bbox(polygons: list[list[list[float]]]) -> list[float]:
    """Return one ``[left, top, width, height]`` box covering every polygon.

    Coordinates are clamped to [0, 1]; points with fewer than two coordinates
    are skipped. With no usable point at all the box is all zeros.
    """
    xs: list[float] = []
    ys: list[float] = []
    for polygon in polygons:
        for point in polygon:
            if len(point) < 2:
                continue
            xs.append(_clamp01(point[0]))
            ys.append(_clamp01(point[1]))

    if not xs:
        return [0.0, 0.0, 0.0, 0.0]

    left = min(xs)
    top = min(ys)
    width = max(max(xs) - left, 0.0)
    height = max(max(ys) - top, 0.0)
    return [left, top, width, height]
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
    """Return ``[x, y]`` pairs clamped to [0, 1], dropping points with < 2 coordinates."""
    cleaned: list[list[float]] = []
    for point in polygon:
        if len(point) < 2:
            continue
        cleaned.append([_clamp01(point[0]), _clamp01(point[1])])
    return cleaned
def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
if not isinstance(value, dict):
return None
try:
strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
except (TypeError, ValueError):
strength = 0.0
if strength <= 0:
return None
method = str(value.get("method") or "chaikin").lower()
if method != "chaikin":
method = "chaikin"
return {"strength": round(strength, 2), "method": method}
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
    """Map a 0..100 strength onto 0..1 and raise it to the power *curve*.

    A falsy *strength* counts as 0; values outside 0..100 are clamped.
    """
    capped = min(float(strength or 0.0), 100.0)
    fraction = max(0.0, capped) / 100.0
    return fraction ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
    """Apply corner-cutting subdivision to a closed polygon.

    Each pass replaces every edge (including the closing edge back to the
    first point) with two points placed at ``corner_cut`` and
    ``1 - corner_cut`` along it. ``corner_cut`` is limited to 0.02..0.25.
    Passes stop early once fewer than three points remain.

    Args:
        polygon: Input ``[x, y]`` points; normalized before smoothing.
        iterations: Number of passes; negative values mean none.
        corner_cut: Fraction of each edge cut off at both ends.

    Returns:
        The smoothed ring with coordinates clamped to [0, 1].
    """
    ring = _normalize_polygon(polygon)
    near = max(0.02, min(float(corner_cut), 0.25))
    far = 1.0 - near

    for _ in range(max(0, iterations)):
        if len(ring) < 3:
            break
        refined: list[list[float]] = []
        # Pair every vertex with its successor, wrapping around at the end.
        for (x0, y0), (x1, y1) in zip(ring, ring[1:] + ring[:1]):
            refined.append([_clamp01(far * x0 + near * x1), _clamp01(far * y0 + near * y1)])
            refined.append([_clamp01(near * x0 + far * x1), _clamp01(near * y0 + far * y1)])
        ring = refined
    return ring
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
    """Reduce the vertex count of a closed polygon with ``cv2.approxPolyDP``.

    The tolerance is the closed perimeter multiplied by a factor that grows
    with ``_smoothing_ratio(strength)``. The input polygon is returned
    unchanged when it has fewer than three points or when the approximation
    would leave fewer than three.
    """
    if len(polygon) < 3:
        return polygon

    contour = np.array(
        [(point[0], point[1]) for point in polygon],
        dtype=np.float32,
    ).reshape(-1, 1, 2)
    tolerance_factor = 0.00015 + _smoothing_ratio(strength) * 0.00735
    epsilon = cv2.arcLength(contour, True) * tolerance_factor
    vertices = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
    if len(vertices) < 3:
        return polygon

    simplified: list[list[float]] = []
    for x, y in vertices:
        simplified.append([_clamp01(float(x)), _clamp01(float(y))])
    return simplified
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
    """Smooth one polygon according to the normalized smoothing options.

    Without options, or with a non-positive strength, the polygon is only
    normalized. Otherwise it is pre-simplified, smoothed with corner cutting,
    and simplified again. If the result has more vertices than the normalized
    input, it is simplified with progressively larger strengths until it does
    not. A result with fewer than three points falls back to the normalized
    input.
    """
    strength = float(smoothing.get("strength") or 0.0) if smoothing else 0.0
    if strength <= 0:
        return _normalize_polygon(polygon)

    effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
    # Stronger smoothing gets more subdivision passes.
    iterations = next(
        (passes for threshold, passes in ((85, 4), (55, 3), (25, 2)) if effective_strength >= threshold),
        1,
    )
    corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22

    normalized = _normalize_polygon(polygon)
    vertex_budget = len(normalized)
    coarse = _simplify_polygon(normalized, effective_strength * 0.25)
    rounded = _chaikin_smooth_polygon(coarse, iterations, corner_cut)
    simplified = _simplify_polygon(rounded, effective_strength)

    if len(simplified) > vertex_budget:
        # Subdivision added vertices; simplify harder until back within budget.
        for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
            simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
            if len(simplified) <= vertex_budget:
                break

    if len(simplified) >= 3:
        return simplified
    return _normalize_polygon(polygon)
def _bbox_area(bbox: list[float]) -> float:
    """Return the area of an ``[x, y, width, height]`` box, treating negative sizes as 0."""
    width = max(float(bbox[2]), 0.0)
    height = max(float(bbox[3]), 0.0)
    return width * height
def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
    """Return the intersection area divided by the smaller box's area.

    Both boxes are ``[x, y, width, height]``. The result is 0.0 when the
    smaller box has no area.
    """
    a_left, a_top, a_width, a_height = a
    b_left, b_top, b_width, b_height = b

    shared_width = max(0.0, min(a_left + a_width, b_left + b_width) - max(a_left, b_left))
    shared_height = max(0.0, min(a_top + a_height, b_top + b_height) - max(a_top, b_top))
    shared_area = shared_width * shared_height

    reference_area = min(_bbox_area(a), _bbox_area(b))
    if reference_area > 0:
        return shared_area / reference_area
    return 0.0
def _stable_json(value: Any) -> str:
    """Serialize *value* to compact JSON with sorted keys and unescaped non-ASCII text."""
    encoder = json.JSONEncoder(
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    return encoder.encode(value)
def _canonicalize_signature_value(value: Any) -> Any:
    """Recursively canonicalize *value* for hashing.

    Floats are rounded to 6 decimals, lists are processed element-wise, and
    dicts are rebuilt with sorted keys. Any other value is returned as is.
    """
    if isinstance(value, dict):
        canonical: dict[Any, Any] = {}
        for key in sorted(value):
            canonical[key] = _canonicalize_signature_value(value[key])
        return canonical
    if isinstance(value, list):
        return [_canonicalize_signature_value(item) for item in value]
    if isinstance(value, float):
        return round(value, 6)
    return value
def _seed_signature(seed: dict[str, Any]) -> str:
    """Return a stable signature of the seed's geometry and semantic attributes.

    A truthy ``propagation_seed_signature`` already present on the seed is
    returned as a string. Otherwise the signature is the SHA-256 hex digest
    of the canonical JSON built from the seed's geometry, label, color,
    class metadata, template id and normalized smoothing options.
    """
    inherited = seed.get("propagation_seed_signature")
    if inherited:
        return str(inherited)

    payload: dict[str, Any] = {
        field: seed.get(field) or []
        for field in ("polygons", "holes", "bbox", "points", "labels")
    }
    payload["label"] = seed.get("label")
    payload["color"] = seed.get("color")
    payload["class_metadata"] = seed.get("class_metadata") or {}
    payload["template_id"] = seed.get("template_id")
    payload["smoothing"] = _normalize_smoothing_options(seed.get("smoothing"))

    canonical_text = _stable_json(_canonicalize_signature_value(payload))
    digest = hashlib.sha256()
    digest.update(canonical_text.encode("utf-8"))
    return digest.hexdigest()
def _seed_key(seed: dict[str, Any]) -> str:
    """Return the identity key of a seed.

    Preference order: ``annotation:<source_annotation_id>``, then
    ``mask:<source_mask_id>``, then a JSON string built from template id,
    class id (or class name), label and color for callers without ids.
    """
    annotation_id = seed.get("source_annotation_id")
    if annotation_id is not None:
        return f"annotation:{annotation_id}"

    mask_id = seed.get("source_mask_id")
    if mask_id:
        return f"mask:{mask_id}"

    class_metadata = seed.get("class_metadata") or {}
    fallback_identity = {
        "template_id": seed.get("template_id"),
        "class_id": class_metadata.get("id") or class_metadata.get("name"),
        "label": seed.get("label"),
        "color": seed.get("color"),
    }
    return _stable_json(fallback_identity)
def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
    """Best-effort match between a stored mask and a seed that lack shared lineage ids.

    The class id (or name) only disqualifies the match when both sides have
    one and their string forms differ. Label and color must be equal.
    """
    stored_class = mask_data.get("class") or {}
    seed_class = seed.get("class_metadata") or {}
    stored_class_id = stored_class.get("id") or stored_class.get("name")
    seed_class_id = seed_class.get("id") or seed_class.get("name")

    both_classes_known = bool(stored_class_id) and bool(seed_class_id)
    if both_classes_known and str(stored_class_id) != str(seed_class_id):
        return False
    if mask_data.get("label") != seed.get("label"):
        return False
    return mask_data.get("color") == seed.get("color")
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
    """Best-effort match for propagated masks stored without a seed key.

    Label, color and class id (or name) must all be equal; two missing class
    ids count as equal.
    """
    stored_class = mask_data.get("class") or {}
    seed_class = seed.get("class_metadata") or {}
    stored_class_id = stored_class.get("id") or stored_class.get("name")
    seed_class_id = seed_class.get("id") or seed_class.get("name")

    checks = (
        mask_data.get("label") == seed.get("label"),
        mask_data.get("color") == seed.get("color"),
        stored_class_id == seed_class_id,
    )
    return all(checks)
def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
    """Return whether the mask's ``source`` equals ``"<model_id>_propagation"``."""
    expected_source = "{}_propagation".format(model_id)
    stored_source = str(mask_data.get("source") or "")
    return stored_source == expected_source
def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
    """Return whether a stored mask descends from the given seed.

    A match is an equal seed key, an equal source annotation id (compared as
    strings) or an equal source mask id. If none matches and either side
    carries any identity information, the result is False. Only when neither
    side has identity information is the legacy attribute comparison used.
    """
    stored_key = mask_data.get("propagation_seed_key")
    stored_annotation_id = mask_data.get("source_annotation_id")
    stored_mask_id = mask_data.get("source_mask_id")
    seed_annotation_id = seed.get("source_annotation_id")
    seed_mask_id = seed.get("source_mask_id")

    if stored_key == seed_key:
        return True
    if seed_annotation_id is not None and str(stored_annotation_id or "") == str(seed_annotation_id):
        return True
    if seed_mask_id and stored_mask_id == seed_mask_id:
        return True

    seed_has_identity = seed_annotation_id is not None or bool(seed_mask_id)
    mask_has_identity = (
        bool(stored_key)
        or stored_annotation_id is not None
        or bool(stored_mask_id)
    )
    if not (seed_has_identity or mask_has_identity):
        return _legacy_seed_matches(mask_data, seed)
    return False
def _seed_identity_markers(seed: dict[str, Any]) -> set[str]:
    """Return the identity markers of a seed.

    Always contains ``seed:<seed key>``; adds ``annotation:<id>`` and
    ``mask:<id>`` when the seed carries those source ids.
    """
    annotation_id = seed.get("source_annotation_id")
    mask_id = seed.get("source_mask_id")
    candidates = (
        (True, f"seed:{_seed_key(seed)}"),
        (annotation_id is not None, f"annotation:{annotation_id}"),
        (bool(mask_id), f"mask:{mask_id}"),
    )
    return {marker for present, marker in candidates if present}
def _mask_identity_markers(mask_data: dict[str, Any]) -> set[str]:
    """Return the identity markers recorded on a stored mask.

    Uses the same ``seed:``, ``annotation:`` and ``mask:`` prefixes as the
    seed markers, so the two sets can be intersected. Empty when the mask has
    no identity fields.
    """
    stored_key = mask_data.get("propagation_seed_key")
    annotation_id = mask_data.get("source_annotation_id")
    mask_id = mask_data.get("source_mask_id")
    candidates = (
        (bool(stored_key), f"seed:{stored_key}"),
        (annotation_id is not None, f"annotation:{annotation_id}"),
        (bool(mask_id), f"mask:{mask_id}"),
    )
    return {marker for present, marker in candidates if present}
def _payload_seed_identity_markers(payload: dict[str, Any]) -> set[str]:
    """Return the union of identity markers of every seed in ``payload["steps"]``."""
    return {
        marker
        for step in payload.get("steps") or []
        for marker in _seed_identity_markers(step.get("seed") or {})
    }
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
    """Return whether *annotation* is a propagated mask descending from *seed*.

    The mask's ``source`` must end with ``_propagation`` and its identity
    must match the seed.
    """
    mask_data = annotation.mask_data or {}
    is_propagated = str(mask_data.get("source") or "").endswith("_propagation")
    return is_propagated and _seed_identity_matches(mask_data, seed_key, seed)
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
    """Return whether the mask's recorded direction is missing or equals *direction*."""
    acceptable = {None, direction}
    recorded = mask_data.get("propagation_direction")
    return recorded in acceptable
def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
    """Return whether any stored polygon overlaps the candidate polygon's box.

    Used as a final guard before replacing a same-object propagated mask: a
    stored polygon with at least three points matches when its bounding box
    overlaps the candidate's by a ratio of 0.15 or more.
    """
    candidate_bbox = _polygon_bbox(polygon)
    stored_polygons = (annotation.mask_data or {}).get("polygons") or []
    return any(
        _bbox_overlap_ratio(_polygon_bbox(stored), candidate_bbox) >= 0.15
        for stored in stored_polygons
        if len(stored) >= 3
    )
def _delete_replaced_frame_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    frame_id: int,
    seed_key: str,
    seed: dict[str, Any],
    polygon: list[list[float]],
) -> int:
    """Delete old propagated masks of the same object on one frame.

    Called immediately before a new result is written. A propagated
    annotation is deleted when it shares lineage with *seed*, or when it
    matches semantically and overlaps *polygon*. Annotations that belong to
    another seed of the same task are left alone.

    Returns:
        Number of annotations deleted; the session is committed only when
        that number is non-zero.
    """
    frame_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id == frame_id)
        .all()
    )
    own_markers = _seed_identity_markers(seed)
    task_markers = _payload_seed_identity_markers(payload)

    removed = 0
    for annotation in frame_annotations:
        mask_data = annotation.mask_data or {}
        if not str(mask_data.get("source") or "").endswith("_propagation"):
            continue

        stored_markers = _mask_identity_markers(mask_data)
        # Keep sibling seeds in the same propagation task from deleting each other.
        owned_by_sibling = (
            bool(stored_markers)
            and stored_markers.isdisjoint(own_markers)
            and not stored_markers.isdisjoint(task_markers)
        )
        if owned_by_sibling:
            continue

        shares_lineage = _seed_identity_matches(mask_data, seed_key, seed)
        replaces_manual_edit = (
            _semantic_seed_matches(mask_data, seed)
            and _annotation_spatially_matches(annotation, polygon)
        )
        if not (shares_lineage or replaces_manual_edit):
            continue
        db.delete(annotation)
        removed += 1

    if removed:
        db.commit()
    return removed
def _prepare_seed_propagation(
    db: Session,
    *,
    payload: dict[str, Any],
    model_id: str,
    seed: dict[str, Any],
    direction: str,
    target_frame_ids: set[int],
) -> dict[str, Any]:
    """Decide whether a seed must be re-propagated and clear its stale results.

    The seed is skipped when there are no target frames, or when every target
    frame already holds a matching propagated annotation and all matching
    annotations carry the current seed signature and the current model.
    Otherwise all matching annotations on the target frames are deleted and
    committed.

    Returns:
        A dict with ``skip``, ``seed_key``, ``seed_signature`` and
        ``deleted_annotation_count``.
    """
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)

    def outcome(*, skip: bool, deleted: int) -> dict[str, Any]:
        return {
            "skip": skip,
            "seed_key": seed_key,
            "seed_signature": seed_signature,
            "deleted_annotation_count": deleted,
        }

    if not target_frame_ids:
        return outcome(skip=True, deleted=0)

    candidates = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id.in_(target_frame_ids))
        .all()
    )
    previous = [
        annotation
        for annotation in candidates
        if _is_propagation_annotation(annotation, seed_key, seed)
        and _direction_matches(annotation.mask_data or {}, direction)
    ]
    covered_frame_ids = {int(annotation.frame_id) for annotation in previous}

    def is_current(annotation: Any) -> bool:
        mask_data = annotation.mask_data or {}
        return (
            mask_data.get("propagation_seed_signature") == seed_signature
            and _source_model_matches(mask_data, model_id)
        )

    up_to_date = (
        bool(previous)
        and all(is_current(annotation) for annotation in previous)
        and target_frame_ids.issubset(covered_frame_ids)
    )
    if up_to_date:
        return outcome(skip=True, deleted=0)

    for annotation in previous:
        db.delete(annotation)
    if previous:
        db.commit()
    return outcome(skip=False, deleted=len(previous))
def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Select the frames to propagate over and locate the source inside them.

    For ``"backward"`` the window ends at the source frame; for any other
    direction it starts there. The window holds at most ``max_frames`` frames
    (at least one is requested).

    Returns:
        ``(window, source_index)`` where ``source_index`` is the position of
        the source frame within ``window``.
    """
    window_size = max(1, min(max_frames, len(frames)))
    if direction != "backward":
        stop = min(len(frames), source_position + window_size)
        return frames[source_position:stop], 0

    first = max(0, source_position - window_size + 1)
    return frames[first:source_position + 1], source_position - first
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download every frame image into *directory* and return the file paths.

    Files are named by their zero-padded position in *frames*
    (``000000.jpg``, ``000001.jpg``, ...).
    """
    written: list[str] = []
    for position, frame in enumerate(frames):
        # SAM2VideoPredictor sorts frames by converting the filename stem to int.
        target = directory / f"{position:06d}.jpg"
        target.write_bytes(download_file(frame.image_url))
        written.append(str(target))
    return written
def _save_propagated_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    selected_frames: list[Frame],
    source_frame: Frame,
    propagated: list[dict[str, Any]],
    seed: dict[str, Any],
) -> tuple[list[Annotation], int]:
    """Persist propagation results, one annotation per frame result.

    All usable polygons of a frame result are stored in a single annotation
    whose bbox covers all of them. Before the first write to a frame, older
    propagated masks of the same object on that frame are removed. Nothing
    is written when ``payload["save_annotations"]`` is ``False``.

    Returns:
        ``(created, deleted_count)``: the refreshed new annotations and the
        number of older annotations removed during the write.
    """
    created: list[Annotation] = []
    if payload.get("save_annotations", True) is False:
        return created, 0

    class_metadata = seed.get("class_metadata")
    template_id = seed.get("template_id")
    label = seed.get("label") or "Propagated Mask"
    color = seed.get("color") or "#06b6d4"
    model_id = sam_registry.normalize_model_id(payload.get("model"))
    include_source = bool(payload.get("include_source", False))
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)
    source_annotation_id = seed.get("source_annotation_id")
    source_mask_id = seed.get("source_mask_id")
    smoothing = _normalize_smoothing_options(seed.get("smoothing"))
    direction = str(payload.get("current_direction") or "")

    deleted_count = 0
    cleaned_frame_ids: set[int] = set()

    for frame_result in propagated:
        relative_index = int(frame_result.get("frame_index", -1))
        if not 0 <= relative_index < len(selected_frames):
            continue
        frame = selected_frames[relative_index]
        if not include_source and frame.id == source_frame.id:
            continue

        raw_polygons = frame_result.get("polygons") or []
        raw_holes = frame_result.get("holes") or []
        raw_scores = frame_result.get("scores") or []

        # Keep the original index so holes and scores stay aligned.
        smoothed = [
            (polygon_index, _smooth_polygon(polygon, smoothing))
            for polygon_index, polygon in enumerate(raw_polygons)
            if len(polygon) >= 3
        ]
        usable = [(polygon_index, polygon) for polygon_index, polygon in smoothed if len(polygon) >= 3]

        if usable and frame.id not in cleaned_frame_ids:
            # The first usable polygon is the spatial reference for cleanup.
            deleted_count += _delete_replaced_frame_annotations(
                db,
                payload=payload,
                frame_id=int(frame.id),
                seed_key=seed_key,
                seed=seed,
                polygon=usable[0][1],
            )
            cleaned_frame_ids.add(int(frame.id))

        polygons_to_save: list[list[list[float]]] = []
        holes_to_save: list[list[list[list[float]]]] = []
        score_values: list[float] = []
        for polygon_index, polygon in usable:
            polygons_to_save.append(polygon)
            hole_group = raw_holes[polygon_index] if polygon_index < len(raw_holes) else None
            holes_to_save.append(hole_group if isinstance(hole_group, list) else [])
            if polygon_index < len(raw_scores):
                try:
                    score_values.append(float(raw_scores[polygon_index]))
                except (TypeError, ValueError):
                    # Scores are optional; an unparsable one is dropped.
                    pass
        if not polygons_to_save:
            continue

        mask_data: dict[str, Any] = {"polygons": polygons_to_save}
        if any(holes_to_save):
            mask_data["holes"] = holes_to_save
            mask_data["hasHoles"] = True
        mask_data.update({
            "label": label,
            "color": color,
            "source": f"{model_id}_propagation",
            "propagated_from_frame_id": source_frame.id,
            "propagated_from_frame_index": source_frame.frame_index,
            "propagation_seed_key": seed_key,
            "propagation_seed_signature": seed_signature,
            "propagation_direction": direction,
            "source_annotation_id": source_annotation_id,
            "source_mask_id": source_mask_id,
            "score": max(score_values) if score_values else None,
        })
        if len(score_values) > 1:
            mask_data["scores"] = score_values
        if smoothing:
            mask_data["geometry_smoothing"] = smoothing
        if class_metadata:
            mask_data["class"] = class_metadata

        annotation = Annotation(
            project_id=int(payload["project_id"]),
            frame_id=frame.id,
            template_id=template_id,
            mask_data=mask_data,
            points=None,
            bbox=_polygons_bbox(polygons_to_save),
        )
        db.add(annotation)
        created.append(annotation)

    db.commit()
    for annotation in created:
        db.refresh(annotation)
    return created, deleted_count
def _run_one_step(
    db: Session,
    *,
    payload: dict[str, Any],
    frames: list[Frame],
    source_frame: Frame,
    source_position: int,
    step: dict[str, Any],
) -> dict[str, Any]:
    """Run one propagation step (one seed, one direction) and save its results.

    The step is skipped without running the model when the seed's previous
    results are still current. Otherwise the frame window is written to a
    temporary directory, propagated through the model registry, and saved.

    Returns:
        A summary dict with the model, direction, frame/annotation counters,
        seed label and seed key.

    Raises:
        ValueError: If the direction is not ``forward``/``backward`` or the
            seed has no polygons, bbox, or points.
    """
    direction = str(step.get("direction") or "forward").lower()
    if direction not in {"forward", "backward"}:
        raise ValueError("direction must be forward or backward")

    requested_frames = int(step.get("max_frames") or payload.get("max_frames") or 30)
    max_frames = max(1, min(requested_frames, 500))

    seed = step.get("seed") or {}
    if not any(seed.get(field) for field in ("polygons", "bbox", "points")):
        raise ValueError("Propagation requires seed polygons, bbox, or points")

    model_id = sam_registry.normalize_model_id(payload.get("model"))
    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    include_source = bool(payload.get("include_source", False))
    target_frame_ids = {
        int(frame.id)
        for frame in selected_frames
        if include_source or frame.id != source_frame.id
    }
    seed_state = _prepare_seed_propagation(
        db,
        payload=payload,
        model_id=model_id,
        seed=seed,
        direction=direction,
        target_frame_ids=target_frame_ids,
    )

    def summary(*, processed: int, created_total: int, deleted: int, skipped: int) -> dict[str, Any]:
        return {
            "model": model_id,
            "direction": direction,
            "processed_frame_count": processed,
            "created_annotation_count": created_total,
            "deleted_annotation_count": deleted,
            "skipped_seed_count": skipped,
            "seed_label": seed.get("label"),
            "seed_key": seed_state["seed_key"],
        }

    if seed_state["skip"]:
        return summary(processed=0, created_total=0, deleted=0, skipped=1)

    tmp_prefix = f"seg_propagate_{payload['project_id']}_"
    with tempfile.TemporaryDirectory(prefix=tmp_prefix) as tmpdir:
        frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
        propagated = sam_registry.propagate_video(
            model_id,
            frame_paths,
            source_relative_index,
            seed,
            direction,
            len(selected_frames),
        )

    created, write_cleanup_count = _save_propagated_annotations(
        db,
        payload={**payload, "current_direction": direction},
        selected_frames=selected_frames,
        source_frame=source_frame,
        propagated=propagated,
        seed=seed,
    )
    return summary(
        processed=len(selected_frames),
        created_total=len(created),
        deleted=int(seed_state["deleted_annotation_count"]) + write_cleanup_count,
        skipped=0,
    )
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
    """Run one queued SAM propagation task and update persisted progress.

    The task payload supplies the project, the source frame, the model and a
    list of steps. Steps run in order; progress moves from 5% to 95% across
    them and the accumulated counters are stored on the task after each one.
    Cancellation is checked before the task starts and before every step.

    Args:
        db: Session used for all reads and writes.
        task_id: Primary key of the ``ProcessingTask`` to run.

    Returns:
        The final result payload on success, or a short dict with
        ``status == TASK_STATUS_CANCELLED`` when the task was cancelled.

    Raises:
        ValueError: If the task, project, source frame or steps are missing,
            or a step is invalid.
        Exception: Any other error from a step is re-raised after the task
            has been marked failed.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")
    if task.status == TASK_STATUS_CANCELLED:
        return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}

    def fail(message: str, error: str) -> None:
        """Mark the task as failed and finished."""
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message=message, error=error, finished=True)

    payload = task.payload or {}
    project_id = int(payload.get("project_id") or task.project_id or 0)
    source_frame_id = int(payload.get("frame_id") or 0)
    try:
        model_id = sam_registry.normalize_model_id(payload.get("model"))
    except ValueError as exc:
        fail("自动传播失败", str(exc))
        raise

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        fail("项目不存在", "Project not found")
        raise ValueError(f"Project not found: {project_id}")

    source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
    if not source_frame:
        fail("参考帧不存在", "Frame not found")
        raise ValueError(f"Frame not found: {source_frame_id}")

    frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        fail("参考帧不在项目帧序列中", "Source frame is not in project frame sequence")
        raise ValueError("Source frame is not in project frame sequence")

    steps = payload.get("steps") or []
    if not steps:
        fail("传播任务缺少步骤", "Propagation task has no steps")
        raise ValueError("Propagation task has no steps")

    _ensure_not_cancelled(db, task)
    _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)

    step_results: list[dict[str, Any]] = []
    created_count = 0
    processed_count = 0
    deleted_count = 0
    skipped_count = 0
    total_steps = len(steps)

    def snapshot(completed_steps: int) -> dict[str, Any]:
        """Build the result payload from the counters accumulated so far."""
        return {
            "project_id": project_id,
            "source_frame_id": source_frame_id,
            "model": model_id,
            "total_steps": total_steps,
            "completed_steps": completed_steps,
            "processed_frame_count": processed_count,
            "created_annotation_count": created_count,
            "deleted_annotation_count": deleted_count,
            "skipped_seed_count": skipped_count,
            # Copy so each stored snapshot is independent of later appends.
            "steps": list(step_results),
        }

    try:
        for index, step in enumerate(steps, start=1):
            _ensure_not_cancelled(db, task)
            seed_label = (step.get("seed") or {}).get("label") or "mask"
            # Normalize like _run_one_step so label and executed direction agree.
            is_backward = str(step.get("direction") or "forward").lower() == "backward"
            direction_label = "向前传播" if is_backward else "向后传播"
            _set_task_state(
                db,
                task,
                progress=5 + int(((index - 1) / total_steps) * 90),
                message=f"{direction_label} {seed_label} ({index}/{total_steps})",
                result=snapshot(index - 1),
            )
            step_result = _run_one_step(
                db,
                payload=payload,
                frames=frames,
                source_frame=source_frame,
                source_position=source_position,
                step=step,
            )
            step_results.append(step_result)
            created_count += int(step_result["created_annotation_count"])
            processed_count += int(step_result["processed_frame_count"])
            deleted_count += int(step_result.get("deleted_annotation_count") or 0)
            skipped_count += int(step_result.get("skipped_seed_count") or 0)
            _set_task_state(
                db,
                task,
                progress=5 + int((index / total_steps) * 90),
                message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
                result=snapshot(index),
            )

        result = snapshot(total_steps)
        if created_count > 0:
            final_message = "自动传播完成"
        elif skipped_count > 0:
            final_message = "自动传播完成,未改变的 mask 已跳过"
        else:
            final_message = "自动传播完成,但没有生成新的 mask"
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message=final_message,
            result=result,
            finished=True,
        )
        return result
    except PropagationTaskCancelled:
        task.status = TASK_STATUS_CANCELLED
        task.progress = 100
        task.message = task.message or "任务已取消"
        task.error = task.error or "Cancelled by user"
        task.finished_at = task.finished_at or _now()
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
    except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
        # Discard the half-finished unit of work first: after a failed flush
        # the session refuses further commits until it is rolled back, which
        # would otherwise hide the original error and leave the task RUNNING.
        db.rollback()
        fail("自动传播失败", str(exc))
        raise
    except Exception as exc:  # noqa: BLE001
        db.rollback()
        # Log the plain argument; reading task.id could touch the database.
        logger.exception("Propagation task failed: task_id=%s", task_id)
        fail("自动传播失败", str(exc))
        raise