2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构

This commit is contained in:
2026-04-29 22:17:25 +08:00
parent c8f8686097
commit fd4b5e5b3d
39 changed files with 3816 additions and 211 deletions

35
backend/config.py Normal file
View File

@@ -0,0 +1,35 @@
"""Application configuration using Pydantic Settings."""
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Every field below can be overridden through an environment variable of
    the same name or through the ``.env`` file named in ``model_config``.
    The defaults are development values only.
    """

    # Pydantic v2 style configuration (the nested ``class Config`` form is
    # deprecated in v2, which is the only major version pydantic_settings
    # supports). Unknown keys in the environment / .env file are ignored.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "extra": "ignore",
    }

    # Database
    # NOTE(review): the default embeds development credentials; override it
    # outside local development.
    db_url: str = "postgresql://seguser:segpass123@localhost:5432/segserver"

    # Redis
    redis_url: str = "redis://localhost:6379/0"

    # MinIO
    minio_endpoint: str = "localhost:9000"
    minio_access_key: str = "minioadmin"
    minio_secret_key: str = "minioadmin"
    minio_secure: bool = False

    # SAM2
    # NOTE(review): machine-specific absolute path; set it per deployment.
    sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
    sam_model_config: str = "sam2_hiera_t.yaml"

    # App
    app_env: str = "development"
    cors_origins: list[str] = ["http://localhost:3000"]


# Module-level singleton imported by the rest of the backend.
settings = Settings()

29
backend/database.py Normal file
View File

@@ -0,0 +1,29 @@
"""Database configuration using synchronous SQLAlchemy."""
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from fastapi import Depends
from typing import Generator
from config import settings
# Single process-wide engine; the pool settings are fixed here and the URL
# comes from the application settings.
engine = create_engine(
    settings.db_url,
    echo=False,
    pool_size=10,
    max_overflow=20,
    pool_pre_ping=True,
)

# Factory for request-scoped sessions (see get_db); commits and flushes are
# explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class shared by all ORM models.
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
    """Yield a database session and close it once the consumer is done.

    Intended to be used as a FastAPI dependency.
    """
    # The session's context manager closes it on exit, including when the
    # consumer raises.
    with SessionLocal() as session:
        yield session

49
backend/download_sam2.py Normal file
View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""
SAM 2 模型权重下载脚本
运行: python download_sam2.py
"""
import os
import urllib.request
import sys
# Directory that receives the checkpoints. The default is the original
# deployment path; set SAM2_MODEL_DIR to download somewhere else.
MODEL_DIR = os.environ.get("SAM2_MODEL_DIR", "/home/wkmgc/Desktop/Seg_Server/models")
os.makedirs(MODEL_DIR, exist_ok=True)

# SAM 2 checkpoint URLs (official Meta AI release), keyed by local file name.
MODELS = {
    "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
    "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
    "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
    "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
}
def download_file(url: str, dest: str) -> bool:
    """Download *url* to *dest* unless *dest* already exists.

    The payload is written to ``dest + ".part"`` first and renamed to *dest*
    only after the transfer finished, so an interrupted run never leaves a
    truncated file that a later run would skip as "already exists".

    Args:
        url: Source URL.
        dest: Local destination path.

    Returns:
        True if *dest* was already present or was downloaded successfully,
        False if the download failed (the error is printed to stderr).
    """
    if os.path.exists(dest):
        print(f"[跳过] {os.path.basename(dest)} 已存在")
        return True

    print(f"[下载] {url}")
    print(f"{dest}")
    partial = dest + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        # Atomic on the same filesystem: dest is either absent or complete.
        os.replace(partial, dest)
    except Exception as e:
        print(f" ✗ 失败: {e}", file=sys.stderr)
        return False
    finally:
        # Also runs for KeyboardInterrupt, which the handler above does not
        # catch; after a successful rename there is nothing left to remove.
        if os.path.exists(partial):
            os.remove(partial)

    size_mb = os.path.getsize(dest) / 1024 / 1024
    print(f" ✓ 完成 ({size_mb:.1f} MB)")
    return True
def main() -> int:
    """Download every checkpoint listed in MODELS into MODEL_DIR.

    Returns:
        0 when every checkpoint file exists afterwards, 1 if at least one
        download failed.
    """
    print("=" * 50)
    print("SAM 2 模型权重下载")
    print("=" * 50)

    failed = []
    for name, url in MODELS.items():
        dest = os.path.join(MODEL_DIR, name)
        download_file(url, dest)
        # download_file reports its own errors; the presence of the file is
        # the success signal used here.
        if not os.path.exists(dest):
            failed.append(name)

    print("=" * 50)
    if failed:
        print(f"下载失败: {', '.join(failed)}", file=sys.stderr)
    else:
        print("全部完成!")
    print(f"模型目录: {MODEL_DIR}")
    print("=" * 50)
    return 1 if failed else 0


if __name__ == "__main__":
    # Propagate failures to the shell so wrapper scripts can detect them.
    sys.exit(main())

83
backend/main.py Normal file
View File

@@ -0,0 +1,83 @@
"""FastAPI application entrypoint."""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from config import settings
from database import Base, engine
from minio_client import ensure_bucket_exists
from redis_client import ping as redis_ping
from routers import projects, templates, media, ai, export, auth
# Root logging setup for the whole backend: INFO level, pipe-separated fields.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
def _run_startup_checks() -> None:
    """Create the tables and probe MinIO and Redis, logging any failure."""
    # Database schema
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)
    else:
        logger.info("Database tables initialized.")

    # Object storage bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    # Cache connectivity
    if not redis_ping():
        logger.warning("Redis connection failed.")
    else:
        logger.info("Redis connection OK.")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup checks before serving and dispose the engine on shutdown.

    Failures of the database or MinIO checks are logged and do not abort
    startup.
    """
    logger.info("Starting up SegServer backend...")
    _run_startup_checks()

    yield

    logger.info("Shutting down SegServer backend...")
    engine.dispose()
# The ASGI application object.
app = FastAPI(
    lifespan=lifespan,
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
)

# CORS: only the configured origins are allowed; credentials, all methods and
# all headers are permitted for them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers, registered in the same order as before.
for _module in (auth, projects, templates, media, ai, export):
    app.include_router(_module.router)
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Return a static payload showing that the service is up."""
    payload = {"status": "ok", "service": "SegServer"}
    return payload

127
backend/minio_client.py Normal file
View File

@@ -0,0 +1,127 @@
"""MinIO client wrapper for object storage operations."""
import io
import logging
from typing import Optional
from minio import Minio
from minio.error import S3Error
from config import settings
logger = logging.getLogger(__name__)

# Bucket used for every object stored by this module.
BUCKET_NAME = "seg-media"

# Lazily created client shared by all helpers in this module.
_minio_client: Optional[Minio] = None


def get_minio_client() -> Minio:
    """Return the module-wide MinIO client, creating it on first use."""
    global _minio_client
    if _minio_client is not None:
        return _minio_client

    _minio_client = Minio(
        settings.minio_endpoint,
        access_key=settings.minio_access_key,
        secret_key=settings.minio_secret_key,
        secure=settings.minio_secure,
    )
    return _minio_client
def ensure_bucket_exists() -> None:
    """Make sure BUCKET_NAME exists, creating it when it is missing.

    Raises:
        S3Error: Re-raised after logging if the check or the creation fails.
    """
    client = get_minio_client()
    try:
        if client.bucket_exists(BUCKET_NAME):
            logger.info("MinIO bucket %s already exists", BUCKET_NAME)
            return
        client.make_bucket(BUCKET_NAME)
        logger.info("Created MinIO bucket: %s", BUCKET_NAME)
    except S3Error as exc:
        logger.error("MinIO bucket check/creation failed: %s", exc)
        raise
def upload_file(
    object_name: str,
    data: bytes,
    content_type: str = "application/octet-stream",
    length: int = -1,
) -> str:
    """Upload bytes or a readable stream to MinIO and return the object name.

    Args:
        object_name: Destination path inside the bucket.
        data: Raw bytes (or bytearray), or a file-like object with ``read``.
        content_type: MIME type of the object.
        length: Object size in bytes. Ignored for bytes input (the real size
            is used); for streams, -1 means unknown and triggers a multipart
            upload.

    Returns:
        The object name (same as input).

    Raises:
        S3Error: Re-raised after logging if the upload fails.
    """
    client = get_minio_client()

    if isinstance(data, (bytes, bytearray)):
        # Take the size before wrapping; avoids copying the buffer just to
        # measure it.
        length = len(data)
        data = io.BytesIO(data)

    extra = {}
    if length < 0:
        # The MinIO SDK rejects an unknown length unless a part size is
        # given; 10 MiB is above the 5 MiB multipart minimum.
        extra["part_size"] = 10 * 1024 * 1024

    try:
        client.put_object(
            BUCKET_NAME,
            object_name,
            data,
            length=length,
            content_type=content_type,
            **extra,
        )
        logger.info("Uploaded to MinIO: %s", object_name)
        return object_name
    except S3Error as exc:
        logger.error("MinIO upload failed: %s", exc)
        raise
from datetime import timedelta
def get_presigned_url(
    object_name: str,
    expires: int = 3600,
    method: str = "GET",
) -> str:
    """Build a time-limited presigned URL for an object in BUCKET_NAME.

    Args:
        object_name: Path inside the bucket.
        expires: Lifetime of the URL in seconds (default 1 hour).
        method: HTTP method the URL is signed for (GET or PUT).

    Returns:
        The presigned URL.

    Raises:
        S3Error: Re-raised after logging if signing fails.
    """
    client = get_minio_client()
    lifetime = timedelta(seconds=expires)
    try:
        return client.get_presigned_url(
            method,
            BUCKET_NAME,
            object_name,
            expires=lifetime,
        )
    except S3Error as exc:
        logger.error("MinIO presigned URL failed: %s", exc)
        raise
def download_file(object_name: str) -> bytes:
    """Download an object from MinIO and return its bytes.

    Args:
        object_name: Path inside the bucket.

    Returns:
        Raw bytes of the object.

    Raises:
        S3Error: Re-raised after logging if the download fails.
    """
    client = get_minio_client()
    response = None
    try:
        response = client.get_object(BUCKET_NAME, object_name)
        return response.read()
    except S3Error as exc:
        logger.error("MinIO download failed: %s", exc)
        raise
    finally:
        # Always hand the connection back to the pool, including when
        # read() fails part-way through.
        if response is not None:
            response.close()
            response.release_conn()

118
backend/models.py Normal file
View File

@@ -0,0 +1,118 @@
"""SQLAlchemy ORM models."""
from sqlalchemy import (
Column,
Integer,
String,
Text,
DateTime,
ForeignKey,
JSON,
Float,
)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from database import Base
class Project(Base):
    """A segmentation project that owns frames and annotations."""

    __tablename__ = "projects"

    # Identity and descriptive fields
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)

    # Location of the uploaded source media; empty until something is uploaded
    video_path = Column(String(512), nullable=True)
    status = Column(String(50), default="pending", nullable=False)

    # Timestamps are filled in by the database server
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Children are removed together with the project
    frames = relationship(
        "Frame",
        back_populates="project",
        cascade="all, delete-orphan",
    )
    annotations = relationship(
        "Annotation",
        back_populates="project",
        cascade="all, delete-orphan",
    )
class Frame(Base):
    """One image frame belonging to a project."""

    __tablename__ = "frames"

    id = Column(Integer, primary_key=True, index=True)
    project_id = Column(
        Integer,
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
    )

    # Position of the frame within its project and where the image is stored
    frame_index = Column(Integer, nullable=False)
    image_url = Column(String(512), nullable=False)

    # Pixel dimensions; nullable, so readers must handle missing values
    width = Column(Integer, nullable=True)
    height = Column(Integer, nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    project = relationship("Project", back_populates="frames")
    # Annotations are removed together with the frame
    annotations = relationship(
        "Annotation",
        back_populates="frame",
        cascade="all, delete-orphan",
    )
class Template(Base):
    """Template (Ontology) model for segmentation classes."""

    __tablename__ = "templates"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    color = Column(String(50), nullable=False)
    z_index = Column(Integer, default=0, nullable=False)
    mapping_rules = Column(JSON, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # No delete cascade here: Annotation.template_id is nullable with
    # ON DELETE SET NULL, so deleting a template must detach its annotations
    # rather than delete them. Without a delete cascade the ORM nulls the
    # foreign key on the children, which matches the database rule.
    annotations = relationship("Annotation", back_populates="template")
class Annotation(Base):
    """A segmentation annotation: mask data plus the prompts that produced it."""

    __tablename__ = "annotations"

    id = Column(Integer, primary_key=True, index=True)

    # Owning project is mandatory; frame and template are optional links
    project_id = Column(
        Integer,
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
    )
    frame_id = Column(
        Integer,
        ForeignKey("frames.id", ondelete="CASCADE"),
        nullable=True,
    )
    template_id = Column(
        Integer,
        ForeignKey("templates.id", ondelete="SET NULL"),
        nullable=True,
    )

    # Free-form JSON payloads
    mask_data = Column(JSON, nullable=True)
    points = Column(JSON, nullable=True)
    bbox = Column(JSON, nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    project = relationship("Project", back_populates="annotations")
    frame = relationship("Frame", back_populates="annotations")
    template = relationship("Template", back_populates="annotations")
    # Derived mask files are removed together with the annotation
    masks = relationship(
        "Mask",
        back_populates="annotation",
        cascade="all, delete-orphan",
    )
class Mask(Base):
    """A stored mask file derived from an annotation."""

    __tablename__ = "masks"

    id = Column(Integer, primary_key=True, index=True)
    annotation_id = Column(
        Integer,
        ForeignKey("annotations.id", ondelete="CASCADE"),
        nullable=False,
    )

    # Where the mask file lives and how it is encoded (png / rle / json)
    mask_url = Column(String(512), nullable=False)
    format = Column(String(50), default="png", nullable=False)

    created_at = Column(DateTime(timezone=True), server_default=func.now())

    annotation = relationship("Annotation", back_populates="masks")

61
backend/redis_client.py Normal file
View File

@@ -0,0 +1,61 @@
"""Redis client wrapper for caching and task queuing."""
import json
import logging
from typing import Optional, Any
import redis
from config import settings
logger = logging.getLogger(__name__)

# Lazily created client shared by all helpers in this module.
_redis_client: Optional[redis.Redis] = None


def get_redis_client() -> redis.Redis:
    """Return the module-wide Redis client, creating it on first use.

    The client is built from ``settings.redis_url`` with
    ``decode_responses=True``.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client

    _redis_client = redis.from_url(settings.redis_url, decode_responses=True)
    return _redis_client
def ping() -> bool:
    """Check Redis connectivity.

    Returns:
        True if the server answered the PING, False if any Redis error
        occurred (the error is logged).
    """
    try:
        return bool(get_redis_client().ping())
    except redis.RedisError as exc:
        # RedisError is the base class already used by the other helpers in
        # this module; it covers connection failures as well as timeouts and
        # other server-side errors, so a probe never raises them to callers.
        logger.error("Redis ping failed: %s", exc)
        return False
def set_json(key: str, value: Any, expire: Optional[int] = None) -> None:
    """Serialize *value* as JSON and store it under *key*.

    Args:
        key: Redis key.
        value: JSON-serializable object.
        expire: Optional value passed as the ``ex`` argument of ``SET``.

    Raises:
        redis.RedisError: Re-raised after logging if the write fails.
    """
    client = get_redis_client()
    try:
        encoded = json.dumps(value)
        client.set(key, encoded, ex=expire)
    except redis.RedisError as exc:
        logger.error("Redis set_json failed: %s", exc)
        raise
def get_json(key: str) -> Optional[Any]:
    """Read *key* from Redis and decode it from JSON.

    Returns:
        The decoded value, or None when the key does not exist.

    Raises:
        redis.RedisError: Re-raised after logging if the read fails.
    """
    client = get_redis_client()
    try:
        raw = client.get(key)
        if raw is None:
            return None
        return json.loads(raw)
    except redis.RedisError as exc:
        logger.error("Redis get_json failed: %s", exc)
        raise
def delete_key(key: str) -> int:
    """Remove *key* from Redis and return how many keys were deleted.

    Raises:
        redis.RedisError: Re-raised after logging if the delete fails.
    """
    client = get_redis_client()
    try:
        removed = client.delete(key)
    except redis.RedisError as exc:
        logger.error("Redis delete_key failed: %s", exc)
        raise
    return removed

38
backend/requirements.txt Normal file
View File

@@ -0,0 +1,38 @@
# Web framework
fastapi
uvicorn[standard]
python-multipart
# Database
sqlalchemy
psycopg2-binary
alembic
# Cache / Task queue
redis
celery
# Object storage
minio
# Image / Video / DICOM processing
opencv-python
pillow
scikit-image
pydicom
numpy
# SAM 2 (may require manual installation depending on CUDA version)
sam2
# PyTorch (CUDA 12.4 wheel index; adjust for your CUDA version if needed)
torch
torchvision
torchaudio
# Configuration
pydantic-settings
# Utilities
python-jose[cryptography]
passlib[bcrypt]

View File

123
backend/routers/ai.py Normal file
View File

@@ -0,0 +1,123 @@
"""AI inference endpoints using SAM 2."""
import logging
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Frame, Annotation
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Fetch the frame's image from storage and return it as an RGB array.

    Raises:
        HTTPException: 500 if the download or the decoding fails.
    """
    try:
        raw = download_file(frame.image_url)
        buffer = np.frombuffer(raw, dtype=np.uint8)
        bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        if bgr is None:
            raise ValueError("OpenCV could not decode image")
        # The image was decoded as BGR; convert to RGB before returning
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to load frame image: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
    return rgb
@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM 2 inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Execute SAM 2 segmentation given an image and a prompt.

    - **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **semantic**: Not yet implemented; falls back to auto segmentation.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if frame is None:
        raise HTTPException(status_code=404, detail="Frame not found")

    # The image is loaded before the prompt is validated, as before.
    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()
    prompt = payload.prompt_data

    if prompt_type == "point":
        if not (isinstance(prompt, list) and len(prompt) > 0):
            raise HTTPException(status_code=400, detail="Invalid point prompt data")
        # Every point is given label 1
        polygons, scores = sam_engine.predict_points(image, prompt, [1] * len(prompt))
    elif prompt_type == "box":
        if not (isinstance(prompt, list) and len(prompt) == 4):
            raise HTTPException(status_code=400, detail="Invalid box prompt data")
        polygons, scores = sam_engine.predict_box(image, prompt)
    elif prompt_type == "semantic":
        # Placeholder: use auto segmentation for now
        logger.info("Semantic prompt not implemented; using auto segmentation")
        polygons, scores = sam_engine.predict_auto(image)
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")

    return {"polygons": polygons, "scores": scores}
@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
    """Run the SAM engine's automatic segmentation on one frame.

    Raises:
        HTTPException: 404 if no frame has the given id.
    """
    target = db.query(Frame).filter(Frame.id == image_id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Frame not found")

    polygons, scores = sam_engine.predict_auto(_load_frame_image(target))
    return {"polygons": polygons, "scores": scores}
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    Raises:
        HTTPException: 404 if the project does not exist, or if a frame id is
            given and no such frame exists within that project.
    """
    # Local import: this module's top-level import only brings in Frame and
    # Annotation.
    from models import Project

    # The project must be looked up in the projects table, not among frames.
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # `is not None` so that an explicit id of 0 is validated too.
    if payload.frame_id is not None:
        frame = (
            db.query(Frame)
            .filter(
                Frame.id == payload.frame_id,
                # The frame must belong to the project being annotated.
                Frame.project_id == payload.project_id,
            )
            .first()
        )
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation

24
backend/routers/auth.py Normal file
View File

@@ -0,0 +1,24 @@
"""Authentication endpoints."""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
router = APIRouter(prefix="/api/auth", tags=["Auth"])
class LoginRequest(BaseModel):
    """Request body of the login endpoint."""

    username: str
    password: str
class LoginResponse(BaseModel):
    """Response body of the login endpoint."""

    token: str
    username: str
@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest) -> dict:
    """Development-only login against one hard-coded account.

    Raises:
        HTTPException: 401 when the credentials do not match.
    """
    # NOTE(review): fixed credentials and a static token; replace before any
    # non-development use.
    supplied = (payload.username, payload.password)
    if supplied != ("admin", "123456"):
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return {"token": "fake-jwt-token-for-admin", "username": payload.username}

194
backend/routers/export.py Normal file
View File

@@ -0,0 +1,194 @@
"""Annotation export endpoints (COCO, PNG masks)."""
import io
import json
import logging
import os
import zipfile
from datetime import datetime
from typing import Any, Dict, List
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Annotation, Frame, Template
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/export", tags=["Export"])
def _mask_from_polygon(
    polygon: List[List[float]],
    width: int,
    height: int,
) -> np.ndarray:
    """Rasterize one polygon into a ``height`` x ``width`` uint8 mask.

    Each vertex coordinate is multiplied by the image width / height and
    truncated to an integer pixel position. Pixels inside the polygon are set
    to 255, all others stay 0.
    """
    import cv2

    pixel_points = []
    for vertex in polygon:
        pixel_points.append([int(vertex[0] * width), int(vertex[1] * height)])
    contour = np.array(pixel_points, dtype=np.int32)

    canvas = np.zeros((height, width), dtype=np.uint8)
    cv2.fillPoly(canvas, [contour], 255)
    return canvas
@router.get(
    "/{project_id}/coco",
    summary="Export annotations in COCO format",
)
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
    """Export all annotations for a project as a COCO-format JSON file.

    Polygon coordinates stored in ``mask_data["polygons"]`` are scaled by the
    frame size; 1920x1080 is used when a frame has no recorded dimensions.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .all()
    )
    frames = (
        db.query(Frame)
        .filter(Frame.project_id == project_id)
        .order_by(Frame.frame_index)
        .all()
    )
    templates = db.query(Template).all()

    # Images: report the stored frame index rather than the list position.
    images = [
        {
            "id": frame.id,
            "file_name": frame.image_url,
            "width": frame.width or 1920,
            "height": frame.height or 1080,
            "frame_index": frame.frame_index,
        }
        for frame in frames
    ]

    # Categories: COCO ids are 1-based positions in the template list.
    categories = []
    template_id_to_cat_id: Dict[int, int] = {}
    for cat_idx, tmpl in enumerate(templates, start=1):
        categories.append({
            "id": cat_idx,
            "name": tmpl.name,
            "color": tmpl.color,
        })
        template_id_to_cat_id[tmpl.id] = cat_idx

    coco_annotations = []
    ann_id = 1
    for ann in annotations:
        if not ann.mask_data:
            continue
        # Drop empty polygons so min()/max() below never see an empty list.
        polygons = [poly for poly in ann.mask_data.get("polygons", []) if poly]
        if not polygons:
            continue

        # Frame dimensions are nullable; use the same fallback as the image
        # list instead of multiplying by None.
        frame = ann.frame
        width = (frame.width if frame else None) or 1920
        height = (frame.height if frame else None) or 1080

        segmentation = []
        xs: List[float] = []
        ys: List[float] = []
        area = 0.0
        for poly in polygons:
            px = [p[0] * width for p in poly]
            py = [p[1] * height for p in poly]
            xs.extend(px)
            ys.extend(py)
            segmentation.append([coord for point in zip(px, py) for coord in point])
            # Shoelace formula; index -1 closes the ring. Holes are not
            # modelled, so areas of all polygons are simply added.
            twice_area = sum(
                px[i - 1] * py[i] - px[i] * py[i - 1] for i in range(len(px))
            )
            area += abs(twice_area) / 2.0

        # The bounding box encloses every polygon of the annotation.
        bbox = [min(xs), min(ys), max(xs) - min(xs), max(ys) - min(ys)]

        coco_annotations.append({
            "id": ann_id,
            "image_id": ann.frame_id,
            "category_id": template_id_to_cat_id.get(ann.template_id, 0),
            "segmentation": segmentation,
            "area": area,
            "bbox": bbox,
            "iscrowd": 0,
        })
        ann_id += 1

    coco = {
        "info": {
            "description": f"Annotations for {project.name}",
            "version": "1.0",
            "year": datetime.now().year,
            "date_created": datetime.now().isoformat(),
        },
        "images": images,
        "annotations": coco_annotations,
        "categories": categories,
    }

    data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
    filename = f"project_{project_id}_coco.json"
    return StreamingResponse(
        io.BytesIO(data),
        media_type="application/json",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
@router.get(
    "/{project_id}/masks",
    summary="Export PNG masks as a ZIP archive",
)
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
    """Export all annotation masks as individual PNG files inside a ZIP archive.

    One ``mask_<annotation id>.png`` is written per annotation that has
    polygons; all polygons of an annotation are merged into a single mask.
    1920x1080 is used when a frame has no recorded dimensions.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    import cv2

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .all()
    )

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        for ann in annotations:
            if not ann.mask_data:
                continue
            polygons = ann.mask_data.get("polygons", [])
            if not polygons:
                continue

            # Frame dimensions are nullable; np.zeros cannot take None.
            frame = ann.frame
            width = (frame.width if frame else None) or 1920
            height = (frame.height if frame else None) or 1080

            combined = np.zeros((height, width), dtype=np.uint8)
            for poly in polygons:
                mask = _mask_from_polygon(poly, width, height)
                combined = np.maximum(combined, mask)

            ok, encoded = cv2.imencode(".png", combined)
            if not ok:
                # Skip this mask instead of writing a broken file.
                logger.warning("PNG encoding failed for annotation id=%s", ann.id)
                continue
            zf.writestr(f"mask_{ann.id:06d}.png", encoded.tobytes())

    zip_buffer.seek(0)
    filename = f"project_{project_id}_masks.zip"
    return StreamingResponse(
        zip_buffer,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )

192
backend/routers/media.py Normal file
View File

@@ -0,0 +1,192 @@
"""Media upload and parsing endpoints."""
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url
from models import Project, Frame
from schemas import FrameOut
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["Media"])
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
def _get_ext(filename: str) -> str:
return Path(filename).suffix.lower()
@router.post(
    "/upload",
    status_code=status.HTTP_201_CREATED,
    summary="Upload a media file",
)
async def upload_media(
    file: UploadFile = File(...),
    project_id: Optional[int] = Form(None),
    db: Session = Depends(get_db),
) -> dict:
    """Accept a video, image, or DICOM file and store it in MinIO.

    If project_id is provided and the project exists, its video_path is set
    to the stored object name. Returns the object name, a presigned URL valid
    for one hour, and the size in bytes.

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension,
            500 if storing the object or signing its URL fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    # The filename is client-controlled: keep only its last path component so
    # it cannot inject extra "directories" (or "..") into the object key.
    safe_name = Path(file.filename.replace("\\", "/")).name
    if not safe_name:
        raise HTTPException(status_code=400, detail="Missing filename")

    ext = _get_ext(safe_name)
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {ext}",
        )

    data = await file.read()
    object_name = f"uploads/{project_id or 'general'}/{safe_name}"

    try:
        upload_file(
            object_name,
            data,
            content_type=file.content_type or "application/octet-stream",
            length=len(data),
        )
        # Signing is a storage operation too; report its failure the same way.
        file_url = get_presigned_url(object_name, expires=3600)
    except Exception as exc:  # noqa: BLE001
        logger.error("Upload failed: %s", exc)
        raise HTTPException(status_code=500, detail="Upload to storage failed") from exc

    if project_id:
        project = db.query(Project).filter(Project.id == project_id).first()
        if project:
            project.video_path = object_name
            db.commit()
            logger.info("Linked upload to project_id=%s", project_id)
        else:
            logger.warning("Project id=%s not found for upload linkage", project_id)

    # TODO: enqueue async parsing job (Celery / background task). Nothing is
    # queued yet; the messages below describe the intended behaviour.
    logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))

    return {
        "object_name": object_name,
        "file_url": file_url,
        "size": len(data),
        "message": "Upload successful. Parsing job queued.",
    }
def _read_frame_size(path: str):
    """Return ``(width, height)`` of the image at *path*, or ``(None, None)``."""
    try:
        import cv2

        img = cv2.imread(path)
        if img is None:
            return None, None
        height, width = img.shape[:2]
        return width, height
    except Exception:  # noqa: BLE001
        # Best effort: missing dimensions are stored as NULL.
        return None, None


@router.post(
    "/parse",
    status_code=status.HTTP_202_ACCEPTED,
    summary="Trigger frame extraction",
)
def parse_media(
    project_id: int,
    source_type: str = "video",  # video | dicom
    db: Session = Depends(get_db),
) -> dict:
    """Trigger frame extraction for a project's uploaded media.

    * video: delegated to ``parse_video`` (any source_type other than
      "dicom" takes this path).
    * dicom: delegated to ``parse_dicom``.

    Extracted frames are uploaded to MinIO and registered in the database,
    and the project status is set to "ready". The temporary working directory
    is removed on every exit path.

    Raises:
        HTTPException: 404 if the project does not exist, 400 if it has no
            media, 500 if download, extraction or frame upload fails.
    """
    from minio_client import download_file

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    if not project.video_path:
        raise HTTPException(status_code=400, detail="Project has no media uploaded")

    try:
        media_bytes = download_file(project.video_path)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to download media for parsing: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc

    tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
    try:
        local_path = os.path.join(tmp_dir, Path(project.video_path).name)
        with open(local_path, "wb") as f:
            f.write(media_bytes)

        output_dir = os.path.join(tmp_dir, "frames")
        os.makedirs(output_dir, exist_ok=True)

        try:
            if source_type == "dicom":
                # parse_dicom works on a directory: a single .dcm file is
                # moved into it, anything else is treated as an archive.
                dcm_dir = os.path.join(tmp_dir, "dcm")
                os.makedirs(dcm_dir, exist_ok=True)
                if local_path.lower().endswith(".dcm"):
                    shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
                else:
                    # NOTE(review): unpacking an uploaded archive is only safe
                    # if its member paths are trusted or filtered.
                    try:
                        shutil.unpack_archive(local_path, dcm_dir)
                    except (shutil.ReadError, ValueError):
                        # Not a recognised archive: hand the file over as-is.
                        shutil.move(local_path, dcm_dir)
                frame_files = parse_dicom(dcm_dir, output_dir)
            else:
                frame_files = parse_video(local_path, output_dir, fps=30)
        except Exception as exc:  # noqa: BLE001
            logger.error("Frame extraction failed: %s", exc)
            raise HTTPException(status_code=500, detail="Frame extraction failed") from exc

        try:
            object_names = upload_frames_to_minio(frame_files, project_id)
        except Exception as exc:  # noqa: BLE001
            logger.error("Frame upload failed: %s", exc)
            raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc

        # Register frames and flip the status in one transaction.
        frame_count = 0
        for idx, (local_frame, obj_name) in enumerate(zip(frame_files, object_names)):
            w, h = _read_frame_size(local_frame)
            db.add(
                Frame(
                    project_id=project_id,
                    frame_index=idx,
                    image_url=obj_name,
                    width=w,
                    height=h,
                )
            )
            frame_count += 1
        project.status = "ready"
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise
    finally:
        # Runs on success and on every failure above.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    logger.info("Parsed %d frames for project_id=%s", frame_count, project_id)

    return {
        "project_id": project_id,
        "frames_extracted": frame_count,
        "status": "ready",
        "message": "Frame extraction completed successfully.",
    }

165
backend/routers/projects.py Normal file
View File

@@ -0,0 +1,165 @@
"""Project and Frame CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Frame
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects", tags=["Projects"])
# ---------------------------------------------------------------------------
# Projects
# ---------------------------------------------------------------------------
@router.post(
    "",
    response_model=ProjectOut,
    status_code=status.HTTP_201_CREATED,
    summary="Create a new project",
)
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
    """Store a new project built from the request body and return it."""
    fields = payload.model_dump()
    new_project = Project(**fields)

    db.add(new_project)
    db.commit()
    # Reload the row after the commit before returning it
    db.refresh(new_project)

    logger.info("Created project id=%s name=%s", new_project.id, new_project.name)
    return new_project
@router.get(
    "",
    response_model=List[ProjectOut],
    summary="List all projects",
)
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
    """Retrieve a paginated list of projects, ordered by id.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
    """
    # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so pages
    # could overlap or miss rows between requests.
    return (
        db.query(Project)
        .order_by(Project.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
@router.get(
    "/{project_id}",
    response_model=ProjectOut,
    summary="Get a single project",
)
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
    """Return the project with the given id.

    Raises:
        HTTPException: 404 if no such project exists.
    """
    found = db.query(Project).filter_by(id=project_id).first()
    if found is None:
        raise HTTPException(status_code=404, detail="Project not found")
    return found
@router.patch(
    "/{project_id}",
    response_model=ProjectOut,
    summary="Update a project",
)
def update_project(
    project_id: int,
    payload: ProjectUpdate,
    db: Session = Depends(get_db),
) -> Project:
    """Apply the fields present in the request body to an existing project.

    Raises:
        HTTPException: 404 if no such project exists.
    """
    target = db.query(Project).filter_by(id=project_id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Project not found")

    # Only fields the client actually sent are written
    changes = payload.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    db.commit()
    db.refresh(target)
    logger.info("Updated project id=%s", project_id)
    return target
@router.delete(
    "/{project_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a project",
)
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
    """Delete a project; its frames and annotations go with it.

    Raises:
        HTTPException: 404 if no such project exists.
    """
    doomed = db.query(Project).filter_by(id=project_id).first()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Project not found")

    db.delete(doomed)
    db.commit()
    logger.info("Deleted project id=%s", project_id)
# ---------------------------------------------------------------------------
# Frames
# ---------------------------------------------------------------------------
@router.post(
    "/{project_id}/frames",
    response_model=FrameOut,
    status_code=status.HTTP_201_CREATED,
    summary="Add a frame to a project",
)
def create_frame(
    project_id: int,
    payload: FrameCreate,
    db: Session = Depends(get_db),
) -> Frame:
    """Create a frame under the project named in the URL.

    The project id from the path is used; any ``project_id`` in the body is
    dropped.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    owner = db.query(Project).filter_by(id=project_id).first()
    if owner is None:
        raise HTTPException(status_code=404, detail="Project not found")

    frame_fields = payload.model_dump(exclude={"project_id"})
    new_frame = Frame(project_id=project_id, **frame_fields)
    db.add(new_frame)
    db.commit()
    db.refresh(new_frame)
    return new_frame
@router.get("/{project_id}/frames", response_model=List[FrameOut], summary="List frames for a project")
def list_frames(
    project_id: int,
    skip: int = 0,
    limit: int = 1000,
    db: Session = Depends(get_db),
) -> List[Frame]:
    """Retrieve all frames belonging to a project."""
    # Unknown project -> 404 rather than an empty list.
    parent = db.query(Project).filter(Project.id == project_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Project not found")

    # Frames of this project, in frame_index order, then paginated.
    frames_query = db.query(Frame).filter(Frame.project_id == project_id)
    ordered = frames_query.order_by(Frame.frame_index)
    page = ordered.offset(skip).limit(limit)
    return page.all()
@router.get("/{project_id}/frames/{frame_id}", response_model=FrameOut, summary="Get a single frame")
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
    """Retrieve a specific frame by ID."""
    # Both ids must match, so a frame that belongs to another project is
    # reported as not found.
    lookup = db.query(Frame).filter(Frame.project_id == project_id, Frame.id == frame_id)
    found = lookup.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Frame not found")

View File

@@ -0,0 +1,97 @@
"""Template (Ontology) CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Template
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the template endpoints below; all of its routes share the
# /api/templates prefix and the "Templates" tag.
router = APIRouter(prefix="/api/templates", tags=["Templates"])
@router.post("", response_model=TemplateOut, status_code=status.HTTP_201_CREATED, summary="Create a new template")
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
    """Create a new ontology template / segmentation class."""
    # Every validated payload field becomes a constructor argument of the model.
    columns = payload.model_dump()
    created = Template(**columns)

    db.add(created)
    db.commit()
    # Refresh so database-generated values (e.g. the id logged below) are loaded.
    db.refresh(created)
    logger.info("Created template id=%s name=%s", created.id, created.name)
    return created
@router.get(
    "",
    response_model=List[TemplateOut],
    summary="List all templates",
)
def list_templates(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
) -> List[Template]:
    """Retrieve all ontology templates."""
    # PostgreSQL rejects negative OFFSET/LIMIT values; answer with a client
    # error instead of letting the database error surface as a 500.
    if skip < 0 or limit < 0:
        raise HTTPException(status_code=422, detail="skip and limit must be non-negative")
    # Without an explicit ORDER BY the row order is unspecified, so consecutive
    # pages could overlap or skip rows. Ordering by primary key keeps pages stable.
    return db.query(Template).order_by(Template.id).offset(skip).limit(limit).all()
@router.get("/{template_id}", response_model=TemplateOut, summary="Get a single template")
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
    """Retrieve a template by its ID."""
    # Look the row up by primary key; a missing row is reported as HTTP 404.
    found = db.query(Template).filter(Template.id == template_id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Template not found")
@router.patch(
    "/{template_id}",
    response_model=TemplateOut,
    summary="Update a template",
)
def update_template(
    template_id: int,
    payload: TemplateUpdate,
    db: Session = Depends(get_db),
) -> Template:
    """Update template fields partially."""
    template = db.query(Template).filter(Template.id == template_id).first()
    if not template:
        raise HTTPException(status_code=404, detail="Template not found")

    # Only fields the client actually sent are applied (exclude_unset=True).
    changes = payload.model_dump(exclude_unset=True)

    # TemplateUpdate allows explicit nulls for every field, but the response
    # schema requires name, color and z_index to be set. Reject such requests
    # up front rather than storing values that cannot be serialised back.
    # mapping_rules is optional in the response schema, so null stays allowed.
    nulled = [
        field
        for field in ("name", "color", "z_index")
        if field in changes and changes[field] is None
    ]
    if nulled:
        raise HTTPException(
            status_code=422,
            detail=f"Fields cannot be null: {', '.join(nulled)}",
        )

    for key, value in changes.items():
        setattr(template, key, value)
    db.commit()
    # Reload so the returned object reflects the stored row.
    db.refresh(template)
    logger.info("Updated template id=%s", template_id)
    return template
@router.delete("/{template_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete a template")
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
    """Delete a template. Associated annotations will have template_id set to NULL."""
    # NOTE(review): the SET NULL behaviour described above is not visible in
    # this function; presumably it comes from the foreign-key definition - confirm.
    target = db.query(Template).filter(Template.id == template_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Template not found")

    db.delete(target)
    db.commit()
    logger.info("Deleted template id=%s", template_id)

157
backend/schemas.py Normal file
View File

@@ -0,0 +1,157 @@
"""Pydantic schemas for request/response validation."""
from datetime import datetime
from typing import Optional, Any
from pydantic import BaseModel, ConfigDict
# ---------------------------------------------------------------------------
# Project schemas
# ---------------------------------------------------------------------------
class ProjectBase(BaseModel):
    # Fields shared by the project creation and output schemas.
    # Only `name` is required; `status` falls back to "pending" when omitted.
    name: str
    description: Optional[str] = None
    video_path: Optional[str] = None
    status: Optional[str] = "pending"
class ProjectCreate(ProjectBase):
    # Creation payload; adds no fields to ProjectBase.
    pass
class ProjectUpdate(BaseModel):
    # Partial-update payload: every field is optional and defaults to None,
    # so a client may send any subset of them.
    name: Optional[str] = None
    description: Optional[str] = None
    video_path: Optional[str] = None
    status: Optional[str] = None
class ProjectOut(ProjectBase):
    # Response schema: the ProjectBase fields plus the id and timestamps.
    # from_attributes=True allows building the model from an object's
    # attributes rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    updated_at: datetime
# ---------------------------------------------------------------------------
# Frame schemas
# ---------------------------------------------------------------------------
class FrameBase(BaseModel):
    # Fields shared by the frame creation and output schemas.
    # frame_index and image_url are required; the dimensions are optional.
    frame_index: int
    image_url: str
    width: Optional[int] = None
    height: Optional[int] = None
class FrameCreate(FrameBase):
    # Creation payload: the FrameBase fields plus the owning project's id.
    project_id: int
class FrameOut(FrameBase):
    # Response schema: the FrameBase fields plus ids and the creation timestamp.
    # from_attributes=True allows building the model from an object's
    # attributes rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)

    id: int
    project_id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# Template schemas
# ---------------------------------------------------------------------------
class TemplateBase(BaseModel):
    # Fields shared by the template creation and output schemas.
    # name and color are required; z_index defaults to 0.
    name: str
    color: str
    z_index: int = 0
    mapping_rules: Optional[dict[str, Any]] = None
class TemplateCreate(TemplateBase):
    # Creation payload; adds no fields to TemplateBase.
    pass
class TemplateUpdate(BaseModel):
    # Partial-update payload: every field is optional and defaults to None,
    # so a client may send any subset of them.
    name: Optional[str] = None
    color: Optional[str] = None
    z_index: Optional[int] = None
    mapping_rules: Optional[dict[str, Any]] = None
class TemplateOut(TemplateBase):
    # Response schema: the TemplateBase fields plus id and creation timestamp.
    # from_attributes=True allows building the model from an object's
    # attributes rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# Annotation schemas
# ---------------------------------------------------------------------------
class AnnotationBase(BaseModel):
    # Fields shared by the annotation creation and output schemas.
    # Only project_id is required; everything else may be omitted.
    project_id: int
    frame_id: Optional[int] = None
    template_id: Optional[int] = None
    mask_data: Optional[dict[str, Any]] = None
    points: Optional[list[list[float]]] = None
    bbox: Optional[list[float]] = None
class AnnotationCreate(AnnotationBase):
    # Creation payload; adds no fields to AnnotationBase.
    pass
class AnnotationUpdate(BaseModel):
    # Partial-update payload: every field is optional and defaults to None.
    # project_id and frame_id are not part of this schema.
    mask_data: Optional[dict[str, Any]] = None
    points: Optional[list[list[float]]] = None
    bbox: Optional[list[float]] = None
    template_id: Optional[int] = None
class AnnotationOut(AnnotationBase):
    # Response schema: the AnnotationBase fields plus the id and timestamps.
    # from_attributes=True allows building the model from an object's
    # attributes rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    updated_at: datetime
# ---------------------------------------------------------------------------
# Mask schemas
# ---------------------------------------------------------------------------
class MaskBase(BaseModel):
    # Fields shared by the mask creation and output schemas.
    # `format` defaults to "png" when omitted.
    annotation_id: int
    mask_url: str
    format: str = "png"
class MaskCreate(MaskBase):
    # Creation payload; adds no fields to MaskBase.
    pass
class MaskOut(MaskBase):
    # Response schema: the MaskBase fields plus id and creation timestamp.
    # from_attributes=True allows building the model from an object's
    # attributes rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# AI schemas
# ---------------------------------------------------------------------------
class PredictRequest(BaseModel):
    # Request body for an AI prediction.
    # prompt_data is typed as Any, so it is not validated by this schema.
    image_id: int
    prompt_type: str  # point / box / semantic
    prompt_data: Any
class PredictResponse(BaseModel):
    # Prediction result: a list of polygons, each a list of float lists,
    # and an optional list of scores.
    polygons: list[list[list[float]]]
    scores: Optional[list[float]] = None
# ---------------------------------------------------------------------------
# Export schemas
# ---------------------------------------------------------------------------
class ExportStatus(BaseModel):
    # Export result descriptor: a URL and a format string, both required.
    url: str
    format: str

View File

View File

@@ -0,0 +1,186 @@
"""Video/DICOM frame parsing and MinIO upload utilities."""
import logging
import os
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional
import cv2
import numpy as np
from pydicom import dcmread
from minio_client import upload_file, BUCKET_NAME
# Module-level logger named after this module; used by the parsing and upload helpers below.
logger = logging.getLogger(__name__)
def parse_video(
    video_path: str,
    output_dir: str,
    fps: int = 30,
    max_frames: Optional[int] = None,
) -> List[str]:
    """Extract frames from a video file using FFmpeg or OpenCV fallback.

    FFmpeg is tried first when the executable is on PATH. If it is missing,
    exits with a non-zero status or raises, frames are extracted with OpenCV.

    Args:
        video_path: Path to the input video file.
        output_dir: Directory to save extracted frames (created if missing).
        fps: Target frame extraction rate; must be positive.
        max_frames: Optional maximum number of frames to extract. A falsy
            value (None or 0) means no limit.

    Returns:
        List of paths to extracted frame images.

    Raises:
        ValueError: If ``fps`` is not positive.
        RuntimeError: If the OpenCV fallback cannot open the video.
    """
    # Validate first: fps <= 0 would otherwise produce an invalid FFmpeg filter
    # and a ZeroDivisionError in the OpenCV sampling-interval computation.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    os.makedirs(output_dir, exist_ok=True)

    # Try FFmpeg first
    if shutil.which("ffmpeg"):
        try:
            pattern = os.path.join(output_dir, "frame_%06d.png")
            cmd = [
                "ffmpeg",
                "-i", video_path,
                "-vf", f"fps={fps},scale='min(1920,iw)':-1",
                "-pix_fmt", "rgb24",
            ]
            if max_frames:
                # Stop encoding at the limit instead of extracting the whole
                # video and discarding the surplus afterwards.
                cmd += ["-frames:v", str(max_frames)]
            cmd += ["-y", pattern]
            logger.info("Running FFmpeg: %s", " ".join(cmd))
            result = subprocess.run(cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                # Collect only files that match the output pattern so unrelated
                # PNGs already present in output_dir are not reported as frames.
                frame_paths = sorted(
                    os.path.join(output_dir, name)
                    for name in os.listdir(output_dir)
                    if name.startswith("frame_") and name.endswith(".png")
                )
                if max_frames:
                    frame_paths = frame_paths[:max_frames]
                logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
                return frame_paths
            logger.warning("FFmpeg failed: %s", result.stderr)
        except Exception as exc:  # noqa: BLE001
            # Best effort: any FFmpeg problem falls through to OpenCV.
            logger.warning("FFmpeg exception: %s", exc)

    # OpenCV fallback
    logger.info("Falling back to OpenCV frame extraction")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")

    frame_paths: List[str] = []
    try:
        # Some containers report 0 FPS; assume 30 in that case.
        video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        # Keep every `interval`-th decoded frame to approximate the target rate.
        interval = max(1, int(round(video_fps / fps)))
        count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                path = os.path.join(output_dir, f"frame_{len(frame_paths):06d}.png")
                # imwrite reports failure through its return value; only record
                # frames that were actually written.
                if cv2.imwrite(path, frame):
                    frame_paths.append(path)
                else:
                    logger.warning("Failed to write frame to %s", path)
                if max_frames and len(frame_paths) >= max_frames:
                    break
            count += 1
    finally:
        # Release the capture handle even if decoding or writing raises.
        cap.release()

    logger.info("Extracted %d frames via OpenCV", len(frame_paths))
    return frame_paths
def parse_dicom(
    dicom_dir: str,
    output_dir: str,
    max_frames: Optional[int] = None,
) -> List[str]:
    """Extract frames from DICOM files in a directory.

    Files are processed in sorted filename order. Files that cannot be read
    or decoded are logged and skipped.

    Args:
        dicom_dir: Directory containing .dcm files.
        output_dir: Directory to save extracted frames (created if missing).
        max_frames: Optional maximum number of frames to extract. A falsy
            value (None or 0) means no limit.

    Returns:
        List of paths to extracted frame images.
    """
    os.makedirs(output_dir, exist_ok=True)
    dcm_files = sorted(
        name for name in os.listdir(dicom_dir) if name.lower().endswith(".dcm")
    )
    frame_paths: List[str] = []

    for idx, fname in enumerate(dcm_files):
        # The limit counts written frames, not files, so a multi-frame file
        # cannot push the result past max_frames.
        if max_frames and len(frame_paths) >= max_frames:
            break
        path = os.path.join(dicom_dir, fname)
        try:
            ds = dcmread(path)
            pixel_array = ds.pixel_array

            # Normalize to 8-bit using this file's own min/max; the epsilon
            # avoids division by zero on constant images.
            if pixel_array.dtype != np.uint8:
                pixel_array = pixel_array.astype(np.float32)
                pixel_array = (
                    (pixel_array - pixel_array.min())
                    / (pixel_array.max() - pixel_array.min() + 1e-8)
                    * 255
                )
                pixel_array = pixel_array.astype(np.uint8)

            # A 3-D array is ambiguous: (frames, rows, cols) for multi-frame
            # grayscale or (rows, cols, samples) for a single colour image.
            # SamplesPerPixel tells them apart; 4-D is always multi-frame colour.
            samples = int(getattr(ds, "SamplesPerPixel", 1) or 1)
            is_multi_frame = pixel_array.ndim == 4 or (
                pixel_array.ndim == 3 and samples == 1
            )
            # NOTE(review): colour data is written as decoded; cv2.imwrite
            # interprets 3-channel arrays as BGR - confirm whether a colour
            # conversion is needed for colour DICOM input.
            if is_multi_frame:
                images = [
                    (f"frame_{idx:06d}_{f:03d}.png", pixel_array[f])
                    for f in range(pixel_array.shape[0])
                ]
            else:
                images = [(f"frame_{idx:06d}.png", pixel_array)]

            for out_name, image in images:
                if max_frames and len(frame_paths) >= max_frames:
                    break
                out_path = os.path.join(output_dir, out_name)
                # Only record frames that were actually written.
                if cv2.imwrite(out_path, image):
                    frame_paths.append(out_path)
                else:
                    logger.error("Failed to write frame %s", out_path)
        except Exception as exc:  # noqa: BLE001
            # Best effort: one unreadable file must not abort the whole batch.
            logger.error("Failed to read DICOM %s: %s", path, exc)

    logger.info("Extracted %d frames from DICOM", len(frame_paths))
    return frame_paths
def upload_frames_to_minio(
    frames: List[str],
    project_id: int,
    object_prefix: Optional[str] = None,
) -> List[str]:
    """Upload local frame images to MinIO and report what was stored.

    Each file is uploaded under ``<prefix>/<basename>``. A frame that fails
    to read or upload is logged and left out of the result; the remaining
    frames are still processed.

    Args:
        frames: List of local file paths.
        project_id: Project ID used to build the default object prefix.
        object_prefix: Optional prefix override; when falsy,
            ``projects/<project_id>/frames`` is used.

    Returns:
        List of object names (paths) of the frames uploaded successfully.
    """
    base = object_prefix or f"projects/{project_id}/frames"
    uploaded: List[str] = []

    for local_path in frames:
        target = f"{base}/{os.path.basename(local_path)}"
        try:
            # The whole file is read into memory and passed on with its length.
            with open(local_path, "rb") as handle:
                payload = handle.read()
            upload_file(
                target,
                payload,
                content_type="image/png",
                length=len(payload),
            )
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to upload %s: %s", local_path, exc)
        else:
            uploaded.append(target)

    logger.info("Uploaded %d/%d frames to MinIO", len(uploaded), len(frames))
    return uploaded

View File

@@ -0,0 +1,234 @@
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
import logging
import os
from typing import Optional
import numpy as np
from config import settings
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Optional SAM 2 import.
# SAM2_AVAILABLE records whether torch and the sam2 package could be imported;
# when they cannot, the engine below answers with stub results instead.
# ---------------------------------------------------------------------------
try:
    import torch
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except Exception as exc:  # noqa: BLE001
    # Any import-time failure counts as "not available", not only ImportError.
    SAM2_AVAILABLE = False
    logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
else:
    SAM2_AVAILABLE = True
    logger.info("SAM2 library imported successfully.")
class SAM2Engine:
    """Lazy-loaded SAM 2 inference engine.

    The model is loaded on the first prediction call. If the SAM 2 library,
    the checkpoint file or the model load itself is unavailable, every
    prediction returns a fixed dummy rectangle with score 0.5 instead of
    raising.

    NOTE(review): a single predictor is shared by all callers and
    ``set_image`` is called on it for every prediction; concurrent calls are
    not serialised here - confirm whether callers guarantee exclusive access.
    """

    def __init__(self) -> None:
        # The predictor is created lazily by _load_model().
        self._predictor: Optional["SAM2ImagePredictor"] = None
        # True once a load attempt has been made, whether or not it succeeded.
        self._model_loaded = False

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------
    def _load_model(self) -> None:
        """Load the SAM 2 model and predictor on first use.

        Only one attempt is ever made; a failed attempt leaves the predictor
        unset so the engine stays in fallback mode.
        """
        if self._model_loaded:
            return
        # Mark the attempt up front so that failures are not retried on
        # every request.
        self._model_loaded = True

        if not SAM2_AVAILABLE:
            logger.warning("SAM2 not available; skipping model load.")
            return
        if not os.path.isfile(settings.sam_model_path):
            logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
            return
        try:
            # Use the GPU when one is present, otherwise run on the CPU rather
            # than failing the load on CUDA-less hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = build_sam2(
                settings.sam_model_config,
                settings.sam_model_path,
                device=device,
            )
            self._predictor = SAM2ImagePredictor(model)
            logger.info(
                "SAM 2 model loaded from %s on %s", settings.sam_model_path, device
            )
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to load SAM 2 model: %s", exc)

    def _ensure_ready(self) -> bool:
        """Ensure the model is loaded; return whether it is usable."""
        self._load_model()
        return SAM2_AVAILABLE and self._predictor is not None

    def _fallback(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Return the dummy result used when SAM 2 cannot produce one."""
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    @staticmethod
    def _collect_polygons(masks, scores) -> tuple[list[list[list[float]]], list[float]]:
        """Convert masks to polygons, keeping each polygon paired with its score.

        Masks that yield no usable contour are dropped together with their
        score, so both returned lists always have the same length.
        """
        polygons: list[list[list[float]]] = []
        kept_scores: list[float] = []
        for mask, score in zip(masks, scores):
            poly = SAM2Engine._mask_to_polygon(mask)
            if poly:
                polygons.append(poly)
                kept_scores.append(float(score))
        return polygons, kept_scores

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------
    def predict_points(
        self,
        image: np.ndarray,
        points: list[list[float]],
        labels: list[int],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run point-prompt segmentation.

        Args:
            image: HWC numpy array (uint8).
            points: List of [x, y] normalized coordinates (0-1).
            labels: 1 for foreground, 0 for background.

        Returns:
            Tuple of (polygons, scores) of equal length. A dummy rectangle
            with score 0.5 is returned if the model is not ready or the
            prediction fails.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback(image)
        try:
            h, w = image.shape[:2]
            # Scale normalized coordinates to pixel coordinates.
            pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
            lbls = np.array(labels, dtype=np.int32)
            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=True,
                )
            return self._collect_polygons(masks, scores)
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 point prediction failed: %s", exc)
            return self._fallback(image)

    def predict_box(
        self,
        image: np.ndarray,
        box: list[float],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run box-prompt segmentation.

        Args:
            image: HWC numpy array (uint8).
            box: [x1, y1, x2, y2] normalized coordinates.

        Returns:
            Tuple of (polygons, scores) of equal length. A dummy rectangle
            with score 0.5 is returned if the model is not ready or the
            prediction fails.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback(image)
        try:
            h, w = image.shape[:2]
            # Scale the normalized box to pixel coordinates.
            bbox = np.array(
                [box[0] * w, box[1] * h, box[2] * w, box[3] * h],
                dtype=np.float32,
            )
            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    box=bbox[None, :],
                    multimask_output=False,
                )
            return self._collect_polygons(masks, scores)
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 box prediction failed: %s", exc)
            return self._fallback(image)

    def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Run automatic mask generation (grid of points).

        Args:
            image: HWC numpy array (uint8).

        Returns:
            Tuple of (polygons, scores) of equal length, built from at most
            the first three masks. A dummy rectangle with score 0.5 is
            returned if the model is not ready or the prediction fails.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback(image)
        try:
            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                # Uniform 17x17 grid of point prompts spanning the image,
                # edges included (17 samples per axis over [0, 1]).
                h, w = image.shape[:2]
                grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
                pts = grid * np.array([w, h])
                lbls = np.ones(pts.shape[0], dtype=np.int32)
                # NOTE(review): all grid points go into a single predict() call
                # as foreground prompts, i.e. they are treated as one combined
                # prompt rather than one prompt per object - confirm intent.
                masks, scores, _ = self._predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=True,
                )
            # Limit to the first 3 masks
            return self._collect_polygons(masks[:3], scores[:3])
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 auto prediction failed: %s", exc)
            return self._fallback(image)

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------
    @staticmethod
    def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
        """Convert a binary mask to a normalized polygon.

        The external contour with the most vertices is used. Returns an empty
        list when no contour with at least three vertices exists.
        """
        import cv2

        if mask.dtype != np.uint8:
            mask = (mask > 0).astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return []
        h, w = mask.shape[:2]
        largest = max(contours, key=len)
        if len(largest) < 3:
            return []
        # Contour points have shape (1, 2) holding (x, y); normalize to 0-1.
        return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]

    @staticmethod
    def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
        """Return a dummy rectangle polygon for fallback mode.

        The rectangle is expressed in normalized coordinates, so ``w`` and
        ``h`` do not influence the result.
        """
        return [
            [
                [0.25, 0.25],
                [0.75, 0.25],
                [0.75, 0.75],
                [0.25, 0.75],
            ]
        ]
# Singleton instance, created at import time. Construction only initialises
# state; the model itself is loaded on the first prediction call.
sam_engine = SAM2Engine()