添加Docker自包含部署分支
- 新增 Seg_Server_Docker 自包含部署内容,包含前后端、FastAPI、Celery、PostgreSQL、Redis、MinIO、演示视频和 DICOM 数据。 - 保留 demo 数据以支持恢复演示出厂设置,排除 SAM 2.1 .pt 权重并在 README 中补充下载命令。 - 补充 GPU 部署、backend/worker 镜像复用、frpc/frps + NPM 公网域名反代部署说明。 - 在 .env/.env.example 中用 # XXXX 标注局域网和公网域名部署需要修改的配置项。 - 添加部署分支 .gitignore,忽略本地模型权重、构建产物、缓存和日志。
This commit is contained in:
842
backend/services/propagation_task_runner.py
Normal file
842
backend/services/propagation_task_runner.py
Normal file
@@ -0,0 +1,842 @@
|
||||
"""Background SAM video propagation runner used by Celery workers."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import download_file
|
||||
from models import Annotation, Frame, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
from statuses import (
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PropagationTaskCancelled(RuntimeError):
|
||||
"""Raised internally when a persisted propagation task has been cancelled."""
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _set_task_state(
|
||||
db: Session,
|
||||
task: ProcessingTask,
|
||||
*,
|
||||
status: str | None = None,
|
||||
progress: int | None = None,
|
||||
message: str | None = None,
|
||||
result: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
started: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = max(0, min(100, progress))
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if started:
|
||||
task.started_at = _now()
|
||||
if finished:
|
||||
task.finished_at = _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||
db.refresh(task)
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
raise PropagationTaskCancelled("Task was cancelled")
|
||||
|
||||
|
||||
def _clamp01(value: float) -> float:
|
||||
return min(max(float(value), 0.0), 1.0)
|
||||
|
||||
|
||||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
|
||||
xs = [_clamp01(point[0]) for point in polygon]
|
||||
ys = [_clamp01(point[1]) for point in polygon]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _polygons_bbox(polygons: list[list[list[float]]]) -> list[float]:
|
||||
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
|
||||
if not points:
|
||||
return [0.0, 0.0, 0.0, 0.0]
|
||||
xs = [_clamp01(point[0]) for point in points]
|
||||
ys = [_clamp01(point[1]) for point in points]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
|
||||
return [[_clamp01(point[0]), _clamp01(point[1])] for point in polygon if len(point) >= 2]
|
||||
|
||||
|
||||
def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
try:
|
||||
strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
|
||||
except (TypeError, ValueError):
|
||||
strength = 0.0
|
||||
if strength <= 0:
|
||||
return None
|
||||
method = str(value.get("method") or "chaikin").lower()
|
||||
if method != "chaikin":
|
||||
method = "chaikin"
|
||||
return {"strength": round(strength, 2), "method": method}
|
||||
|
||||
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = _normalize_polygon(polygon)
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
next_points: list[list[float]] = []
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
|
||||
|
||||
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
|
||||
if len(polygon) < 3:
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in approx]
|
||||
|
||||
|
||||
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
|
||||
if not smoothing:
|
||||
return _normalize_polygon(polygon)
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
def _bbox_area(bbox: list[float]) -> float:
|
||||
return max(float(bbox[2]), 0.0) * max(float(bbox[3]), 0.0)
|
||||
|
||||
|
||||
def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
|
||||
ax1, ay1, aw, ah = a
|
||||
bx1, by1, bw, bh = b
|
||||
ax2 = ax1 + aw
|
||||
ay2 = ay1 + ah
|
||||
bx2 = bx1 + bw
|
||||
by2 = by1 + bh
|
||||
overlap_width = max(0.0, min(ax2, bx2) - max(ax1, bx1))
|
||||
overlap_height = max(0.0, min(ay2, by2) - max(ay1, by1))
|
||||
overlap_area = overlap_width * overlap_height
|
||||
smallest_area = min(_bbox_area(a), _bbox_area(b))
|
||||
return overlap_area / smallest_area if smallest_area > 0 else 0.0
|
||||
|
||||
|
||||
def _stable_json(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _canonicalize_signature_value(value: Any) -> Any:
|
||||
if isinstance(value, float):
|
||||
return round(value, 6)
|
||||
if isinstance(value, list):
|
||||
return [_canonicalize_signature_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: _canonicalize_signature_value(value[key]) for key in sorted(value)}
|
||||
return value
|
||||
|
||||
|
||||
def _seed_signature(seed: dict[str, Any]) -> str:
|
||||
"""Return a stable signature for seed geometry and semantic attrs."""
|
||||
inherited_signature = seed.get("propagation_seed_signature")
|
||||
if inherited_signature:
|
||||
return str(inherited_signature)
|
||||
signature_payload = {
|
||||
"polygons": seed.get("polygons") or [],
|
||||
"holes": seed.get("holes") or [],
|
||||
"bbox": seed.get("bbox") or [],
|
||||
"points": seed.get("points") or [],
|
||||
"labels": seed.get("labels") or [],
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
"class_metadata": seed.get("class_metadata") or {},
|
||||
"template_id": seed.get("template_id"),
|
||||
"smoothing": _normalize_smoothing_options(seed.get("smoothing")),
|
||||
}
|
||||
return hashlib.sha256(_stable_json(_canonicalize_signature_value(signature_payload)).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _seed_key(seed: dict[str, Any]) -> str:
|
||||
"""Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
if source_instance_id:
|
||||
return f"instance:{source_instance_id}"
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
return f"annotation:{source_annotation_id}"
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
return f"mask:{source_mask_id}"
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return _stable_json({
|
||||
"template_id": seed.get("template_id"),
|
||||
"class_id": class_id,
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
})
|
||||
|
||||
|
||||
def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
|
||||
"""Best-effort match when a manually edited replacement lacks old lineage ids."""
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
previous_class = mask_data.get("class") or {}
|
||||
previous_class_id = previous_class.get("id") or previous_class.get("name")
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
if previous_class_id and class_id and str(previous_class_id) != str(class_id):
|
||||
return False
|
||||
return (
|
||||
mask_data.get("label") == seed.get("label")
|
||||
and mask_data.get("color") == seed.get("color")
|
||||
)
|
||||
|
||||
|
||||
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
|
||||
"""Best-effort match for propagation annotations created before seed keys."""
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
previous_class = mask_data.get("class") or {}
|
||||
previous_class_id = previous_class.get("id") or previous_class.get("name")
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return (
|
||||
mask_data.get("label") == seed.get("label")
|
||||
and mask_data.get("color") == seed.get("color")
|
||||
and previous_class_id == class_id
|
||||
)
|
||||
|
||||
|
||||
def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
|
||||
return str(mask_data.get("source") or "") == f"{model_id}_propagation"
|
||||
|
||||
|
||||
def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key == seed_key:
|
||||
return True
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
if source_instance_id and (
|
||||
mask_data.get("source_instance_id") == source_instance_id
|
||||
or mask_data.get("instance_id") == source_instance_id
|
||||
):
|
||||
return True
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None and str(mask_data.get("source_annotation_id") or "") == str(source_annotation_id):
|
||||
return True
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id and mask_data.get("source_mask_id") == source_mask_id:
|
||||
return True
|
||||
has_persisted_seed_identity = bool(source_instance_id) or source_annotation_id is not None or bool(source_mask_id)
|
||||
has_previous_identity = (
|
||||
bool(previous_seed_key)
|
||||
or mask_data.get("source_instance_id") is not None
|
||||
or mask_data.get("instance_id") is not None
|
||||
or mask_data.get("source_annotation_id") is not None
|
||||
or bool(mask_data.get("source_mask_id"))
|
||||
)
|
||||
if has_persisted_seed_identity or has_previous_identity:
|
||||
return False
|
||||
return _legacy_seed_matches(mask_data, seed)
|
||||
|
||||
|
||||
def _seed_identity_markers(seed: dict[str, Any]) -> set[str]:
|
||||
markers = {f"seed:{_seed_key(seed)}"}
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
if source_instance_id:
|
||||
markers.add(f"instance:{source_instance_id}")
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
markers.add(f"annotation:{source_annotation_id}")
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
markers.add(f"mask:{source_mask_id}")
|
||||
return markers
|
||||
|
||||
|
||||
def _mask_identity_markers(mask_data: dict[str, Any]) -> set[str]:
|
||||
markers: set[str] = set()
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key:
|
||||
markers.add(f"seed:{previous_seed_key}")
|
||||
source_instance_id = mask_data.get("source_instance_id")
|
||||
if source_instance_id:
|
||||
markers.add(f"instance:{source_instance_id}")
|
||||
instance_id = mask_data.get("instance_id")
|
||||
if instance_id:
|
||||
markers.add(f"instance:{instance_id}")
|
||||
source_annotation_id = mask_data.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
markers.add(f"annotation:{source_annotation_id}")
|
||||
source_mask_id = mask_data.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
markers.add(f"mask:{source_mask_id}")
|
||||
return markers
|
||||
|
||||
|
||||
def _payload_seed_identity_markers(payload: dict[str, Any]) -> set[str]:
|
||||
markers: set[str] = set()
|
||||
for step in payload.get("steps") or []:
|
||||
seed = step.get("seed") or {}
|
||||
markers.update(_seed_identity_markers(seed))
|
||||
return markers
|
||||
|
||||
|
||||
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
|
||||
mask_data = annotation.mask_data or {}
|
||||
source = str(mask_data.get("source") or "")
|
||||
if not source.endswith("_propagation"):
|
||||
return False
|
||||
return _seed_identity_matches(mask_data, seed_key, seed)
|
||||
|
||||
|
||||
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
|
||||
previous_direction = mask_data.get("propagation_direction")
|
||||
return previous_direction in {None, direction}
|
||||
|
||||
|
||||
def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
|
||||
"""Use target-frame overlap as a final guard before replacing same-object propagation."""
|
||||
candidate_bbox = _polygon_bbox(polygon)
|
||||
for previous_polygon in (annotation.mask_data or {}).get("polygons") or []:
|
||||
if len(previous_polygon) < 3:
|
||||
continue
|
||||
if _bbox_overlap_ratio(_polygon_bbox(previous_polygon), candidate_bbox) >= 0.15:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _delete_replaced_frame_annotations(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
frame_id: int,
|
||||
seed_key: str,
|
||||
seed: dict[str, Any],
|
||||
polygon: list[list[float]],
|
||||
) -> int:
|
||||
"""Delete old propagated masks for the same object immediately before writing a new result."""
|
||||
previous_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == int(payload["project_id"]))
|
||||
.filter(Annotation.frame_id == frame_id)
|
||||
.all()
|
||||
)
|
||||
deleted_count = 0
|
||||
current_seed_markers = _seed_identity_markers(seed)
|
||||
task_seed_markers = _payload_seed_identity_markers(payload)
|
||||
for annotation in previous_annotations:
|
||||
mask_data = annotation.mask_data or {}
|
||||
source = str(mask_data.get("source") or "")
|
||||
if not source.endswith("_propagation"):
|
||||
continue
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
mask_instance_ids = {
|
||||
str(value)
|
||||
for value in (mask_data.get("source_instance_id"), mask_data.get("instance_id"))
|
||||
if value
|
||||
}
|
||||
if source_instance_id and mask_instance_ids and str(source_instance_id) not in mask_instance_ids:
|
||||
continue
|
||||
mask_markers = _mask_identity_markers(mask_data)
|
||||
# Keep sibling seeds in the same propagation task from deleting each other.
|
||||
if mask_markers and mask_markers.isdisjoint(current_seed_markers) and not mask_markers.isdisjoint(task_seed_markers):
|
||||
continue
|
||||
same_lineage = _seed_identity_matches(mask_data, seed_key, seed)
|
||||
same_manual_replacement = (
|
||||
_semantic_seed_matches(mask_data, seed)
|
||||
and _annotation_spatially_matches(annotation, polygon)
|
||||
)
|
||||
if same_lineage or same_manual_replacement:
|
||||
db.delete(annotation)
|
||||
deleted_count += 1
|
||||
if deleted_count:
|
||||
db.commit()
|
||||
return deleted_count
|
||||
|
||||
|
||||
def _prepare_seed_propagation(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
model_id: str,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
target_frame_ids: set[int],
|
||||
) -> dict[str, Any]:
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
if not target_frame_ids:
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": 0,
|
||||
}
|
||||
previous_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == int(payload["project_id"]))
|
||||
.filter(Annotation.frame_id.in_(target_frame_ids))
|
||||
.all()
|
||||
)
|
||||
matching = [
|
||||
annotation for annotation in previous_annotations
|
||||
if _is_propagation_annotation(annotation, seed_key, seed)
|
||||
and _direction_matches(annotation.mask_data or {}, direction)
|
||||
]
|
||||
covered_frame_ids = {int(annotation.frame_id) for annotation in matching}
|
||||
if matching and all(
|
||||
(annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature
|
||||
and _source_model_matches(annotation.mask_data or {}, model_id)
|
||||
for annotation in matching
|
||||
) and target_frame_ids.issubset(covered_frame_ids):
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": 0,
|
||||
}
|
||||
|
||||
deleted_count = 0
|
||||
if matching:
|
||||
for annotation in matching:
|
||||
db.delete(annotation)
|
||||
deleted_count += 1
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"skip": False,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
}
|
||||
|
||||
|
||||
def _frame_window(
|
||||
frames: list[Frame],
|
||||
source_position: int,
|
||||
direction: str,
|
||||
max_frames: int,
|
||||
) -> tuple[list[Frame], int]:
|
||||
count = max(1, min(max_frames, len(frames)))
|
||||
if direction == "backward":
|
||||
start = max(0, source_position - count + 1)
|
||||
return frames[start:source_position + 1], source_position - start
|
||||
end = min(len(frames), source_position + count)
|
||||
return frames[source_position:end], 0
|
||||
|
||||
|
||||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
|
||||
paths = []
|
||||
for index, frame in enumerate(frames):
|
||||
data = download_file(frame.image_url)
|
||||
# SAM2VideoPredictor sorts frames by converting the filename stem to int.
|
||||
path = directory / f"{index:06d}.jpg"
|
||||
path.write_bytes(data)
|
||||
paths.append(str(path))
|
||||
return paths
|
||||
|
||||
|
||||
def _save_propagated_annotations(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
selected_frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
propagated: list[dict[str, Any]],
|
||||
seed: dict[str, Any],
|
||||
) -> tuple[list[Annotation], int]:
|
||||
created: list[Annotation] = []
|
||||
if payload.get("save_annotations", True) is False:
|
||||
return created, 0
|
||||
|
||||
class_metadata = seed.get("class_metadata")
|
||||
template_id = seed.get("template_id")
|
||||
label = seed.get("label") or "Propagated Mask"
|
||||
color = seed.get("color") or "#06b6d4"
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
include_source = bool(payload.get("include_source", False))
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
source_instance_id = seed.get("source_instance_id") or seed_key
|
||||
smoothing = _normalize_smoothing_options(seed.get("smoothing"))
|
||||
direction = str(payload.get("current_direction") or "")
|
||||
deleted_count = 0
|
||||
cleaned_frame_ids: set[int] = set()
|
||||
|
||||
for frame_result in propagated:
|
||||
relative_index = int(frame_result.get("frame_index", -1))
|
||||
if relative_index < 0 or relative_index >= len(selected_frames):
|
||||
continue
|
||||
frame = selected_frames[relative_index]
|
||||
if not include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
result_holes = frame_result.get("holes") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
prepared_polygons = [
|
||||
(polygon_index, _smooth_polygon(polygon, smoothing))
|
||||
for polygon_index, polygon in enumerate(result_polygons)
|
||||
if len(polygon) >= 3
|
||||
]
|
||||
cleanup_polygon = next((polygon for _polygon_index, polygon in prepared_polygons if len(polygon) >= 3), None)
|
||||
if cleanup_polygon is not None and frame.id not in cleaned_frame_ids:
|
||||
deleted_count += _delete_replaced_frame_annotations(
|
||||
db,
|
||||
payload=payload,
|
||||
frame_id=int(frame.id),
|
||||
seed_key=seed_key,
|
||||
seed=seed,
|
||||
polygon=cleanup_polygon,
|
||||
)
|
||||
cleaned_frame_ids.add(int(frame.id))
|
||||
polygons_to_save: list[list[list[float]]] = []
|
||||
holes_to_save: list[list[list[list[float]]]] = []
|
||||
score_values: list[float] = []
|
||||
for polygon_index, polygon in prepared_polygons:
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
polygons_to_save.append(polygon)
|
||||
hole_group = result_holes[polygon_index] if polygon_index < len(result_holes) and isinstance(result_holes[polygon_index], list) else []
|
||||
holes_to_save.append(hole_group if isinstance(hole_group, list) else [])
|
||||
if polygon_index < len(scores):
|
||||
try:
|
||||
score_values.append(float(scores[polygon_index]))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if not polygons_to_save:
|
||||
continue
|
||||
annotation = Annotation(
|
||||
project_id=int(payload["project_id"]),
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": polygons_to_save,
|
||||
**({"holes": holes_to_save, "hasHoles": True} if any(holes_to_save) else {}),
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
"propagated_from_frame_id": source_frame.id,
|
||||
"propagated_from_frame_index": source_frame.frame_index,
|
||||
"propagation_seed_key": seed_key,
|
||||
"propagation_seed_signature": seed_signature,
|
||||
"propagation_direction": direction,
|
||||
"instance_id": source_instance_id,
|
||||
"source_instance_id": source_instance_id,
|
||||
"source_annotation_id": source_annotation_id,
|
||||
"source_mask_id": source_mask_id,
|
||||
"score": max(score_values) if score_values else None,
|
||||
**({"scores": score_values} if len(score_values) > 1 else {}),
|
||||
**({"geometry_smoothing": smoothing} if smoothing else {}),
|
||||
**({"class": class_metadata} if class_metadata else {}),
|
||||
},
|
||||
points=None,
|
||||
bbox=_polygons_bbox(polygons_to_save),
|
||||
)
|
||||
db.add(annotation)
|
||||
created.append(annotation)
|
||||
|
||||
db.commit()
|
||||
for annotation in created:
|
||||
db.refresh(annotation)
|
||||
return created, deleted_count
|
||||
|
||||
|
||||
def _run_one_step(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
source_position: int,
|
||||
step: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
direction = str(step.get("direction") or "forward").lower()
|
||||
if direction not in {"forward", "backward"}:
|
||||
raise ValueError("direction must be forward or backward")
|
||||
max_frames = max(1, min(int(step.get("max_frames") or payload.get("max_frames") or 30), 500))
|
||||
seed = step.get("seed") or {}
|
||||
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
|
||||
raise ValueError("Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
include_source = bool(payload.get("include_source", False))
|
||||
target_frame_ids = {
|
||||
int(frame.id)
|
||||
for frame in selected_frames
|
||||
if include_source or frame.id != source_frame.id
|
||||
}
|
||||
seed_state = _prepare_seed_propagation(
|
||||
db,
|
||||
payload=payload,
|
||||
model_id=model_id,
|
||||
seed=seed,
|
||||
direction=direction,
|
||||
target_frame_ids=target_frame_ids,
|
||||
)
|
||||
if seed_state["skip"]:
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": 0,
|
||||
"created_annotation_count": 0,
|
||||
"deleted_annotation_count": 0,
|
||||
"skipped_seed_count": 1,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
model_id,
|
||||
frame_paths,
|
||||
source_relative_index,
|
||||
seed,
|
||||
direction,
|
||||
len(selected_frames),
|
||||
)
|
||||
|
||||
save_payload = {**payload, "current_direction": direction}
|
||||
created, write_cleanup_count = _save_propagated_annotations(
|
||||
db,
|
||||
payload=save_payload,
|
||||
selected_frames=selected_frames,
|
||||
source_frame=source_frame,
|
||||
propagated=propagated,
|
||||
seed=seed,
|
||||
)
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": len(selected_frames),
|
||||
"created_annotation_count": len(created),
|
||||
"deleted_annotation_count": int(seed_state["deleted_annotation_count"]) + write_cleanup_count,
|
||||
"skipped_seed_count": 0,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
|
||||
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Run one queued SAM propagation task and update persisted progress."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}
|
||||
|
||||
payload = task.payload or {}
|
||||
project_id = int(payload.get("project_id") or task.project_id or 0)
|
||||
source_frame_id = int(payload.get("frame_id") or 0)
|
||||
try:
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
except ValueError as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True)
|
||||
raise ValueError(f"Project not found: {project_id}")
|
||||
|
||||
source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
|
||||
if not source_frame:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True)
|
||||
raise ValueError(f"Frame not found: {source_frame_id}")
|
||||
|
||||
frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
|
||||
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
|
||||
if source_position is None:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True)
|
||||
raise ValueError("Source frame is not in project frame sequence")
|
||||
|
||||
steps = payload.get("steps") or []
|
||||
if not steps:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True)
|
||||
raise ValueError("Propagation task has no steps")
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)
|
||||
|
||||
step_results: list[dict[str, Any]] = []
|
||||
created_count = 0
|
||||
processed_count = 0
|
||||
deleted_count = 0
|
||||
skipped_count = 0
|
||||
total_steps = len(steps)
|
||||
|
||||
try:
|
||||
for index, step in enumerate(steps, start=1):
|
||||
_ensure_not_cancelled(db, task)
|
||||
seed_label = (step.get("seed") or {}).get("label") or "mask"
|
||||
direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播"
|
||||
progress_before = 5 + int(((index - 1) / total_steps) * 90)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=progress_before,
|
||||
message=f"{direction_label} {seed_label} ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index - 1,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = _run_one_step(
|
||||
db,
|
||||
payload=payload,
|
||||
frames=frames,
|
||||
source_frame=source_frame,
|
||||
source_position=source_position,
|
||||
step=step,
|
||||
)
|
||||
step_results.append(result)
|
||||
created_count += int(result["created_annotation_count"])
|
||||
processed_count += int(result["processed_frame_count"])
|
||||
deleted_count += int(result.get("deleted_annotation_count") or 0)
|
||||
skipped_count += int(result.get("skipped_seed_count") or 0)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=5 + int((index / total_steps) * 90),
|
||||
message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = {
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": total_steps,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_SUCCESS,
|
||||
progress=100,
|
||||
message="自动传播完成" if created_count > 0 else (
|
||||
"自动传播完成,未改变的 mask 已跳过" if skipped_count > 0 else "自动传播完成,但没有生成新的 mask"
|
||||
),
|
||||
result=result,
|
||||
finished=True,
|
||||
)
|
||||
return result
|
||||
except PropagationTaskCancelled:
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = task.message or "任务已取消"
|
||||
task.error = task.error or "Cancelled by user"
|
||||
task.finished_at = task.finished_at or _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
|
||||
except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Propagation task failed: task_id=%s", task.id)
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
Reference in New Issue
Block a user