支持中空mask编辑和传播保洞
- 前端按 polygonRingCounts 维护外圈/内洞 ring 分组,中空 mask 在调整多边形时显示内洞顶点和插点手柄。 - 保存与回显标注时将中空结构拆分为 mask_data.polygons 和 mask_data.holes,导入/普通 mask 共享同一编辑体验。 - 自动传播 seed 携带 holes,SAM 2 seed 栅格化时扣除内洞,避免中空 mask 以实心形式传播。 - 传播结果轮廓提取改为保留层级内洞,并在同步传播和 Celery 传播落库时写回 holes 与 hasHoles。 - 传播 seed 签名纳入 holes,并加固保存结果时 holes 与原始 polygon 索引对齐。 - 补充前端保存/回显、Canvas 内洞编辑和后端 SAM 2 hole 处理测试。 - 更新 AGENTS、接口契约、需求冻结、设计冻结和测试计划文档,移除中空结构未实现的旧描述。
This commit is contained in:
@@ -803,17 +803,20 @@ def propagate(
|
||||
if not payload.include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
result_holes = frame_result.get("holes") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
for polygon_index, polygon in enumerate(result_polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
polygon_to_save = _smooth_polygon(polygon, smoothing) if smoothing else polygon
|
||||
hole_group = result_holes[polygon_index] if polygon_index < len(result_holes) and isinstance(result_holes[polygon_index], list) else []
|
||||
annotation = Annotation(
|
||||
project_id=payload.project_id,
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon_to_save],
|
||||
**({"holes": [hole_group], "hasHoles": True} if hole_group else {}),
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
|
||||
@@ -294,6 +294,7 @@ class SmoothMaskResponse(BaseModel):
|
||||
|
||||
class PropagationSeed(BaseModel):
|
||||
polygons: Optional[list[list[list[float]]]] = None
|
||||
holes: Optional[list[list[list[list[float]]]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
labels: Optional[list[int]] = None
|
||||
|
||||
@@ -207,6 +207,7 @@ def _seed_signature(seed: dict[str, Any]) -> str:
|
||||
return str(inherited_signature)
|
||||
signature_payload = {
|
||||
"polygons": seed.get("polygons") or [],
|
||||
"holes": seed.get("holes") or [],
|
||||
"bbox": seed.get("bbox") or [],
|
||||
"points": seed.get("points") or [],
|
||||
"labels": seed.get("labels") or [],
|
||||
@@ -458,13 +459,14 @@ def _save_propagated_annotations(
|
||||
if not include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
result_holes = frame_result.get("holes") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
smoothed_polygons = [
|
||||
_smooth_polygon(polygon, smoothing)
|
||||
for polygon in result_polygons
|
||||
prepared_polygons = [
|
||||
(polygon_index, _smooth_polygon(polygon, smoothing))
|
||||
for polygon_index, polygon in enumerate(result_polygons)
|
||||
if len(polygon) >= 3
|
||||
]
|
||||
cleanup_polygon = next((polygon for polygon in smoothed_polygons if len(polygon) >= 3), None)
|
||||
cleanup_polygon = next((polygon for _polygon_index, polygon in prepared_polygons if len(polygon) >= 3), None)
|
||||
if cleanup_polygon is not None and frame.id not in cleaned_frame_ids:
|
||||
deleted_count += _delete_replaced_frame_annotations(
|
||||
db,
|
||||
@@ -475,15 +477,17 @@ def _save_propagated_annotations(
|
||||
polygon=cleanup_polygon,
|
||||
)
|
||||
cleaned_frame_ids.add(int(frame.id))
|
||||
for polygon_index, polygon in enumerate(smoothed_polygons):
|
||||
for polygon_index, polygon in prepared_polygons:
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
hole_group = result_holes[polygon_index] if polygon_index < len(result_holes) and isinstance(result_holes[polygon_index], list) else []
|
||||
annotation = Annotation(
|
||||
project_id=int(payload["project_id"]),
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
**({"holes": [hole_group], "hasHoles": True} if hole_group else {}),
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
|
||||
@@ -507,7 +507,7 @@ class SAM2Engine:
|
||||
if source_image is None:
|
||||
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
|
||||
height, width = source_image.shape[:2]
|
||||
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
|
||||
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height, seed.get("holes") or [])
|
||||
if not seed_mask.any():
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
@@ -543,15 +543,18 @@ class SAM2Engine:
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
polygons = []
|
||||
holes = []
|
||||
scores = []
|
||||
for mask in masks:
|
||||
polygon = self._mask_to_polygon(mask > 0)
|
||||
if polygon:
|
||||
mask_polygons, mask_holes = self._mask_to_polygon_data(mask > 0)
|
||||
for polygon_index, polygon in enumerate(mask_polygons):
|
||||
polygons.append(polygon)
|
||||
holes.append(mask_holes[polygon_index] if polygon_index < len(mask_holes) else [])
|
||||
scores.append(1.0)
|
||||
results[int(out_frame_idx)] = {
|
||||
"frame_index": int(out_frame_idx),
|
||||
"polygons": polygons,
|
||||
"holes": holes,
|
||||
"scores": scores,
|
||||
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
|
||||
}
|
||||
@@ -574,19 +577,49 @@ class SAM2Engine:
|
||||
@staticmethod
|
||||
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
|
||||
"""Convert a binary mask to a normalized polygon."""
|
||||
polygons, _holes = SAM2Engine._mask_to_polygon_data(mask)
|
||||
return polygons[0] if polygons else []
|
||||
|
||||
@staticmethod
|
||||
def _mask_to_polygon_data(mask: np.ndarray) -> tuple[list[list[list[float]]], list[list[list[list[float]]]]]:
|
||||
"""Convert a binary mask to normalized outer polygons and aligned hole rings."""
|
||||
import cv2
|
||||
|
||||
if mask.dtype != np.uint8:
|
||||
mask = (mask > 0).astype(np.uint8)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
|
||||
h, w = mask.shape[:2]
|
||||
largest = []
|
||||
for cnt in contours:
|
||||
if len(cnt) > len(largest):
|
||||
largest = cnt
|
||||
if len(largest) < 3:
|
||||
return []
|
||||
return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]
|
||||
if hierarchy is None:
|
||||
return [], []
|
||||
|
||||
def contour_to_polygon(contour: np.ndarray) -> list[list[float]]:
|
||||
if len(contour) < 3:
|
||||
return []
|
||||
return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in contour]
|
||||
|
||||
hierarchy_rows = hierarchy[0]
|
||||
outer_indices = [
|
||||
index for index, row in enumerate(hierarchy_rows)
|
||||
if int(row[3]) < 0 and len(contours[index]) >= 3
|
||||
]
|
||||
outer_indices.sort(key=lambda index: cv2.contourArea(contours[index]), reverse=True)
|
||||
|
||||
polygons: list[list[list[float]]] = []
|
||||
holes: list[list[list[list[float]]]] = []
|
||||
for outer_index in outer_indices:
|
||||
outer = contour_to_polygon(contours[outer_index])
|
||||
if not outer:
|
||||
continue
|
||||
child_index = int(hierarchy_rows[outer_index][2])
|
||||
hole_group: list[list[list[float]]] = []
|
||||
while child_index >= 0:
|
||||
hole = contour_to_polygon(contours[child_index])
|
||||
if hole:
|
||||
hole_group.append(hole)
|
||||
child_index = int(hierarchy_rows[child_index][0])
|
||||
polygons.append(outer)
|
||||
holes.append(hole_group)
|
||||
return polygons, holes
|
||||
|
||||
@staticmethod
|
||||
def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
|
||||
@@ -601,11 +634,16 @@ class SAM2Engine:
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
|
||||
def _polygons_to_mask(
|
||||
polygons: list[list[list[float]]],
|
||||
width: int,
|
||||
height: int,
|
||||
holes_by_polygon: list[list[list[list[float]]]] | None = None,
|
||||
) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
for polygon in polygons:
|
||||
for polygon_index, polygon in enumerate(polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
pts = np.array(
|
||||
@@ -619,6 +657,21 @@ class SAM2Engine:
|
||||
dtype=np.int32,
|
||||
)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
holes = holes_by_polygon[polygon_index] if holes_by_polygon and polygon_index < len(holes_by_polygon) else []
|
||||
for hole in holes:
|
||||
if len(hole) < 3:
|
||||
continue
|
||||
hole_pts = np.array(
|
||||
[
|
||||
[
|
||||
int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
|
||||
int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
|
||||
]
|
||||
for x, y in hole
|
||||
],
|
||||
dtype=np.int32,
|
||||
)
|
||||
cv2.fillPoly(mask, [hole_pts], 0)
|
||||
return mask.astype(bool)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -78,3 +78,27 @@ def test_sam2_status_exposes_selectable_variants(monkeypatch, tmp_path):
|
||||
assert status["label"] == "SAM 2.1 Small"
|
||||
assert status["checkpoint_exists"] is True
|
||||
assert status["checkpoint_path"].endswith("sam2.1_hiera_small.pt")
|
||||
|
||||
|
||||
def test_sam2_seed_mask_subtracts_holes():
|
||||
mask = SAM2Engine._polygons_to_mask(
|
||||
polygons=[[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9], [0.1, 0.9]]],
|
||||
width=100,
|
||||
height=100,
|
||||
holes_by_polygon=[[[[0.4, 0.4], [0.6, 0.4], [0.6, 0.6], [0.4, 0.6]]]],
|
||||
)
|
||||
|
||||
assert bool(mask[20, 20]) is True
|
||||
assert bool(mask[50, 50]) is False
|
||||
|
||||
|
||||
def test_sam2_mask_to_polygon_data_preserves_holes():
|
||||
mask = np.zeros((100, 100), dtype=np.uint8)
|
||||
mask[10:90, 10:90] = 1
|
||||
mask[40:60, 40:60] = 0
|
||||
|
||||
polygons, holes = SAM2Engine._mask_to_polygon_data(mask)
|
||||
|
||||
assert len(polygons) == 1
|
||||
assert len(holes) == 1
|
||||
assert len(holes[0]) == 1
|
||||
|
||||
Reference in New Issue
Block a user