2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构

This commit is contained in:
2026-04-29 22:17:25 +08:00
parent c8f8686097
commit fd4b5e5b3d
39 changed files with 3816 additions and 211 deletions

123
backend/routers/ai.py Normal file
View File

@@ -0,0 +1,123 @@
"""AI inference endpoints using SAM 2."""
import logging
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Frame, Annotation
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
def _load_frame_image(frame: Frame) -> np.ndarray:
"""Download a frame from MinIO and decode it to an RGB numpy array."""
try:
data = download_file(frame.image_url)
arr = np.frombuffer(data, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise ValueError("OpenCV could not decode image")
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
except Exception as exc: # noqa: BLE001
logger.error("Failed to load frame image: %s", exc)
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
@router.post(
"/predict",
response_model=PredictResponse,
summary="Run SAM 2 inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
"""Execute SAM 2 segmentation given an image and a prompt.
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
- **semantic**: Not yet implemented; falls back to auto segmentation.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
polygons: List[List[List[float]]] = []
scores: List[float] = []
if prompt_type == "point":
points = payload.prompt_data
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
labels = [1] * len(points)
polygons, scores = sam_engine.predict_points(image, points, labels)
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_engine.predict_box(image, box)
elif prompt_type == "semantic":
# Placeholder: use auto segmentation for now
logger.info("Semantic prompt not implemented; using auto segmentation")
polygons, scores = sam_engine.predict_auto(image)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
return {"polygons": polygons, "scores": scores}
@router.post(
"/auto",
response_model=PredictResponse,
summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
"""Run automatic mask generation on a frame using a grid of point prompts."""
frame = db.query(Frame).filter(Frame.id == image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
image = _load_frame_image(frame)
polygons, scores = sam_engine.predict_auto(image)
return {"polygons": polygons, "scores": scores}
@router.post(
"/annotate",
response_model=AnnotationOut,
status_code=status.HTTP_201_CREATED,
summary="Save an AI-generated annotation",
)
def save_annotation(
payload: AnnotationCreate,
db: Session = Depends(get_db),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if payload.frame_id:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
annotation = Annotation(**payload.model_dump())
db.add(annotation)
db.commit()
db.refresh(annotation)
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
return annotation