fix: 避免自动传播重复叠加同源 mask
Bugfix:自动传播 worker 改为在本次目标帧段内按 seed 来源、方向、权重和签名查找旧传播结果;未修改且目标帧已覆盖时直接跳过,不再重复跑 SAM 造成 mask 堆叠。 Bugfix:同一 seed 被编辑、目标帧段只部分覆盖或切换 SAM 2.1 权重时,worker 会先删除本次目标帧段内同源旧自动传播标注,再重新传播。 Bugfix:未编辑的自动传播结果再次作为参考 seed 时会继承原始 propagation_seed_signature;编辑后的传播结果只保留 source_annotation_id/source_mask_id lineage,不继承旧签名,从而触发重传路径。 Bugfix:后端传播签名增加 canonical rounding,减少浮点精度细微变化导致未编辑 mask 被误判为已修改。 功能调整:清空片段遮罩改成与自动传播一致的时间轴范围选择流程,首次点击进入选区,拖拽选择起止帧后点击确认清空才执行。 接口契约:PropagationSeed 增加 propagation_seed_signature 字段,用于前端把未编辑传播结果绑定回原始 seed 传播链。 测试:补充前端 VideoWorkspace 范围清空、传播 lineage 传递测试;补充后端未编辑传播 seed 跳过重复传播、旧结果清理与换权重重传测试。 文档:同步更新 doc/03、doc/04、doc/07、doc/08、doc/09,明确 A/B 传播去重规则、清空片段范围选择和新增 seed signature 契约。
This commit is contained in:
@@ -220,6 +220,7 @@ class PropagationSeed(BaseModel):
|
||||
template_id: Optional[int] = None
|
||||
source_mask_id: Optional[str] = None
|
||||
source_annotation_id: Optional[int] = None
|
||||
propagation_seed_signature: Optional[str] = None
|
||||
|
||||
|
||||
class PropagateRequest(BaseModel):
|
||||
|
||||
@@ -85,8 +85,21 @@ def _stable_json(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _canonicalize_signature_value(value: Any) -> Any:
|
||||
if isinstance(value, float):
|
||||
return round(value, 6)
|
||||
if isinstance(value, list):
|
||||
return [_canonicalize_signature_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: _canonicalize_signature_value(value[key]) for key in sorted(value)}
|
||||
return value
|
||||
|
||||
|
||||
def _seed_signature(seed: dict[str, Any]) -> str:
|
||||
"""Return a stable signature for seed geometry and semantic attrs."""
|
||||
inherited_signature = seed.get("propagation_seed_signature")
|
||||
if inherited_signature:
|
||||
return str(inherited_signature)
|
||||
signature_payload = {
|
||||
"polygons": seed.get("polygons") or [],
|
||||
"bbox": seed.get("bbox") or [],
|
||||
@@ -97,7 +110,7 @@ def _seed_signature(seed: dict[str, Any]) -> str:
|
||||
"class_metadata": seed.get("class_metadata") or {},
|
||||
"template_id": seed.get("template_id"),
|
||||
}
|
||||
return hashlib.sha256(_stable_json(signature_payload).encode("utf-8")).hexdigest()
|
||||
return hashlib.sha256(_stable_json(_canonicalize_signature_value(signature_payload)).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _seed_key(seed: dict[str, Any]) -> str:
|
||||
@@ -135,22 +148,25 @@ def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
|
||||
return str(mask_data.get("source") or "") == f"{model_id}_propagation"
|
||||
|
||||
|
||||
def _is_propagation_annotation(
|
||||
annotation: Annotation,
|
||||
source_frame: Frame,
|
||||
seed_key: str,
|
||||
seed: dict[str, Any],
|
||||
) -> bool:
|
||||
def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key == seed_key:
|
||||
return True
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None and str(mask_data.get("source_annotation_id") or "") == str(source_annotation_id):
|
||||
return True
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id and mask_data.get("source_mask_id") == source_mask_id:
|
||||
return True
|
||||
return _legacy_seed_matches(mask_data, seed)
|
||||
|
||||
|
||||
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
|
||||
mask_data = annotation.mask_data or {}
|
||||
source = str(mask_data.get("source") or "")
|
||||
if not source.endswith("_propagation"):
|
||||
return False
|
||||
if int(mask_data.get("propagated_from_frame_id") or 0) != int(source_frame.id):
|
||||
return False
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key is not None:
|
||||
return previous_seed_key == seed_key or _legacy_seed_matches(mask_data, seed)
|
||||
return _legacy_seed_matches(mask_data, seed)
|
||||
return _seed_identity_matches(mask_data, seed_key, seed)
|
||||
|
||||
|
||||
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
|
||||
@@ -163,27 +179,36 @@ def _prepare_seed_propagation(
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
model_id: str,
|
||||
source_frame: Frame,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
target_frame_ids: set[int],
|
||||
) -> dict[str, Any]:
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
if not target_frame_ids:
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": 0,
|
||||
}
|
||||
previous_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == int(payload["project_id"]))
|
||||
.filter(Annotation.frame_id.in_(target_frame_ids))
|
||||
.all()
|
||||
)
|
||||
matching = [
|
||||
annotation for annotation in previous_annotations
|
||||
if _is_propagation_annotation(annotation, source_frame, seed_key, seed)
|
||||
if _is_propagation_annotation(annotation, seed_key, seed)
|
||||
and _direction_matches(annotation.mask_data or {}, direction)
|
||||
]
|
||||
covered_frame_ids = {int(annotation.frame_id) for annotation in matching}
|
||||
if matching and all(
|
||||
(annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature
|
||||
and _source_model_matches(annotation.mask_data or {}, model_id)
|
||||
for annotation in matching
|
||||
):
|
||||
) and target_frame_ids.issubset(covered_frame_ids):
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
@@ -317,13 +342,20 @@ def _run_one_step(
|
||||
raise ValueError("Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
include_source = bool(payload.get("include_source", False))
|
||||
target_frame_ids = {
|
||||
int(frame.id)
|
||||
for frame in selected_frames
|
||||
if include_source or frame.id != source_frame.id
|
||||
}
|
||||
seed_state = _prepare_seed_propagation(
|
||||
db,
|
||||
payload=payload,
|
||||
model_id=model_id,
|
||||
source_frame=source_frame,
|
||||
seed=seed,
|
||||
direction=direction,
|
||||
target_frame_ids=target_frame_ids,
|
||||
)
|
||||
if seed_state["skip"]:
|
||||
return {
|
||||
@@ -337,7 +369,6 @@ def _run_one_step(
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
|
||||
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from models import Annotation, ProcessingTask
|
||||
from services.propagation_task_runner import run_propagate_project_task
|
||||
from services.propagation_task_runner import _seed_signature, run_propagate_project_task
|
||||
|
||||
|
||||
def _create_project_and_frame(client):
|
||||
@@ -614,6 +614,94 @@ def test_propagation_task_runner_replaces_legacy_or_different_weight_results(cli
|
||||
assert annotations[0].mask_data["polygons"] == [output_polygon]
|
||||
|
||||
|
||||
def test_propagation_task_runner_skips_unmodified_propagated_seed_on_overlapping_frames(client, db_session, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Propagation Overlap Skip"}).json()
|
||||
frames = [
|
||||
client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
for idx in range(3)
|
||||
]
|
||||
|
||||
original_seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
|
||||
propagated_seed_polygon = [[0.14, 0.14], [0.24, 0.14], [0.24, 0.24]]
|
||||
downstream_polygon = [[0.18, 0.18], [0.28, 0.18], [0.28, 0.28]]
|
||||
inherited_signature = _seed_signature({
|
||||
"polygons": [original_seed_polygon],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"source_annotation_id": 7,
|
||||
"source_mask_id": "annotation-7",
|
||||
})
|
||||
|
||||
db_session.add(Annotation(
|
||||
project_id=project["id"],
|
||||
frame_id=frames[2]["id"],
|
||||
mask_data={
|
||||
"polygons": [downstream_polygon],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"source": "sam2.1_hiera_tiny_propagation",
|
||||
"propagated_from_frame_id": frames[0]["id"],
|
||||
"propagation_seed_key": "annotation:7",
|
||||
"propagation_seed_signature": inherited_signature,
|
||||
"propagation_direction": "forward",
|
||||
"source_annotation_id": 7,
|
||||
"source_mask_id": "annotation-7",
|
||||
},
|
||||
bbox=[0.18, 0.18, 0.1, 0.1],
|
||||
))
|
||||
db_session.commit()
|
||||
|
||||
task = ProcessingTask(
|
||||
task_type="propagate_masks",
|
||||
status="queued",
|
||||
progress=0,
|
||||
project_id=project["id"],
|
||||
payload={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[1]["id"],
|
||||
"model": "sam2.1_hiera_tiny",
|
||||
"include_source": False,
|
||||
"save_annotations": True,
|
||||
"steps": [{
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"seed": {
|
||||
"polygons": [propagated_seed_polygon],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"source_annotation_id": 7,
|
||||
"source_mask_id": "annotation-7",
|
||||
"propagation_seed_signature": inherited_signature,
|
||||
},
|
||||
}],
|
||||
},
|
||||
)
|
||||
db_session.add(task)
|
||||
db_session.commit()
|
||||
db_session.refresh(task)
|
||||
|
||||
propagate_calls = []
|
||||
monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
|
||||
monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
|
||||
monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda *args, **kwargs: propagate_calls.append(args) or [])
|
||||
|
||||
result = run_propagate_project_task(db_session, task.id)
|
||||
|
||||
assert result["created_annotation_count"] == 0
|
||||
assert result["deleted_annotation_count"] == 0
|
||||
assert result["skipped_seed_count"] == 1
|
||||
assert propagate_calls == []
|
||||
annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
|
||||
assert len(annotations) == 1
|
||||
assert annotations[0].mask_data["polygons"] == [downstream_polygon]
|
||||
|
||||
|
||||
def test_predict_validation_errors(client, monkeypatch):
|
||||
project, _, _ = _create_project_and_frame(client)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user