Files
Pre_Seg_Server/backend/services/propagation_task_runner.py
admin 093ef6c63a 优化工作区传播和清空交互
- 手工多边形、矩形和圆在未选语义分类时默认归入 maskid:0 的待分类类别。

- 后端自动传播按来源 annotation/mask/seed key 区分同类多实例,避免多个同类型 mask 传播时互相清理。

- 左侧工具栏在橡皮擦下方新增彩色 AI 自动传播入口,传播权重和范围控件只在进入传播后显示。

- 移除顶栏重复的清空片段遮罩入口,并取消当前清空/DEL 弹窗中的按帧范围清空路径。

- Canvas 右下角显示当前帧:XX/XXX,并调整布尔操作浮层位置避免重叠。

- 更新前端和后端回归测试,覆盖待分类默认、工具栏自动传播和同类多实例传播。

- 同步 AGENTS 与 doc 文档,说明新的工具栏、清空和传播行为。
2026-05-04 00:26:11 +08:00

745 lines
29 KiB
Python

"""Background SAM video propagation runner used by Celery workers."""
import hashlib
import json
import logging
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from sqlalchemy.orm import Session
from minio_client import download_file
from models import Annotation, Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.sam_registry import ModelUnavailableError, sam_registry
from statuses import (
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
# Module-level logger; used below to record unexpected propagation task failures.
logger = logging.getLogger(__name__)
class PropagationTaskCancelled(RuntimeError):
    """Internal signal: the persisted task row was found in the cancelled state.

    Raised by the cancellation check between steps and caught by the task runner,
    which then records the task as cancelled instead of failed.
    """
def _now() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    utc = timezone.utc
    return datetime.now(tz=utc)
def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Apply the given non-None fields to ``task``, commit, reload it and publish progress.

    ``progress`` is clamped to the 0..100 range. ``started``/``finished`` stamp the
    corresponding timestamp with the current UTC time.
    """
    if status is not None:
        task.status = status
    if progress is not None:
        task.progress = max(0, min(100, progress))
    # Plain pass-through fields: only overwrite what the caller actually supplied.
    for attribute, value in (("message", message), ("result", result), ("error", error)):
        if value is not None:
            setattr(task, attribute, value)
    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload ``task`` from the database and abort if it has been cancelled.

    Raises:
        PropagationTaskCancelled: when the refreshed status equals TASK_STATUS_CANCELLED.
    """
    db.refresh(task)
    if task.status != TASK_STATUS_CANCELLED:
        return
    raise PropagationTaskCancelled("Task was cancelled")
def _clamp01(value: float) -> float:
    """Convert ``value`` to float and clamp it into the closed interval [0, 1]."""
    number = float(value)
    return min(max(number, 0.0), 1.0)
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return ``[x, y, width, height]`` of the clamped polygon's axis-aligned bounding box.

    Every point is indexed at [0] and [1]; an empty polygon raises ValueError from min().
    """
    clamped_x = [_clamp01(vertex[0]) for vertex in polygon]
    clamped_y = [_clamp01(vertex[1]) for vertex in polygon]
    x_min, y_min = min(clamped_x), min(clamped_y)
    width = max(max(clamped_x) - x_min, 0.0)
    height = max(max(clamped_y) - y_min, 0.0)
    return [x_min, y_min, width, height]
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
    """Clamp each point's x/y into [0, 1], dropping points with fewer than two coordinates."""
    cleaned: list[list[float]] = []
    for vertex in polygon:
        if len(vertex) < 2:
            continue
        cleaned.append([_clamp01(vertex[0]), _clamp01(vertex[1])])
    return cleaned
def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
if not isinstance(value, dict):
return None
try:
strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
except (TypeError, ValueError):
strength = 0.0
if strength <= 0:
return None
method = str(value.get("method") or "chaikin").lower()
if method != "chaikin":
method = "chaikin"
return {"strength": round(strength, 2), "method": method}
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
    """Map a 0..100 strength (falsy treated as 0) to 0..1, eased by raising to ``curve``."""
    capped = min(float(strength or 0.0), 100.0)
    fraction = max(0.0, capped) / 100.0
    return fraction ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
    """Apply closed-ring Chaikin corner cutting ``iterations`` times.

    ``corner_cut`` is clamped to [0.02, 0.25]. Each pass replaces every edge with two
    points at ``cut`` and ``1 - cut`` along it; stops early if fewer than 3 points remain.
    """
    ring = _normalize_polygon(polygon)
    cut = max(0.02, min(float(corner_cut), 0.25))
    keep = 1.0 - cut
    for _ in range(max(0, iterations)):
        if len(ring) < 3:
            break
        refined: list[list[float]] = []
        # Pair each vertex with its successor, wrapping the last back to the first.
        for start, end in zip(ring, ring[1:] + ring[:1]):
            refined.append([
                _clamp01(keep * start[0] + cut * end[0]),
                _clamp01(keep * start[1] + cut * end[1]),
            ])
            refined.append([
                _clamp01(cut * start[0] + keep * end[0]),
                _clamp01(cut * start[1] + keep * end[1]),
            ])
        ring = refined
    return ring
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
    """Reduce vertex count with Douglas-Peucker (cv2.approxPolyDP) on a closed contour.

    The tolerance is a fraction of the contour perimeter that grows with ``strength``.
    Returns ``polygon`` untouched if it has fewer than 3 points or simplification would
    leave fewer than 3.
    """
    if len(polygon) < 3:
        return polygon
    contour = np.array([[[vertex[0], vertex[1]]] for vertex in polygon], dtype=np.float32)
    perimeter = cv2.arcLength(contour, True)
    tolerance = perimeter * (0.00015 + _smoothing_ratio(strength) * 0.00735)
    reduced = cv2.approxPolyDP(contour, tolerance, True).reshape(-1, 2)
    if len(reduced) < 3:
        return polygon
    return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in reduced]
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
    """Smooth a polygon via simplify -> Chaikin -> simplify, keeping the vertex count bounded.

    With no smoothing options (or a non-positive strength) the clamped input is returned.
    If the smoothed result has more vertices than the clamped input, it is re-simplified
    with progressively stronger tolerances. A result with fewer than 3 vertices falls
    back to the clamped input.
    """
    strength = float(smoothing.get("strength") or 0.0) if smoothing else 0.0
    if strength <= 0:
        return _normalize_polygon(polygon)
    effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
    # Stronger effective smoothing -> more Chaikin passes (1..4).
    iterations = next(
        (passes for floor, passes in ((85, 4), (55, 3), (25, 2)) if effective_strength >= floor),
        1,
    )
    corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
    base = _normalize_polygon(polygon)
    pre_reduced = _simplify_polygon(base, effective_strength * 0.25)
    result = _simplify_polygon(
        _chaikin_smooth_polygon(pre_reduced, iterations, corner_cut),
        effective_strength,
    )
    if len(result) > len(base):
        for floor_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
            result = _simplify_polygon(result, max(effective_strength, floor_strength))
            if len(result) <= len(base):
                break
    return result if len(result) >= 3 else _normalize_polygon(polygon)
def _bbox_area(bbox: list[float]) -> float:
    """Return width * height of an ``[x, y, w, h]`` box, treating negative sides as zero."""
    width = max(float(bbox[2]), 0.0)
    height = max(float(bbox[3]), 0.0)
    return width * height
def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
    """Return intersection area of two ``[x, y, w, h]`` boxes over the smaller box's area.

    Returns 0.0 when the smaller box has no positive area.
    """
    left_a, top_a, width_a, height_a = a
    left_b, top_b, width_b, height_b = b
    inter_width = max(0.0, min(left_a + width_a, left_b + width_b) - max(left_a, left_b))
    inter_height = max(0.0, min(top_a + height_a, top_b + height_b) - max(top_a, top_b))
    denominator = min(_bbox_area(a), _bbox_area(b))
    if denominator > 0:
        return (inter_width * inter_height) / denominator
    return 0.0
def _stable_json(value: Any) -> str:
    """Serialize ``value`` deterministically: sorted keys, compact separators, raw Unicode."""
    options = {"ensure_ascii": False, "sort_keys": True, "separators": (",", ":")}
    return json.dumps(value, **options)
def _canonicalize_signature_value(value: Any) -> Any:
    """Recursively round floats to 6 decimals and sort dict keys; other values pass through."""
    if isinstance(value, dict):
        return {key: _canonicalize_signature_value(value[key]) for key in sorted(value)}
    if isinstance(value, list):
        return list(map(_canonicalize_signature_value, value))
    if isinstance(value, float):
        return round(value, 6)
    return value
def _seed_signature(seed: dict[str, Any]) -> str:
    """Return a stable signature for seed geometry and semantic attrs.

    A truthy ``propagation_seed_signature`` already present on the seed is reused as-is.
    Otherwise the SHA-256 hex digest of the canonical JSON of the seed's geometry,
    label/colour/class/template and normalized smoothing options is returned.
    """
    inherited = seed.get("propagation_seed_signature")
    if inherited:
        return str(inherited)
    fingerprint_source = {
        "polygons": seed.get("polygons") or [],
        "holes": seed.get("holes") or [],
        "bbox": seed.get("bbox") or [],
        "points": seed.get("points") or [],
        "labels": seed.get("labels") or [],
        "label": seed.get("label"),
        "color": seed.get("color"),
        "class_metadata": seed.get("class_metadata") or {},
        "template_id": seed.get("template_id"),
        "smoothing": _normalize_smoothing_options(seed.get("smoothing")),
    }
    canonical_text = _stable_json(_canonicalize_signature_value(fingerprint_source))
    return hashlib.sha256(canonical_text.encode("utf-8")).hexdigest()
def _seed_key(seed: dict[str, Any]) -> str:
    """Prefer stable persisted ids; fall back to semantic attrs for legacy callers.

    Returns ``annotation:<id>`` when a source annotation id is set (including 0),
    else ``mask:<id>`` for a truthy source mask id, else stable JSON of the
    template/class/label/colour attributes.
    """
    annotation_id = seed.get("source_annotation_id")
    if annotation_id is not None:
        return f"annotation:{annotation_id}"
    mask_id = seed.get("source_mask_id")
    if mask_id:
        return f"mask:{mask_id}"
    metadata = seed.get("class_metadata") or {}
    semantic_identity = {
        "template_id": seed.get("template_id"),
        "class_id": metadata.get("id") or metadata.get("name"),
        "label": seed.get("label"),
        "color": seed.get("color"),
    }
    return _stable_json(semantic_identity)
def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
    """Best-effort match when a manually edited replacement lacks old lineage ids.

    Class ids (id, else name) only veto the match when both sides have one and they
    differ as strings; otherwise label and colour must be equal.
    """
    seed_class = seed.get("class_metadata") or {}
    mask_class = mask_data.get("class") or {}
    mask_class_id = mask_class.get("id") or mask_class.get("name")
    seed_class_id = seed_class.get("id") or seed_class.get("name")
    if mask_class_id and seed_class_id and str(mask_class_id) != str(seed_class_id):
        return False
    if mask_data.get("label") != seed.get("label"):
        return False
    return mask_data.get("color") == seed.get("color")
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
    """Best-effort match for propagation annotations created before seed keys.

    Requires label, colour and class id (id, else name) to be strictly equal.
    """
    seed_class = seed.get("class_metadata") or {}
    mask_class = mask_data.get("class") or {}
    mask_class_id = mask_class.get("id") or mask_class.get("name")
    seed_class_id = seed_class.get("id") or seed_class.get("name")
    if mask_data.get("label") != seed.get("label"):
        return False
    if mask_data.get("color") != seed.get("color"):
        return False
    return mask_class_id == seed_class_id
def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
    """Return True when the mask's ``source`` is exactly ``<model_id>_propagation``."""
    expected_source = f"{model_id}_propagation"
    recorded_source = str(mask_data.get("source") or "")
    return recorded_source == expected_source
def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
    """Decide whether a stored propagated mask belongs to the same seed lineage.

    Matches on the stored seed key, the source annotation id (compared as strings) or
    the source mask id. The legacy semantic comparison is used only when neither the
    seed nor the stored mask carries any persisted identity.
    """
    previous_key = mask_data.get("propagation_seed_key")
    if previous_key == seed_key:
        return True
    annotation_id = seed.get("source_annotation_id")
    if annotation_id is not None:
        if str(mask_data.get("source_annotation_id") or "") == str(annotation_id):
            return True
    mask_id = seed.get("source_mask_id")
    if mask_id and mask_data.get("source_mask_id") == mask_id:
        return True
    seed_is_identified = annotation_id is not None or bool(mask_id)
    mask_is_identified = (
        bool(previous_key)
        or mask_data.get("source_annotation_id") is not None
        or bool(mask_data.get("source_mask_id"))
    )
    if seed_is_identified or mask_is_identified:
        return False
    return _legacy_seed_matches(mask_data, seed)
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
    """Return True for a propagation-produced annotation belonging to this seed's lineage."""
    stored = annotation.mask_data or {}
    produced_by_propagation = str(stored.get("source") or "").endswith("_propagation")
    if not produced_by_propagation:
        return False
    return _seed_identity_matches(stored, seed_key, seed)
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
    """Return True if the stored direction is unset (legacy) or equals ``direction``."""
    accepted = {None, direction}
    return mask_data.get("propagation_direction") in accepted
def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
    """Use target-frame overlap as a final guard before replacing same-object propagation.

    True if any stored polygon with at least 3 points has a bounding box overlapping the
    candidate polygon's box by >= 15% of the smaller box's area.
    """
    candidate_box = _polygon_bbox(polygon)
    stored_polygons = (annotation.mask_data or {}).get("polygons") or []
    return any(
        _bbox_overlap_ratio(_polygon_bbox(stored), candidate_box) >= 0.15
        for stored in stored_polygons
        if len(stored) >= 3
    )
def _delete_replaced_frame_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    frame_id: int,
    seed_key: str,
    seed: dict[str, Any],
    polygon: list[list[float]],
) -> int:
    """Delete old propagated masks for the same object immediately before writing a new result.

    Only annotations whose ``source`` ends with ``_propagation`` are considered. One is
    removed when it shares the seed lineage, or when it semantically matches the seed
    and overlaps ``polygon`` spatially. Commits only if something was deleted and
    returns the number removed.
    """
    frame_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id == frame_id)
        .all()
    )
    removed = 0
    for existing in frame_annotations:
        stored = existing.mask_data or {}
        if not str(stored.get("source") or "").endswith("_propagation"):
            continue
        # Both checks are evaluated up front, as before, so any error from the
        # spatial comparison surfaces regardless of the lineage result.
        lineage_hit = _seed_identity_matches(stored, seed_key, seed)
        # NOTE(review): this fallback ignores lineage ids, so a *different* same-class
        # instance whose bbox overlaps the new polygon by >= 15% of the smaller box is
        # also deleted. Confirm this is intended for overlapping same-class instances.
        replacement_hit = (
            _semantic_seed_matches(stored, seed)
            and _annotation_spatially_matches(existing, polygon)
        )
        if lineage_hit or replacement_hit:
            db.delete(existing)
            removed += 1
    if removed:
        db.commit()
    return removed
def _prepare_seed_propagation(
    db: Session,
    *,
    payload: dict[str, Any],
    model_id: str,
    seed: dict[str, Any],
    direction: str,
    target_frame_ids: set[int],
) -> dict[str, Any]:
    """Decide whether a seed needs re-propagation and clear its stale results.

    The seed is skipped when there are no target frames, or when every existing
    same-lineage, same-direction annotation in the target frames has the current seed
    signature and model, and together they cover all target frames. Otherwise those
    annotations are deleted (and committed).

    Returns a dict with ``skip``, ``seed_key``, ``seed_signature`` and
    ``deleted_annotation_count``.
    """
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)

    def outcome(skip: bool, deleted: int) -> dict[str, Any]:
        return {
            "skip": skip,
            "seed_key": seed_key,
            "seed_signature": seed_signature,
            "deleted_annotation_count": deleted,
        }

    if not target_frame_ids:
        return outcome(True, 0)

    candidates = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id.in_(target_frame_ids))
        .all()
    )
    stale = [
        candidate
        for candidate in candidates
        if _is_propagation_annotation(candidate, seed_key, seed)
        and _direction_matches(candidate.mask_data or {}, direction)
    ]
    covered_frames = {int(candidate.frame_id) for candidate in stale}
    still_current = all(
        (candidate.mask_data or {}).get("propagation_seed_signature") == seed_signature
        and _source_model_matches(candidate.mask_data or {}, model_id)
        for candidate in stale
    )
    if stale and still_current and target_frame_ids.issubset(covered_frames):
        return outcome(True, 0)

    for candidate in stale:
        db.delete(candidate)
    if stale:
        db.commit()
    return outcome(False, len(stale))
def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Slice at most ``max_frames`` frames ending (backward) or starting (forward) at the source.

    Returns the window and the source frame's index inside it. Any direction other than
    "backward" is treated as forward.
    """
    window_size = max(1, min(max_frames, len(frames)))
    if direction != "backward":
        stop = min(len(frames), source_position + window_size)
        return frames[source_position:stop], 0
    first = max(0, source_position - window_size + 1)
    return frames[first:source_position + 1], source_position - first
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download each frame image into ``directory`` as a zero-padded numbered .jpg.

    Returns the written file paths in frame order.
    """
    written: list[str] = []
    for position, frame in enumerate(frames):
        image_bytes = download_file(frame.image_url)
        # SAM2VideoPredictor sorts frames by converting the filename stem to int.
        target = directory / f"{position:06d}.jpg"
        target.write_bytes(image_bytes)
        written.append(str(target))
    return written
def _save_propagated_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    selected_frames: list[Frame],
    source_frame: Frame,
    propagated: list[dict[str, Any]],
    seed: dict[str, Any],
) -> tuple[list[Annotation], int]:
    """Persist SAM propagation output as one annotation per resulting polygon.

    Frame results with an out-of-range ``frame_index`` are ignored, and so is the
    source frame unless ``include_source`` is set. Before a frame's first polygon is
    written, older propagated masks for the same object on that frame are removed.
    Nothing is written when ``save_annotations`` is explicitly False.

    Returns the created (refreshed) annotations and the number of deleted annotations.
    """
    saved: list[Annotation] = []
    if payload.get("save_annotations", True) is False:
        return saved, 0

    class_metadata = seed.get("class_metadata")
    template_id = seed.get("template_id")
    label = seed.get("label") or "Propagated Mask"
    color = seed.get("color") or "#06b6d4"
    model_id = sam_registry.normalize_model_id(payload.get("model"))
    include_source = bool(payload.get("include_source", False))
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)
    source_annotation_id = seed.get("source_annotation_id")
    source_mask_id = seed.get("source_mask_id")
    smoothing = _normalize_smoothing_options(seed.get("smoothing"))
    direction = str(payload.get("current_direction") or "")

    removed_total = 0
    cleaned_frames: set[int] = set()
    window_length = len(selected_frames)
    for frame_result in propagated:
        position = int(frame_result.get("frame_index", -1))
        if not 0 <= position < window_length:
            continue
        frame = selected_frames[position]
        if not include_source and frame.id == source_frame.id:
            continue
        raw_polygons = frame_result.get("polygons") or []
        raw_holes = frame_result.get("holes") or []
        scores = frame_result.get("scores") or []
        smoothed = [
            (index, _smooth_polygon(raw, smoothing))
            for index, raw in enumerate(raw_polygons)
            if len(raw) >= 3
        ]
        # Smoothing can fall back to a clamped polygon that lost malformed points.
        usable = [(index, shape) for index, shape in smoothed if len(shape) >= 3]
        if usable and frame.id not in cleaned_frames:
            removed_total += _delete_replaced_frame_annotations(
                db,
                payload=payload,
                frame_id=int(frame.id),
                seed_key=seed_key,
                seed=seed,
                polygon=usable[0][1],
            )
            cleaned_frames.add(int(frame.id))
        for index, shape in usable:
            has_hole_entry = index < len(raw_holes) and isinstance(raw_holes[index], list)
            holes = raw_holes[index] if has_hole_entry else []
            # Keys are inserted in the same order as the stored JSON has always used.
            mask_data: dict[str, Any] = {"polygons": [shape]}
            if holes:
                mask_data["holes"] = [holes]
                mask_data["hasHoles"] = True
            mask_data.update({
                "label": label,
                "color": color,
                "source": f"{model_id}_propagation",
                "propagated_from_frame_id": source_frame.id,
                "propagated_from_frame_index": source_frame.frame_index,
                "propagation_seed_key": seed_key,
                "propagation_seed_signature": seed_signature,
                "propagation_direction": direction,
                "source_annotation_id": source_annotation_id,
                "source_mask_id": source_mask_id,
                "score": scores[index] if index < len(scores) else None,
            })
            if smoothing:
                mask_data["geometry_smoothing"] = smoothing
            if class_metadata:
                mask_data["class"] = class_metadata
            record = Annotation(
                project_id=int(payload["project_id"]),
                frame_id=frame.id,
                template_id=template_id,
                mask_data=mask_data,
                points=None,
                bbox=_polygon_bbox(shape),
            )
            db.add(record)
            saved.append(record)
    db.commit()
    for record in saved:
        db.refresh(record)
    return saved, removed_total
def _run_one_step(
    db: Session,
    *,
    payload: dict[str, Any],
    frames: list[Frame],
    source_frame: Frame,
    source_position: int,
    step: dict[str, Any],
) -> dict[str, Any]:
    """Propagate one seed in one direction and persist the resulting masks.

    ``max_frames`` comes from the step, then the payload, then defaults to 30, and is
    clamped to 1..500. A seed whose previous propagation is still current is skipped
    without calling SAM.

    Raises:
        ValueError: for a direction other than forward/backward, or a seed with no
            polygons, bbox or points.
    """
    direction = str(step.get("direction") or "forward").lower()
    if direction not in {"forward", "backward"}:
        raise ValueError("direction must be forward or backward")
    requested_frames = int(step.get("max_frames") or payload.get("max_frames") or 30)
    max_frames = max(1, min(requested_frames, 500))
    seed = step.get("seed") or {}
    has_prompt = seed.get("polygons") or seed.get("bbox") or seed.get("points")
    if not has_prompt:
        raise ValueError("Propagation requires seed polygons, bbox, or points")
    model_id = sam_registry.normalize_model_id(payload.get("model"))
    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    include_source = bool(payload.get("include_source", False))
    target_frame_ids = {
        int(candidate.id)
        for candidate in selected_frames
        if include_source or candidate.id != source_frame.id
    }
    seed_state = _prepare_seed_propagation(
        db,
        payload=payload,
        model_id=model_id,
        seed=seed,
        direction=direction,
        target_frame_ids=target_frame_ids,
    )

    def summary(processed: int, created: int, deleted: int, skipped: int) -> dict[str, Any]:
        return {
            "model": model_id,
            "direction": direction,
            "processed_frame_count": processed,
            "created_annotation_count": created,
            "deleted_annotation_count": deleted,
            "skipped_seed_count": skipped,
            "seed_label": seed.get("label"),
            "seed_key": seed_state["seed_key"],
        }

    if seed_state["skip"]:
        return summary(0, 0, 0, 1)

    with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
        frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
        propagated = sam_registry.propagate_video(
            model_id,
            frame_paths,
            source_relative_index,
            seed,
            direction,
            len(selected_frames),
        )
    created, write_cleanup_count = _save_propagated_annotations(
        db,
        payload={**payload, "current_direction": direction},
        selected_frames=selected_frames,
        source_frame=source_frame,
        propagated=propagated,
        seed=seed,
    )
    return summary(
        len(selected_frames),
        len(created),
        int(seed_state["deleted_annotation_count"]) + write_cleanup_count,
        0,
    )
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
    """Run one queued SAM propagation task and update persisted progress.

    Validates the task payload (model, project, source frame, steps), then runs each
    step in order, persisting progress and an aggregate result after every step.

    Returns:
        The aggregate result dict on success, or a small status dict when the task is
        (or becomes) cancelled.

    Raises:
        ValueError: unknown task, or invalid model/project/frame/steps/step contents.
        ModelUnavailableError, NotImplementedError, or any other exception raised by a
            step. The task is marked failed before the exception is re-raised.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")
    if task.status == TASK_STATUS_CANCELLED:
        return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}
    payload = task.payload or {}
    project_id = int(payload.get("project_id") or task.project_id or 0)
    source_frame_id = int(payload.get("frame_id") or 0)
    try:
        model_id = sam_registry.normalize_model_id(payload.get("model"))
    except ValueError as exc:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
        raise
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True)
        raise ValueError(f"Project not found: {project_id}")
    source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
    if not source_frame:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True)
        raise ValueError(f"Frame not found: {source_frame_id}")
    frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True)
        raise ValueError("Source frame is not in project frame sequence")
    steps = payload.get("steps") or []
    if not steps:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True)
        raise ValueError("Propagation task has no steps")

    step_results: list[dict[str, Any]] = []
    created_count = 0
    processed_count = 0
    deleted_count = 0
    skipped_count = 0
    total_steps = len(steps)

    def summary(completed_steps: int) -> dict[str, Any]:
        # Reads the running counters at call time; one shape for progress and final result.
        return {
            "project_id": project_id,
            "source_frame_id": source_frame_id,
            "model": model_id,
            "total_steps": total_steps,
            "completed_steps": completed_steps,
            "processed_frame_count": processed_count,
            "created_annotation_count": created_count,
            "deleted_annotation_count": deleted_count,
            "skipped_seed_count": skipped_count,
            "steps": step_results,
        }

    try:
        # Inside the try so a cancellation that lands between the initial status check
        # and start-up is recorded as cancelled instead of escaping unhandled.
        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)
        for index, step in enumerate(steps, start=1):
            _ensure_not_cancelled(db, task)
            seed_label = (step.get("seed") or {}).get("label") or "mask"
            # Normalize exactly as _run_one_step does so the label matches the direction run.
            step_direction = str(step.get("direction") or "forward").lower()
            direction_label = "向前传播" if step_direction == "backward" else "向后传播"
            _set_task_state(
                db,
                task,
                progress=5 + int(((index - 1) / total_steps) * 90),
                message=f"{direction_label} {seed_label} ({index}/{total_steps})",
                result=summary(index - 1),
            )
            result = _run_one_step(
                db,
                payload=payload,
                frames=frames,
                source_frame=source_frame,
                source_position=source_position,
                step=step,
            )
            step_results.append(result)
            created_count += int(result["created_annotation_count"])
            processed_count += int(result["processed_frame_count"])
            deleted_count += int(result.get("deleted_annotation_count") or 0)
            skipped_count += int(result.get("skipped_seed_count") or 0)
            _set_task_state(
                db,
                task,
                progress=5 + int((index / total_steps) * 90),
                message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
                result=summary(index),
            )
        result = summary(total_steps)
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message="自动传播完成" if created_count > 0 else (
                "自动传播完成,未改变的 mask 已跳过" if skipped_count > 0 else "自动传播完成,但没有生成新的 mask"
            ),
            result=result,
            finished=True,
        )
        return result
    except PropagationTaskCancelled:
        task.status = TASK_STATUS_CANCELLED
        task.progress = 100
        task.message = task.message or "任务已取消"
        task.error = task.error or "Cancelled by user"
        task.finished_at = task.finished_at or _now()
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
    except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
        # Discard any half-written step (and clear a failed-flush state) before the
        # failure itself is committed; otherwise that commit could raise and mask exc.
        db.rollback()
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
        raise
    except Exception as exc:  # noqa: BLE001
        # Log with the argument id: after a failed flush, touching task attributes may hit the DB.
        logger.exception("Propagation task failed: task_id=%s", task_id)
        db.rollback()
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
        raise