import numpy as np import cv2 from pathlib import Path from models import Annotation, ProcessingTask from services.propagation_task_runner import _seed_signature, run_propagate_project_task def _create_project_and_frame(client): project = client.post("/api/projects", json={"name": "AI Project"}).json() frame = client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": 0, "image_url": "frames/0.jpg", "width": 640, "height": 360, }).json() template = client.post("/api/templates", json={ "name": "Template", "color": "#06b6d4", "z_index": 0, "classes": [], "rules": [], }).json() return project, frame, template def test_predict_accepts_point_object_with_labels(client, monkeypatch): _, frame, _ = _create_project_and_frame(client) calls = {} monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) def fake_predict_points(image, points, labels): calls["args"] = (points, labels) return ( [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]], [0.95], ) monkeypatch.setattr("routers.ai.sam_registry.predict_points", lambda model, image, points, labels: fake_predict_points(image, points, labels)) response = client.post("/api/ai/predict", json={ "image_id": frame["id"], "prompt_type": "point", "prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]}, }) assert response.status_code == 200 assert response.json()["scores"] == [0.95] assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0]) def test_predict_applies_crop_and_background_filter_options(client, monkeypatch): _, frame, _ = _create_project_and_frame(client) calls = {} monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((100, 200, 3), dtype=np.uint8)) def fake_predict_points(model, image, points, labels): calls["shape"] = image.shape calls["points"] = points calls["labels"] = labels return ( [ [[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]], [[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]], ], [0.9, 0.01], ) 
monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points) response = client.post("/api/ai/predict", json={ "image_id": frame["id"], "prompt_type": "point", "prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]}, "options": { "crop_to_prompt": True, "crop_margin": 0.1, "auto_filter_background": True, "min_score": 0.05, }, }) assert response.status_code == 200 assert calls["shape"][0] < 100 assert calls["shape"][1] < 200 assert calls["labels"] == [1, 0] assert response.json()["scores"] == [0.9] polygon = response.json()["polygons"][0] assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point) def test_predict_box_and_rejects_semantic_prompt(client, monkeypatch): _, frame, _ = _create_project_and_frame(client) monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) monkeypatch.setattr("routers.ai.sam_registry.predict_box", lambda model, image, box: ( [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [0.8], )) box_response = client.post("/api/ai/predict", json={ "image_id": frame["id"], "prompt_type": "box", "prompt_data": [0.2, 0.2, 0.8, 0.8], }) semantic_response = client.post("/api/ai/predict", json={ "image_id": frame["id"], "prompt_type": "semantic", "prompt_data": "胆囊", "model": "sam3", "options": {"min_score": 0.05}, }) assert box_response.status_code == 200 assert box_response.json()["scores"] == [0.8] assert semantic_response.status_code == 400 assert "Unsupported model: sam3" in semantic_response.json()["detail"] def test_predict_interactive_combines_box_and_points(client, monkeypatch): _, frame, _ = _create_project_and_frame(client) calls = {} monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8)) def fake_predict_interactive(model, image, box, points, labels): calls["model"] = model calls["box"] = box calls["points"] = points calls["labels"] = labels return ( [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]], [0.88], ) 
monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive) response = client.post("/api/ai/predict", json={ "image_id": frame["id"], "prompt_type": "interactive", "prompt_data": { "box": [0.1, 0.1, 0.9, 0.9], "points": [[0.5, 0.5], [0.2, 0.2]], "labels": [1, 0], }, "model": "sam2.1_hiera_small", }) assert response.status_code == 200 assert response.json()["scores"] == [0.88] assert calls == { "model": "sam2.1_hiera_small", "box": [0.1, 0.1, 0.9, 0.9], "points": [[0.5, 0.5], [0.2, 0.2]], "labels": [1, 0], } def test_model_status_reports_runtime(client, monkeypatch): monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: { "selected_model": "sam2.1_hiera_tiny", "gpu": { "available": False, "device": "cpu", "name": None, "torch_available": True, "torch_version": "2.x", "cuda_version": None, }, "models": [ { "id": "sam2.1_hiera_tiny", "label": "SAM 2.1 Tiny", "available": True, "loaded": False, "device": "cpu", "supports": ["point", "box", "auto"], "message": "ready", "package_available": True, "checkpoint_exists": True, "checkpoint_path": "model.pt", "python_ok": True, "torch_ok": True, "cuda_required": False, }, ], }) response = client.get("/api/ai/models/status") assert response.status_code == 200 body = response.json() assert body["selected_model"] == "sam2.1_hiera_tiny" assert len(body["models"]) == 1 assert body["models"][0]["id"] == "sam2.1_hiera_tiny" def test_model_status_rejects_disabled_sam3(client): response = client.get("/api/ai/models/status?selected_model=sam3") assert response.status_code == 400 assert "Unsupported model" in response.json()["detail"] def test_analyze_mask_returns_backend_geometry_properties(client): _, frame, _ = _create_project_and_frame(client) response = client.post("/api/ai/analyze-mask", json={ "frame_id": frame["id"], "mask_data": { "polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3], [0.1, 0.3]]], "source": "sam2.1_hiera_tiny", "score": 0.87, }, "extract_skeleton": 
True, }) assert response.status_code == 200 body = response.json() assert body["confidence"] == 0.87 assert body["confidence_source"] == "model_score" assert body["topology_anchor_count"] == 4 assert body["area"] > 0 assert body["message"] == "已从后端重新提取几何拓扑锚点" def test_analyze_mask_reports_actual_polygon_anchor_count(client): _, frame, _ = _create_project_and_frame(client) polygon = [[0.1 + index * 0.005, 0.1 + (0.01 if index % 2 else 0)] for index in range(80)] response = client.post("/api/ai/analyze-mask", json={ "frame_id": frame["id"], "mask_data": { "polygons": [polygon], "label": "AI Mask", "color": "#06b6d4", }, "points": [[0.2, 0.2]], }) assert response.status_code == 200 body = response.json() assert body["topology_anchor_count"] == len(polygon) assert len(body["topology_anchors"]) <= 64 def test_smooth_mask_simplifies_noisy_ai_polygon(client): _, frame, _ = _create_project_and_frame(client) polygon = [] for index in range(20): polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)]) for index in range(20): polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02]) for index in range(20): polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)]) for index in range(20): polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02]) response = client.post("/api/ai/smooth-mask", json={ "frame_id": frame["id"], "mask_data": { "polygons": [polygon], "label": "AI Mask", "color": "#06b6d4", }, "strength": 80, }) assert response.status_code == 200 body = response.json() assert body["topology_anchor_count"] == len(body["polygons"][0]) assert len(body["polygons"][0]) < len(polygon) def test_smooth_mask_uses_eased_strength_curve(client): _, frame, _ = _create_project_and_frame(client) polygon = [] for index in range(20): polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)]) for index in range(20): polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02]) for index in range(20): 
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)]) for index in range(20): polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02]) def smoothed_count(strength: int) -> int: response = client.post("/api/ai/smooth-mask", json={ "frame_id": frame["id"], "mask_data": { "polygons": [polygon], "label": "AI Mask", "color": "#06b6d4", }, "strength": strength, }) assert response.status_code == 200 return len(response.json()["polygons"][0]) low_count = smoothed_count(20) mid_count = smoothed_count(70) high_count = smoothed_count(95) assert low_count <= len(polygon) assert mid_count < low_count assert high_count < mid_count def test_smooth_mask_returns_backend_smoothed_geometry(client): _, frame, _ = _create_project_and_frame(client) response = client.post("/api/ai/smooth-mask", json={ "frame_id": frame["id"], "mask_data": { "polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3], [0.1, 0.3]]], "label": "胆囊", "color": "#ff0000", }, "strength": 45, }) assert response.status_code == 200 body = response.json() assert body["smoothing"] == {"strength": 45.0, "method": "chaikin"} assert len(body["polygons"]) == 1 assert len(body["polygons"][0]) > 4 assert body["topology_anchor_count"] > 0 assert body["message"] == "已应用边缘平滑强度 45" def test_seed_signature_includes_smoothing_parameters(): seed = { "polygons": [[[0.1, 0.1], [0.3, 0.1], [0.3, 0.3]]], "label": "胆囊", "color": "#ff0000", } assert _seed_signature({**seed, "smoothing": {"strength": 20, "method": "chaikin"}}) != _seed_signature({ **seed, "smoothing": {"strength": 40, "method": "chaikin"}, }) def test_propagate_saves_tracked_annotations(client, monkeypatch): project = client.post("/api/projects", json={"name": "Video Project"}).json() frames = [ client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": idx, "image_url": f"frames/{idx}.jpg", "width": 640, "height": 360, }).json() for idx in range(3) ] calls = {} 
monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg") def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames): calls["model"] = model calls["source_frame_index"] = source_frame_index calls["seed"] = seed calls["direction"] = direction calls["max_frames"] = max_frames calls["frame_count"] = len(frame_paths) return [ { "frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9], "object_ids": [1], }, { "frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8], "object_ids": [1], }, ] monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video) response = client.post("/api/ai/propagate", json={ "project_id": project["id"], "frame_id": frames[0]["id"], "model": "sam2.1_hiera_tiny", "direction": "forward", "max_frames": 2, "include_source": False, "seed": { "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "bbox": [0.1, 0.1, 0.1, 0.1], "label": "胆囊", "color": "#ff0000", "class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20}, "template_id": None, "smoothing": {"strength": 45, "method": "chaikin"}, }, }) assert response.status_code == 200 body = response.json() assert body["created_annotation_count"] == 1 assert body["processed_frame_count"] == 2 assert calls["model"] == "sam2.1_hiera_tiny" assert calls["source_frame_index"] == 0 assert calls["direction"] == "forward" assert calls["frame_count"] == 2 saved = body["annotations"][0] assert saved["frame_id"] == frames[1]["id"] assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation" assert saved["mask_data"]["class"]["name"] == "胆囊" assert saved["mask_data"]["score"] == 0.8 assert saved["mask_data"]["geometry_smoothing"] == {"strength": 45.0, "method": "chaikin"} assert saved["mask_data"]["polygons"][0] != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]] assert len(saved["mask_data"]["polygons"][0]) > 3 listing = 
client.get(f"/api/ai/annotations?project_id={project['id']}") assert len(listing.json()) == 1 def test_queue_propagation_task_creates_processing_task(client, monkeypatch): project = client.post("/api/projects", json={"name": "Queued Propagation"}).json() frame = client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": 0, "image_url": "frames/0.jpg", "width": 640, "height": 360, }).json() class FakeAsyncResult: id = "celery-propagate-1" queued = [] monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult()) monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None) response = client.post("/api/ai/propagate/task", json={ "project_id": project["id"], "frame_id": frame["id"], "model": "sam2.1_hiera_tiny", "steps": [{ "direction": "forward", "max_frames": 2, "seed": { "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "label": "胆囊", "smoothing": {"strength": 35, "method": "chaikin"}, }, }], }) assert response.status_code == 202 body = response.json() assert body["task_type"] == "propagate_masks" assert body["status"] == "queued" assert body["celery_task_id"] == "celery-propagate-1" assert body["payload"]["model"] == "sam2.1_hiera_tiny" assert body["payload"]["steps"][0]["seed"]["label"] == "胆囊" assert body["payload"]["steps"][0]["seed"]["smoothing"] == {"strength": 35, "method": "chaikin"} assert queued == [body["id"]] def test_queue_propagation_task_normalizes_model_and_rejects_unsupported(client, monkeypatch): project = client.post("/api/projects", json={"name": "Propagation Model"}).json() frame = client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": 0, "image_url": "frames/0.jpg", "width": 640, "height": 360, }).json() class FakeAsyncResult: id = "celery-propagate-model" monkeypatch.setattr("routers.ai.propagate_project_masks.delay", lambda task_id: FakeAsyncResult()) 
monkeypatch.setattr("routers.ai.publish_task_progress_event", lambda task: None) response = client.post("/api/ai/propagate/task", json={ "project_id": project["id"], "frame_id": frame["id"], "model": "sam2", "steps": [{ "direction": "forward", "max_frames": 2, "seed": { "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], }, }], }) assert response.status_code == 202 assert response.json()["payload"]["model"] == "sam2.1_hiera_tiny" unsupported = client.post("/api/ai/propagate/task", json={ "project_id": project["id"], "frame_id": frame["id"], "model": "sam3", "steps": [{ "direction": "forward", "max_frames": 2, "seed": { "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], }, }], }) assert unsupported.status_code == 400 assert "Unsupported model" in unsupported.json()["detail"] def test_propagation_task_runner_saves_annotations_and_progress(client, db_session, monkeypatch): project = client.post("/api/projects", json={"name": "Propagation Worker"}).json() frames = [ client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": idx, "image_url": f"frames/{idx}.jpg", "width": 640, "height": 360, }).json() for idx in range(2) ] task = ProcessingTask( task_type="propagate_masks", status="queued", progress=0, project_id=project["id"], payload={ "project_id": project["id"], "frame_id": frames[0]["id"], "model": "sam2.1_hiera_tiny", "include_source": False, "save_annotations": True, "steps": [{ "direction": "forward", "max_frames": 2, "seed": { "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "label": "胆囊", "color": "#ff0000", "class_metadata": {"id": "c1", "name": "胆囊"}, "smoothing": {"strength": 40, "method": "chaikin"}, }, }], }, ) db_session.add(task) db_session.commit() db_session.refresh(task) published = [] monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: 
published.append((event_task.status, event_task.progress))) def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames): assert [Path(path).name for path in frame_paths] == ["000000.jpg", "000001.jpg"] return [ {"frame_index": 0, "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]], "scores": [0.9]}, {"frame_index": 1, "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]], "scores": [0.8]}, ] monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video) result = run_propagate_project_task(db_session, task.id) db_session.refresh(task) assert task.status == "success" assert task.progress == 100 assert task.result["model"] == "sam2.1_hiera_tiny" assert task.result["steps"][0]["model"] == "sam2.1_hiera_tiny" assert result["created_annotation_count"] == 1 assert result["processed_frame_count"] == 2 assert published[0][0] == "running" assert published[-1] == ("success", 100) listing = client.get(f"/api/ai/annotations?project_id={project['id']}") assert listing.json()[0]["frame_id"] == frames[1]["id"] assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation" stored_polygon = listing.json()[0]["mask_data"]["polygons"][0] assert listing.json()[0]["mask_data"]["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"} assert stored_polygon != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]] assert len(stored_polygon) > 3 def test_propagation_task_runner_keeps_disconnected_result_polygons_in_one_annotation(client, db_session, monkeypatch): project = client.post("/api/projects", json={"name": "Propagation Disconnected Mask"}).json() frames = [ client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": idx, "image_url": f"frames/{idx}.jpg", "width": 640, "height": 360, }).json() for idx in range(2) ] first_piece = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25], [0.15, 0.25]] second_piece = [[0.70, 0.70], [0.90, 0.70], [0.90, 0.90], 
[0.70, 0.90]] second_hole = [[[0.76, 0.76], [0.82, 0.76], [0.82, 0.82], [0.76, 0.82]]] task = ProcessingTask( task_type="propagate_masks", status="queued", progress=0, project_id=project["id"], payload={ "project_id": project["id"], "frame_id": frames[0]["id"], "model": "sam2.1_hiera_tiny", "include_source": False, "save_annotations": True, "steps": [{ "direction": "forward", "max_frames": 2, "seed": { "polygons": [ [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]], [[0.6, 0.6], [0.8, 0.6], [0.8, 0.8]], ], "label": "多区域", "color": "#ff0000", "source_annotation_id": 7, "source_mask_id": "annotation-7", }, }], }, ) db_session.add(task) db_session.commit() db_session.refresh(task) monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None) monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda *args, **kwargs: [ {"frame_index": 0, "polygons": [], "scores": []}, { "frame_index": 1, "polygons": [first_piece, second_piece], "holes": [[], second_hole], "scores": [0.72, 0.93], }, ]) result = run_propagate_project_task(db_session, task.id) assert result["created_annotation_count"] == 1 annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all() assert len(annotations) == 1 annotation = annotations[0] assert annotation.frame_id == frames[1]["id"] assert annotation.bbox == [0.15, 0.15, 0.75, 0.75] assert annotation.mask_data["polygons"] == [first_piece, second_piece] assert annotation.mask_data["holes"] == [[], second_hole] assert annotation.mask_data["hasHoles"] is True assert annotation.mask_data["score"] == 0.93 assert annotation.mask_data["scores"] == [0.72, 0.93] def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch): project = client.post("/api/projects", json={"name": "Propagation Dedupe"}).json() frames = [ 
client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": idx, "image_url": f"frames/{idx}.jpg", "width": 640, "height": 360, }).json() for idx in range(2) ] def make_task(seed_polygon): task = ProcessingTask( task_type="propagate_masks", status="queued", progress=0, project_id=project["id"], payload={ "project_id": project["id"], "frame_id": frames[0]["id"], "model": "sam2.1_hiera_tiny", "include_source": False, "save_annotations": True, "steps": [{ "direction": "forward", "max_frames": 2, "seed": { "polygons": [seed_polygon], "label": "胆囊", "color": "#ff0000", "source_annotation_id": 7, "source_mask_id": "annotation-7", }, }], }, ) db_session.add(task) db_session.commit() db_session.refresh(task) return task seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]] first_output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]] changed_seed_polygon = [[0.2, 0.2], [0.3, 0.2], [0.3, 0.3]] replacement_output_polygon = [[0.22, 0.22], [0.32, 0.22], [0.32, 0.32]] monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None) propagate_calls = [] def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames): propagate_calls.append(seed["polygons"][0]) output_polygon = replacement_output_polygon if seed["polygons"][0] == changed_seed_polygon else first_output_polygon return [ {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]}, {"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]}, ] monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video) first_result = run_propagate_project_task(db_session, make_task(seed_polygon).id) assert first_result["created_annotation_count"] == 1 assert len(propagate_calls) == 1 unchanged_result = run_propagate_project_task(db_session, 
make_task(seed_polygon).id) assert unchanged_result["created_annotation_count"] == 0 assert unchanged_result["skipped_seed_count"] == 1 assert len(propagate_calls) == 1 assert db_session.query(Annotation).filter(Annotation.project_id == project["id"]).count() == 1 changed_result = run_propagate_project_task(db_session, make_task(changed_seed_polygon).id) assert changed_result["created_annotation_count"] == 1 assert changed_result["deleted_annotation_count"] == 1 assert len(propagate_calls) == 2 annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all() assert len(annotations) == 1 assert annotations[0].mask_data["polygons"] == [replacement_output_polygon] assert annotations[0].mask_data["source_annotation_id"] == 7 def test_propagation_task_runner_replaces_legacy_or_different_weight_results(client, db_session, monkeypatch): project = client.post("/api/projects", json={"name": "Propagation Legacy Cleanup"}).json() frames = [ client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": idx, "image_url": f"frames/{idx}.jpg", "width": 640, "height": 360, }).json() for idx in range(2) ] seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]] output_polygon = [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]] db_session.add(Annotation( project_id=project["id"], frame_id=frames[1]["id"], mask_data={ "polygons": [[[0.12, 0.12], [0.22, 0.12], [0.22, 0.22]]], "label": "胆囊", "color": "#ff0000", "source": "sam2.1_hiera_tiny_propagation", "propagated_from_frame_id": frames[0]["id"], "propagation_seed_key": "mask:temporary-front-end-id", "propagation_direction": "forward", }, bbox=[0.12, 0.12, 0.1, 0.1], )) db_session.commit() task = ProcessingTask( task_type="propagate_masks", status="queued", progress=0, project_id=project["id"], payload={ "project_id": project["id"], "frame_id": frames[0]["id"], "model": "sam2.1_hiera_small", "include_source": False, "save_annotations": True, "steps": [{ "direction": 
"forward", "max_frames": 2, "seed": { "polygons": [seed_polygon], "label": "胆囊", "color": "#ff0000", "source_annotation_id": 7, "source_mask_id": "annotation-7", }, }], }, ) db_session.add(task) db_session.commit() db_session.refresh(task) monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None) monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda model, frame_paths, source_frame_index, seed, direction, max_frames: [ {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]}, {"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]}, ]) result = run_propagate_project_task(db_session, task.id) assert result["created_annotation_count"] == 1 assert result["deleted_annotation_count"] == 1 annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all() assert len(annotations) == 1 assert annotations[0].mask_data["source"] == "sam2.1_hiera_small_propagation" assert annotations[0].mask_data["source_annotation_id"] == 7 assert annotations[0].mask_data["polygons"] == [output_polygon] def test_propagation_task_runner_keeps_same_class_seeds_separate(client, db_session, monkeypatch): project = client.post("/api/projects", json={"name": "Propagation Multi Instance"}).json() frames = [ client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": idx, "image_url": f"frames/{idx}.jpg", "width": 640, "height": 360, }).json() for idx in range(2) ] output_by_source = { 7: [[0.10, 0.10], [0.30, 0.10], [0.30, 0.30], [0.10, 0.30]], 8: [[0.12, 0.12], [0.32, 0.12], [0.32, 0.32], [0.12, 0.32]], } task = ProcessingTask( task_type="propagate_masks", status="queued", progress=0, project_id=project["id"], payload={ "project_id": project["id"], "frame_id": frames[0]["id"], "model": "sam2.1_hiera_tiny", 
"include_source": False, "save_annotations": True, "steps": [ { "direction": "forward", "max_frames": 2, "seed": { "polygons": [[[0.05, 0.05], [0.15, 0.05], [0.15, 0.15]]], "label": "胆囊", "color": "#ff0000", "source_annotation_id": 7, "source_mask_id": "annotation-7", "class_metadata": {"id": "gallbladder", "name": "胆囊"}, }, }, { "direction": "forward", "max_frames": 2, "seed": { "polygons": [[[0.65, 0.65], [0.75, 0.65], [0.75, 0.75]]], "label": "胆囊", "color": "#ff0000", "source_annotation_id": 8, "source_mask_id": "annotation-8", "class_metadata": {"id": "gallbladder", "name": "胆囊"}, }, }, ], }, ) db_session.add(task) db_session.commit() db_session.refresh(task) monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None) def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames): output_polygon = output_by_source[seed["source_annotation_id"]] return [ {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]}, {"frame_index": 1, "polygons": [output_polygon], "scores": [0.8]}, ] monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", fake_propagate_video) result = run_propagate_project_task(db_session, task.id) assert result["created_annotation_count"] == 2 assert result["deleted_annotation_count"] == 0 annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).order_by(Annotation.id).all() assert [annotation.mask_data["source_annotation_id"] for annotation in annotations] == [7, 8] assert [annotation.mask_data["polygons"][0] for annotation in annotations] == [output_by_source[7], output_by_source[8]] def test_propagation_task_runner_replaces_downstream_result_from_middle_frame_manual_seed(client, db_session, monkeypatch): project = client.post("/api/projects", json={"name": "Propagation Middle Frame 
Replacement"}).json() frames = [ client.post(f"/api/projects/{project['id']}/frames", json={ "project_id": project["id"], "frame_index": idx, "image_url": f"frames/{idx}.jpg", "width": 640, "height": 360, }).json() for idx in range(3) ] old_downstream_polygon = [[0.18, 0.18], [0.28, 0.18], [0.28, 0.28]] replacement_seed_polygon = [[0.16, 0.16], [0.26, 0.16], [0.26, 0.26]] replacement_downstream_polygon = [[0.19, 0.19], [0.29, 0.19], [0.29, 0.29]] db_session.add(Annotation( project_id=project["id"], frame_id=frames[2]["id"], template_id=3, mask_data={ "polygons": [old_downstream_polygon], "label": "胆囊", "color": "#ff0000", "class": {"id": "c1", "name": "胆囊", "color": "#ff0000"}, "source": "sam2.1_hiera_tiny_propagation", "propagated_from_frame_id": frames[0]["id"], "propagation_seed_key": "annotation:7", "propagation_seed_signature": "old-signature", "propagation_direction": "forward", "source_annotation_id": 7, "source_mask_id": "annotation-7", }, bbox=[0.18, 0.18, 0.1, 0.1], )) db_session.commit() task = ProcessingTask( task_type="propagate_masks", status="queued", progress=0, project_id=project["id"], payload={ "project_id": project["id"], "frame_id": frames[1]["id"], "model": "sam2.1_hiera_tiny", "include_source": False, "save_annotations": True, "steps": [{ "direction": "forward", "max_frames": 2, "seed": { "polygons": [replacement_seed_polygon], "label": "胆囊", "color": "#ff0000", "source_annotation_id": 20, "source_mask_id": "annotation-20", }, }], }, ) db_session.add(task) db_session.commit() db_session.refresh(task) monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg") monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None) monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda model, frame_paths, source_frame_index, seed, direction, max_frames: [ {"frame_index": 0, "polygons": [seed["polygons"][0]], "scores": [0.9]}, 
{"frame_index": 1, "polygons": [replacement_downstream_polygon], "scores": [0.8]},
    ])

    result = run_propagate_project_task(db_session, task.id)

    assert result["created_annotation_count"] == 1
    assert result["deleted_annotation_count"] == 1
    annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
    assert len(annotations) == 1
    assert annotations[0].frame_id == frames[2]["id"]
    assert annotations[0].mask_data["polygons"] == [replacement_downstream_polygon]
    assert annotations[0].mask_data["source_annotation_id"] == 20
    assert annotations[0].mask_data["source_mask_id"] == "annotation-20"


def test_propagation_task_runner_replaces_forward_result_when_middle_frame_propagates_backward(client, db_session, monkeypatch):
    """A backward propagation started from frame 1 replaces an older forward-propagated mask on frame 0."""
    project = client.post("/api/projects", json={"name": "Propagation Backward Middle Replacement"}).json()
    # Three frames (indices 0, 1, 2) of 640x360.
    frames = [
        client.post(f"/api/projects/{project['id']}/frames", json={
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }).json()
        for idx in range(3)
    ]
    old_upstream_polygon = [[0.12, 0.12], [0.22, 0.12], [0.22, 0.22]]
    replacement_seed_polygon = [[0.16, 0.16], [0.26, 0.16], [0.26, 0.26]]
    replacement_upstream_polygon = [[0.13, 0.13], [0.23, 0.13], [0.23, 0.23]]
    # Pre-existing propagated annotation on frame 0, produced by an earlier
    # forward run with a stale seed signature and seed key "annotation:7".
    # The asserts below expect the backward run to delete it.
    # ("胆囊" is the Chinese label for "gallbladder".)
    db_session.add(Annotation(
        project_id=project["id"],
        frame_id=frames[0]["id"],
        mask_data={
            "polygons": [old_upstream_polygon],
            "label": "胆囊",
            "color": "#ff0000",
            "source": "sam2.1_hiera_tiny_propagation",
            "propagated_from_frame_id": frames[0]["id"],
            "propagation_seed_key": "annotation:7",
            "propagation_seed_signature": "old-signature",
            "propagation_direction": "forward",
            "source_annotation_id": 7,
            "source_mask_id": "annotation-7",
        },
        bbox=[0.12, 0.12, 0.1, 0.1],
    ))
    db_session.commit()

    # Queued task: propagate a new seed (source annotation 20) backward from
    # the middle frame, covering up to 2 frames.
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload={
            "project_id": project["id"],
            "frame_id": frames[1]["id"],
            "model": "sam2.1_hiera_tiny",
            "include_source": False,
            "save_annotations": True,
            "steps": [{
                "direction": "backward",
                "max_frames": 2,
                "seed": {
                    "polygons": [replacement_seed_polygon],
                    "label": "胆囊",
                    "color": "#ff0000",
                    "source_annotation_id": 20,
                    "source_mask_id": "annotation-20",
                },
            }],
        },
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    # Replace the runner's file download and progress-event publishing with inert stubs.
    monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
    # Fake SAM propagation: returns a new polygon for frame 0 and echoes the
    # seed polygon back for the source frame 1.
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda model, frame_paths, source_frame_index, seed, direction, max_frames: [
        {"frame_index": 0, "polygons": [replacement_upstream_polygon], "scores": [0.8]},
        {"frame_index": 1, "polygons": [seed["polygons"][0]], "scores": [0.9]},
    ])

    result = run_propagate_project_task(db_session, task.id)

    # Exactly one annotation is created and one deleted. The frame-1 (source)
    # result is not persisted, presumably because "include_source" is False.
    assert result["created_annotation_count"] == 1
    assert result["deleted_annotation_count"] == 1
    annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
    assert len(annotations) == 1
    assert annotations[0].frame_id == frames[0]["id"]
    assert annotations[0].mask_data["polygons"] == [replacement_upstream_polygon]
    assert annotations[0].mask_data["propagation_direction"] == "backward"
    assert annotations[0].mask_data["source_annotation_id"] == 20


def test_propagation_task_runner_skips_unmodified_propagated_seed_on_overlapping_frames(client, db_session, monkeypatch):
    """A seed that still carries the signature it inherited is skipped without calling propagation."""
    project = client.post("/api/projects", json={"name": "Propagation Overlap Skip"}).json()
    frames = [
        client.post(f"/api/projects/{project['id']}/frames", json={
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }).json()
        for idx in range(3)
    ]
    original_seed_polygon = [[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]
    propagated_seed_polygon = [[0.14, 0.14], [0.24, 0.14], [0.24, 0.24]]
    downstream_polygon = [[0.18, 0.18], [0.28, 0.18], [0.28, 0.28]]
    # Signature of the ORIGINAL seed (annotation 7). Both the existing frame-2
    # annotation and the new seed carry it, although the new seed's polygon
    # differs from the original.
    inherited_signature = _seed_signature({
        "polygons": [original_seed_polygon],
        "label": "胆囊",
        "color": "#ff0000",
        "source_annotation_id": 7,
        "source_mask_id": "annotation-7",
    })
    # Existing forward-propagated result on frame 2, from a run seeded at frame 0.
    db_session.add(Annotation(
        project_id=project["id"],
        frame_id=frames[2]["id"],
        mask_data={
            "polygons": [downstream_polygon],
            "label": "胆囊",
            "color": "#ff0000",
            "source": "sam2.1_hiera_tiny_propagation",
            "propagated_from_frame_id": frames[0]["id"],
            "propagation_seed_key": "annotation:7",
            "propagation_seed_signature": inherited_signature,
            "propagation_direction": "forward",
            "source_annotation_id": 7,
            "source_mask_id": "annotation-7",
        },
        bbox=[0.18, 0.18, 0.1, 0.1],
    ))
    db_session.commit()
    # Forward step from frame 1 whose seed is the propagated (not user-edited)
    # polygon and which carries the inherited signature.
    task = ProcessingTask(
        task_type="propagate_masks",
        status="queued",
        progress=0,
        project_id=project["id"],
        payload={
            "project_id": project["id"],
            "frame_id": frames[1]["id"],
            "model": "sam2.1_hiera_tiny",
            "include_source": False,
            "save_annotations": True,
            "steps": [{
                "direction": "forward",
                "max_frames": 2,
                "seed": {
                    "polygons": [propagated_seed_polygon],
                    "label": "胆囊",
                    "color": "#ff0000",
                    "source_annotation_id": 7,
                    "source_mask_id": "annotation-7",
                    "propagation_seed_signature": inherited_signature,
                },
            }],
        },
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)
    # Records every positional-args tuple passed to propagate_video, so the
    # test can prove it was never called.
    propagate_calls = []
    monkeypatch.setattr("services.propagation_task_runner.download_file", lambda object_name: b"jpeg")
    monkeypatch.setattr("services.propagation_task_runner.publish_task_progress_event", lambda event_task: None)
    # `list.append` returns None, so `... or []` makes the stub return an empty result list.
    monkeypatch.setattr("services.propagation_task_runner.sam_registry.propagate_video", lambda *args, **kwargs: propagate_calls.append(args) or [])

    result = run_propagate_project_task(db_session, task.id)

    # Nothing created or deleted, the seed counted as skipped, the model never invoked.
    assert result["created_annotation_count"] == 0
    assert result["deleted_annotation_count"] == 0
    assert result["skipped_seed_count"] == 1
    assert propagate_calls == []
    annotations = db_session.query(Annotation).filter(Annotation.project_id == project["id"]).all()
    assert len(annotations) == 1
    assert annotations[0].mask_data["polygons"] == [downstream_polygon]


def test_predict_validation_errors(client, monkeypatch):
    """/api/ai/predict returns 404 for an unknown image and 400 for a malformed box prompt."""
    project, _, _ = _create_project_and_frame(client)
    # Unknown image id.
    assert client.post("/api/ai/predict", json={
        "image_id": 999,
        "prompt_type": "point",
        "prompt_data": [[0.5, 0.5]],
    }).status_code == 404
    # Second frame, created without explicit width/height; its image load is stubbed below.
    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 1,
        "image_url": "frames/1.jpg",
    }).json()
    monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
    # Box prompt with only two values; presumably rejected because a box needs four.
    assert client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.1, 0.2],
    }).status_code == 400


def test_save_annotation_validates_project_and_frame(client):
    """Saving and listing annotations works, and unknown project/frame ids yield 404."""
    project, frame, template = _create_project_and_frame(client)
    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    })
    assert saved.status_code == 201
    assert saved.json()["project_id"] == project["id"]
    # Listing by project, then filtered by project + frame.
    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    assert listing.json()[0]["id"] == saved.json()["id"]
    frame_listing = client.get(f"/api/ai/annotations?project_id={project['id']}&frame_id={frame['id']}")
    assert frame_listing.status_code == 200
    assert len(frame_listing.json()) == 1
    # Negative cases: unknown project on save, unknown frame on save, unknown project on list.
    missing_project = client.post("/api/ai/annotate", json={"project_id": 999})
    assert missing_project.status_code == 404
    missing_frame = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": 999,
    })
    assert missing_frame.status_code == 404
    missing_project_list = client.get("/api/ai/annotations?project_id=999")
    assert missing_project_list.status_code == 404


def test_update_and_delete_annotation(client):
    """PATCH replaces mask data, points and bbox; DELETE removes the annotation from listings."""
    project, frame, template = _create_project_and_frame(client)
    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            "label": "AI Mask",
            "color": "#06b6d4",
        },
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    }).json()
    # Update to the "胆囊" (gallbladder) class, including a nested class object.
    updated = client.patch(f"/api/ai/annotations/{saved['id']}", json={
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            "label": "胆囊",
            "color": "#ff0000",
            "class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        },
        "points": [[0.4, 0.4]],
        "bbox": [0.2, 0.2, 0.6, 0.6],
    })
    assert updated.status_code == 200
    body = updated.json()
    assert body["mask_data"]["label"] == "胆囊"
    assert body["mask_data"]["class"]["id"] == "c1"
    assert body["points"] == [[0.4, 0.4]]
    assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]
    # The update is visible through the listing endpoint as well.
    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"
    deleted = client.delete(f"/api/ai/annotations/{saved['id']}")
    assert deleted.status_code == 204
    empty_listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert empty_listing.status_code == 200
    assert empty_listing.json() == []


def test_update_and_delete_annotation_validation(client):
    """PATCH/DELETE on unknown annotations, and PATCH with an unknown template, return 404."""
    project, frame, template = _create_project_and_frame(client)
    # Minimal annotation (no mask data), used only as the target of the bad-template PATCH.
    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
    }).json()
    assert client.patch("/api/ai/annotations/999", json={"bbox": [0, 0, 1, 1]}).status_code == 404
    assert client.delete("/api/ai/annotations/999").status_code == 404
    assert client.patch(
        f"/api/ai/annotations/{saved['id']}",
        json={"template_id": 999},
    ).status_code == 404


def test_import_gt_mask_creates_annotations_with_seed_points(client):
    """A binary PNG mask (value 255) is imported as one annotation with a single normalized seed point."""
    project, frame, template = _create_project_and_frame(client)
    # 360x640 single-channel mask with one filled rectangle of value 255.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "template_id": str(template["id"]),
            "label": "Imported GT",
            "color": "#22c55e",
        },
        files={"file": ("mask.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 201
    body = response.json()
    assert len(body) == 1
    assert body[0]["project_id"] == project["id"]
    assert body[0]["frame_id"] == frame["id"]
    assert body[0]["template_id"] == template["id"]
    assert body[0]["mask_data"]["label"] == "Imported GT"
    assert body[0]["mask_data"]["source"] == "gt_mask"
    assert body[0]["mask_data"]["gt_label_value"] == 255
    assert len(body[0]["mask_data"]["polygons"][0]) >= 3
    # Exactly one seed point, with coordinates in [0, 1].
    assert len(body[0]["points"]) == 1
    assert 0.0 <= body[0]["points"][0][0] <= 1.0
    assert 0.0 <= body[0]["points"][0][1] <= 1.0


def test_import_gt_mask_polygons_work_with_analysis_and_smoothing(client):
    """Imported GT polygons are accepted by the analyze-mask and smooth-mask endpoints."""
    project, frame, _ = _create_project_and_frame(client)
    # Filled, rotated ellipse drawn with label value 1.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.ellipse(mask, (260, 160), (130, 70), 20, 0, 360, 1, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "label": "Imported GT",
            "color": "#22c55e",
        },
        files={"file": ("mask.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 201
    annotation = response.json()[0]
    assert annotation["mask_data"]["source"] == "gt_mask"
    # Analysis reports one topology anchor per vertex of the imported polygon.
    analysis = client.post("/api/ai/analyze-mask", json={
        "frame_id": frame["id"],
        "mask_data": annotation["mask_data"],
        "points": annotation["points"],
        "bbox": annotation["bbox"],
    })
    assert analysis.status_code == 200
    assert analysis.json()["topology_anchor_count"] == len(annotation["mask_data"]["polygons"][0])
    # After smoothing, the anchor count matches the vertex count of the smoothed polygon.
    smoothing = client.post("/api/ai/smooth-mask", json={
        "frame_id": frame["id"],
        "mask_data": annotation["mask_data"],
        "points": annotation["points"],
        "bbox": annotation["bbox"],
        "strength": 35,
    })
    assert smoothing.status_code == 200
    assert smoothing.json()["topology_anchor_count"] == len(smoothing.json()["polygons"][0])


def test_import_gt_mask_preserves_detailed_contours(client):
    """A spiky 96-vertex contour is imported with more than 80 and at most 2048 vertices."""
    project, frame, _ = _create_project_and_frame(client)
    mask = np.zeros((360, 640), dtype=np.uint8)
    center = np.array([320, 180])
    # Star-like polygon: 96 vertices alternating between radius 120 and 88, so
    # heavy simplification would be detectable through the vertex count.
    vertices = []
    for index in range(96):
        angle = 2 * np.pi * index / 96
        radius = 120 if index % 2 == 0 else 88
        vertices.append([
            int(center[0] + np.cos(angle) * radius),
            int(center[1] + np.sin(angle) * radius),
        ])
    cv2.fillPoly(mask, [np.array(vertices, dtype=np.int32)], 1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "label": "Detailed GT",
            "color": "#22c55e",
        },
        files={"file": ("mask.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 201
    polygon = response.json()[0]["mask_data"]["polygons"][0]
    # Enough detail kept, but bounded in size.
    assert len(polygon) > 80
    assert len(polygon) <= 2048


def test_import_gt_mask_splits_label_values(client):
    """Each distinct non-zero label value becomes its own annotation, labelled "<label> <value>"."""
    project, frame, _ = _create_project_and_frame(client)
    # Two rectangles with label values 1 and 2.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
    cv2.rectangle(mask, (220, 80), (320, 180), 2, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "label": "GT Class",
        },
        files={"file": ("labels.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 201
    # Sort by label value so the assertions do not depend on response order.
    body = sorted(response.json(), key=lambda item: item["mask_data"]["gt_label_value"])
    assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
    assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
    assert all(len(item["points"]) == 1 for item in body)


def test_import_gt_mask_rejects_background_only_label_image(client):
    """An all-zero (background-only) label image is rejected with 400."""
    project, frame, _ = _create_project_and_frame(client)
    mask = np.zeros((360, 640), dtype=np.uint8)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "label": "GT Class",
        },
        files={"file": ("empty-gt-label.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 400
    # Expected detail (Chinese): "the GT Mask image has no non-background maskid region."
    assert response.json()["detail"] == "GT Mask 图片中没有非背景 maskid 区域。"


def test_import_gt_mask_accepts_uint8_low_value_gtlabel_png(client):
    """A uint8 label PNG with value 1 maps to the template class whose maskId is 1."""
    project, frame, _ = _create_project_and_frame(client)
    # Template with one class, "肿瘤" ("tumor"), bound to maskId 1.
    template = client.post("/api/templates", json={
        "name": "GTLabel Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [
            {"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1},
        ],
        "rules": [],
    }).json()
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "template_id": str(template["id"]),
            "unknown_color_policy": "discard",
        },
        files={"file": ("GT_label.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 201
    body = response.json()
    assert len(body) == 1
    assert body[0]["mask_data"]["gt_label_value"] == 1
    assert body[0]["mask_data"]["class"]["name"] == "肿瘤"
    assert body[0]["mask_data"]["class"]["maskId"] == 1


def test_import_gt_mask_rejects_rgb_color_masks(client):
    """A true-colour RGB mask (channels differ) is rejected with 400."""
    project, frame, _ = _create_project_and_frame(client)
    # "已知类别" means "known class".
    template = client.post("/api/templates", json={
        "name": "Color Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [
            {"id": "known", "name": "已知类别", "color": "#ff0000", "zIndex": 10, "maskId": 1},
        ],
        "rules": [],
    }).json()
    # 3-channel image whose channels are NOT equal per pixel (OpenCV stores BGR).
    mask = np.zeros((80, 120, 3), dtype=np.uint8)
    mask[10:40, 10:40] = [0, 0, 255]  # BGR red -> #ff0000
    mask[40:70, 70:110] = [0, 255, 0]  # BGR green -> unknown #00ff00
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "template_id": str(template["id"]),
            "unknown_color_policy": "discard",
        },
        files={"file": ("color-mask.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 400
    # Error detail mentions (Chinese) "the three RGB channels are exactly identical",
    # i.e. the requirement this image violates.
    assert "RGB 三通道完全相同" in response.json()["detail"]


def test_import_gt_mask_rejects_uint16_gt_label(client):
    """A 16-bit label PNG is rejected with 400."""
    project, frame, _ = _create_project_and_frame(client)
    template = client.post("/api/templates", json={
        "name": "Label Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
        "rules": [],
    }).json()
    # uint16 input; cv2.imencode writes it as a 16-bit PNG.
    mask = np.zeros((360, 640), dtype=np.uint16)
    cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "template_id": str(template["id"]),
            "unknown_color_policy": "discard",
        },
        files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
    )
    assert response.status_code == 400
    # Error detail (Chinese) contains "only 8-bit is supported".
    assert "仅支持 8-bit" in response.json()["detail"]


def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
    """Unknown mask ids are dropped or kept per policy, and a smaller mask is resized to the frame."""
    project, frame, _ = _create_project_and_frame(client)
    # Only maskId 1 is defined ("已定义" = "defined").
    template = client.post("/api/templates", json={
        "name": "Color Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [{"id": "known", "name": "已定义", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
        "rules": [],
    }).json()
    # 160x90 3-channel mask with equal channels per pixel; value 1 (defined)
    # and value 2 (not defined in the template). The frame is 640x360, so a
    # resize is required.
    mask = np.zeros((90, 160, 3), dtype=np.uint8)
    cv2.rectangle(mask, (5, 5), (40, 40), (1, 1, 1), thickness=-1)
    cv2.rectangle(mask, (80, 5), (120, 40), (2, 2, 2), thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok
    # "discard" policy: only the defined class survives.
    discard_response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "template_id": str(template["id"]),
            "unknown_color_policy": "discard",
        },
        files={"file": ("colors.png", encoded.tobytes(), "image/png")},
    )
    assert discard_response.status_code == 201
    assert [item["mask_data"]["label"] for item in discard_response.json()] == ["已定义"]
    # The original mask size and the resize to the frame size are both recorded.
    assert discard_response.json()[0]["mask_data"]["gt_original_size"] == {"width": 160, "height": 90}
    assert discard_response.json()[0]["mask_data"]["gt_resized_to_frame"] is True
    assert discard_response.json()[0]["mask_data"]["image_size"] == {"width": 640, "height": 360}
    # "undefined" policy: the unknown value 2 is kept as "未定义类别 2" ("undefined class 2").
    undefined_response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "template_id": str(template["id"]),
            "unknown_color_policy": "undefined",
        },
        files={"file": ("colors.png", encoded.tobytes(), "image/png")},
    )
    assert undefined_response.status_code == 201
    labels = {item["mask_data"]["label"] for item in undefined_response.json()}
    assert labels == {"已定义", "未定义类别 2"}
    unknown = next(item for item in undefined_response.json() if item["mask_data"]["label"].startswith("未定义"))
    assert unknown["mask_data"]["gt_unknown_class"] is True
    assert unknown["mask_data"]["gt_label_value"] == 2
    assert unknown["mask_data"]["gt_resized_to_frame"] is True