2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构
This commit is contained in:
123
backend/routers/ai.py
Normal file
123
backend/routers/ai.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""AI inference endpoints using SAM 2."""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Frame, Annotation
|
||||
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
|
||||
from services.sam2_engine import sam_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
"""Download a frame from MinIO and decode it to an RGB numpy array."""
|
||||
try:
|
||||
data = download_file(frame.image_url)
|
||||
arr = np.frombuffer(data, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise ValueError("OpenCV could not decode image")
|
||||
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to load frame image: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/predict",
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM 2 inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Execute SAM 2 segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **semantic**: Not yet implemented; falls back to auto segmentation.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
|
||||
polygons: List[List[List[float]]] = []
|
||||
scores: List[float] = []
|
||||
|
||||
if prompt_type == "point":
|
||||
points = payload.prompt_data
|
||||
if not isinstance(points, list) or len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_engine.predict_points(image, points, labels)
|
||||
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_engine.predict_box(image, box)
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
# Placeholder: use auto segmentation for now
|
||||
logger.info("Semantic prompt not implemented; using auto segmentation")
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
summary="Run automatic segmentation",
|
||||
)
|
||||
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
"""Run automatic mask generation on a frame using a grid of point prompts."""
|
||||
frame = db.query(Frame).filter(Frame.id == image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/annotate",
|
||||
response_model=AnnotationOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Save an AI-generated annotation",
|
||||
)
|
||||
def save_annotation(
|
||||
payload: AnnotationCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if payload.frame_id:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
annotation = Annotation(**payload.model_dump())
|
||||
db.add(annotation)
|
||||
db.commit()
|
||||
db.refresh(annotation)
|
||||
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
|
||||
return annotation
|
||||
Reference in New Issue
Block a user