"""AI inference endpoints using selectable SAM runtimes."""

import logging
import tempfile
from pathlib import Path
from typing import Any, List

import cv2
import numpy as np
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session

from database import get_db
from minio_client import download_file
from models import Project, Frame, Template, Annotation
from schemas import (
    AiRuntimeStatus,
    MaskAnalysisRequest,
    MaskAnalysisResponse,
    PredictRequest,
    PredictResponse,
    PropagateRequest,
    PropagateResponse,
    AnnotationOut,
    AnnotationCreate,
    AnnotationUpdate,
)
from services.sam_registry import ModelUnavailableError, sam_registry

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/ai", tags=["AI"])


# ---------------------------------------------------------------------------
# Geometry helpers. Coordinates exchanged with the frontend are normalized to
# [0, 1] relative to the image width/height.
# ---------------------------------------------------------------------------


def _clamp01(value: float) -> float:
    """Coerce *value* to float and clamp it into the closed interval [0, 1]."""
    return min(max(float(value), 0.0), 1.0)


def _normalize_xy(x: float, y: float, width: int, height: int) -> list[float]:
    """Convert a pixel coordinate pair into a clamped normalized ``[x, y]`` pair.

    ``max(..., 1)`` guards against division by zero for degenerate image sizes.
    """
    return [_clamp01(float(x) / max(width, 1)), _clamp01(float(y) / max(height, 1))]


def _load_frame_image(frame: Frame) -> np.ndarray:
    """Download a frame from MinIO and decode it to an RGB numpy array.

    Raises:
        HTTPException: 500 when the object cannot be downloaded or decoded.
    """
    try:
        data = download_file(frame.image_url)
        arr = np.frombuffer(data, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("OpenCV could not decode image")
        # OpenCV decodes to BGR; downstream SAM code receives RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        # logger.exception keeps the traceback for post-mortem debugging.
        logger.exception("Failed to load frame image: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc


def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
    """Approximate a contour and convert it to normalized polygon coordinates.

    The simplification tolerance is 1% of the contour perimeter (at least 1 px).
    If simplification collapses the shape below three vertices, the raw contour
    points are used instead.
    """
    arc_length = cv2.arcLength(contour, True)
    epsilon = max(1.0, arc_length * 0.01)
    approx = cv2.approxPolyDP(contour, epsilon, True)
    points = approx.reshape(-1, 2)
    if len(points) < 3:
        points = contour.reshape(-1, 2)
    return [_normalize_xy(x, y, width, height) for x, y in points]


def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
    """Return the contour's bounding rectangle as normalized ``[x, y, w, h]``."""
    x, y, w, h = cv2.boundingRect(contour)
    return [*_normalize_xy(x, y, width, height), *_normalize_xy(w, h, width, height)]


def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return the normalized ``[x, y, w, h]`` bounding box of a point list."""
    xs = [_clamp01(point[0]) for point in polygon]
    ys = [_clamp01(point[1]) for point in polygon]
    left, right = min(xs), max(xs)
    top, bottom = min(ys), max(ys)
    return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]


def _polygon_area(polygon: list[list[float]]) -> float:
    """Return the shoelace-formula area of a normalized polygon (0 for < 3 points)."""
    if len(polygon) < 3:
        return 0.0
    total = 0.0
    for index, point in enumerate(polygon):
        next_point = polygon[(index + 1) % len(polygon)]
        total += _clamp01(point[0]) * _clamp01(next_point[1])
        total -= _clamp01(next_point[0]) * _clamp01(point[1])
    return abs(total) / 2.0


def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
    """Return prompt anchors for the mask inspector.

    Caller-supplied *points* win (all of them, clamped). Otherwise roughly 12
    evenly strided vertices are sampled from each polygon, capped at 32 total.
    """
    if points:
        return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
    anchors: list[list[float]] = []
    for polygon in polygons:
        if not polygon:
            continue
        step = max(1, len(polygon) // 12)
        anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon[::step]])
    return anchors[:32]


def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Select up to *max_frames* frames around ``frames[source_position]``.

    - ``"backward"``: the window ends at the source frame.
    - ``"both"``: the window is centred on the source frame and grows toward
      whichever side still has frames when one side runs out.
    - anything else (``"forward"``): the window starts at the source frame.

    Returns:
        ``(window, source_index_within_window)``.
    """
    count = max(1, min(max_frames, len(frames)))
    if direction == "backward":
        start = max(0, source_position - count + 1)
        return frames[start:source_position + 1], source_position - start
    if direction == "both":
        before = (count - 1) // 2
        after = count - 1 - before
        start = max(0, source_position - before)
        end = min(len(frames), source_position + after + 1)
        while end - start < count and start > 0:
            start -= 1
        while end - start < count and end < len(frames):
            end += 1
        return frames[start:end], source_position - start
    end = min(len(frames), source_position + count)
    return frames[source_position:end], 0


def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download *frames* into *directory* as ``frame_000000.jpg``, ``frame_000001.jpg``, ...

    The downloaded bytes are written unchanged.
    NOTE(review): a non-JPEG source still gets a ``.jpg`` name — confirm the
    video predictor sniffs the content rather than trusting the extension.

    Returns:
        The written file paths, in frame order.
    """
    paths = []
    for index, frame in enumerate(frames):
        data = download_file(frame.image_url)
        path = directory / f"frame_{index:06d}.jpg"
        path.write_bytes(data)
        paths.append(str(path))
    return paths


def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
    """Reduce a binary component to one positive prompt point using distance transform.

    The chosen point is the pixel farthest from the component boundary.
    """
    dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
    _, _, _, max_loc = cv2.minMaxLoc(dist)
    x, y = max_loc
    return _normalize_xy(x, y, width, height)


def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
    """Return whether a normalized point is inside a normalized polygon (ray casting)."""
    if len(polygon) < 3:
        return False
    x, y = point
    inside = False
    j = len(polygon) - 1
    for i, current in enumerate(polygon):
        xi, yi = current
        xj, yj = polygon[j]
        intersects = ((yi > y) != (yj > y)) and (
            x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
        )
        if intersects:
            inside = not inside
        j = i
    return inside


def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
    """Return normalized ``(x1, y1, x2, y2)`` crop bounds around *points* plus *margin*.

    Each axis is widened to at least 0.05 (before clamping to [0, 1]) so tiny
    prompts still yield a usable crop.
    """
    xs = [_clamp01(point[0]) for point in points]
    ys = [_clamp01(point[1]) for point in points]
    x1 = max(0.0, min(xs) - margin)
    y1 = max(0.0, min(ys) - margin)
    x2 = min(1.0, max(xs) + margin)
    y2 = min(1.0, max(ys) + margin)
    if x2 - x1 < 0.05:
        center = (x1 + x2) / 2
        x1 = max(0.0, center - 0.025)
        x2 = min(1.0, center + 0.025)
    if y2 - y1 < 0.05:
        center = (y1 + y2) / 2
        y1 = max(0.0, center - 0.025)
        y2 = min(1.0, center + 0.025)
    return x1, y1, x2, y2


def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
    """Slice *image* to the normalized ``(x1, y1, x2, y2)`` *bounds*; never returns an empty crop."""
    height, width = image.shape[:2]
    x1, y1, x2, y2 = bounds
    # Keep the crop origin inside the image; otherwise a bound that rounds onto
    # the right/bottom edge would produce an empty slice.
    left = min(int(round(x1 * width)), max(width - 1, 0))
    top = min(int(round(y1 * height)), max(height - 1, 0))
    right = max(left + 1, int(round(x2 * width)))
    bottom = max(top + 1, int(round(y2 * height)))
    return image[top:bottom, left:right]


def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
    """Map a full-image normalized point into crop-relative normalized coordinates."""
    x1, y1, x2, y2 = bounds
    return [
        _clamp01((float(point[0]) - x1) / max(x2 - x1, 1e-9)),
        _clamp01((float(point[1]) - y1) / max(y2 - y1, 1e-9)),
    ]


def _from_crop_polygon(
    polygon: list[list[float]],
    bounds: tuple[float, float, float, float],
) -> list[list[float]]:
    """Map a crop-relative normalized polygon back into full-image coordinates."""
    x1, y1, x2, y2 = bounds
    return [
        [
            _clamp01(x1 + float(point[0]) * (x2 - x1)),
            _clamp01(y1 + float(point[1]) * (y2 - y1)),
        ]
        for point in polygon
    ]


def _box_corners(box: list[float]) -> list[list[float]]:
    """Split an ``[x1, y1, x2, y2]`` box into its two corner points."""
    return [[box[0], box[1]], [box[2], box[3]]]


def _box_to_crop(box: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
    """Map an ``[x1, y1, x2, y2]`` box into crop-relative normalized coordinates."""
    top_left, bottom_right = _box_corners(box)
    return [*_to_crop_point(top_left, bounds), *_to_crop_point(bottom_right, bounds)]


def _filter_predictions(
    polygons: list[list[list[float]]],
    scores: list[float],
    options: dict[str, Any],
    negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
    """Drop low-score masks and masks covering a negative click.

    Only active when ``options["auto_filter_background"]`` is truthy; otherwise
    the inputs are returned unchanged.
    """
    if not options.get("auto_filter_background"):
        return polygons, scores
    min_score = float(options.get("min_score", 0.0) or 0.0)
    next_polygons: list[list[list[float]]] = []
    next_scores: list[float] = []
    for index, polygon in enumerate(polygons):
        score = scores[index] if index < len(scores) else 0.0
        if score < min_score:
            continue
        if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
            continue
        next_polygons.append(polygon)
        next_scores.append(score)
    return next_polygons, next_scores


# ---------------------------------------------------------------------------
# Prompt handling for /predict.
# ---------------------------------------------------------------------------

# (polygons, scores, negative_points) produced by one prompt handler.
_Prediction = tuple[list[list[list[float]]], list[float], list[list[float]]]


def _option_margin(options: dict[str, Any], default: float) -> float:
    """Read ``options["crop_margin"]``, honouring an explicit 0.

    Missing / ``None`` / empty values fall back to *default*; negative values
    are clamped to 0. A non-numeric value raises ``ValueError`` (mapped to 400).
    """
    raw = options.get("crop_margin")
    if raw is None or raw == "":
        return default
    return max(float(raw), 0.0)


def _prompt_labels(points: list, labels: Any) -> list:
    """Return *labels* when it pairs one-to-one with *points*, else all-positive labels."""
    if not isinstance(labels, list) or len(labels) != len(points):
        return [1] * len(points)
    return labels


def _negative_prompt_points(points: list, labels: list) -> list[list[float]]:
    """Return the points whose label is 0 (background clicks)."""
    return [point for point, label in zip(points, labels) if label == 0]


def _predict_point(model: str | None, image: np.ndarray, prompt_data: Any, options: dict[str, Any]) -> _Prediction:
    """Run a point prompt: ``[[x, y], ...]`` or ``{"points": [...], "labels": [...]}``."""
    if isinstance(prompt_data, dict):
        points = prompt_data.get("points")
        labels = prompt_data.get("labels")
    else:
        points = prompt_data
        labels = None
    if not isinstance(points, list) or len(points) == 0:
        raise HTTPException(status_code=400, detail="Invalid point prompt data")
    labels = _prompt_labels(points, labels)
    negative_points = _negative_prompt_points(points, labels)

    inference_image = image
    inference_points = points
    crop_bounds = None
    if options.get("crop_to_prompt"):
        crop_bounds = _crop_bounds_from_points(points, _option_margin(options, 0.25))
        inference_image = _crop_image(image, crop_bounds)
        inference_points = [_to_crop_point(point, crop_bounds) for point in points]
    polygons, scores = sam_registry.predict_points(model, inference_image, inference_points, labels)
    if crop_bounds is not None:
        polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
    return polygons, scores, negative_points


def _predict_box(model: str | None, image: np.ndarray, prompt_data: Any, options: dict[str, Any]) -> _Prediction:
    """Run a box prompt given as normalized ``[x1, y1, x2, y2]``."""
    box = prompt_data
    if not isinstance(box, list) or len(box) != 4:
        raise HTTPException(status_code=400, detail="Invalid box prompt data")
    inference_image = image
    inference_box = box
    crop_bounds = None
    if options.get("crop_to_prompt"):
        crop_bounds = _crop_bounds_from_points(_box_corners(box), _option_margin(options, 0.05))
        inference_image = _crop_image(image, crop_bounds)
        inference_box = _box_to_crop(box, crop_bounds)
    polygons, scores = sam_registry.predict_box(model, inference_image, inference_box)
    if crop_bounds is not None:
        polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
    return polygons, scores, []


def _predict_interactive(
    model: str | None,
    image: np.ndarray,
    prompt_data: Any,
    options: dict[str, Any],
) -> _Prediction:
    """Run a combined ``{"box": [...], "points": [...], "labels": [...]}`` prompt."""
    if not isinstance(prompt_data, dict):
        raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
    box = prompt_data.get("box")
    points = prompt_data.get("points") or []
    labels = prompt_data.get("labels")
    if box is not None and (not isinstance(box, list) or len(box) != 4):
        raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
    if not isinstance(points, list):
        raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
    if not box and len(points) == 0:
        raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
    labels = _prompt_labels(points, labels)
    negative_points = _negative_prompt_points(points, labels)

    inference_image = image
    inference_box = box
    inference_points = points
    crop_bounds = None
    if options.get("crop_to_prompt"):
        crop_points = list(points)
        if box:
            crop_points.extend(_box_corners(box))
        crop_bounds = _crop_bounds_from_points(crop_points, _option_margin(options, 0.05))
        inference_image = _crop_image(image, crop_bounds)
        inference_points = [_to_crop_point(point, crop_bounds) for point in points]
        if box:
            inference_box = _box_to_crop(box, crop_bounds)
    polygons, scores = sam_registry.predict_interactive(
        model,
        inference_image,
        inference_box,
        inference_points,
        labels,
    )
    if crop_bounds is not None:
        polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
    return polygons, scores, negative_points


def _predict_semantic(
    model: str | None,
    image: np.ndarray,
    prompt_data: Any,
    options: dict[str, Any],
) -> _Prediction:
    """Run a text prompt; a positive ``options["min_score"]`` becomes the confidence threshold."""
    text = prompt_data if isinstance(prompt_data, str) else ""
    min_score = options.get("min_score")
    confidence_threshold = None
    if min_score is not None:
        try:
            parsed_min_score = float(min_score)
            if parsed_min_score > 0:
                confidence_threshold = parsed_min_score
        except (TypeError, ValueError):
            confidence_threshold = None
    polygons, scores = sam_registry.predict_semantic(
        model,
        image,
        text,
        confidence_threshold=confidence_threshold,
    )
    return polygons, scores, []


_PROMPT_HANDLERS = {
    "point": _predict_point,
    "box": _predict_box,
    "interactive": _predict_interactive,
    "semantic": _predict_semantic,
}


@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Execute selected SAM segmentation given an image and a prompt.

    - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
      coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **interactive**: `prompt_data` is
      `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
    - **semantic**: disabled in the current SAM 2.1 point/box product flow.

    Options: ``crop_to_prompt`` runs inference on a crop around the prompt
    (``crop_margin`` widens it); ``auto_filter_background`` with ``min_score``
    drops weak masks and masks covering a negative click.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    prompt_type = payload.prompt_type.lower()
    handler = _PROMPT_HANDLERS.get(prompt_type)
    if handler is None:
        # Reject before downloading the image.
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")

    image = _load_frame_image(frame)
    options = payload.options or {}

    try:
        polygons, scores, negative_points = handler(payload.model, image, payload.prompt_data, options)
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    polygons, scores = _filter_predictions(polygons, scores, options, negative_points)

    logger.info(
        "AI predict completed model=%s prompt_type=%s frame_id=%s polygons=%d",
        payload.model or "default",
        prompt_type,
        payload.image_id,
        len(polygons),
    )
    return {"polygons": polygons, "scores": scores}


@router.get(
    "/models/status",
    response_model=AiRuntimeStatus,
    summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
    """Return real runtime availability for GPU and the currently enabled SAM model."""
    try:
        return sam_registry.runtime_status(selected_model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@router.post(
    "/analyze-mask",
    response_model=MaskAnalysisResponse,
    summary="Analyze mask geometry and prompt anchors",
)
def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) -> dict:
    """Return backend-computed mask properties for the frontend inspector.

    Area is summed over every valid polygon. When no bbox is supplied, it is
    computed over every valid polygon as well, so both describe the whole mask.
    """
    if payload.frame_id is not None:
        frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    mask_data = payload.mask_data or {}
    polygons = mask_data.get("polygons") or []
    if not polygons:
        raise HTTPException(status_code=400, detail="Mask analysis requires polygons")

    valid_polygons = [
        [[_clamp01(point[0]), _clamp01(point[1])] for point in polygon if len(point) >= 2]
        for polygon in polygons
    ]
    valid_polygons = [polygon for polygon in valid_polygons if len(polygon) >= 3]
    if not valid_polygons:
        raise HTTPException(status_code=400, detail="Mask analysis requires at least one valid polygon")

    area = sum(_polygon_area(polygon) for polygon in valid_polygons)
    all_points = [point for polygon in valid_polygons for point in polygon]
    bbox = payload.bbox or _polygon_bbox(all_points)

    source = mask_data.get("source")
    raw_score = mask_data.get("score")
    confidence: float | None = None
    if isinstance(raw_score, (int, float)):
        confidence = max(0.0, min(float(raw_score), 1.0))
        confidence_source = "model_score"
    elif source:
        confidence_source = "source_without_score"
    else:
        confidence_source = "manual_or_imported"

    anchors = _analysis_anchors(valid_polygons, payload.points)
    message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
    return {
        "confidence": confidence,
        "confidence_source": confidence_source,
        "topology_anchor_count": len(anchors),
        "topology_anchors": anchors,
        "area": area,
        "bbox": bbox,
        "source": source,
        "message": message,
    }


def _build_propagated_annotations(
    propagated: list[dict[str, Any]],
    selected_frames: list[Frame],
    source_frame: Frame,
    seed: dict[str, Any],
    project_id: int,
    model_id: str,
    include_source: bool,
) -> list[Annotation]:
    """Turn propagation results into unsaved ``Annotation`` rows.

    Each entry of *propagated* is read with ``.get("frame_index")``,
    ``.get("polygons")`` and ``.get("scores")``; ``frame_index`` is relative to
    *selected_frames*. Out-of-range indices and polygons with < 3 points are
    skipped, and the source frame is skipped unless *include_source*.
    """
    class_metadata = seed.get("class_metadata")
    template_id = seed.get("template_id")
    label = seed.get("label") or "Propagated Mask"
    color = seed.get("color") or "#06b6d4"

    annotations: list[Annotation] = []
    for frame_result in propagated:
        relative_index = int(frame_result.get("frame_index", -1))
        if relative_index < 0 or relative_index >= len(selected_frames):
            continue
        frame = selected_frames[relative_index]
        if not include_source and frame.id == source_frame.id:
            continue
        result_polygons = frame_result.get("polygons") or []
        scores = frame_result.get("scores") or []
        for polygon_index, polygon in enumerate(result_polygons):
            if len(polygon) < 3:
                continue
            annotations.append(
                Annotation(
                    project_id=project_id,
                    frame_id=frame.id,
                    template_id=template_id,
                    mask_data={
                        "polygons": [polygon],
                        "label": label,
                        "color": color,
                        "source": f"{model_id}_propagation",
                        "propagated_from_frame_id": source_frame.id,
                        "propagated_from_frame_index": source_frame.frame_index,
                        "score": scores[polygon_index] if polygon_index < len(scores) else None,
                        **({"class": class_metadata} if class_metadata else {}),
                    },
                    points=None,
                    bbox=_polygon_bbox(polygon),
                )
            )
    return annotations


@router.post(
    "/propagate",
    response_model=PropagateResponse,
    summary="Propagate one current-frame region across a video frame segment",
)
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
    """Track one selected region from the current frame across nearby frames.

    SAM 2 uses the official video predictor with the selected mask as the seed.
    SAM 3 video tracking is currently disabled in this product flow.
    The frame window is capped at 500 frames (default 30).
    """
    direction = payload.direction.lower()
    if direction not in {"forward", "backward", "both"}:
        raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
    max_frames = max(1, min(int(payload.max_frames or 30), 500))

    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    source_frame = db.query(Frame).filter(
        Frame.id == payload.frame_id,
        Frame.project_id == payload.project_id,
    ).first()
    if not source_frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    seed = payload.seed.model_dump(exclude_none=True)
    polygons = seed.get("polygons") or []
    bbox = seed.get("bbox")
    points = seed.get("points") or []
    if not polygons and not bbox and not points:
        raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")

    frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")

    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    if len(selected_frames) == 0:
        raise HTTPException(status_code=400, detail="No frames available for propagation")

    try:
        with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
            frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
            propagated = sam_registry.propagate_video(
                payload.model,
                frame_paths,
                source_relative_index,
                seed,
                direction,
                len(selected_frames),
            )
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:  # noqa: BLE001
        logger.exception("Video propagation failed: %s", exc)
        raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc

    model_id = sam_registry.normalize_model_id(payload.model)
    created: list[Annotation] = []
    if payload.save_annotations:
        created = _build_propagated_annotations(
            propagated,
            selected_frames,
            source_frame,
            seed,
            payload.project_id,
            model_id,
            payload.include_source,
        )
        for annotation in created:
            db.add(annotation)
        db.commit()
        for annotation in created:
            db.refresh(annotation)

    return {
        "model": model_id,
        "direction": direction,
        "source_frame_id": source_frame.id,
        "processed_frame_count": len(selected_frames),
        "created_annotation_count": len(created),
        "annotations": created,
    }


@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
    """Run automatic mask generation on a frame using a grid of point prompts.

    Runtime errors are mapped like the other inference endpoints: model
    unavailable -> 503, unsupported/invalid request -> 400.
    """
    frame = db.query(Frame).filter(Frame.id == image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    try:
        polygons, scores = sam_registry.predict_auto(None, image)
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"polygons": polygons, "scores": scores}


@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    Raises 404 when the project does not exist, or when ``frame_id`` is given
    but does not name a frame of that project.
    """
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    if payload.frame_id is not None:
        # Scope the frame to the project, as /propagate and /import-gt-mask do,
        # so an annotation can never point at another project's frame.
        frame = db.query(Frame).filter(
            Frame.id == payload.frame_id,
            Frame.project_id == payload.project_id,
        ).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation


@router.post(
    "/import-gt-mask",
    response_model=List[AnnotationOut],
    status_code=status.HTTP_201_CREATED,
    summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
    project_id: int = Form(...),
    frame_id: int = Form(...),
    template_id: int | None = Form(None),
    label: str = Form("GT Mask"),
    color: str = Form("#22c55e"),
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
) -> List[Annotation]:
    """Convert a binary/label mask image into persisted polygon annotations.

    Each connected component of each non-zero gray level becomes one
    annotation. The `points` field stores a positive seed point at the
    component's distance-transform center, which gives the frontend an
    editable point-region representation instead of a static bitmap layer.

    NOTE(review): this is ``async def`` but performs blocking SQLAlchemy and
    OpenCV work, which runs on the event loop thread — consider moving it to a
    sync endpoint or a threadpool.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    if template_id is not None:
        template = db.query(Template).filter(Template.id == template_id).first()
        if not template:
            raise HTTPException(status_code=404, detail="Template not found")

    data = await file.read()
    image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid mask image")

    width = int(frame.width or image.shape[1])
    height = int(frame.height or image.shape[0])
    label_values = [int(value) for value in np.unique(image) if int(value) > 0]
    if not label_values:
        raise HTTPException(status_code=400, detail="No foreground mask regions found")

    has_multiple_labels = len(label_values) > 1
    annotations: list[Annotation] = []
    for label_value in label_values:
        binary = np.where(image == label_value, 255, 0).astype(np.uint8)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        annotation_label = f"{label} {label_value}" if has_multiple_labels else label
        for contour in contours:
            if cv2.contourArea(contour) < 1:
                continue
            # Geometry is normalized by the mask image size; image_size records
            # the frame's own dimensions (falling back to the mask's).
            polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
            if len(polygon) < 3:
                continue
            component = np.zeros_like(binary, dtype=np.uint8)
            cv2.drawContours(component, [contour], -1, 1, thickness=-1)
            seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
            bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
            annotation = Annotation(
                project_id=project_id,
                frame_id=frame_id,
                template_id=template_id,
                mask_data={
                    "polygons": [polygon],
                    "label": annotation_label,
                    "color": color,
                    "source": "gt_mask",
                    "gt_label_value": label_value,
                    "image_size": {"width": width, "height": height},
                },
                points=[seed_point],
                bbox=bbox,
            )
            db.add(annotation)
            annotations.append(annotation)

    if not annotations:
        raise HTTPException(status_code=400, detail="No foreground mask regions found")

    db.commit()
    for annotation in annotations:
        db.refresh(annotation)
    logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
    return annotations


@router.get(
    "/annotations",
    response_model=List[AnnotationOut],
    summary="List saved annotations for a project",
)
def list_annotations(
    project_id: int,
    frame_id: int | None = None,
    db: Session = Depends(get_db),
) -> List[Annotation]:
    """Return persisted annotations for a project, optionally scoped to one frame, ordered by id."""
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    query = db.query(Annotation).filter(Annotation.project_id == project_id)
    if frame_id is not None:
        query = query.filter(Annotation.frame_id == frame_id)
    return query.order_by(Annotation.id).all()


@router.patch(
    "/annotations/{annotation_id}",
    response_model=AnnotationOut,
    summary="Update a saved annotation",
)
def update_annotation(
    annotation_id: int,
    payload: AnnotationUpdate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Update mutable annotation fields persisted in the database.

    Only fields explicitly present in the request body are applied. A non-null
    ``template_id`` must reference an existing template.
    """
    annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not annotation:
        raise HTTPException(status_code=404, detail="Annotation not found")

    updates = payload.model_dump(exclude_unset=True)
    if "template_id" in updates and updates["template_id"] is not None:
        template = db.query(Template).filter(Template.id == updates["template_id"]).first()
        if not template:
            raise HTTPException(status_code=404, detail="Template not found")

    for field, value in updates.items():
        setattr(annotation, field, value)
    db.commit()
    db.refresh(annotation)
    logger.info("Updated annotation id=%s", annotation.id)
    return annotation


@router.delete(
    "/annotations/{annotation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a saved annotation",
)
def delete_annotation(
    annotation_id: int,
    db: Session = Depends(get_db),
) -> Response:
    """Delete an annotation and its derived mask rows through ORM cascade."""
    annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not annotation:
        raise HTTPException(status_code=404, detail="Annotation not found")
    db.delete(annotation)
    db.commit()
    logger.info("Deleted annotation id=%s", annotation_id)
    return Response(status_code=status.HTTP_204_NO_CONTENT)