"""AI inference endpoints using SAM 2.""" import logging from typing import Any, List import cv2 import numpy as np from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from database import get_db from minio_client import download_file from models import Frame, Annotation from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate from services.sam2_engine import sam_engine logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/ai", tags=["AI"]) def _load_frame_image(frame: Frame) -> np.ndarray: """Download a frame from MinIO and decode it to an RGB numpy array.""" try: data = download_file(frame.image_url) arr = np.frombuffer(data, dtype=np.uint8) img = cv2.imdecode(arr, cv2.IMREAD_COLOR) if img is None: raise ValueError("OpenCV could not decode image") return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) except Exception as exc: # noqa: BLE001 logger.error("Failed to load frame image: %s", exc) raise HTTPException(status_code=500, detail="Failed to load frame image") from exc @router.post( "/predict", response_model=PredictResponse, summary="Run SAM 2 inference with a prompt", ) def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: """Execute SAM 2 segmentation given an image and a prompt. - **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates. - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates. - **semantic**: Not yet implemented; falls back to auto segmentation. """ frame = db.query(Frame).filter(Frame.id == payload.image_id).first() if not frame: raise HTTPException(status_code=404, detail="Frame not found") image = _load_frame_image(frame) prompt_type = payload.prompt_type.lower() polygons: List[List[List[float]]] = [] scores: List[float] = [] if prompt_type == "point": points = payload.prompt_data if not isinstance(points, list) or len(points) == 0: raise HTTPException(status_code=400, detail="Invalid point prompt data") labels = [1] * len(points) polygons, scores = sam_engine.predict_points(image, points, labels) elif prompt_type == "box": box = payload.prompt_data if not isinstance(box, list) or len(box) != 4: raise HTTPException(status_code=400, detail="Invalid box prompt data") polygons, scores = sam_engine.predict_box(image, box) elif prompt_type == "semantic": # Placeholder: use auto segmentation for now logger.info("Semantic prompt not implemented; using auto segmentation") polygons, scores = sam_engine.predict_auto(image) else: raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}") return {"polygons": polygons, "scores": scores} @router.post( "/auto", response_model=PredictResponse, summary="Run automatic segmentation", ) def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict: """Run automatic mask generation on a frame using a grid of point prompts.""" frame = db.query(Frame).filter(Frame.id == image_id).first() if not frame: raise HTTPException(status_code=404, detail="Frame not found") image = _load_frame_image(frame) polygons, scores = sam_engine.predict_auto(image) return {"polygons": polygons, "scores": scores} @router.post( "/annotate", response_model=AnnotationOut, status_code=status.HTTP_201_CREATED, summary="Save an AI-generated annotation", ) def save_annotation( payload: AnnotationCreate, db: Session = Depends(get_db), ) -> Annotation: """Persist an annotation (mask, points, bbox) into the database.""" project = db.query(Frame).filter(Frame.id == payload.project_id).first() if not project: raise HTTPException(status_code=404, detail="Project not found") if payload.frame_id: frame = db.query(Frame).filter(Frame.id == payload.frame_id).first() if not frame: raise HTTPException(status_code=404, detail="Frame not found") annotation = Annotation(**payload.model_dump()) db.add(annotation) db.commit() db.refresh(annotation) logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id) return annotation