124 lines
4.3 KiB
Python
124 lines
4.3 KiB
Python
"""AI inference endpoints using SAM 2."""
|
|
|
|
import logging
from typing import Any, List

import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from database import get_db
from minio_client import download_file
from models import Frame, Annotation, Project
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
# Router for the AI endpoints: every route registered on it gets the /api/ai prefix and the "AI" tag.
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
|
|
|
|
|
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Download a frame from MinIO and decode it to an RGB numpy array.

    Args:
        frame: Frame record whose ``image_url`` is handed to ``download_file``.

    Returns:
        The decoded image, converted from OpenCV's BGR channel order to RGB.

    Raises:
        HTTPException: 500 if downloading, decoding or colour conversion fails.
            The underlying error is chained as ``__cause__`` and logged.
    """
    try:
        data = download_file(frame.image_url)
        arr = np.frombuffer(data, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode reports an undecodable buffer by returning None, not by raising.
            raise ValueError("OpenCV could not decode image")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        # logger.exception records the traceback; the frame id identifies which
        # record failed. The client still only sees the generic 500 detail.
        logger.exception(
            "Failed to load frame image (frame_id=%s)", getattr(frame, "id", None)
        )
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
|
|
|
|
|
def _is_coordinate(value: Any) -> bool:
    """Return True for a real number; bool is rejected although it subclasses int."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _is_coordinate_seq(value: Any, length: int) -> bool:
    """Return True if *value* is a list/tuple of exactly *length* real numbers."""
    return (
        isinstance(value, (list, tuple))
        and len(value) == length
        and all(_is_coordinate(item) for item in value)
    )


@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM 2 inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Execute SAM 2 segmentation given an image and a prompt.

    - **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **semantic**: Not yet implemented; falls back to auto segmentation.

    Responds with 404 when the frame does not exist and with 400 when the
    prompt type is unsupported or `prompt_data` does not match the shape
    listed above.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()

    # Declared for readers and type checkers; every branch below either assigns both or raises.
    polygons: List[List[List[float]]]
    scores: List[float]

    if prompt_type == "point":
        points = payload.prompt_data
        # Check each point too, not just the outer list: a flat [x, y] would
        # otherwise reach the engine as two malformed "points".
        if (
            not isinstance(points, list)
            or not points
            or not all(_is_coordinate_seq(point, 2) for point in points)
        ):
            raise HTTPException(status_code=400, detail="Invalid point prompt data")
        # Every supplied point is passed with label 1.
        labels = [1] * len(points)
        polygons, scores = sam_engine.predict_points(image, points, labels)

    elif prompt_type == "box":
        box = payload.prompt_data
        # Exactly four numeric entries, so malformed boxes fail here with a 400.
        if not isinstance(box, list) or not _is_coordinate_seq(box, 4):
            raise HTTPException(status_code=400, detail="Invalid box prompt data")
        polygons, scores = sam_engine.predict_box(image, box)

    elif prompt_type == "semantic":
        # Placeholder: use auto segmentation for now
        logger.info("Semantic prompt not implemented; using auto segmentation")
        polygons, scores = sam_engine.predict_auto(image)

    else:
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")

    return {"polygons": polygons, "scores": scores}
|
|
|
|
|
|
@router.post("/auto", response_model=PredictResponse, summary="Run automatic segmentation")
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
    """Run automatic mask generation on a frame, with no user-supplied prompt.

    The frame is looked up by `image_id`; a missing frame yields a 404.
    Otherwise its image is loaded and passed to `sam_engine.predict_auto`,
    whose polygons and scores are returned.
    """
    record = db.query(Frame).filter(Frame.id == image_id).first()
    if record:
        masks, confidences = sam_engine.predict_auto(_load_frame_image(record))
        return dict(polygons=masks, scores=confidences)

    # Reached only when the lookup found nothing.
    raise HTTPException(status_code=404, detail="Frame not found")
|
|
|
|
|
|
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    Responds with 404 when the referenced project does not exist, or when a
    `frame_id` is supplied and that frame does not exist. On success the
    stored annotation is returned with status 201.
    """
    # Check the project against the Project table. Filtering Frame by
    # project_id would accept or reject the request depending on whether an
    # unrelated frame shares that id.
    # NOTE(review): assumes models.Project exists with an `id` column - confirm.
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # frame_id is optional. Compare with None so a falsy id such as 0 is
    # still validated instead of silently skipping the existence check.
    if payload.frame_id is not None:
        frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    # Reload the row so the id logged and returned below reflects what the database stored.
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation
|