"""AI inference endpoints using selectable SAM runtimes."""

import logging
from typing import Any, List

import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy.orm import Session

from database import get_db
from minio_client import download_file
from models import Project, Frame, Template, Annotation
from schemas import (
    AiRuntimeStatus,
    PredictRequest,
    PredictResponse,
    AnnotationOut,
    AnnotationCreate,
    AnnotationUpdate,
)
from services.sam_registry import ModelUnavailableError, sam_registry

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/ai", tags=["AI"])


def _load_frame_image(frame: Frame) -> np.ndarray:
    """Download a frame from MinIO and decode it to an RGB numpy array.

    Raises:
        HTTPException: 500 if the download or the OpenCV decode fails.
    """
    try:
        data = download_file(frame.image_url)
        arr = np.frombuffer(data, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("OpenCV could not decode image")
        # cv2.imdecode returns BGR channel order; callers expect RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        # logger.exception records the traceback as well as the message.
        logger.exception("Failed to load frame image: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc


def _parse_point_prompt(prompt_data: Any) -> tuple[list, list]:
    """Split a point prompt into ``(points, labels)``.

    Accepts either a bare list of points or a dict with ``points`` and an
    optional ``labels`` list. Missing labels default to ``1`` for every point.

    Raises:
        HTTPException: 400 if there are no points, or if labels are supplied
            but are not a list with exactly one entry per point.
    """
    if isinstance(prompt_data, dict):
        points = prompt_data.get("points")
        labels = prompt_data.get("labels")
    else:
        points = prompt_data
        labels = None

    if not isinstance(points, list) or len(points) == 0:
        raise HTTPException(status_code=400, detail="Invalid point prompt data")

    if labels is None:
        labels = [1] * len(points)
    elif not isinstance(labels, list) or len(labels) != len(points):
        # Reject rather than guess: substituting all-1 labels would silently
        # turn the caller's negative (0) clicks into positive ones.
        raise HTTPException(status_code=400, detail="Invalid point prompt labels")
    return points, labels


@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Execute selected SAM segmentation given an image and a prompt.

    - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
      coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
      When `labels` is omitted every point gets label 1; when given it must
      have one entry per point.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.

    Errors: 404 if the frame does not exist, 400 for malformed prompts,
    unsupported prompt types, or registry `ValueError`/`NotImplementedError`,
    503 if the selected model runtime is unavailable.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")
    image = _load_frame_image(frame)

    prompt_type = payload.prompt_type.lower()
    polygons: List[List[List[float]]] = []
    scores: List[float] = []
    try:
        if prompt_type == "point":
            points, labels = _parse_point_prompt(payload.prompt_data)
            polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
        elif prompt_type == "box":
            box = payload.prompt_data
            if not isinstance(box, list) or len(box) != 4:
                raise HTTPException(status_code=400, detail="Invalid box prompt data")
            polygons, scores = sam_registry.predict_box(payload.model, image, box)
        elif prompt_type == "semantic":
            # Non-string prompt data is sent to the registry as an empty text prompt.
            text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
            polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return {"polygons": polygons, "scores": scores}


@router.get(
    "/models/status",
    response_model=AiRuntimeStatus,
    summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
    """Return real runtime availability for GPU, SAM 2, and SAM 3.

    Raises:
        HTTPException: 400 if the registry rejects `selected_model` with a ValueError.
    """
    try:
        return sam_registry.runtime_status(selected_model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(
    image_id: int,
    db: Session = Depends(get_db),
    model: str | None = None,
) -> dict:
    """Run automatic mask generation on a frame using a grid of point prompts.

    `model` optionally selects the SAM runtime. When omitted, `None` is passed
    to the registry, which leaves the choice of runtime to it.

    Errors: 404 if the frame does not exist, 400 for registry
    `ValueError`/`NotImplementedError`, 503 if the runtime is unavailable.
    """
    frame = db.query(Frame).filter(Frame.id == image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")
    image = _load_frame_image(frame)

    try:
        polygons, scores = sam_registry.predict_auto(model, image)
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    # Map the same registry errors as /predict so they are not surfaced as 500s.
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return {"polygons": polygons, "scores": scores}


@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    Raises:
        HTTPException: 404 if the project, or a supplied frame, does not exist.
    """
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # `is not None` so a falsy id (e.g. 0) is still validated.
    # NOTE(review): this does not check that the frame belongs to
    # payload.project_id — confirm whether that should be enforced.
    if payload.frame_id is not None:
        frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation


@router.get(
    "/annotations",
    response_model=List[AnnotationOut],
    summary="List saved annotations for a project",
)
def list_annotations(
    project_id: int,
    frame_id: int | None = None,
    db: Session = Depends(get_db),
) -> List[Annotation]:
    """Return persisted annotations for a project, optionally scoped to one frame.

    Results are ordered by annotation id.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    query = db.query(Annotation).filter(Annotation.project_id == project_id)
    if frame_id is not None:
        query = query.filter(Annotation.frame_id == frame_id)
    return query.order_by(Annotation.id).all()


@router.patch(
    "/annotations/{annotation_id}",
    response_model=AnnotationOut,
    summary="Update a saved annotation",
)
def update_annotation(
    annotation_id: int,
    payload: AnnotationUpdate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Update mutable annotation fields persisted in the database.

    Only fields explicitly set in the request body are applied. An explicit
    `template_id: null` is allowed and is written as-is.

    Raises:
        HTTPException: 404 if the annotation, or a non-null `template_id`, does not exist.
    """
    annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not annotation:
        raise HTTPException(status_code=404, detail="Annotation not found")

    updates = payload.model_dump(exclude_unset=True)
    if updates.get("template_id") is not None:
        template = db.query(Template).filter(Template.id == updates["template_id"]).first()
        if not template:
            raise HTTPException(status_code=404, detail="Template not found")

    for field, value in updates.items():
        setattr(annotation, field, value)
    db.commit()
    db.refresh(annotation)
    logger.info("Updated annotation id=%s", annotation.id)
    return annotation


@router.delete(
    "/annotations/{annotation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a saved annotation",
)
def delete_annotation(
    annotation_id: int,
    db: Session = Depends(get_db),
) -> Response:
    """Delete an annotation and its derived mask rows through ORM cascade.

    Raises:
        HTTPException: 404 if the annotation does not exist.
    """
    annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not annotation:
        raise HTTPException(status_code=404, detail="Annotation not found")

    db.delete(annotation)
    db.commit()
    logger.info("Deleted annotation id=%s", annotation_id)
    return Response(status_code=status.HTTP_204_NO_CONTENT)