feat: 完善 mask 编辑、传播平滑与开发重启闭环

功能增加:

- 新增后端 /api/ai/smooth-mask 接口,对当前 mask polygon 执行 Chaikin 边缘平滑,并返回 polygon、bbox、area 与拓扑锚点。

- 在右侧实例属性面板加入边缘平滑强度和应用边缘平滑操作,应用后将 mask 标记为 draft/dirty,并通过正常保存链路落库。

- 保存标注与传播 seed 时保留 geometry_smoothing 元数据,自动传播 forward/backward 结果保存前应用同一平滑参数。

- 自动传播 seed signature 纳入平滑参数,修改平滑强度后会触发旧同源传播结果清理并重新传播。

- 支持跨帧跟随同一传播链 mask,AI 推送回工作区时保留当前帧视角。

Bugfix:

- 修复中间帧向前传播时旧 forward/backward 同物体结果未被清理导致双重 mask 的问题。

- 修复 propagation worker 写入目标帧前只按旧方向清理导致 backward 重传残留的问题。

- 修复多边形顶点拖拽和编辑后画布视口异常移动的问题,并补充拖拽状态回写。

- 修复实例属性标题跟随全局 active class 而不是当前 mask label 的问题,并移除后端模型置信度展示。

开发与部署:

- 新增 restart_dev_services.sh,使用 setsid 独立后台重启 FastAPI、Celery 和前端,写入 pid/log 文件并做 3000/8000 健康检查。

- 明确后端或 Celery 相关改动完成后需要运行重启脚本,保证运行态加载最新代码。

测试与文档:

- 补充后端 smooth-mask、传播平滑 metadata、seed signature、传播去重方向覆盖等测试。

- 补充前端 OntologyInspector、VideoWorkspace、CanvasArea 和 api 契约测试,覆盖边缘平滑、传播参数、跨帧选区跟随和画布编辑行为。

- 更新 README、AGENTS、安装文档、前端元素审计、需求冻结、设计冻结和测试计划,记录当前真实行为与重启要求。
This commit is contained in:
2026-05-02 17:04:02 +08:00
parent f365539ff2
commit 4c1d3dba73
20 changed files with 1358 additions and 71 deletions

View File

@@ -17,6 +17,8 @@ from schemas import (
AiRuntimeStatus,
MaskAnalysisRequest,
MaskAnalysisResponse,
SmoothMaskRequest,
SmoothMaskResponse,
PredictRequest,
PredictResponse,
PropagateRequest,
@@ -96,6 +98,14 @@ def _polygon_area(polygon: list[list[float]]) -> float:
return abs(total) / 2.0
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
return [[_clamp01(point[0]), _clamp01(point[1])] for point in polygon if len(point) >= 2]
def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[float]]]:
return [polygon for polygon in (_normalize_polygon(polygon) for polygon in polygons) if len(polygon) >= 3]
def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
if points:
return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
@@ -108,6 +118,63 @@ def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float
return anchors[:32]
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
clamped_strength = max(0.0, min(float(strength or 0.0), 100.0))
normalized_method = (method or "chaikin").lower()
if normalized_method != "chaikin":
normalized_method = "chaikin"
return {
"strength": round(clamped_strength, 2),
"method": normalized_method,
}
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
points = polygon
for _ in range(max(0, iterations)):
if len(points) < 3:
break
next_points: list[list[float]] = []
for index, current in enumerate(points):
following = points[(index + 1) % len(points)]
next_points.append([
_clamp01(0.75 * current[0] + 0.25 * following[0]),
_clamp01(0.75 * current[1] + 0.25 * following[1]),
])
next_points.append([
_clamp01(0.25 * current[0] + 0.75 * following[0]),
_clamp01(0.25 * current[1] + 0.75 * following[1]),
])
points = next_points
return points
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
if len(polygon) < 3 or strength <= 0:
return polygon
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
arc_length = cv2.arcLength(contour, True)
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
if len(approx) < 3:
return polygon
return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in approx]
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> list[list[float]]:
strength = float(smoothing.get("strength") or 0.0)
if strength <= 0:
return _normalize_polygon(polygon)
iterations = max(1, min(3, int(strength // 35) + 1))
smoothed = _chaikin_smooth_polygon(_normalize_polygon(polygon), iterations)
simplified = _simplify_polygon(smoothed, strength)
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
def _smooth_polygons(polygons: list[list[list[float]]], smoothing: dict[str, Any]) -> list[list[list[float]]]:
return [polygon for polygon in (_smooth_polygon(polygon, smoothing) for polygon in polygons) if len(polygon) >= 3]
def _frame_window(
frames: list[Frame],
source_position: int,
@@ -436,11 +503,7 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
if not polygons:
raise HTTPException(status_code=400, detail="Mask analysis requires polygons")
valid_polygons = [
[[_clamp01(point[0]), _clamp01(point[1])] for point in polygon if len(point) >= 2]
for polygon in polygons
]
valid_polygons = [polygon for polygon in valid_polygons if len(polygon) >= 3]
valid_polygons = _normalize_polygons(polygons)
if not valid_polygons:
raise HTTPException(status_code=400, detail="Mask analysis requires at least one valid polygon")
@@ -473,6 +536,46 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
}
@router.post(
"/smooth-mask",
response_model=SmoothMaskResponse,
summary="Smooth editable mask polygons with backend geometry rules",
)
def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> dict:
"""Return a smoothed polygon mask without persisting it.
The frontend keeps this as an explicit edit operation: users preview/apply it
to the current mask, then save through the normal annotation endpoint.
"""
if payload.frame_id is not None:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
polygons = payload.mask_data.get("polygons") or []
valid_polygons = _normalize_polygons(polygons)
if not valid_polygons:
raise HTTPException(status_code=400, detail="Mask smoothing requires at least one valid polygon")
smoothing = _normalize_smoothing_options(payload.strength, payload.method)
smoothed_polygons = _smooth_polygons(valid_polygons, smoothing)
if not smoothed_polygons:
raise HTTPException(status_code=400, detail="Mask smoothing produced no valid polygons")
area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
bbox = _polygon_bbox(smoothed_polygons[0])
anchors = _analysis_anchors(smoothed_polygons, payload.points)
return {
"polygons": smoothed_polygons,
"topology_anchor_count": len(anchors),
"topology_anchors": anchors,
"area": area,
"bbox": bbox,
"smoothing": smoothing,
"message": f"已应用边缘平滑强度 {smoothing['strength']:.0f}",
}
@router.post(
"/propagate",
response_model=PropagateResponse,
@@ -544,6 +647,13 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
label = seed.get("label") or "Propagated Mask"
color = seed.get("color") or "#06b6d4"
model_id = sam_registry.normalize_model_id(payload.model)
seed_smoothing = seed.get("smoothing")
smoothing = _normalize_smoothing_options(
seed_smoothing.get("strength"),
seed_smoothing.get("method"),
) if isinstance(seed_smoothing, dict) else None
if smoothing and smoothing["strength"] <= 0:
smoothing = None
for frame_result in propagated:
relative_index = int(frame_result.get("frame_index", -1))
@@ -557,22 +667,24 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
for polygon_index, polygon in enumerate(result_polygons):
if len(polygon) < 3:
continue
polygon_to_save = _smooth_polygon(polygon, smoothing) if smoothing else polygon
annotation = Annotation(
project_id=payload.project_id,
frame_id=frame.id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"polygons": [polygon_to_save],
"label": label,
"color": color,
"source": f"{model_id}_propagation",
"propagated_from_frame_id": source_frame.id,
"propagated_from_frame_index": source_frame.frame_index,
"score": scores[polygon_index] if polygon_index < len(scores) else None,
**({"geometry_smoothing": smoothing} if smoothing else {}),
**({"class": class_metadata} if class_metadata else {}),
},
points=None,
bbox=_polygon_bbox(polygon),
bbox=_polygon_bbox(polygon_to_save),
)
db.add(annotation)
created.append(annotation)

View File

@@ -209,6 +209,25 @@ class MaskAnalysisResponse(BaseModel):
message: str
class SmoothMaskRequest(BaseModel):
frame_id: Optional[int] = None
mask_data: dict[str, Any]
points: Optional[list[list[float]]] = None
bbox: Optional[list[float]] = None
strength: float = 0.0
method: str = "chaikin"
class SmoothMaskResponse(BaseModel):
polygons: list[list[list[float]]]
topology_anchor_count: int
topology_anchors: list[list[float]]
area: float
bbox: Optional[list[float]] = None
smoothing: dict[str, Any]
message: str
class PropagationSeed(BaseModel):
polygons: Optional[list[list[list[float]]]] = None
bbox: Optional[list[float]] = None
@@ -221,6 +240,7 @@ class PropagationSeed(BaseModel):
source_mask_id: Optional[str] = None
source_annotation_id: Optional[int] = None
propagation_seed_signature: Optional[str] = None
smoothing: Optional[dict[str, Any]] = None
class PropagateRequest(BaseModel):

View File

@@ -8,6 +8,8 @@ from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from sqlalchemy.orm import Session
from minio_client import download_file
@@ -81,6 +83,87 @@ def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
return [[_clamp01(point[0]), _clamp01(point[1])] for point in polygon if len(point) >= 2]
def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
if not isinstance(value, dict):
return None
try:
strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
except (TypeError, ValueError):
strength = 0.0
if strength <= 0:
return None
method = str(value.get("method") or "chaikin").lower()
if method != "chaikin":
method = "chaikin"
return {"strength": round(strength, 2), "method": method}
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
points = _normalize_polygon(polygon)
for _ in range(max(0, iterations)):
if len(points) < 3:
break
next_points: list[list[float]] = []
for index, current in enumerate(points):
following = points[(index + 1) % len(points)]
next_points.append([
_clamp01(0.75 * current[0] + 0.25 * following[0]),
_clamp01(0.75 * current[1] + 0.25 * following[1]),
])
next_points.append([
_clamp01(0.25 * current[0] + 0.75 * following[0]),
_clamp01(0.25 * current[1] + 0.75 * following[1]),
])
points = next_points
return points
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
if len(polygon) < 3:
return polygon
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
arc_length = cv2.arcLength(contour, True)
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
if len(approx) < 3:
return polygon
return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in approx]
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
if not smoothing:
return _normalize_polygon(polygon)
strength = float(smoothing.get("strength") or 0.0)
if strength <= 0:
return _normalize_polygon(polygon)
iterations = max(1, min(3, int(strength // 35) + 1))
smoothed = _chaikin_smooth_polygon(polygon, iterations)
simplified = _simplify_polygon(smoothed, strength)
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
def _bbox_area(bbox: list[float]) -> float:
return max(float(bbox[2]), 0.0) * max(float(bbox[3]), 0.0)
def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
ax1, ay1, aw, ah = a
bx1, by1, bw, bh = b
ax2 = ax1 + aw
ay2 = ay1 + ah
bx2 = bx1 + bw
by2 = by1 + bh
overlap_width = max(0.0, min(ax2, bx2) - max(ax1, bx1))
overlap_height = max(0.0, min(ay2, by2) - max(ay1, by1))
overlap_area = overlap_width * overlap_height
smallest_area = min(_bbox_area(a), _bbox_area(b))
return overlap_area / smallest_area if smallest_area > 0 else 0.0
def _stable_json(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
@@ -109,6 +192,7 @@ def _seed_signature(seed: dict[str, Any]) -> str:
"color": seed.get("color"),
"class_metadata": seed.get("class_metadata") or {},
"template_id": seed.get("template_id"),
"smoothing": _normalize_smoothing_options(seed.get("smoothing")),
}
return hashlib.sha256(_stable_json(_canonicalize_signature_value(signature_payload)).encode("utf-8")).hexdigest()
@@ -131,6 +215,20 @@ def _seed_key(seed: dict[str, Any]) -> str:
})
def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
"""Best-effort match when a manually edited replacement lacks old lineage ids."""
class_metadata = seed.get("class_metadata") or {}
previous_class = mask_data.get("class") or {}
previous_class_id = previous_class.get("id") or previous_class.get("name")
class_id = class_metadata.get("id") or class_metadata.get("name")
if previous_class_id and class_id and str(previous_class_id) != str(class_id):
return False
return (
mask_data.get("label") == seed.get("label")
and mask_data.get("color") == seed.get("color")
)
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
"""Best-effort match for propagation annotations created before seed keys."""
class_metadata = seed.get("class_metadata") or {}
@@ -174,6 +272,52 @@ def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
return previous_direction in {None, direction}
def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
"""Use target-frame overlap as a final guard before replacing same-object propagation."""
candidate_bbox = _polygon_bbox(polygon)
for previous_polygon in (annotation.mask_data or {}).get("polygons") or []:
if len(previous_polygon) < 3:
continue
if _bbox_overlap_ratio(_polygon_bbox(previous_polygon), candidate_bbox) >= 0.15:
return True
return False
def _delete_replaced_frame_annotations(
db: Session,
*,
payload: dict[str, Any],
frame_id: int,
seed_key: str,
seed: dict[str, Any],
polygon: list[list[float]],
) -> int:
"""Delete old propagated masks for the same object immediately before writing a new result."""
previous_annotations = (
db.query(Annotation)
.filter(Annotation.project_id == int(payload["project_id"]))
.filter(Annotation.frame_id == frame_id)
.all()
)
deleted_count = 0
for annotation in previous_annotations:
mask_data = annotation.mask_data or {}
source = str(mask_data.get("source") or "")
if not source.endswith("_propagation"):
continue
same_lineage = _seed_identity_matches(mask_data, seed_key, seed)
same_manual_replacement = (
_semantic_seed_matches(mask_data, seed)
and _annotation_spatially_matches(annotation, polygon)
)
if same_lineage or same_manual_replacement:
db.delete(annotation)
deleted_count += 1
if deleted_count:
db.commit()
return deleted_count
def _prepare_seed_propagation(
db: Session,
*,
@@ -264,10 +408,10 @@ def _save_propagated_annotations(
source_frame: Frame,
propagated: list[dict[str, Any]],
seed: dict[str, Any],
) -> list[Annotation]:
) -> tuple[list[Annotation], int]:
created: list[Annotation] = []
if payload.get("save_annotations", True) is False:
return created
return created, 0
class_metadata = seed.get("class_metadata")
template_id = seed.get("template_id")
@@ -279,7 +423,10 @@ def _save_propagated_annotations(
seed_signature = _seed_signature(seed)
source_annotation_id = seed.get("source_annotation_id")
source_mask_id = seed.get("source_mask_id")
smoothing = _normalize_smoothing_options(seed.get("smoothing"))
direction = str(payload.get("current_direction") or "")
deleted_count = 0
cleaned_frame_ids: set[int] = set()
for frame_result in propagated:
relative_index = int(frame_result.get("frame_index", -1))
@@ -290,7 +437,23 @@ def _save_propagated_annotations(
continue
result_polygons = frame_result.get("polygons") or []
scores = frame_result.get("scores") or []
for polygon_index, polygon in enumerate(result_polygons):
smoothed_polygons = [
_smooth_polygon(polygon, smoothing)
for polygon in result_polygons
if len(polygon) >= 3
]
cleanup_polygon = next((polygon for polygon in smoothed_polygons if len(polygon) >= 3), None)
if cleanup_polygon is not None and frame.id not in cleaned_frame_ids:
deleted_count += _delete_replaced_frame_annotations(
db,
payload=payload,
frame_id=int(frame.id),
seed_key=seed_key,
seed=seed,
polygon=cleanup_polygon,
)
cleaned_frame_ids.add(int(frame.id))
for polygon_index, polygon in enumerate(smoothed_polygons):
if len(polygon) < 3:
continue
annotation = Annotation(
@@ -310,6 +473,7 @@ def _save_propagated_annotations(
"source_annotation_id": source_annotation_id,
"source_mask_id": source_mask_id,
"score": scores[polygon_index] if polygon_index < len(scores) else None,
**({"geometry_smoothing": smoothing} if smoothing else {}),
**({"class": class_metadata} if class_metadata else {}),
},
points=None,
@@ -321,7 +485,7 @@ def _save_propagated_annotations(
db.commit()
for annotation in created:
db.refresh(annotation)
return created
return created, deleted_count
def _run_one_step(
@@ -381,7 +545,7 @@ def _run_one_step(
)
save_payload = {**payload, "current_direction": direction}
created = _save_propagated_annotations(
created, write_cleanup_count = _save_propagated_annotations(
db,
payload=save_payload,
selected_frames=selected_frames,
@@ -394,7 +558,7 @@ def _run_one_step(
"direction": direction,
"processed_frame_count": len(selected_frames),
"created_annotation_count": len(created),
"deleted_annotation_count": int(seed_state["deleted_annotation_count"]),
"deleted_annotation_count": int(seed_state["deleted_annotation_count"]) + write_cleanup_count,
"skipped_seed_count": 0,
"seed_label": seed.get("label"),
"seed_key": seed_state["seed_key"],

View File

@@ -223,6 +223,41 @@ def test_analyze_mask_returns_backend_geometry_properties(client):
assert body["message"] == "已从后端重新提取几何拓扑锚点"
def test_smooth_mask_returns_backend_smoothed_geometry(client):
_, frame, _ = _create_project_and_frame(client)
response = client.post("/api/ai/smooth-mask", json={
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3], [0.1, 0.3]]],
"label": "胆囊",
"color": "#ff0000",
},
"strength": 45,
})
assert response.status_code == 200
body = response.json()
assert body["smoothing"] == {"strength": 45.0, "method": "chaikin"}
assert len(body["polygons"]) == 1
assert len(body["polygons"][0]) > 4
assert body["topology_anchor_count"] > 0
assert body["message"] == "已应用边缘平滑强度 45"
def test_seed_signature_includes_smoothing_parameters():
seed = {
"polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]],
"label": "胆囊",
"color": "#ff0000",
}
assert _seed_signature({**seed, "smoothing": {"strength": 20, "method": "chaikin"}}) != _seed_signature({
**seed,
"smoothing": {"strength": 40, "method": "chaikin"},
})
def test_propagate_saves_tracked_annotations(client, monkeypatch):
project = client.post("/api/projects", json={"name": "Video Project"}).json()
frames = [
@@ -324,6 +359,7 @@ def test_queue_propagation_task_creates_processing_task(client, monkeypatch):
"seed": {
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
"label": "胆囊",
"smoothing": {"strength": 35, "method": "chaikin"},
},
}],
})
@@ -335,6 +371,7 @@ def test_queue_propagation_task_creates_processing_task(client, monkeypatch):
assert body["celery_task_id"] == "celery-propagate-1"
assert body["payload"]["model"] == "sam2.1_hiera_tiny"
assert body["payload"]["steps"][0]["seed"]["label"] == "胆囊"
assert body["payload"]["steps"][0]["seed"]["smoothing"] == {"strength": 35, "method": "chaikin"}
assert queued == [body["id"]]
@@ -418,6 +455,7 @@ def test_propagation_task_runner_saves_annotations_and_progress(client, db_sessi
"label": "胆囊",
"color": "#ff0000",
"class_metadata": {"id": "c1", "name": "胆囊"},
"smoothing": {"strength": 40, "method": "chaikin"},
},
}],
},
@@ -452,6 +490,8 @@ def test_propagation_task_runner_saves_annotations_and_progress(client, db_sessi
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert listing.json()[0]["frame_id"] == frames[1]["id"]
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
assert listing.json()[0]["mask_data"]["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"}
assert len(listing.json()[0]["mask_data"]["polygons"][0]) > 3
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
@@ -614,6 +654,172 @@ def test_propagation_task_runner_replaces_legacy_or_different_weight_results(cli
assert annotations[0].mask_data["polygons"] == [output_polygon]
def test_propagation_task_runner_replaces_downstream_result_from_middle_frame_manual_seed(client, db_session, monkeypatch):
project = client.post("/api/projects", json={"name": "Propagation Middle Frame Replacement"}).json()
frames = [
client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 640,
"height": 360,
}).json()
for idx in range(3)
]
old_downstream_polygon = [[0.18, 0.18], [0.28, 0.18], [0.28, 0.28]]
replacement_seed_polygon = [[0.16, 0.16], [0.26, 0.16], [0.26, 0.26]]
replacement_downstream_polygon = [[0.19, 0.19], [0.29, 0.19], [0.29, 0.29]]
db_session.add(Annotation(
project_id=project["id"],
frame_id=frames[2]["id"],
template_id=3,
mask_data={
"polygons": [old_downstream_polygon],
"label": "胆囊",
"color": "#ff0000",
"class": {"id": "c1", "name": "胆囊", "color": "#ff0000"},
"source": "sam2.1_hiera_tiny_propagation",
"propagated_from_frame_id": frames[0]["id"],
"propagation_seed_key": "annotation:7",
"propagation_seed_signature": "old-signature",
"propagation_direction": "forward",
"source_annotation_id": 7,
"source_mask_id": "annotation-7",
},
bbox=[0.18, 0.18, 0.1, 0.1],
))
db_session.commit()
task = ProcessingTask(
task_type="propagate_masks",
status="queued",
progress=0,
project_id=project["id"],
payload={
"project_id": project["id"],
"frame_id": frames[1]["id"],
"model": "sam2.1_hiera_tiny",
"include_source": False,
"save_annotations": True,
"steps": [{
"direction": "forward",
"max_frames": 2,
"seed": {
"polygons": [replacement_seed_polygon],
"label": "胆囊",
"color": "#ff0000",
"source_annotation_id": 20,
"source_mask_id": "annotation-20",
},
}],
},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda model, frame_paths, source_frame_index, seed, direction, max_frames: [
{"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]},
{"frame_index": 1, "polygons": [replacement_downstream_polygon], "scores": [0.8]},
])
result = run_propagate_project_task(db_session, task.id)
assert result["created_annotation_count"] == 1
assert result["deleted_annotation_count"] == 1
annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
assert len(annotations) == 1
assert annotations[0].frame_id == frames[2]["id"]
assert annotations[0].mask_data["polygons"] == [replacement_downstream_polygon]
assert annotations[0].mask_data["source_annotation_id"] == 20
assert annotations[0].mask_data["source_mask_id"] == "annotation-20"
def test_propagation_task_runner_replaces_forward_result_when_middle_frame_propagates_backward(client, db_session, monkeypatch):
project = client.post("/api/projects", json={"name": "Propagation Backward Middle Replacement"}).json()
frames = [
client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 640,
"height": 360,
}).json()
for idx in range(3)
]
old_upstream_polygon = [[0.12, 0.12], [0.22, 0.12], [0.22, 0.22]]
replacement_seed_polygon = [[0.16, 0.16], [0.26, 0.16], [0.26, 0.26]]
replacement_upstream_polygon = [[0.13, 0.13], [0.23, 0.13], [0.23, 0.23]]
db_session.add(Annotation(
project_id=project["id"],
frame_id=frames[0]["id"],
mask_data={
"polygons": [old_upstream_polygon],
"label": "胆囊",
"color": "#ff0000",
"source": "sam2.1_hiera_tiny_propagation",
"propagated_from_frame_id": frames[0]["id"],
"propagation_seed_key": "annotation:7",
"propagation_seed_signature": "old-signature",
"propagation_direction": "forward",
"source_annotation_id": 7,
"source_mask_id": "annotation-7",
},
bbox=[0.12, 0.12, 0.1, 0.1],
))
db_session.commit()
task = ProcessingTask(
task_type="propagate_masks",
status="queued",
progress=0,
project_id=project["id"],
payload={
"project_id": project["id"],
"frame_id": frames[1]["id"],
"model": "sam2.1_hiera_tiny",
"include_source": False,
"save_annotations": True,
"steps": [{
"direction": "backward",
"max_frames": 2,
"seed": {
"polygons": [replacement_seed_polygon],
"label": "胆囊",
"color": "#ff0000",
"source_annotation_id": 20,
"source_mask_id": "annotation-20",
},
}],
},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda model, frame_paths, source_frame_index, seed, direction, max_frames: [
{"frame_index": 0, "polygons": [replacement_upstream_polygon], "scores": [0.8]},
{"frame_index": 1, "polygons": [seed["polygons"][0]], "scores": [0.9]},
])
result = run_propagate_project_task(db_session, task.id)
assert result["created_annotation_count"] == 1
assert result["deleted_annotation_count"] == 1
annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
assert len(annotations) == 1
assert annotations[0].frame_id == frames[0]["id"]
assert annotations[0].mask_data["polygons"] == [replacement_upstream_polygon]
assert annotations[0].mask_data["propagation_direction"] == "backward"
assert annotations[0].mask_data["source_annotation_id"] == 20
def test_propagation_task_runner_skips_unmodified_propagated_seed_on_overlapping_frames(client, db_session, monkeypatch):
project = client.post("/api/projects", json={"name": "Propagation Overlap Skip"}).json()
frames = [