修复同类多实例传播结果被清理

- 在传播结果写入前的空间清理中识别同一任务里的其它 seed,避免后续同类 seed 删除前面刚写入的传播结果。

- 保留旧结果替换逻辑:不属于本次其它 seed 的旧传播结果仍可按语义和空间重叠被新 seed 替换。

- 将同类多实例传播回归测试改为重叠输出场景,覆盖此前两个同类方块只保留一个的问题。
This commit is contained in:
2026-05-04 00:44:11 +08:00
parent 093ef6c63a
commit 2fe4623cae
2 changed files with 41 additions and 2 deletions

View File

@@ -286,6 +286,39 @@ def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[
return _legacy_seed_matches(mask_data, seed) return _legacy_seed_matches(mask_data, seed)
def _seed_identity_markers(seed: dict[str, Any]) -> set[str]:
markers = {f"seed:{_seed_key(seed)}"}
source_annotation_id = seed.get("source_annotation_id")
if source_annotation_id is not None:
markers.add(f"annotation:{source_annotation_id}")
source_mask_id = seed.get("source_mask_id")
if source_mask_id:
markers.add(f"mask:{source_mask_id}")
return markers
def _mask_identity_markers(mask_data: dict[str, Any]) -> set[str]:
markers: set[str] = set()
previous_seed_key = mask_data.get("propagation_seed_key")
if previous_seed_key:
markers.add(f"seed:{previous_seed_key}")
source_annotation_id = mask_data.get("source_annotation_id")
if source_annotation_id is not None:
markers.add(f"annotation:{source_annotation_id}")
source_mask_id = mask_data.get("source_mask_id")
if source_mask_id:
markers.add(f"mask:{source_mask_id}")
return markers
def _payload_seed_identity_markers(payload: dict[str, Any]) -> set[str]:
markers: set[str] = set()
for step in payload.get("steps") or []:
seed = step.get("seed") or {}
markers.update(_seed_identity_markers(seed))
return markers
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool: def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
mask_data = annotation.mask_data or {} mask_data = annotation.mask_data or {}
source = str(mask_data.get("source") or "") source = str(mask_data.get("source") or "")
@@ -327,11 +360,17 @@ def _delete_replaced_frame_annotations(
.all() .all()
) )
deleted_count = 0 deleted_count = 0
current_seed_markers = _seed_identity_markers(seed)
task_seed_markers = _payload_seed_identity_markers(payload)
for annotation in previous_annotations: for annotation in previous_annotations:
mask_data = annotation.mask_data or {} mask_data = annotation.mask_data or {}
source = str(mask_data.get("source") or "") source = str(mask_data.get("source") or "")
if not source.endswith("_propagation"): if not source.endswith("_propagation"):
continue continue
mask_markers = _mask_identity_markers(mask_data)
# Keep sibling seeds in the same propagation task from deleting each other.
if mask_markers and mask_markers.isdisjoint(current_seed_markers) and not mask_markers.isdisjoint(task_seed_markers):
continue
same_lineage = _seed_identity_matches(mask_data, seed_key, seed) same_lineage = _seed_identity_matches(mask_data, seed_key, seed)
same_manual_replacement = ( same_manual_replacement = (
_semantic_seed_matches(mask_data, seed) _semantic_seed_matches(mask_data, seed)

View File

@@ -756,8 +756,8 @@ def test_propagation_task_runner_keeps_same_class_seeds_separate(client, db_sess
] ]
output_by_source = { output_by_source = {
7: [[0.10, 0.10], [0.20, 0.10], [0.20, 0.20]], 7: [[0.10, 0.10], [0.30, 0.10], [0.30, 0.30], [0.10, 0.30]],
8: [[0.70, 0.70], [0.80, 0.70], [0.80, 0.80]], 8: [[0.12, 0.12], [0.32, 0.12], [0.32, 0.32], [0.12, 0.32]],
} }
task = ProcessingTask( task = ProcessingTask(
task_type="propagate_masks", task_type="propagate_masks",