Files
Pre_Seg_Server/backend/services/propagation_task_runner.py
admin f88f9bdbb9 支持中空mask编辑和传播保洞
- 前端按 polygonRingCounts 维护外圈/内洞 ring 分组,中空 mask 在调整多边形时显示内洞顶点和插点手柄。

- 保存与回显标注时将中空结构拆分为 mask_data.polygons 和 mask_data.holes,导入/普通 mask 共享同一编辑体验。

- 自动传播 seed 携带 holes,SAM 2 seed 栅格化时扣除内洞,避免中空 mask 以实心形式传播。

- 传播结果轮廓提取改为保留层级内洞,并在同步传播和 Celery 传播落库时写回 holes 与 hasHoles。

- 传播 seed 签名纳入 holes,并加固保存结果时 holes 与原始 polygon 索引对齐。

- 补充前端保存/回显、Canvas 内洞编辑和后端 SAM 2 hole 处理测试。

- 更新 AGENTS、接口契约、需求冻结、设计冻结和测试计划文档,移除中空结构未实现的旧描述。
2026-05-03 18:28:46 +08:00

741 lines
28 KiB
Python

"""Background SAM video propagation runner used by Celery workers."""
import hashlib
import json
import logging
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from sqlalchemy.orm import Session
from minio_client import download_file
from models import Annotation, Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.sam_registry import ModelUnavailableError, sam_registry
from statuses import (
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class PropagationTaskCancelled(RuntimeError):
    """Internal signal that the persisted propagation task was cancelled.

    Raised when a refreshed task row reports the cancelled status, so the
    runner can stop its step loop and finalise the task as cancelled.
    """
def _now() -> datetime:
return datetime.now(timezone.utc)
def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Persist a partial update of the task row and publish a progress event.

    Only arguments that are not ``None`` are written to the task. ``progress``
    is clamped to the 0..100 range. ``started`` / ``finished`` stamp
    ``started_at`` / ``finished_at`` with the current UTC time. The change is
    committed, the task is refreshed from the database, and the refreshed task
    is handed to ``publish_task_progress_event``.
    """
    if progress is not None:
        progress = max(0, min(100, progress))

    # Written in the same order as the keyword arguments are declared.
    field_updates = {
        "status": status,
        "progress": progress,
        "message": message,
        "result": result,
        "error": error,
    }
    for field_name, new_value in field_updates.items():
        if new_value is not None:
            setattr(task, field_name, new_value)

    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload the task row and abort if it has been cancelled.

    Raises:
        PropagationTaskCancelled: when the refreshed status equals
            ``TASK_STATUS_CANCELLED``.
    """
    db.refresh(task)
    if task.status != TASK_STATUS_CANCELLED:
        return
    raise PropagationTaskCancelled("Task was cancelled")
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return ``[left, top, width, height]`` of the polygon's bounding box.

    Every coordinate is clamped to [0, 1] before the extent is measured.
    An empty polygon raises ``ValueError`` from ``min``.
    """
    clamped_xs = [_clamp01(vertex[0]) for vertex in polygon]
    clamped_ys = [_clamp01(vertex[1]) for vertex in polygon]
    left = min(clamped_xs)
    right = max(clamped_xs)
    top = min(clamped_ys)
    bottom = max(clamped_ys)
    width = max(right - left, 0.0)
    height = max(bottom - top, 0.0)
    return [left, top, width, height]
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
    """Return the polygon as ``[x, y]`` pairs clamped to [0, 1].

    Points with fewer than two components are dropped; extra components
    beyond the first two are ignored.
    """
    cleaned: list[list[float]] = []
    for vertex in polygon:
        if len(vertex) < 2:
            continue
        cleaned.append([_clamp01(vertex[0]), _clamp01(vertex[1])])
    return cleaned
def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
if not isinstance(value, dict):
return None
try:
strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
except (TypeError, ValueError):
strength = 0.0
if strength <= 0:
return None
method = str(value.get("method") or "chaikin").lower()
if method != "chaikin":
method = "chaikin"
return {"strength": round(strength, 2), "method": method}
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
return normalized ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
    """Apply corner-cutting subdivision to a closed polygon.

    The polygon is normalised first. In each pass every edge (including the
    closing edge from the last point back to the first) is replaced by two
    points placed at fraction ``q`` and ``1 - q`` along it, where ``q`` is
    ``corner_cut`` limited to [0.02, 0.25]. Passes stop early once fewer than
    three points remain. Output coordinates are clamped to [0, 1].
    """
    points = _normalize_polygon(polygon)
    q = max(0.02, min(float(corner_cut), 0.25))
    keep = 1.0 - q
    for _ in range(max(0, iterations)):
        if len(points) < 3:
            break
        # Pair every point with its successor, wrapping around at the end.
        successors = points[1:] + points[:1]
        refined: list[list[float]] = []
        for (cur_x, cur_y), (nxt_x, nxt_y) in zip(points, successors):
            refined.append([
                _clamp01(keep * cur_x + q * nxt_x),
                _clamp01(keep * cur_y + q * nxt_y),
            ])
            refined.append([
                _clamp01(q * cur_x + keep * nxt_x),
                _clamp01(q * cur_y + keep * nxt_y),
            ])
        points = refined
    return points
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
    """Reduce the vertex count of a closed polygon with ``cv2.approxPolyDP``.

    The tolerance is a fraction of the closed perimeter that grows with
    ``strength`` (from 0.00015 up to 0.0075 of the arc length). The input is
    returned unchanged when it has fewer than three points or when the
    approximation would collapse below three points.
    """
    if len(polygon) < 3:
        return polygon

    vertices = np.array([(point[0], point[1]) for point in polygon], dtype=np.float32)
    contour = vertices.reshape(-1, 1, 2)
    perimeter = cv2.arcLength(contour, True)
    tolerance = perimeter * (0.00015 + _smoothing_ratio(strength) * 0.00735)
    reduced = cv2.approxPolyDP(contour, tolerance, True).reshape(-1, 2)
    if len(reduced) < 3:
        return polygon
    return [[_clamp01(float(vertex[0])), _clamp01(float(vertex[1]))] for vertex in reduced]
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
    """Smooth a polygon according to the given smoothing options.

    Without options, or with a non-positive strength, the polygon is only
    normalised. Otherwise the polygon is pre-simplified, corner-cut for 1-4
    passes depending on the effective strength, and simplified again. If the
    result has more vertices than the normalised input, progressively
    stronger simplification is applied until it no longer does. Falls back to
    the normalised input when the result has fewer than three points.
    """
    strength = float(smoothing.get("strength") or 0.0) if smoothing else 0.0
    if strength <= 0:
        return _normalize_polygon(polygon)

    effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
    # One pass by default, plus one for every threshold reached.
    iterations = 1 + sum(effective_strength >= threshold for threshold in (25, 55, 85))
    corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22

    baseline = _normalize_polygon(polygon)
    coarse = _simplify_polygon(baseline, effective_strength * 0.25)
    rounded = _chaikin_smooth_polygon(coarse, iterations, corner_cut)
    candidate = _simplify_polygon(rounded, effective_strength)

    if len(candidate) > len(baseline):
        # Smoothing must not inflate the vertex count; tighten until it fits.
        for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
            candidate = _simplify_polygon(candidate, max(effective_strength, fallback_strength))
            if len(candidate) <= len(baseline):
                break

    if len(candidate) >= 3:
        return candidate
    return _normalize_polygon(polygon)
def _bbox_area(bbox: list[float]) -> float:
return max(float(bbox[2]), 0.0) * max(float(bbox[3]), 0.0)
def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
    """Return the intersection area divided by the smaller box's area.

    Both boxes are ``[x, y, width, height]``. Returns ``0.0`` when the
    smaller area is not positive.
    """
    a_left, a_top, a_width, a_height = a
    b_left, b_top, b_width, b_height = b

    shared_width = max(0.0, min(a_left + a_width, b_left + b_width) - max(a_left, b_left))
    shared_height = max(0.0, min(a_top + a_height, b_top + b_height) - max(a_top, b_top))
    shared_area = shared_width * shared_height

    smallest_area = min(_bbox_area(a), _bbox_area(b))
    if smallest_area > 0:
        return shared_area / smallest_area
    return 0.0
def _stable_json(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
def _canonicalize_signature_value(value: Any) -> Any:
if isinstance(value, float):
return round(value, 6)
if isinstance(value, list):
return [_canonicalize_signature_value(item) for item in value]
if isinstance(value, dict):
return {key: _canonicalize_signature_value(value[key]) for key in sorted(value)}
return value
def _seed_signature(seed: dict[str, Any]) -> str:
"""Return a stable signature for seed geometry and semantic attrs."""
inherited_signature = seed.get("propagation_seed_signature")
if inherited_signature:
return str(inherited_signature)
signature_payload = {
"polygons": seed.get("polygons") or [],
"holes": seed.get("holes") or [],
"bbox": seed.get("bbox") or [],
"points": seed.get("points") or [],
"labels": seed.get("labels") or [],
"label": seed.get("label"),
"color": seed.get("color"),
"class_metadata": seed.get("class_metadata") or {},
"template_id": seed.get("template_id"),
"smoothing": _normalize_smoothing_options(seed.get("smoothing")),
}
return hashlib.sha256(_stable_json(_canonicalize_signature_value(signature_payload)).encode("utf-8")).hexdigest()
def _seed_key(seed: dict[str, Any]) -> str:
"""Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
source_annotation_id = seed.get("source_annotation_id")
if source_annotation_id is not None:
return f"annotation:{source_annotation_id}"
source_mask_id = seed.get("source_mask_id")
if source_mask_id:
return f"mask:{source_mask_id}"
class_metadata = seed.get("class_metadata") or {}
class_id = class_metadata.get("id") or class_metadata.get("name")
return _stable_json({
"template_id": seed.get("template_id"),
"class_id": class_id,
"label": seed.get("label"),
"color": seed.get("color"),
})
def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
"""Best-effort match when a manually edited replacement lacks old lineage ids."""
class_metadata = seed.get("class_metadata") or {}
previous_class = mask_data.get("class") or {}
previous_class_id = previous_class.get("id") or previous_class.get("name")
class_id = class_metadata.get("id") or class_metadata.get("name")
if previous_class_id and class_id and str(previous_class_id) != str(class_id):
return False
return (
mask_data.get("label") == seed.get("label")
and mask_data.get("color") == seed.get("color")
)
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
"""Best-effort match for propagation annotations created before seed keys."""
class_metadata = seed.get("class_metadata") or {}
previous_class = mask_data.get("class") or {}
previous_class_id = previous_class.get("id") or previous_class.get("name")
class_id = class_metadata.get("id") or class_metadata.get("name")
return (
mask_data.get("label") == seed.get("label")
and mask_data.get("color") == seed.get("color")
and previous_class_id == class_id
)
def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
return str(mask_data.get("source") or "") == f"{model_id}_propagation"
def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
previous_seed_key = mask_data.get("propagation_seed_key")
if previous_seed_key == seed_key:
return True
source_annotation_id = seed.get("source_annotation_id")
if source_annotation_id is not None and str(mask_data.get("source_annotation_id") or "") == str(source_annotation_id):
return True
source_mask_id = seed.get("source_mask_id")
if source_mask_id and mask_data.get("source_mask_id") == source_mask_id:
return True
return _legacy_seed_matches(mask_data, seed)
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
    """Return whether the annotation is a propagated mask of the given seed.

    The mask's ``source`` must end with ``"_propagation"`` and its stored
    identity must match the seed.
    """
    mask_data = annotation.mask_data or {}
    recorded_source = str(mask_data.get("source") or "")
    if recorded_source.endswith("_propagation"):
        return _seed_identity_matches(mask_data, seed_key, seed)
    return False
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
previous_direction = mask_data.get("propagation_direction")
return previous_direction in {None, direction}
def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
    """Use target-frame overlap as a final guard before replacing same-object propagation.

    Returns ``True`` when any stored polygon with at least three points has a
    bounding box whose overlap ratio with the candidate polygon's bounding
    box is 0.15 or more.
    """
    candidate_bbox = _polygon_bbox(polygon)
    stored_polygons = (annotation.mask_data or {}).get("polygons") or []
    return any(
        _bbox_overlap_ratio(_polygon_bbox(stored), candidate_bbox) >= 0.15
        for stored in stored_polygons
        if len(stored) >= 3
    )
def _delete_replaced_frame_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    frame_id: int,
    seed_key: str,
    seed: dict[str, Any],
    polygon: list[list[float]],
) -> int:
    """Delete old propagated masks for the same object immediately before writing a new result.

    Looks at every annotation of the project on ``frame_id`` whose ``source``
    ends with ``"_propagation"``. One is deleted when it shares the seed's
    lineage, or when it matches the seed semantically and overlaps
    ``polygon`` spatially. Commits only if something was deleted.

    Returns:
        The number of deleted annotations.
    """
    project_id = int(payload["project_id"])
    frame_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .filter(Annotation.frame_id == frame_id)
        .all()
    )

    removed = 0
    for existing in frame_annotations:
        mask_data = existing.mask_data or {}
        if not str(mask_data.get("source") or "").endswith("_propagation"):
            continue
        # Both criteria are always evaluated, as they are independent checks.
        same_lineage = _seed_identity_matches(mask_data, seed_key, seed)
        same_manual_replacement = (
            _semantic_seed_matches(mask_data, seed)
            and _annotation_spatially_matches(existing, polygon)
        )
        if not (same_lineage or same_manual_replacement):
            continue
        db.delete(existing)
        removed += 1

    if removed:
        db.commit()
    return removed
def _prepare_seed_propagation(
    db: Session,
    *,
    payload: dict[str, Any],
    model_id: str,
    seed: dict[str, Any],
    direction: str,
    target_frame_ids: set[int],
) -> dict[str, Any]:
    """Decide whether a seed needs propagating and clear its stale results.

    The seed is skipped when there are no target frames, or when every
    target frame already holds a matching propagated annotation and all
    matching annotations carry the current seed signature and were produced
    by ``model_id``. Otherwise all matching annotations (same seed, compatible
    direction) on the target frames are deleted and committed.

    Returns:
        A dict with ``skip``, ``seed_key``, ``seed_signature`` and
        ``deleted_annotation_count``.
    """
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)

    def outcome(skip: bool, deleted: int = 0) -> dict[str, Any]:
        return {
            "skip": skip,
            "seed_key": seed_key,
            "seed_signature": seed_signature,
            "deleted_annotation_count": deleted,
        }

    if not target_frame_ids:
        return outcome(True)

    candidates = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id.in_(target_frame_ids))
        .all()
    )
    matching = []
    for candidate in candidates:
        if not _is_propagation_annotation(candidate, seed_key, seed):
            continue
        if _direction_matches(candidate.mask_data or {}, direction):
            matching.append(candidate)

    def is_current(existing: Annotation) -> bool:
        mask_data = existing.mask_data or {}
        return (
            mask_data.get("propagation_seed_signature") == seed_signature
            and _source_model_matches(mask_data, model_id)
        )

    covered_frame_ids = {int(existing.frame_id) for existing in matching}
    up_to_date = (
        bool(matching)
        and all(is_current(existing) for existing in matching)
        and target_frame_ids.issubset(covered_frame_ids)
    )
    if up_to_date:
        return outcome(True)

    for stale in matching:
        db.delete(stale)
    if matching:
        db.commit()
    return outcome(False, len(matching))
def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Select the run of frames to propagate over.

    The window holds at most ``max_frames`` frames (at least one, at most
    ``len(frames)``). For ``"backward"`` it ends at the source frame; for any
    other direction it starts at the source frame.

    Returns:
        The selected frames and the source frame's index within them.
    """
    count = max(1, min(max_frames, len(frames)))
    if direction != "backward":
        stop = min(len(frames), source_position + count)
        return frames[source_position:stop], 0

    first = max(0, source_position - count + 1)
    return frames[first:source_position + 1], source_position - first
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download each frame image into ``directory`` and return the file paths.

    Files are named by zero-padded position (``000000.jpg``, ``000001.jpg``,
    ...) in the order of ``frames``.
    """
    written: list[str] = []
    for position, frame in enumerate(frames):
        # SAM2VideoPredictor sorts frames by converting the filename stem to int.
        target = directory / f"{position:06d}.jpg"
        target.write_bytes(download_file(frame.image_url))
        written.append(str(target))
    return written
def _save_propagated_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    selected_frames: list[Frame],
    source_frame: Frame,
    propagated: list[dict[str, Any]],
    seed: dict[str, Any],
) -> tuple[list[Annotation], int]:
    """Persist propagation results as one annotation per polygon.

    Nothing is written when ``payload["save_annotations"]`` is ``False``.
    Frame results with an out-of-range ``frame_index`` are ignored, and the
    source frame is skipped unless ``include_source`` is set. Polygons with
    at least three points are smoothed; before the first write to a frame,
    previously propagated masks of the same object on that frame are deleted.
    Holes are looked up by the polygon's original index in the frame result
    and stored together with ``hasHoles`` when present.

    Returns:
        The created (refreshed) annotations and the number of annotations
        deleted during the write-time cleanup.
    """
    created: list[Annotation] = []
    if payload.get("save_annotations", True) is False:
        return created, 0

    class_metadata = seed.get("class_metadata")
    template_id = seed.get("template_id")
    label = seed.get("label") or "Propagated Mask"
    color = seed.get("color") or "#06b6d4"
    model_id = sam_registry.normalize_model_id(payload.get("model"))
    include_source = bool(payload.get("include_source", False))
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)
    source_annotation_id = seed.get("source_annotation_id")
    source_mask_id = seed.get("source_mask_id")
    smoothing = _normalize_smoothing_options(seed.get("smoothing"))
    direction = str(payload.get("current_direction") or "")
    project_id = int(payload["project_id"])

    removed = 0
    cleaned_frame_ids: set[int] = set()
    for frame_result in propagated:
        relative_index = int(frame_result.get("frame_index", -1))
        if not 0 <= relative_index < len(selected_frames):
            continue
        frame = selected_frames[relative_index]
        if frame.id == source_frame.id and not include_source:
            continue

        outlines = frame_result.get("polygons") or []
        holes_by_polygon = frame_result.get("holes") or []
        scores = frame_result.get("scores") or []

        # Keep the original polygon index so holes and scores stay aligned.
        smoothed = [
            (position, _smooth_polygon(outline, smoothing))
            for position, outline in enumerate(outlines)
            if len(outline) >= 3
        ]
        usable = [(position, outline) for position, outline in smoothed if len(outline) >= 3]

        if usable and frame.id not in cleaned_frame_ids:
            removed += _delete_replaced_frame_annotations(
                db,
                payload=payload,
                frame_id=int(frame.id),
                seed_key=seed_key,
                seed=seed,
                polygon=usable[0][1],
            )
            cleaned_frame_ids.add(int(frame.id))

        for position, outline in usable:
            hole_group = []
            if position < len(holes_by_polygon) and isinstance(holes_by_polygon[position], list):
                hole_group = holes_by_polygon[position]

            mask_data: dict[str, Any] = {"polygons": [outline]}
            if hole_group:
                mask_data["holes"] = [hole_group]
                mask_data["hasHoles"] = True
            mask_data.update({
                "label": label,
                "color": color,
                "source": f"{model_id}_propagation",
                "propagated_from_frame_id": source_frame.id,
                "propagated_from_frame_index": source_frame.frame_index,
                "propagation_seed_key": seed_key,
                "propagation_seed_signature": seed_signature,
                "propagation_direction": direction,
                "source_annotation_id": source_annotation_id,
                "source_mask_id": source_mask_id,
                "score": scores[position] if position < len(scores) else None,
            })
            if smoothing:
                mask_data["geometry_smoothing"] = smoothing
            if class_metadata:
                mask_data["class"] = class_metadata

            annotation = Annotation(
                project_id=project_id,
                frame_id=frame.id,
                template_id=template_id,
                mask_data=mask_data,
                points=None,
                bbox=_polygon_bbox(outline),
            )
            db.add(annotation)
            created.append(annotation)

    db.commit()
    for annotation in created:
        db.refresh(annotation)
    return created, removed
def _run_one_step(
    db: Session,
    *,
    payload: dict[str, Any],
    frames: list[Frame],
    source_frame: Frame,
    source_position: int,
    step: dict[str, Any],
) -> dict[str, Any]:
    """Propagate one seed in one direction and save the results.

    The step's direction defaults to ``"forward"``; its frame budget comes
    from the step, then the payload, then 30, limited to 1..500. If the seed
    is already up to date on every target frame the step is reported as
    skipped without running the model. Otherwise the frame window is written
    to a temporary directory, propagated through ``sam_registry``, and saved.

    Returns:
        A summary dict with model, direction, processed/created/deleted/
        skipped counts, seed label and seed key.

    Raises:
        ValueError: for an unknown direction or a seed without polygons,
            bbox, or points.
    """
    direction = str(step.get("direction") or "forward").lower()
    if direction not in {"forward", "backward"}:
        raise ValueError("direction must be forward or backward")

    requested_frames = int(step.get("max_frames") or payload.get("max_frames") or 30)
    max_frames = max(1, min(requested_frames, 500))

    seed = step.get("seed") or {}
    has_prompt = seed.get("polygons") or seed.get("bbox") or seed.get("points")
    if not has_prompt:
        raise ValueError("Propagation requires seed polygons, bbox, or points")

    model_id = sam_registry.normalize_model_id(payload.get("model"))
    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    include_source = bool(payload.get("include_source", False))
    target_frame_ids: set[int] = set()
    for frame in selected_frames:
        if include_source or frame.id != source_frame.id:
            target_frame_ids.add(int(frame.id))

    seed_state = _prepare_seed_propagation(
        db,
        payload=payload,
        model_id=model_id,
        seed=seed,
        direction=direction,
        target_frame_ids=target_frame_ids,
    )

    def summarize(*, processed: int, created_total: int, deleted: int, skipped: int) -> dict[str, Any]:
        return {
            "model": model_id,
            "direction": direction,
            "processed_frame_count": processed,
            "created_annotation_count": created_total,
            "deleted_annotation_count": deleted,
            "skipped_seed_count": skipped,
            "seed_label": seed.get("label"),
            "seed_key": seed_state["seed_key"],
        }

    if seed_state["skip"]:
        return summarize(processed=0, created_total=0, deleted=0, skipped=1)

    # The frame files only need to exist while the model is running.
    with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
        frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
        propagated = sam_registry.propagate_video(
            model_id,
            frame_paths,
            source_relative_index,
            seed,
            direction,
            len(selected_frames),
        )

    created, write_cleanup_count = _save_propagated_annotations(
        db,
        payload={**payload, "current_direction": direction},
        selected_frames=selected_frames,
        source_frame=source_frame,
        propagated=propagated,
        seed=seed,
    )
    return summarize(
        processed=len(selected_frames),
        created_total=len(created),
        deleted=int(seed_state["deleted_annotation_count"]) + write_cleanup_count,
        skipped=0,
    )
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
    """Run one queued SAM propagation task and update persisted progress.

    The task payload is validated (model, project, source frame, steps), the
    task is moved to RUNNING, and every step is executed in order while the
    running totals are persisted before and after each step.

    Args:
        db: Database session used for all reads and writes.
        task_id: Primary key of the ``ProcessingTask`` to run.

    Returns:
        On success, a summary dict with project/frame/model ids, step counts,
        the processed/created/deleted/skipped totals and the per-step results.
        If the task is (or becomes) cancelled, a short dict carrying the
        cancelled status and message instead.

    Raises:
        ValueError: if the task, project, source frame, or steps are missing,
            or a step is invalid. Except for a missing task, the task is
            marked FAILED before the error is re-raised.
        Exception: any other failure is logged, recorded on the task as
            FAILED, and re-raised.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")
    if task.status == TASK_STATUS_CANCELLED:
        return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}

    payload = task.payload or {}
    project_id = int(payload.get("project_id") or task.project_id or 0)
    source_frame_id = int(payload.get("frame_id") or 0)

    def fail(message: str, error: str) -> None:
        """Record a terminal FAILED state on the task."""
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message=message, error=error, finished=True)

    try:
        model_id = sam_registry.normalize_model_id(payload.get("model"))
    except ValueError as exc:
        fail("自动传播失败", str(exc))
        raise

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        fail("项目不存在", "Project not found")
        raise ValueError(f"Project not found: {project_id}")

    source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
    if not source_frame:
        fail("参考帧不存在", "Frame not found")
        raise ValueError(f"Frame not found: {source_frame_id}")

    frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        fail("参考帧不在项目帧序列中", "Source frame is not in project frame sequence")
        raise ValueError("Source frame is not in project frame sequence")

    steps = payload.get("steps") or []
    if not steps:
        fail("传播任务缺少步骤", "Propagation task has no steps")
        raise ValueError("Propagation task has no steps")

    step_results: list[dict[str, Any]] = []
    totals = {
        "processed_frame_count": 0,
        "created_annotation_count": 0,
        "deleted_annotation_count": 0,
        "skipped_seed_count": 0,
    }
    total_steps = len(steps)

    def snapshot(completed_steps: int) -> dict[str, Any]:
        """Build the progress/result payload for the current totals."""
        return {
            "project_id": project_id,
            "source_frame_id": source_frame_id,
            "model": model_id,
            "total_steps": total_steps,
            "completed_steps": completed_steps,
            **totals,
            "steps": step_results,
        }

    try:
        # The start-up cancellation check lives inside the try block so that a
        # cancel arriving during validation is finalised like any other cancel
        # instead of escaping as an unhandled PropagationTaskCancelled.
        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)

        for index, step in enumerate(steps, start=1):
            _ensure_not_cancelled(db, task)
            seed_label = (step.get("seed") or {}).get("label") or "mask"
            # Human-readable direction label used in the progress messages.
            direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播"
            _set_task_state(
                db,
                task,
                progress=5 + int(((index - 1) / total_steps) * 90),
                message=f"{direction_label} {seed_label} ({index}/{total_steps})",
                result=snapshot(index - 1),
            )
            step_result = _run_one_step(
                db,
                payload=payload,
                frames=frames,
                source_frame=source_frame,
                source_position=source_position,
                step=step,
            )
            step_results.append(step_result)
            totals["created_annotation_count"] += int(step_result["created_annotation_count"])
            totals["processed_frame_count"] += int(step_result["processed_frame_count"])
            totals["deleted_annotation_count"] += int(step_result.get("deleted_annotation_count") or 0)
            totals["skipped_seed_count"] += int(step_result.get("skipped_seed_count") or 0)
            _set_task_state(
                db,
                task,
                progress=5 + int((index / total_steps) * 90),
                message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
                result=snapshot(index),
            )

        result = snapshot(total_steps)
        if totals["created_annotation_count"] > 0:
            final_message = "自动传播完成"
        elif totals["skipped_seed_count"] > 0:
            final_message = "自动传播完成,未改变的 mask 已跳过"
        else:
            final_message = "自动传播完成,但没有生成新的 mask"
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message=final_message,
            result=result,
            finished=True,
        )
        return result
    except PropagationTaskCancelled:
        # Keep any message/error/finish time already recorded on the task.
        task.status = TASK_STATUS_CANCELLED
        task.progress = 100
        task.message = task.message or "任务已取消"
        task.error = task.error or "Cancelled by user"
        task.finished_at = task.finished_at or _now()
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
    except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
        fail("自动传播失败", str(exc))
        raise
    except Exception as exc:  # noqa: BLE001
        logger.exception("Propagation task failed: task_id=%s", task.id)
        fail("自动传播失败", str(exc))
        raise