功能增加:点击 Canvas mask 后,右侧语义分类树会按 classId/className/label 自动匹配分类,并滚动聚焦到对应分类按钮。
功能增加:工作区新增按起止帧批量清空片段遮罩,复用传播范围输入,范围内已保存标注走 DELETE /api/ai/annotations/{id},本地 draft mask 同步移除。
功能增加:右侧语义分类树上方新增工作区 mask 透明度滑杆,写入 Zustand maskPreviewOpacity,Canvas mask 预览按该值渲染并保留选中加亮反馈。
功能增加:视频处理进度条记录最近自动传播区间,使用不同色系深浅渐变提示最近处理片段。
功能增加:工作区自动传播前会先保存 draft/dirty seed mask,使用稳定后端 source_annotation_id 入队,减少二次传播重复结果。
Bugfix:后端传播任务对旧临时 seed id、不同 SAM 2.1 权重结果做兼容清理;相同 seed 和相同权重才跳过,否则先删旧自动传播标注再重传。
Bugfix:修复 polygon 顶点拖拽结束后触发 Stage 平移导致画布中心偏移的问题,并补充测试环境对 drag target 的模拟。
Bugfix:工具提示会在数秒后自动隐藏,避免创建多边形/矩形等提示长期遮挡画布。
UI 调整:移除右侧面板顶部‘本体论与属性分类管理树’说明栏,减少无效占位。
UI 调整:左侧工具栏和右侧语义面板使用低对比 seg-scrollbar;左侧工具栏外扩滚动条槽位,避免滚动条挤占图标列。
UI 调整:工作区模型状态徽标改为紧凑显示,减少与传播权重选择重复;传播权重下拉改成深色背景和青色文字,避免灰底白字不可读。
UI 调整:缩略图状态框固定优先级,当前帧、人工/AI 标注帧、自动传播帧可用外框/内框组合同时表达。
测试:补充 VideoWorkspace、CanvasArea、FrameTimeline、OntologyInspector、ToolsPalette、useStore 和后端 test_ai 覆盖新增交互、传播去重、批量清空、透明度、滚动条和 UI 状态。
文档:同步更新 README、AGENTS 和 doc/03、doc/04、doc/07、doc/08、doc/09,记录当前功能、接口契约、需求设计冻结和测试覆盖。
520 lines
19 KiB
Python
"""Background SAM video propagation runner used by Celery workers."""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import tempfile
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from minio_client import download_file
|
|
from models import Annotation, Frame, ProcessingTask, Project
|
|
from progress_events import publish_task_progress_event
|
|
from services.sam_registry import ModelUnavailableError, sam_registry
|
|
from statuses import (
|
|
TASK_STATUS_CANCELLED,
|
|
TASK_STATUS_FAILED,
|
|
TASK_STATUS_RUNNING,
|
|
TASK_STATUS_SUCCESS,
|
|
)
|
|
|
|
# Module logger; the task runner uses it to record unexpected propagation failures.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PropagationTaskCancelled(RuntimeError):
    """Signal that the persisted propagation task was cancelled mid-run.

    Raised by the cancellation check and caught by the task runner, which
    records the cancelled state instead of a failure.
    """
|
|
|
|
|
|
def _now() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    utc = timezone.utc
    return datetime.now(utc)
|
|
|
|
|
|
def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Apply field updates to ``task``, persist them and broadcast the new state.

    Only arguments that are not ``None`` overwrite the matching task column.
    ``progress`` is clamped into ``[0, 100]``. ``started`` / ``finished`` stamp
    ``started_at`` / ``finished_at`` with the current UTC time. Afterwards the
    session is committed, the task reloaded and a progress event published.
    """
    clamped_progress = None if progress is None else max(0, min(100, progress))
    field_updates = (
        ("status", status),
        ("progress", clamped_progress),
        ("message", message),
        ("result", result),
        ("error", error),
    )
    for column, new_value in field_updates:
        if new_value is not None:
            setattr(task, column, new_value)

    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
|
|
|
|
|
|
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload ``task`` from the database and abort the run if it was cancelled.

    Raises:
        PropagationTaskCancelled: when the persisted status is cancelled.
    """
    db.refresh(task)
    if task.status != TASK_STATUS_CANCELLED:
        return
    raise PropagationTaskCancelled("Task was cancelled")
|
|
|
|
|
|
def _clamp01(value: float) -> float:
    """Convert ``value`` to ``float`` and clip it into the closed range [0, 1]."""
    number = float(value)
    if number < 0.0:
        return 0.0
    if number > 1.0:
        return 1.0
    return number
|
|
|
|
|
|
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return the ``[x, y, width, height]`` box enclosing ``polygon``.

    Each vertex is read as ``(x, y)`` from its first two entries. Both
    coordinates are clamped into [0, 1] before the extent is measured, and
    width/height are floored at 0.
    """
    x_values = [_clamp01(vertex[0]) for vertex in polygon]
    y_values = [_clamp01(vertex[1]) for vertex in polygon]
    x_min, x_max = min(x_values), max(x_values)
    y_min, y_max = min(y_values), max(y_values)
    width = max(x_max - x_min, 0.0)
    height = max(y_max - y_min, 0.0)
    return [x_min, y_min, width, height]
|
|
|
|
|
|
def _stable_json(value: Any) -> str:
    """Serialize ``value`` as compact, key-sorted JSON so equal data gives equal text."""
    return json.dumps(
        value,
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=False,
    )
|
|
|
|
|
|
def _seed_signature(seed: dict[str, Any]) -> str:
    """Hash a seed's geometry and semantic attributes into a hex SHA-256 digest.

    Falsy list fields (``polygons``, ``bbox``, ``points``, ``labels``) and a
    falsy ``class_metadata`` are normalized to empty containers, so ``None``
    and an empty value produce the same signature.
    """
    fingerprint: dict[str, Any] = {}
    for list_field in ("polygons", "bbox", "points", "labels"):
        fingerprint[list_field] = seed.get(list_field) or []
    fingerprint["label"] = seed.get("label")
    fingerprint["color"] = seed.get("color")
    fingerprint["class_metadata"] = seed.get("class_metadata") or {}
    fingerprint["template_id"] = seed.get("template_id")

    encoded = _stable_json(fingerprint).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()
|
|
|
|
|
|
def _seed_key(seed: dict[str, Any]) -> str:
    """Return the identity key used to find earlier propagation results of ``seed``.

    Resolution order:
    1. a persisted ``source_annotation_id`` (any non-``None`` value);
    2. a truthy ``source_mask_id``;
    3. for legacy callers that send neither, compact JSON of the template id,
       the class id (falling back to the class name), the label and the color.
    """
    annotation_id = seed.get("source_annotation_id")
    if annotation_id is not None:
        return f"annotation:{annotation_id}"

    mask_id = seed.get("source_mask_id")
    if mask_id:
        return f"mask:{mask_id}"

    metadata = seed.get("class_metadata") or {}
    semantic_identity = {
        "template_id": seed.get("template_id"),
        "class_id": metadata.get("id") or metadata.get("name"),
        "label": seed.get("label"),
        "color": seed.get("color"),
    }
    return _stable_json(semantic_identity)
|
|
|
|
|
|
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
    """Heuristically decide whether stored ``mask_data`` was propagated from ``seed``.

    Used when no trustworthy seed key is available. The label and the color
    must be equal, and so must the class identity. That identity is the class
    ``id``, falling back to its ``name``; it is read from ``mask_data["class"]``
    on the stored side and from ``seed["class_metadata"]`` on the seed side.
    """
    stored_class = mask_data.get("class") or {}
    seed_class = seed.get("class_metadata") or {}
    stored_class_id = stored_class.get("id") or stored_class.get("name")
    seed_class_id = seed_class.get("id") or seed_class.get("name")

    same_label = mask_data.get("label") == seed.get("label")
    same_color = mask_data.get("color") == seed.get("color")
    return same_label and same_color and stored_class_id == seed_class_id
|
|
|
|
|
|
def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
    """Return whether ``mask_data`` records a propagation made with ``model_id``."""
    expected_source = f"{model_id}_propagation"
    recorded_source = str(mask_data.get("source") or "")
    return recorded_source == expected_source
|
|
|
|
|
|
def _is_propagation_annotation(
    annotation: Annotation,
    source_frame: Frame,
    seed_key: str,
    seed: dict[str, Any],
) -> bool:
    """Return whether ``annotation`` is an earlier propagation result of ``seed``.

    The annotation must carry a ``*_propagation`` source and must have been
    propagated from ``source_frame``. It then matches if its stored
    ``propagation_seed_key`` equals ``seed_key``.

    Otherwise the label/color/class heuristic in ``_legacy_seed_matches``
    decides. This lets results keyed by an old temporary mask id, or created
    before seed keys existed, be cleaned up. There is one exception: when both
    the stored result and ``seed`` carry a persisted ``source_annotation_id``
    and the ids differ, they come from distinct annotations and never match.
    """
    mask_data = annotation.mask_data or {}
    source = str(mask_data.get("source") or "")
    if not source.endswith("_propagation"):
        return False
    if int(mask_data.get("propagated_from_frame_id") or 0) != int(source_frame.id):
        return False

    # seed_key is always a str, so a missing stored key (None) never matches here.
    if mask_data.get("propagation_seed_key") == seed_key:
        return True

    # Stable backend ids identify a seed unambiguously. Without this guard the
    # heuristic below also matched a sibling instance sharing label/color/class
    # (e.g. two "person" masks on one frame). Propagating one mask then deleted
    # the other's results, even between steps of the same task.
    previous_annotation_id = mask_data.get("source_annotation_id")
    current_annotation_id = seed.get("source_annotation_id")
    if (
        previous_annotation_id is not None
        and current_annotation_id is not None
        and str(previous_annotation_id) != str(current_annotation_id)
    ):
        return False

    return _legacy_seed_matches(mask_data, seed)
|
|
|
|
|
|
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
    """Return whether a stored propagation direction is compatible with ``direction``.

    Results that never recorded a direction (value ``None`` or missing) are
    treated as matching any direction.
    """
    accepted_directions = {None, direction}
    return mask_data.get("propagation_direction") in accepted_directions
|
|
|
|
|
|
def _prepare_seed_propagation(
    db: Session,
    *,
    payload: dict[str, Any],
    model_id: str,
    source_frame: Frame,
    seed: dict[str, Any],
    direction: str,
) -> dict[str, Any]:
    """Decide whether ``seed`` needs re-propagation and clear its stale results.

    Earlier propagation results of the same seed that have a compatible
    direction are collected from the project (see
    ``_is_propagation_annotation`` / ``_direction_matches``).

    If there is at least one such result, and every one of them was produced
    from an identical seed signature with the same model, the step is
    skipped. Otherwise the stale results are deleted and the deletion is
    committed, so the new run replaces them. The exception is a payload that
    disables saving (``save_annotations is False``): nothing would replace
    the results, so they are kept.

    Returns:
        A dict with ``skip`` (bool), ``seed_key``, ``seed_signature`` and
        ``deleted_annotation_count``.
    """
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)
    # NOTE(review): loads every annotation of the project; filtering happens in Python
    # because the match criteria live inside the mask_data JSON.
    previous_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .all()
    )
    matching = [
        annotation
        for annotation in previous_annotations
        if _is_propagation_annotation(annotation, source_frame, seed_key, seed)
        and _direction_matches(annotation.mask_data or {}, direction)
    ]

    unchanged = bool(matching) and all(
        (annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature
        and _source_model_matches(annotation.mask_data or {}, model_id)
        for annotation in matching
    )
    if unchanged:
        return {
            "skip": True,
            "seed_key": seed_key,
            "seed_signature": seed_signature,
            "deleted_annotation_count": 0,
        }

    deleted_count = 0
    # Same check as _save_propagated_annotations: when results will not be saved,
    # deleting the old ones would leave the frames with no propagation at all.
    saving_enabled = payload.get("save_annotations", True) is not False
    if matching and saving_enabled:
        for annotation in matching:
            db.delete(annotation)
        deleted_count = len(matching)
        db.commit()

    return {
        "skip": False,
        "seed_key": seed_key,
        "seed_signature": seed_signature,
        "deleted_annotation_count": deleted_count,
    }
|
|
|
|
|
|
def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Select the contiguous run of frames that one propagation pass covers.

    The window holds at most ``max_frames`` frames. It is never shorter than
    one frame nor longer than ``len(frames)``, and it always contains the
    source frame: the window ends on the source for ``"backward"`` and starts
    on it for any other direction.

    Returns:
        ``(window, source_index)``, where ``source_index`` is the position of
        the source frame inside ``window``.
    """
    size = max(1, min(max_frames, len(frames)))
    if direction != "backward":
        stop = min(len(frames), source_position + size)
        return frames[source_position:stop], 0

    first = max(0, source_position - size + 1)
    window = frames[first : source_position + 1]
    return window, source_position - first
|
|
|
|
|
|
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download each frame image into ``directory`` as a numbered JPEG sequence.

    Files are written in the given order as ``000000.jpg``, ``000001.jpg``, ...
    and the written paths (as strings) are returned in the same order.
    """
    written: list[str] = []
    for position, frame in enumerate(frames):
        image_bytes = download_file(frame.image_url)
        # SAM2VideoPredictor sorts frames by converting the filename stem to int.
        target = directory.joinpath(f"{position:06d}.jpg")
        target.write_bytes(image_bytes)
        written.append(str(target))
    return written
|
|
|
|
|
|
def _save_propagated_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    selected_frames: list[Frame],
    source_frame: Frame,
    propagated: list[dict[str, Any]],
    seed: dict[str, Any],
) -> list[Annotation]:
    """Create one annotation per usable propagated polygon and commit them.

    Returns an empty list without touching the database when
    ``payload["save_annotations"]`` is exactly ``False``. The following are
    skipped:
    - frame results whose ``frame_index`` falls outside ``selected_frames``;
    - the source frame, unless ``include_source`` is set;
    - polygons with fewer than three vertices.

    Each stored ``mask_data`` records the model-specific source, the
    originating frame, the seed key and signature, the direction taken from
    ``payload["current_direction"]``, and the per-polygon score when available.

    Returns:
        The created annotations, refreshed after the commit.
    """
    if payload.get("save_annotations", True) is False:
        return []

    class_metadata = seed.get("class_metadata")
    template_id = seed.get("template_id")
    label = seed.get("label") or "Propagated Mask"
    color = seed.get("color") or "#06b6d4"
    model_id = sam_registry.normalize_model_id(payload.get("model"))
    include_source = bool(payload.get("include_source", False))
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)
    source_annotation_id = seed.get("source_annotation_id")
    source_mask_id = seed.get("source_mask_id")
    direction = str(payload.get("current_direction") or "")
    class_entry = {"class": class_metadata} if class_metadata else {}

    created: list[Annotation] = []
    window_size = len(selected_frames)
    for frame_result in propagated:
        relative_index = int(frame_result.get("frame_index", -1))
        if not 0 <= relative_index < window_size:
            continue
        frame = selected_frames[relative_index]
        if not include_source and frame.id == source_frame.id:
            continue

        polygons = frame_result.get("polygons") or []
        scores = frame_result.get("scores") or []
        for polygon_index, polygon in enumerate(polygons):
            if len(polygon) < 3:
                continue
            score = scores[polygon_index] if polygon_index < len(scores) else None
            mask_data = {
                "polygons": [polygon],
                "label": label,
                "color": color,
                "source": f"{model_id}_propagation",
                "propagated_from_frame_id": source_frame.id,
                "propagated_from_frame_index": source_frame.frame_index,
                "propagation_seed_key": seed_key,
                "propagation_seed_signature": seed_signature,
                "propagation_direction": direction,
                "source_annotation_id": source_annotation_id,
                "source_mask_id": source_mask_id,
                "score": score,
                **class_entry,
            }
            annotation = Annotation(
                project_id=int(payload["project_id"]),
                frame_id=frame.id,
                template_id=template_id,
                mask_data=mask_data,
                points=None,
                bbox=_polygon_bbox(polygon),
            )
            db.add(annotation)
            created.append(annotation)

    db.commit()
    for annotation in created:
        db.refresh(annotation)
    return created
|
|
|
|
|
|
def _run_one_step(
    db: Session,
    *,
    payload: dict[str, Any],
    frames: list[Frame],
    source_frame: Frame,
    source_position: int,
    step: dict[str, Any],
) -> dict[str, Any]:
    """Propagate one seed in one direction and persist the results.

    The direction defaults to ``"forward"`` and is lower-cased. The frame
    budget comes from the step, then the payload, then defaults to 30, and is
    clamped to ``[1, 500]``. When ``_prepare_seed_propagation`` reports the
    seed unchanged, nothing is run. Otherwise the frame window is written to a
    temporary directory, propagated by the SAM registry, and saved.

    Returns:
        A per-step summary with model, direction, processed/created/deleted
        counts, skipped-seed count, seed label and seed key.

    Raises:
        ValueError: for an unknown direction, or a seed without polygons,
            bbox or points.
    """
    direction = str(step.get("direction") or "forward").lower()
    if direction not in {"forward", "backward"}:
        raise ValueError("direction must be forward or backward")
    requested_frames = int(step.get("max_frames") or payload.get("max_frames") or 30)
    max_frames = max(1, min(requested_frames, 500))

    seed = step.get("seed") or {}
    has_prompt = seed.get("polygons") or seed.get("bbox") or seed.get("points")
    if not has_prompt:
        raise ValueError("Propagation requires seed polygons, bbox, or points")

    model_id = sam_registry.normalize_model_id(payload.get("model"))
    seed_state = _prepare_seed_propagation(
        db,
        payload=payload,
        model_id=model_id,
        source_frame=source_frame,
        seed=seed,
        direction=direction,
    )

    def summarize(processed: int, created_total: int, deleted: int, skipped: int) -> dict[str, Any]:
        # One place for the step-summary shape used by both return paths.
        return {
            "model": model_id,
            "direction": direction,
            "processed_frame_count": processed,
            "created_annotation_count": created_total,
            "deleted_annotation_count": deleted,
            "skipped_seed_count": skipped,
            "seed_label": seed.get("label"),
            "seed_key": seed_state["seed_key"],
        }

    if seed_state["skip"]:
        return summarize(0, 0, 0, 1)

    window, source_index = _frame_window(frames, source_position, direction, max_frames)
    with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
        frame_paths = _write_frame_sequence(window, Path(tmpdir))
        propagated = sam_registry.propagate_video(
            model_id,
            frame_paths,
            source_index,
            seed,
            direction,
            len(window),
        )

    created = _save_propagated_annotations(
        db,
        payload={**payload, "current_direction": direction},
        selected_frames=window,
        source_frame=source_frame,
        propagated=propagated,
        seed=seed,
    )
    return summarize(
        len(window),
        len(created),
        int(seed_state["deleted_annotation_count"]),
        0,
    )
|
|
|
|
|
|
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
    """Run one queued SAM propagation task and update persisted progress.

    Loads the task and validates its payload (model, project, source frame,
    steps). It then runs each propagation step in order, persisting and
    publishing progress before and after every step. Cancellation is checked
    before the run starts and before each step.

    Returns:
        The final result summary on success, or a small status dict when the
        task was, or becomes, cancelled.

    Raises:
        ValueError: when the task, project, source frame or steps are missing
            or invalid. Where the task exists, it is marked failed first.
        ModelUnavailableError, NotImplementedError, Exception: re-raised from
            step execution after the session is rolled back and the task is
            marked failed.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")

    if task.status == TASK_STATUS_CANCELLED:
        return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}

    payload = task.payload or {}
    project_id = int(payload.get("project_id") or task.project_id or 0)
    source_frame_id = int(payload.get("frame_id") or 0)
    try:
        model_id = sam_registry.normalize_model_id(payload.get("model"))
    except ValueError as exc:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
        raise

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True)
        raise ValueError(f"Project not found: {project_id}")

    source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
    if not source_frame:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True)
        raise ValueError(f"Frame not found: {source_frame_id}")

    frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True)
        raise ValueError("Source frame is not in project frame sequence")

    steps = payload.get("steps") or []
    if not steps:
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True)
        raise ValueError("Propagation task has no steps")

    # The step helpers index payload["project_id"] directly. Hand them the id
    # resolved above, which may have come from task.project_id instead.
    step_payload = {**payload, "project_id": project_id}

    step_results: list[dict[str, Any]] = []
    created_count = 0
    processed_count = 0
    deleted_count = 0
    skipped_count = 0
    total_steps = len(steps)

    def snapshot(completed_steps: int) -> dict[str, Any]:
        # Reads the running counters at call time (closure over this scope).
        return {
            "project_id": project_id,
            "source_frame_id": source_frame_id,
            "model": model_id,
            "total_steps": total_steps,
            "completed_steps": completed_steps,
            "processed_frame_count": processed_count,
            "created_annotation_count": created_count,
            "deleted_annotation_count": deleted_count,
            "skipped_seed_count": skipped_count,
            "steps": step_results,
        }

    try:
        # Inside the try so a cancellation that lands during validation is
        # recorded as cancelled instead of escaping to the worker.
        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)

        for index, step in enumerate(steps, start=1):
            _ensure_not_cancelled(db, task)
            seed_label = (step.get("seed") or {}).get("label") or "mask"
            # Normalize like _run_one_step so the label agrees with what actually runs.
            is_backward = str(step.get("direction") or "forward").lower() == "backward"
            direction_label = "向前传播" if is_backward else "向后传播"
            _set_task_state(
                db,
                task,
                progress=5 + int(((index - 1) / total_steps) * 90),
                message=f"{direction_label} {seed_label} ({index}/{total_steps})",
                result=snapshot(index - 1),
            )

            step_result = _run_one_step(
                db,
                payload=step_payload,
                frames=frames,
                source_frame=source_frame,
                source_position=source_position,
                step=step,
            )
            step_results.append(step_result)
            created_count += int(step_result["created_annotation_count"])
            processed_count += int(step_result["processed_frame_count"])
            deleted_count += int(step_result.get("deleted_annotation_count") or 0)
            skipped_count += int(step_result.get("skipped_seed_count") or 0)
            _set_task_state(
                db,
                task,
                progress=5 + int((index / total_steps) * 90),
                message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
                result=snapshot(index),
            )

        result = snapshot(total_steps)
        if created_count > 0:
            final_message = "自动传播完成"
        elif skipped_count > 0:
            final_message = "自动传播完成,未改变的 mask 已跳过"
        else:
            final_message = "自动传播完成,但没有生成新的 mask"
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message=final_message,
            result=result,
            finished=True,
        )
        return result
    except PropagationTaskCancelled:
        task.status = TASK_STATUS_CANCELLED
        task.progress = 100
        task.message = task.message or "任务已取消"
        task.error = task.error or "Cancelled by user"
        task.finished_at = task.finished_at or _now()
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
    except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
        # Discard any half-written step so the failure state itself can be committed.
        db.rollback()
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
        raise
    except Exception as exc:  # noqa: BLE001
        logger.exception("Propagation task failed: task_id=%s", task_id)
        # After a failed flush the session refuses further commits until it is
        # rolled back; without this the task would stay stuck in RUNNING.
        db.rollback()
        _set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
        raise
|