"""Tests for the ``/api/ai`` endpoints.

Covers prediction (point, box, semantic and interactive prompts), model
runtime status, video propagation, annotation create/update/delete/list and
ground-truth mask import. Model inference is always replaced with fakes via
``monkeypatch`` on ``routers.ai``.
"""

import cv2
import numpy as np


def _create_project_and_frame(client):
    """Create a project, one frame in it and a template through the API.

    Args:
        client: HTTP test client fixture used to call the API.

    Returns:
        A ``(project, frame, template)`` tuple holding the decoded JSON body
        of each creation request.
    """
    project = client.post("/api/projects", json={"name": "AI Project"}).json()
    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 640,
        "height": 360,
    }).json()
    template = client.post("/api/templates", json={
        "name": "Template",
        "color": "#06b6d4",
        "z_index": 0,
        "classes": [],
        "rules": [],
    }).json()
    return project, frame, template


def test_predict_accepts_point_object_with_labels(client, monkeypatch):
    """A point prompt given as ``{"points", "labels"}`` reaches the predictor as sent."""
    _, frame, _ = _create_project_and_frame(client)
    calls = {}

    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((10, 10, 3), dtype=np.uint8),
    )

    # Same (model, image, points, labels) signature as the fakes in the other
    # tests of this module; ``model`` and ``image`` are simply not recorded.
    def fake_predict_points(model, image, points, labels):
        calls["args"] = (points, labels)
        return (
            [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            [0.95],
        )

    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.1, 0.1]], "labels": [1, 0]},
    })

    assert response.status_code == 200
    assert response.json()["scores"] == [0.95]
    assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])


def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
    """Crop/filter options shrink the image given to the predictor and trim the results.

    The fake returns two polygons scored 0.9 and 0.01; only the 0.9 result is
    expected in the response, with polygon coordinates inside ``[0, 1]``.
    """
    _, frame, _ = _create_project_and_frame(client)
    calls = {}

    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((100, 200, 3), dtype=np.uint8),
    )

    def fake_predict_points(model, image, points, labels):
        calls["shape"] = image.shape
        calls["points"] = points
        calls["labels"] = labels
        return (
            [
                [[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]],
                [[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]],
            ],
            [0.9, 0.01],
        )

    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
        "options": {
            "crop_to_prompt": True,
            "crop_margin": 0.1,
            "auto_filter_background": True,
            "min_score": 0.05,
        },
    })

    assert response.status_code == 200
    # The predictor must have received an image smaller than the 100x200 source.
    assert calls["shape"][0] < 100
    assert calls["shape"][1] < 200
    assert calls["labels"] == [1, 0]
    assert response.json()["scores"] == [0.9]
    polygon = response.json()["polygons"][0]
    assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point)


def test_predict_box_and_semantic_fallback(client, monkeypatch):
    """Box and semantic prompts are dispatched to their respective predictors.

    For the semantic prompt, the requested model, the text and the
    ``min_score`` option (as ``confidence_threshold``) must be forwarded.
    """
    _, frame, _ = _create_project_and_frame(client)
    calls = {}

    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((10, 10, 3), dtype=np.uint8),
    )
    monkeypatch.setattr(
        "routers.ai.sam_registry.predict_box",
        lambda model, image, box: (
            [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            [0.8],
        ),
    )

    def fake_predict_semantic(model, image, text, confidence_threshold=None):
        calls["semantic"] = {
            "model": model,
            "text": text,
            "confidence_threshold": confidence_threshold,
        }
        return (
            [[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]],
            [0.5],
        )

    monkeypatch.setattr("routers.ai.sam_registry.predict_semantic", fake_predict_semantic)

    box_response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.2, 0.2, 0.8, 0.8],
    })
    semantic_response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "semantic",
        "prompt_data": "胆囊",
        "model": "sam3",
        "options": {"min_score": 0.05},
    })

    assert box_response.status_code == 200
    assert box_response.json()["scores"] == [0.8]
    assert semantic_response.status_code == 200
    assert semantic_response.json()["scores"] == [0.5]
    assert calls["semantic"] == {
        "model": "sam3",
        "text": "胆囊",
        "confidence_threshold": 0.05,
    }


def test_predict_interactive_combines_box_and_points(client, monkeypatch):
    """An interactive prompt forwards model, box, points and labels together."""
    _, frame, _ = _create_project_and_frame(client)
    calls = {}

    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((10, 10, 3), dtype=np.uint8),
    )

    def fake_predict_interactive(model, image, box, points, labels):
        calls["model"] = model
        calls["box"] = box
        calls["points"] = points
        calls["labels"] = labels
        return (
            [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            [0.88],
        )

    monkeypatch.setattr(
        "routers.ai.sam_registry.predict_interactive",
        fake_predict_interactive,
    )

    response = client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "interactive",
        "prompt_data": {
            "box": [0.1, 0.1, 0.9, 0.9],
            "points": [[0.5, 0.5], [0.2, 0.2]],
            "labels": [1, 0],
        },
        "model": "sam2",
    })

    assert response.status_code == 200
    assert response.json()["scores"] == [0.88]
    assert calls == {
        "model": "sam2",
        "box": [0.1, 0.1, 0.9, 0.9],
        "points": [[0.5, 0.5], [0.2, 0.2]],
        "labels": [1, 0],
    }


def test_model_status_reports_runtime(client, monkeypatch):
    """The status endpoint returns the registry's runtime report for the selected model."""
    monkeypatch.setattr(
        "routers.ai.sam_registry.runtime_status",
        lambda selected_model=None: {
            "selected_model": selected_model or "sam2",
            "gpu": {
                "available": False,
                "device": "cpu",
                "name": None,
                "torch_available": True,
                "torch_version": "2.x",
                "cuda_version": None,
            },
            "models": [
                {
                    "id": "sam2",
                    "label": "SAM 2",
                    "available": True,
                    "loaded": False,
                    "device": "cpu",
                    "supports": ["point", "box", "auto"],
                    "message": "ready",
                    "package_available": True,
                    "checkpoint_exists": True,
                    "checkpoint_path": "model.pt",
                    "python_ok": True,
                    "torch_ok": True,
                    "cuda_required": False,
                },
                {
                    "id": "sam3",
                    "label": "SAM 3",
                    "available": False,
                    "loaded": False,
                    "device": "unavailable",
                    "supports": ["semantic"],
                    "message": "missing Python 3.12+ runtime",
                    "package_available": False,
                    "checkpoint_exists": False,
                    "checkpoint_path": None,
                    "python_ok": False,
                    "torch_ok": True,
                    "cuda_required": True,
                },
            ],
        },
    )

    response = client.get("/api/ai/models/status?selected_model=sam3")

    assert response.status_code == 200
    body = response.json()
    assert body["selected_model"] == "sam3"
    assert body["models"][1]["id"] == "sam3"
    assert body["models"][1]["available"] is False


def test_propagate_saves_tracked_annotations(client, monkeypatch):
    """Propagation stores the tracked result for the non-source frame only.

    The fake tracker returns results for frame indices 0 and 1. With
    ``include_source`` set to ``False`` one annotation is expected, attached
    to the second frame and tagged with the ``sam2_propagation`` source.
    """
    project = client.post("/api/projects", json={"name": "Video Project"}).json()
    frames = [
        client.post(f"/api/projects/{project['id']}/frames", json={
            "project_id": project["id"],
            "frame_index": idx,
            "image_url": f"frames/{idx}.jpg",
            "width": 640,
            "height": 360,
        }).json()
        for idx in range(3)
    ]
    calls = {}

    monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg")

    def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
        calls["model"] = model
        calls["source_frame_index"] = source_frame_index
        calls["seed"] = seed
        calls["direction"] = direction
        calls["max_frames"] = max_frames
        calls["frame_count"] = len(frame_paths)
        return [
            {
                "frame_index": 0,
                "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
                "scores": [0.9],
                "object_ids": [1],
            },
            {
                "frame_index": 1,
                "polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]],
                "scores": [0.8],
                "object_ids": [1],
            },
        ]

    monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video)

    response = client.post("/api/ai/propagate", json={
        "project_id": project["id"],
        "frame_id": frames[0]["id"],
        "model": "sam2",
        "direction": "forward",
        "max_frames": 2,
        "include_source": False,
        "seed": {
            "polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
            "bbox": [0.1, 0.1, 0.1, 0.1],
            "label": "胆囊",
            "color": "#ff0000",
            "class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
            "template_id": None,
        },
    })

    assert response.status_code == 200
    body = response.json()
    assert body["created_annotation_count"] == 1
    assert body["processed_frame_count"] == 2
    assert calls["model"] == "sam2"
    assert calls["source_frame_index"] == 0
    assert calls["direction"] == "forward"
    assert calls["frame_count"] == 2

    saved = body["annotations"][0]
    assert saved["frame_id"] == frames[1]["id"]
    assert saved["mask_data"]["source"] == "sam2_propagation"
    assert saved["mask_data"]["class"]["name"] == "胆囊"
    assert saved["mask_data"]["score"] == 0.8

    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    # Check the status first: a one-key JSON error object also has length 1,
    # so the length check alone could pass on a failed request.
    assert listing.status_code == 200
    assert len(listing.json()) == 1


def test_predict_validation_errors(client, monkeypatch):
    """Unknown images yield 404 and a malformed box prompt yields 400."""
    project, _, _ = _create_project_and_frame(client)

    assert client.post("/api/ai/predict", json={
        "image_id": 999,
        "prompt_type": "point",
        "prompt_data": [[0.5, 0.5]],
    }).status_code == 404

    frame = client.post(f"/api/projects/{project['id']}/frames", json={
        "project_id": project["id"],
        "frame_index": 1,
        "image_url": "frames/1.jpg",
    }).json()
    monkeypatch.setattr(
        "routers.ai._load_frame_image",
        lambda frame: np.zeros((10, 10, 3), dtype=np.uint8),
    )

    # A box prompt with only two values is rejected.
    assert client.post("/api/ai/predict", json={
        "image_id": frame["id"],
        "prompt_type": "box",
        "prompt_data": [0.1, 0.2],
    }).status_code == 400


def test_save_annotation_validates_project_and_frame(client):
    """Annotations can be saved and listed; unknown projects or frames yield 404."""
    project, frame, template = _create_project_and_frame(client)

    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]},
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    })
    assert saved.status_code == 201
    assert saved.json()["project_id"] == project["id"]

    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    assert listing.json()[0]["id"] == saved.json()["id"]

    frame_listing = client.get(
        f"/api/ai/annotations?project_id={project['id']}&frame_id={frame['id']}"
    )
    assert frame_listing.status_code == 200
    assert len(frame_listing.json()) == 1

    missing_project = client.post("/api/ai/annotate", json={"project_id": 999})
    assert missing_project.status_code == 404

    missing_frame = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": 999,
    })
    assert missing_frame.status_code == 404

    missing_project_list = client.get("/api/ai/annotations?project_id=999")
    assert missing_project_list.status_code == 404


def test_update_and_delete_annotation(client):
    """PATCH replaces mask data, points and bbox; DELETE removes the annotation."""
    project, frame, template = _create_project_and_frame(client)

    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
            "label": "AI Mask",
            "color": "#06b6d4",
        },
        "points": [[0.5, 0.5]],
        "bbox": [0.1, 0.1, 0.8, 0.8],
    }).json()

    updated = client.patch(f"/api/ai/annotations/{saved['id']}", json={
        "template_id": template["id"],
        "mask_data": {
            "polygons": [[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
            "label": "胆囊",
            "color": "#ff0000",
            "class": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
        },
        "points": [[0.4, 0.4]],
        "bbox": [0.2, 0.2, 0.6, 0.6],
    })
    assert updated.status_code == 200
    body = updated.json()
    assert body["mask_data"]["label"] == "胆囊"
    assert body["mask_data"]["class"]["id"] == "c1"
    assert body["points"] == [[0.4, 0.4]]
    assert body["bbox"] == [0.2, 0.2, 0.6, 0.6]

    listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert listing.status_code == 200
    assert listing.json()[0]["mask_data"]["class"]["name"] == "胆囊"

    deleted = client.delete(f"/api/ai/annotations/{saved['id']}")
    assert deleted.status_code == 204

    empty_listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
    assert empty_listing.status_code == 200
    assert empty_listing.json() == []


def test_update_and_delete_annotation_validation(client):
    """Unknown annotation ids and unknown template ids yield 404 on update/delete."""
    project, frame, template = _create_project_and_frame(client)

    saved = client.post("/api/ai/annotate", json={
        "project_id": project["id"],
        "frame_id": frame["id"],
        "template_id": template["id"],
    }).json()

    assert client.patch(
        "/api/ai/annotations/999",
        json={"bbox": [0, 0, 1, 1]},
    ).status_code == 404
    assert client.delete("/api/ai/annotations/999").status_code == 404
    assert client.patch(
        f"/api/ai/annotations/{saved['id']}",
        json={"template_id": 999},
    ).status_code == 404


def test_import_gt_mask_creates_annotations_with_seed_points(client):
    """A mask with one filled rectangle imports as a single annotation with one seed point."""
    project, frame, template = _create_project_and_frame(client)

    # 640x360 mask (same size as the frame) with one rectangle of value 255.
    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "template_id": str(template["id"]),
            "label": "Imported GT",
            "color": "#22c55e",
        },
        files={"file": ("mask.png", encoded.tobytes(), "image/png")},
    )

    assert response.status_code == 201
    body = response.json()
    assert len(body) == 1
    assert body[0]["project_id"] == project["id"]
    assert body[0]["frame_id"] == frame["id"]
    assert body[0]["template_id"] == template["id"]
    assert body[0]["mask_data"]["label"] == "Imported GT"
    assert body[0]["mask_data"]["source"] == "gt_mask"
    assert body[0]["mask_data"]["gt_label_value"] == 255
    assert len(body[0]["mask_data"]["polygons"][0]) >= 3
    assert len(body[0]["points"]) == 1
    assert 0.0 <= body[0]["points"][0][0] <= 1.0
    assert 0.0 <= body[0]["points"][0][1] <= 1.0


def test_import_gt_mask_splits_label_values(client):
    """A mask holding pixel values 1 and 2 imports as one annotation per value."""
    project, frame, _ = _create_project_and_frame(client)

    mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
    cv2.rectangle(mask, (220, 80), (320, 180), 2, thickness=-1)
    ok, encoded = cv2.imencode(".png", mask)
    assert ok

    response = client.post(
        "/api/ai/import-gt-mask",
        data={
            "project_id": str(project["id"]),
            "frame_id": str(frame["id"]),
            "label": "GT Class",
        },
        files={"file": ("labels.png", encoded.tobytes(), "image/png")},
    )

    assert response.status_code == 201
    # Sort by label value so the assertions do not depend on response order.
    body = sorted(response.json(), key=lambda item: item["mask_data"]["gt_label_value"])
    assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
    assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
    assert all(len(item["points"]) == 1 for item in body)