2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构
This commit is contained in:
35
backend/config.py
Normal file
35
backend/config.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Application configuration using Pydantic Settings."""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values are read from the process environment and from the ``.env`` file;
    unknown keys are ignored. The defaults below are development values only.
    """

    # Pydantic v2 style configuration. The inner ``class Config`` form is
    # deprecated in Pydantic v2 and emits a deprecation warning.
    model_config = {
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "extra": "ignore",
    }

    # Database
    # NOTE(review): default embeds development credentials; override via the
    # environment for any non-local deployment.
    db_url: str = "postgresql://seguser:segpass123@localhost:5432/segserver"

    # Redis
    redis_url: str = "redis://localhost:6379/0"

    # MinIO
    # NOTE(review): "minioadmin" is the stock MinIO credential pair; override
    # outside local development.
    minio_endpoint: str = "localhost:9000"
    minio_access_key: str = "minioadmin"
    minio_secret_key: str = "minioadmin"
    minio_secure: bool = False

    # SAM2
    # NOTE(review): machine-specific absolute path; expected to be overridden.
    sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
    sam_model_config: str = "sam2_hiera_t.yaml"

    # App
    app_env: str = "development"
    cors_origins: list[str] = ["http://localhost:3000"]


# Module-level singleton imported by the rest of the backend.
settings = Settings()
|
||||
29
backend/database.py
Normal file
29
backend/database.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Database configuration using synchronous SQLAlchemy."""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from fastapi import Depends
|
||||
from typing import Generator
|
||||
|
||||
from config import settings
|
||||
|
||||
# Process-wide SQLAlchemy engine built from the configured database URL.
# pool_pre_ping makes the pool test a connection before handing it out.
engine = create_engine(
    settings.db_url,
    echo=False,
    pool_size=10,
    max_overflow=20,
    pool_pre_ping=True,
)

# Session factory bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that all ORM models inherit from.
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
    """Yield a database session for one request and close it afterwards.

    Intended to be used as a FastAPI dependency; the session is closed in the
    ``finally`` clause regardless of how the consumer finishes.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||||
49
backend/download_sam2.py
Normal file
49
backend/download_sam2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SAM 2 模型权重下载脚本
|
||||
运行: python download_sam2.py
|
||||
"""
|
||||
import os
|
||||
import urllib.request
|
||||
import sys
|
||||
|
||||
# Destination directory for the checkpoints; created at import time.
MODEL_DIR = "/home/wkmgc/Desktop/Seg_Server/models"
os.makedirs(MODEL_DIR, exist_ok=True)

# SAM 2 model weights (official Meta AI release). Maps file name -> URL.
_SAM2_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/072824"
MODELS = {
    filename: f"{_SAM2_BASE_URL}/{filename}"
    for filename in (
        "sam2_hiera_tiny.pt",
        "sam2_hiera_small.pt",
        "sam2_hiera_base_plus.pt",
        "sam2_hiera_large.pt",
    )
}
|
||||
|
||||
def download_file(url: str, dest: str) -> bool:
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The payload is first written to ``<dest>.part`` and only renamed to
    ``dest`` once the transfer has finished, so an interrupted run never
    leaves a truncated file that a later run would skip as "already exists".

    Args:
        url: Source URL.
        dest: Destination file path.

    Returns:
        True if ``dest`` exists after the call (downloaded or already
        present), False if the download failed. Failures are reported on
        stderr instead of being raised.
    """
    if os.path.exists(dest):
        print(f"[跳过] {os.path.basename(dest)} 已存在")
        return True

    print(f"[下载] {url}")
    print(f" → {dest}")
    tmp_path = f"{dest}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic within one filesystem: dest is either absent or complete.
        os.replace(tmp_path, dest)
    except Exception as e:
        print(f" ✗ 失败: {e}", file=sys.stderr)
        return False
    finally:
        # Also runs on KeyboardInterrupt, which ``except Exception`` misses.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    size_mb = os.path.getsize(dest) / 1024 / 1024
    print(f" ✓ 完成 ({size_mb:.1f} MB)")
    return True
|
||||
|
||||
def main():
    """Download every checkpoint listed in MODELS into MODEL_DIR."""
    banner = "=" * 50
    print(banner)
    print("SAM 2 模型权重下载")
    print(banner)
    for filename in MODELS:
        download_file(MODELS[filename], os.path.join(MODEL_DIR, filename))
    print(banner)
    print("全部完成!")
    print(f"模型目录: {MODEL_DIR}")
    print(banner)


if __name__ == "__main__":
    main()
|
||||
83
backend/main.py
Normal file
83
backend/main.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""FastAPI application entrypoint."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import settings
|
||||
from database import Base, engine
|
||||
from minio_client import ensure_bucket_exists
|
||||
from redis_client import ping as redis_ping
|
||||
|
||||
from routers import projects, templates, media, ai, export, auth
|
||||
|
||||
# Configure root logging once at import time: INFO level, pipe-separated
# "time | level | logger | message" records.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup and shutdown hooks.

    Startup is best-effort: database table creation, the MinIO bucket check
    and the Redis ping are each attempted, and a failure is logged (with
    traceback) without preventing the application from starting.

    Shutdown disposes the SQLAlchemy engine. The disposal sits in a
    ``finally`` clause so it also runs when an exception is thrown into the
    generator at the ``yield`` point.
    """
    # Startup
    logger.info("Starting up SegServer backend...")

    # Initialize database tables
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc, exc_info=True)
    else:
        logger.info("Database tables initialized.")

    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc, exc_info=True)

    # Check Redis. Guarded like the other checks so an unexpected client
    # error cannot abort startup.
    try:
        redis_ok = redis_ping()
    except Exception:  # noqa: BLE001
        redis_ok = False
    if redis_ok:
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")

    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down SegServer backend...")
        engine.dispose()
|
||||
|
||||
|
||||
# The ASGI application object; startup/shutdown work is delegated to lifespan.
app = FastAPI(
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: only the configured origins are allowed; any method or header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers, registered in the same order as before.
for _router_module in (auth, projects, templates, media, ai, export):
    app.include_router(_router_module.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Return a static payload indicating the service is up."""
    payload = {"status": "ok", "service": "SegServer"}
    return payload
|
||||
127
backend/minio_client.py
Normal file
127
backend/minio_client.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""MinIO client wrapper for object storage operations."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)

# Bucket used by every helper in this module.
BUCKET_NAME = "seg-media"

# Lazily created client, shared by all callers in this process.
_minio_client: Optional[Minio] = None


def get_minio_client() -> Minio:
    """Return the shared MinIO client, creating it on first use.

    Connection parameters come from the application settings.
    """
    global _minio_client
    if _minio_client is not None:
        return _minio_client

    _minio_client = Minio(
        settings.minio_endpoint,
        access_key=settings.minio_access_key,
        secret_key=settings.minio_secret_key,
        secure=settings.minio_secure,
    )
    return _minio_client
|
||||
|
||||
|
||||
def ensure_bucket_exists() -> None:
    """Make sure the module's bucket exists, creating it when missing.

    Raises:
        S3Error: Re-raised after logging if the check or creation fails.
    """
    client = get_minio_client()
    try:
        already_present = client.bucket_exists(BUCKET_NAME)
        if already_present:
            logger.info("MinIO bucket %s already exists", BUCKET_NAME)
            return
        client.make_bucket(BUCKET_NAME)
        logger.info("Created MinIO bucket: %s", BUCKET_NAME)
    except S3Error as exc:
        logger.error("MinIO bucket check/creation failed: %s", exc)
        raise
|
||||
|
||||
|
||||
def upload_file(
    object_name: str,
    data: bytes,
    content_type: str = "application/octet-stream",
    length: int = -1,
) -> str:
    """Upload bytes or a readable stream to MinIO and return the object name.

    Args:
        object_name: Destination path inside the bucket.
        data: Raw bytes (or another bytes-like object), or a file-like object
            exposing ``read()``.
        content_type: MIME type of the object.
        length: Object size in bytes. Ignored for bytes-like input, where the
            size is measured directly. For streams, -1 means "unknown" and
            triggers a multipart upload.

    Returns:
        The object name (same as input).

    Raises:
        S3Error: Re-raised after logging if the upload fails.
    """
    client = get_minio_client()

    if isinstance(data, (bytes, bytearray, memoryview)):
        # Measure before wrapping; BytesIO.getvalue() would copy the payload.
        length = memoryview(data).nbytes
        data = io.BytesIO(data)

    extra_kwargs = {}
    if length < 0:
        # put_object() rejects an unknown length unless a multipart part
        # size is supplied; 10 MiB is above the 5 MiB S3 minimum.
        extra_kwargs["part_size"] = 10 * 1024 * 1024

    try:
        client.put_object(
            BUCKET_NAME,
            object_name,
            data,
            length=length,
            content_type=content_type,
            **extra_kwargs,
        )
    except S3Error as exc:
        logger.error("MinIO upload failed: %s", exc)
        raise

    logger.info("Uploaded to MinIO: %s", object_name)
    return object_name
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
def get_presigned_url(
    object_name: str,
    expires: int = 3600,
    method: str = "GET",
) -> str:
    """Build a time-limited presigned URL for an object in the bucket.

    Args:
        object_name: Path inside the bucket.
        expires: Lifetime of the URL in seconds (default 1 hour).
        method: HTTP method the URL is signed for (GET or PUT).

    Returns:
        Presigned URL string.

    Raises:
        S3Error: Re-raised after logging if signing fails.
    """
    client = get_minio_client()
    try:
        lifetime = timedelta(seconds=expires)
        return client.get_presigned_url(
            method,
            BUCKET_NAME,
            object_name,
            expires=lifetime,
        )
    except S3Error as exc:
        logger.error("MinIO presigned URL failed: %s", exc)
        raise
|
||||
|
||||
|
||||
def download_file(object_name: str) -> bytes:
    """Download an object from MinIO and return its bytes.

    Args:
        object_name: Path inside the bucket.

    Returns:
        Raw bytes of the object.

    Raises:
        S3Error: Re-raised after logging if the object cannot be fetched.
    """
    client = get_minio_client()
    response = None
    try:
        response = client.get_object(BUCKET_NAME, object_name)
        return response.read()
    except S3Error as exc:
        logger.error("MinIO download failed: %s", exc)
        raise
    finally:
        # Always hand the connection back to the pool, including when
        # read() fails part-way through.
        if response is not None:
            response.close()
            response.release_conn()
|
||||
118
backend/models.py
Normal file
118
backend/models.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""SQLAlchemy ORM models."""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
Float,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from database import Base
|
||||
|
||||
|
||||
class Project(Base):
    """ORM mapping for the ``projects`` table: one segmentation project."""

    __tablename__ = "projects"

    # Identity and descriptive fields.
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)
    # Location of the uploaded media; empty until something is uploaded.
    video_path = Column(String(512), nullable=True)
    status = Column(String(50), nullable=False, default="pending")

    # Timestamps are filled in by the database server.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Children are deleted together with the project.
    frames = relationship(
        "Frame",
        back_populates="project",
        cascade="all, delete-orphan",
    )
    annotations = relationship(
        "Annotation",
        back_populates="project",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class Frame(Base):
    """ORM mapping for the ``frames`` table: one extracted video frame."""

    __tablename__ = "frames"

    id = Column(Integer, primary_key=True, index=True)
    # Owning project; rows are removed by the database when it is deleted.
    project_id = Column(
        Integer,
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
    )
    frame_index = Column(Integer, nullable=False)
    image_url = Column(String(512), nullable=False)
    # Pixel dimensions; may be NULL when unknown.
    width = Column(Integer, nullable=True)
    height = Column(Integer, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    project = relationship("Project", back_populates="frames")
    annotations = relationship(
        "Annotation",
        back_populates="frame",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class Template(Base):
    """Template (Ontology) model for segmentation classes."""

    __tablename__ = "templates"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    color = Column(String(50), nullable=False)
    z_index = Column(Integer, default=0, nullable=False)
    mapping_rules = Column(JSON, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # Annotation.template_id is nullable with ON DELETE SET NULL, i.e. an
    # annotation is meant to outlive its template. A "delete-orphan" cascade
    # here would contradict that: deleting a template, or merely detaching an
    # annotation from it, would delete the annotation. Use the default
    # cascade and let the database null out the foreign key.
    annotations = relationship(
        "Annotation",
        back_populates="template",
        passive_deletes=True,
    )
|
||||
|
||||
|
||||
class Annotation(Base):
    """ORM mapping for the ``annotations`` table: masks and their prompts."""

    __tablename__ = "annotations"

    id = Column(Integer, primary_key=True, index=True)

    # Parent references. Project is mandatory; frame and template are optional.
    project_id = Column(
        Integer,
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
    )
    frame_id = Column(
        Integer,
        ForeignKey("frames.id", ondelete="CASCADE"),
        nullable=True,
    )
    template_id = Column(
        Integer,
        ForeignKey("templates.id", ondelete="SET NULL"),
        nullable=True,
    )

    # Free-form JSON payloads.
    mask_data = Column(JSON, nullable=True)
    points = Column(JSON, nullable=True)
    bbox = Column(JSON, nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    project = relationship("Project", back_populates="annotations")
    frame = relationship("Frame", back_populates="annotations")
    template = relationship("Template", back_populates="annotations")
    masks = relationship(
        "Mask",
        back_populates="annotation",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class Mask(Base):
    """ORM mapping for the ``masks`` table: exported/derived mask files."""

    __tablename__ = "masks"

    id = Column(Integer, primary_key=True, index=True)
    # Owning annotation; rows are removed by the database when it is deleted.
    annotation_id = Column(
        Integer,
        ForeignKey("annotations.id", ondelete="CASCADE"),
        nullable=False,
    )
    mask_url = Column(String(512), nullable=False)
    # One of: png / rle / json
    format = Column(String(50), nullable=False, default="png")
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    annotation = relationship("Annotation", back_populates="masks")
|
||||
61
backend/redis_client.py
Normal file
61
backend/redis_client.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Redis client wrapper for caching and task queuing."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
|
||||
import redis
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)

# Lazily created client, shared by all callers in this process.
_redis_client: Optional[redis.Redis] = None


def get_redis_client() -> redis.Redis:
    """Return the shared Redis client, creating it on first use.

    The client is built from the configured URL with ``decode_responses``
    enabled.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client

    _redis_client = redis.from_url(settings.redis_url, decode_responses=True)
    return _redis_client
|
||||
|
||||
|
||||
def ping() -> bool:
    """Check Redis connectivity.

    Returns:
        True if the server answered the PING, False on any Redis error.
        ``redis.RedisError`` is caught (not only ``ConnectionError``) so that
        timeouts and other client errors also yield False instead of
        propagating to the caller.
    """
    try:
        return bool(get_redis_client().ping())
    except redis.RedisError as exc:
        logger.error("Redis ping failed: %s", exc)
        return False
|
||||
|
||||
|
||||
def set_json(key: str, value: Any, expire: Optional[int] = None) -> None:
    """Serialize ``value`` as JSON and store it under ``key``.

    Args:
        key: Redis key.
        value: JSON-serializable value.
        expire: Passed to the client as ``ex``; ``None`` leaves it unset.

    Raises:
        redis.RedisError: Re-raised after logging.
    """
    client = get_redis_client()
    try:
        serialized = json.dumps(value)
        client.set(key, serialized, ex=expire)
    except redis.RedisError as exc:
        logger.error("Redis set_json failed: %s", exc)
        raise
|
||||
|
||||
|
||||
def get_json(key: str) -> Optional[Any]:
    """Fetch ``key`` and decode its JSON payload.

    Returns:
        The decoded value, or ``None`` when the key is absent.

    Raises:
        redis.RedisError: Re-raised after logging.
    """
    client = get_redis_client()
    try:
        raw = client.get(key)
        if raw is None:
            return None
        return json.loads(raw)
    except redis.RedisError as exc:
        logger.error("Redis get_json failed: %s", exc)
        raise
|
||||
|
||||
|
||||
def delete_key(key: str) -> int:
    """Remove ``key`` from Redis and return the number of keys deleted.

    Raises:
        redis.RedisError: Re-raised after logging.
    """
    client = get_redis_client()
    try:
        removed = client.delete(key)
    except redis.RedisError as exc:
        logger.error("Redis delete_key failed: %s", exc)
        raise
    return removed
|
||||
38
backend/requirements.txt
Normal file
38
backend/requirements.txt
Normal file
@@ -0,0 +1,38 @@
|
||||
# Web framework
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
python-multipart
|
||||
|
||||
# Database
|
||||
sqlalchemy
|
||||
psycopg2-binary
|
||||
alembic
|
||||
|
||||
# Cache / Task queue
|
||||
redis
|
||||
celery
|
||||
|
||||
# Object storage
|
||||
minio
|
||||
|
||||
# Image / Video / DICOM processing
|
||||
opencv-python
|
||||
pillow
|
||||
scikit-image
|
||||
pydicom
|
||||
numpy
|
||||
|
||||
# SAM 2 (may require manual installation depending on CUDA version)
|
||||
sam2
|
||||
|
||||
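# PyTorch (install from the wheel index matching your CUDA version, e.g. for CUDA 12.4: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124)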
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
|
||||
# Configuration
|
||||
pydantic-settings
|
||||
|
||||
# Utilities
|
||||
python-jose[cryptography]
|
||||
passlib[bcrypt]
|
||||
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
123
backend/routers/ai.py
Normal file
123
backend/routers/ai.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""AI inference endpoints using SAM 2."""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Frame, Annotation
|
||||
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
|
||||
from services.sam2_engine import sam_engine
|
||||
|
||||
# Module logger and the router that groups all AI endpoints under /api/ai.
logger = logging.getLogger(__name__)
router = APIRouter(
    prefix="/api/ai",
    tags=["AI"],
)
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Fetch a frame's image from storage and decode it to an RGB array.

    Raises:
        HTTPException: 500 if the download or the decoding fails.
    """
    try:
        raw = download_file(frame.image_url)
        bgr = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
        if bgr is None:
            raise ValueError("OpenCV could not decode image")
        # OpenCV decodes to BGR; convert to RGB before returning.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to load frame image: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
    return rgb
|
||||
|
||||
|
||||
@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM 2 inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Execute SAM 2 segmentation given an image and a prompt.

    - **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **semantic**: Not yet implemented; falls back to auto segmentation.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()

    if prompt_type == "point":
        points = payload.prompt_data
        if not (isinstance(points, list) and len(points) > 0):
            raise HTTPException(status_code=400, detail="Invalid point prompt data")
        # Every supplied point is passed with label 1.
        polygons, scores = sam_engine.predict_points(image, points, [1] * len(points))
    elif prompt_type == "box":
        box = payload.prompt_data
        if not (isinstance(box, list) and len(box) == 4):
            raise HTTPException(status_code=400, detail="Invalid box prompt data")
        polygons, scores = sam_engine.predict_box(image, box)
    elif prompt_type == "semantic":
        # Placeholder: use auto segmentation for now
        logger.info("Semantic prompt not implemented; using auto segmentation")
        polygons, scores = sam_engine.predict_auto(image)
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")

    return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
    """Run the engine's automatic segmentation on one stored frame.

    Raises:
        HTTPException: 404 if no frame has the given id.
    """
    frame = db.query(Frame).filter(Frame.id == image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    polygons, scores = sam_engine.predict_auto(_load_frame_image(frame))
    return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    Raises:
        HTTPException: 404 if the referenced project, or the referenced
            frame (when one is given), does not exist.
    """
    # Imported locally because this module's top-level import only brings in
    # Frame and Annotation.
    from models import Project

    # The existence check must target the projects table; querying Frame by
    # project_id would accept or reject on an unrelated row.
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Compare against None so a falsy id is still validated.
    if payload.frame_id is not None:
        frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation
|
||||
24
backend/routers/auth.py
Normal file
24
backend/routers/auth.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Router grouping the authentication endpoints under /api/auth.
router = APIRouter(
    prefix="/api/auth",
    tags=["Auth"],
)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for the login endpoint."""

    # Both fields are required strings.
    username: str
    password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
    """Response body returned after a successful login."""

    # Token issued to the caller, and the username it was issued for.
    token: str
    username: str
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest) -> dict:
    """Simple login for development.

    NOTE(review): the credentials are hard-coded and the returned token is a
    fixed placeholder, not a signed JWT. Replace before any real deployment.

    Raises:
        HTTPException: 401 when the credentials do not match.
    """
    import hmac

    # Constant-time comparisons. Both are evaluated before combining so the
    # response time does not reveal which field was wrong.
    username_ok = hmac.compare_digest(payload.username.encode("utf-8"), b"admin")
    password_ok = hmac.compare_digest(payload.password.encode("utf-8"), b"123456")
    if username_ok and password_ok:
        return {"token": "fake-jwt-token-for-admin", "username": payload.username}
    raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
194
backend/routers/export.py
Normal file
194
backend/routers/export.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Annotation export endpoints (COCO, PNG masks)."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Annotation, Frame, Template
|
||||
|
||||
# Module logger and the router that groups export endpoints under /api/export.
logger = logging.getLogger(__name__)
router = APIRouter(
    prefix="/api/export",
    tags=["Export"],
)
|
||||
|
||||
|
||||
def _mask_from_polygon(
|
||||
polygon: List[List[float]],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> np.ndarray:
|
||||
"""Render a normalized polygon to a binary mask."""
|
||||
import cv2
|
||||
|
||||
pts = np.array(
|
||||
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
|
||||
dtype=np.int32,
|
||||
)
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [pts], 255)
|
||||
return mask
|
||||
|
||||
|
||||
@router.get(
    "/{project_id}/coco",
    summary="Export annotations in COCO format",
)
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
    """Export all annotations for a project as a COCO-format JSON file.

    Annotations without polygon data are skipped. Polygon coordinates are
    multiplied by the frame's width/height; when a frame is missing or has no
    stored dimensions, 1920x1080 is assumed (the same fallback used for the
    ``images`` section).

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .all()
    )
    frames = (
        db.query(Frame)
        .filter(Frame.project_id == project_id)
        .order_by(Frame.frame_index)
        .all()
    )
    templates = db.query(Template).all()

    # Build COCO structure
    images = [
        {
            "id": frame.id,
            "file_name": frame.image_url,
            "width": frame.width or 1920,
            "height": frame.height or 1080,
            "frame_index": idx,
        }
        for idx, frame in enumerate(frames)
    ]

    categories = []
    template_id_to_cat_id: Dict[int, int] = {}
    for cat_idx, tmpl in enumerate(templates, start=1):
        categories.append({
            "id": cat_idx,
            "name": tmpl.name,
            "color": tmpl.color,
        })
        template_id_to_cat_id[tmpl.id] = cat_idx

    coco_annotations = []
    for ann in annotations:
        # mask_data is free-form JSON; only dict payloads carry polygons.
        if not isinstance(ann.mask_data, dict):
            continue
        # Drop empty polygons so min()/max() below cannot fail.
        polygons = [poly for poly in ann.mask_data.get("polygons", []) if poly]
        if not polygons:
            continue

        # A frame may exist but have NULL dimensions; fall back in both cases.
        frame = ann.frame
        width = (frame.width if frame else None) or 1920
        height = (frame.height if frame else None) or 1080

        # Use first polygon for bbox / area approximation
        first_poly = polygons[0]
        xs = [p[0] for p in first_poly]
        ys = [p[1] for p in first_poly]
        bbox = [
            min(xs) * width,
            min(ys) * height,
            (max(xs) - min(xs)) * width,
            (max(ys) - min(ys)) * height,
        ]
        area = bbox[2] * bbox[3]

        # COCO segmentation: one flat [x0, y0, x1, y1, ...] list per polygon.
        segmentation = [
            [coord for p in poly for coord in (p[0] * width, p[1] * height)]
            for poly in polygons
        ]

        coco_annotations.append({
            "id": len(coco_annotations) + 1,
            "image_id": ann.frame_id,
            "category_id": template_id_to_cat_id.get(ann.template_id, 0),
            "segmentation": segmentation,
            "area": area,
            "bbox": bbox,
            "iscrowd": 0,
        })

    now = datetime.now()
    coco = {
        "info": {
            "description": f"Annotations for {project.name}",
            "version": "1.0",
            "year": now.year,
            "date_created": now.isoformat(),
        },
        "images": images,
        "annotations": coco_annotations,
        "categories": categories,
    }

    data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
    filename = f"project_{project_id}_coco.json"

    return StreamingResponse(
        io.BytesIO(data),
        media_type="application/json",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
||||
|
||||
|
||||
@router.get(
    "/{project_id}/masks",
    summary="Export PNG masks as a ZIP archive",
)
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
    """Export all annotation masks as individual PNG files inside a ZIP archive.

    One ``mask_<annotation id>.png`` is written per annotation that has
    polygon data; all polygons of an annotation are merged into one mask.
    When a frame is missing or has no stored dimensions, 1920x1080 is assumed.

    Raises:
        HTTPException: 404 if the project does not exist.
    """
    import cv2

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .all()
    )

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        for ann in annotations:
            # mask_data is free-form JSON; only dict payloads carry polygons.
            if not isinstance(ann.mask_data, dict):
                continue
            polygons = ann.mask_data.get("polygons", [])
            if not polygons:
                continue

            # A frame may exist but have NULL dimensions; fall back in both cases.
            frame = ann.frame
            width = (frame.width if frame else None) or 1920
            height = (frame.height if frame else None) or 1080
            combined = np.zeros((height, width), dtype=np.uint8)

            for poly in polygons:
                if not poly:
                    # Nothing to rasterize for an empty polygon.
                    continue
                mask = _mask_from_polygon(poly, width, height)
                combined = np.maximum(combined, mask)

            # Encode PNG
            ok, encoded = cv2.imencode(".png", combined)
            if not ok:
                logger.warning("PNG encoding failed for annotation id=%s", ann.id)
                continue
            fname = f"mask_{ann.id:06d}.png"
            zf.writestr(fname, encoded.tobytes())

    zip_buffer.seek(0)
    filename = f"project_{project_id}_masks.zip"

    return StreamingResponse(
        zip_buffer,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
||||
192
backend/routers/media.py
Normal file
192
backend/routers/media.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import Project, Frame
|
||||
from schemas import FrameOut
|
||||
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
|
||||
|
||||
# Module logger and the router that groups media endpoints under /api/media.
logger = logging.getLogger(__name__)
router = APIRouter(
    prefix="/api/media",
    tags=["Media"],
)

# Lower-case file suffixes accepted by the upload endpoint.
ALLOWED_EXTENSIONS = {
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
    ".webm",
    ".png",
    ".jpg",
    ".jpeg",
    ".dcm",
}
|
||||
|
||||
|
||||
def _get_ext(filename: str) -> str:
|
||||
return Path(filename).suffix.lower()
|
||||
|
||||
|
||||
@router.post(
    "/upload",
    status_code=status.HTTP_201_CREATED,
    summary="Upload a media file",
)
async def upload_media(
    file: UploadFile = File(...),
    project_id: Optional[int] = Form(None),
    db: Session = Depends(get_db),
) -> dict:
    """Accept a video, image, or DICOM file and store it in MinIO.

    If project_id is provided, the video_path of the project is updated.
    Returns the presigned URL of the uploaded object.

    Only the final path component of the client-supplied filename is used
    for the object key, so directory parts or ``..`` segments sent by the
    client cannot influence where the object is stored.

    Raises:
        HTTPException: 400 for a missing/invalid filename or unsupported
            extension; 500 if storing the object or signing its URL fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    # Strip any directory part (both separator styles) from untrusted input.
    safe_name = file.filename.replace("\\", "/").rsplit("/", 1)[-1]
    if safe_name in {"", ".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid filename")

    ext = _get_ext(safe_name)
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {ext}",
        )

    data = await file.read()
    object_name = f"uploads/{project_id or 'general'}/{safe_name}"

    try:
        upload_file(
            object_name,
            data,
            content_type=file.content_type or "application/octet-stream",
            length=len(data),
        )
        file_url = get_presigned_url(object_name, expires=3600)
    except Exception as exc:  # noqa: BLE001
        logger.error("Upload failed: %s", exc)
        raise HTTPException(status_code=500, detail="Upload to storage failed") from exc

    if project_id:
        project = db.query(Project).filter(Project.id == project_id).first()
        if project:
            project.video_path = object_name
            db.commit()
            logger.info("Linked upload to project_id=%s", project_id)
        else:
            # Best-effort linkage: the upload itself still succeeds.
            logger.warning("Project id=%s not found for upload linkage", project_id)

    # TODO: enqueue async parsing job (Celery / background task)
    logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))

    return {
        "object_name": object_name,
        "file_url": file_url,
        "size": len(data),
        "message": "Upload successful. Parsing job queued.",
    }
|
||||
|
||||
|
||||
@router.post(
    "/parse",
    status_code=status.HTTP_202_ACCEPTED,
    summary="Trigger frame extraction",
)
def parse_media(
    project_id: int,
    source_type: str = "video",  # video | dicom
    db: Session = Depends(get_db),
) -> dict:
    """Trigger frame extraction for a project's uploaded media.

    * video: uses FFmpeg or OpenCV fallback.
    * dicom: uses pydicom to read DCM frames.

    Extracted frames are uploaded to MinIO and registered in the database.

    The work runs synchronously inside the request. Any ``source_type`` other
    than ``"dicom"`` is handled as video. The temporary working directory is
    removed on every exit path.

    Raises:
        HTTPException: 404 if the project does not exist, 400 if it has no
            media, 500 if download, extraction or frame upload fails.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    if not project.video_path:
        raise HTTPException(status_code=400, detail="Project has no media uploaded")

    # Download from MinIO (the payload is held in memory, then written to disk)
    from minio_client import download_file

    try:
        media_bytes = download_file(project.video_path)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to download media for parsing: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc

    tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
    try:
        local_path = os.path.join(tmp_dir, Path(project.video_path).name)
        with open(local_path, "wb") as f:
            f.write(media_bytes)

        output_dir = os.path.join(tmp_dir, "frames")
        os.makedirs(output_dir, exist_ok=True)

        try:
            if source_type == "dicom":
                # A single .dcm goes into its own sub-directory; anything else
                # is first tried as an archive of .dcm files.
                dcm_dir = os.path.join(tmp_dir, "dcm")
                os.makedirs(dcm_dir, exist_ok=True)
                if local_path.lower().endswith(".dcm"):
                    shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
                else:
                    try:
                        # unpack_archive is pure Python and needs no external
                        # `unzip` binary.
                        shutil.unpack_archive(local_path, dcm_dir)
                    except (shutil.ReadError, ValueError):
                        # Not a recognised archive: hand the file over as-is.
                        shutil.move(local_path, dcm_dir)
                frame_files = parse_dicom(dcm_dir, output_dir)
            else:
                frame_files = parse_video(local_path, output_dir, fps=30)
        except Exception as exc:  # noqa: BLE001
            logger.error("Frame extraction failed: %s", exc)
            raise HTTPException(status_code=500, detail="Frame extraction failed") from exc

        # Upload frames to MinIO
        try:
            object_names = upload_frames_to_minio(frame_files, project_id)
        except Exception as exc:  # noqa: BLE001
            logger.error("Frame upload failed: %s", exc)
            raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc

        # Image dimensions are optional metadata; a missing OpenCV must not
        # fail the request.
        try:
            import cv2
        except Exception:  # noqa: BLE001
            cv2 = None

        # The uploader skips frames that fail, so object_names can be shorter
        # than frame_files. Pair them by file name, which the uploader keeps
        # as the last key segment, instead of by position.
        local_by_name = {os.path.basename(p): p for p in frame_files}

        # Register frames in DB
        frames_out = []
        for idx, obj_name in enumerate(object_names):
            local_frame = local_by_name.get(os.path.basename(obj_name))
            w = h = None
            if cv2 is not None and local_frame is not None:
                try:
                    img = cv2.imread(local_frame)
                    if img is not None:
                        h, w = img.shape[:2]
                except Exception:  # noqa: BLE001
                    w = h = None

            frame = Frame(
                project_id=project_id,
                frame_index=idx,
                image_url=obj_name,
                width=w,
                height=h,
            )
            db.add(frame)
            frames_out.append(frame)

        # One commit so the frames and the status change land together.
        project.status = "ready"
        db.commit()
    finally:
        # Cleanup temp files on success and on every failure path.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
    return {
        "project_id": project_id,
        "frames_extracted": len(frames_out),
        "status": "ready",
        "message": "Frame extraction completed successfully.",
    }
|
||||
165
backend/routers/projects.py
Normal file
165
backend/routers/projects.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Project and Frame CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projects
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
    "",
    summary="Create a new project",
    status_code=status.HTTP_201_CREATED,
    response_model=ProjectOut,
)
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
    """Create a new segmentation project."""
    # The validated request body maps one-to-one onto the ORM constructor.
    fields = payload.model_dump()
    new_project = Project(**fields)

    db.add(new_project)
    db.commit()
    # Re-read the row so attributes filled in on insert (such as the id logged
    # below) are loaded on the returned object.
    db.refresh(new_project)

    logger.info("Created project id=%s name=%s", new_project.id, new_project.name)
    return new_project
|
||||
|
||||
|
||||
@router.get(
    "",
    response_model=List[ProjectOut],
    summary="List all projects",
)
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
    """Retrieve a paginated list of projects.

    Rows are ordered by project id so that successive pages neither overlap
    nor skip entries; without an explicit ORDER BY the database may return
    rows in any order. Negative ``skip`` or ``limit`` values are treated as 0.
    """
    return (
        db.query(Project)
        .order_by(Project.id)
        .offset(max(skip, 0))
        .limit(max(limit, 0))
        .all()
    )
|
||||
|
||||
|
||||
@router.get("/{project_id}", summary="Get a single project", response_model=ProjectOut)
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
    """Retrieve a project by its ID."""
    lookup = db.query(Project).filter(Project.id == project_id)
    found = lookup.first()
    if found:
        return found
    # No matching row: report it to the client as a 404.
    raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
|
||||
@router.patch("/{project_id}", summary="Update a project", response_model=ProjectOut)
def update_project(
    project_id: int,
    payload: ProjectUpdate,
    db: Session = Depends(get_db),
) -> Project:
    """Update project fields partially."""
    target = db.query(Project).filter(Project.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    # Only fields the client actually sent are applied; omitted ones are left
    # untouched.
    changes = payload.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    db.commit()
    db.refresh(target)
    logger.info("Updated project id=%s", project_id)
    return target
|
||||
|
||||
|
||||
@router.delete("/{project_id}", summary="Delete a project", status_code=status.HTTP_204_NO_CONTENT)
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
    """Delete a project and all related frames and annotations."""
    doomed = db.query(Project).filter(Project.id == project_id).first()
    if doomed:
        db.delete(doomed)
        db.commit()
        logger.info("Deleted project id=%s", project_id)
        return
    # Nothing to delete: unknown ids are reported as 404.
    raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frames
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
    "/{project_id}/frames",
    summary="Add a frame to a project",
    status_code=status.HTTP_201_CREATED,
    response_model=FrameOut,
)
def create_frame(
    project_id: int,
    payload: FrameCreate,
    db: Session = Depends(get_db),
) -> Frame:
    """Register a new frame under a project."""
    parent = db.query(Project).filter(Project.id == project_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Project not found")

    # The project id from the URL wins; the copy carried in the request body
    # is dropped before building the row.
    attributes = payload.model_dump(exclude={"project_id"})
    new_frame = Frame(project_id=project_id, **attributes)

    db.add(new_frame)
    db.commit()
    db.refresh(new_frame)
    return new_frame
|
||||
|
||||
|
||||
@router.get("/{project_id}/frames", summary="List frames for a project", response_model=List[FrameOut])
def list_frames(
    project_id: int,
    skip: int = 0,
    limit: int = 1000,
    db: Session = Depends(get_db),
) -> List[Frame]:
    """Retrieve all frames belonging to a project."""
    # Distinguish "unknown project" (404) from "project without frames" ([]).
    owner = db.query(Project).filter(Project.id == project_id).first()
    if not owner:
        raise HTTPException(status_code=404, detail="Project not found")

    frames_query = db.query(Frame).filter(Frame.project_id == project_id)
    # Ordered by frame_index, then windowed by skip/limit.
    page = frames_query.order_by(Frame.frame_index).offset(skip).limit(limit)
    return page.all()
|
||||
|
||||
|
||||
@router.get("/{project_id}/frames/{frame_id}", summary="Get a single frame", response_model=FrameOut)
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
    """Retrieve a specific frame by ID."""
    # Both ids must match, so a frame of another project is reported as 404.
    matching = db.query(Frame).filter(Frame.project_id == project_id, Frame.id == frame_id)
    found = matching.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Frame not found")
|
||||
97
backend/routers/templates.py
Normal file
97
backend/routers/templates.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Template (Ontology) CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/templates", tags=["Templates"])
|
||||
|
||||
|
||||
@router.post(
    "",
    summary="Create a new template",
    status_code=status.HTTP_201_CREATED,
    response_model=TemplateOut,
)
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
    """Create a new ontology template / segmentation class."""
    # The validated request body maps one-to-one onto the ORM constructor.
    fields = payload.model_dump()
    new_template = Template(**fields)

    db.add(new_template)
    db.commit()
    # Re-read the row so attributes filled in on insert (such as the id logged
    # below) are loaded on the returned object.
    db.refresh(new_template)

    logger.info("Created template id=%s name=%s", new_template.id, new_template.name)
    return new_template
|
||||
|
||||
|
||||
@router.get(
    "",
    response_model=List[TemplateOut],
    summary="List all templates",
)
def list_templates(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
) -> List[Template]:
    """Retrieve all ontology templates.

    Rows are ordered by template id so that successive pages neither overlap
    nor skip entries; without an explicit ORDER BY the database may return
    rows in any order. Negative ``skip`` or ``limit`` values are treated as 0.
    """
    return (
        db.query(Template)
        .order_by(Template.id)
        .offset(max(skip, 0))
        .limit(max(limit, 0))
        .all()
    )
|
||||
|
||||
|
||||
@router.get("/{template_id}", summary="Get a single template", response_model=TemplateOut)
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
    """Retrieve a template by its ID."""
    lookup = db.query(Template).filter(Template.id == template_id)
    found = lookup.first()
    if found:
        return found
    # No matching row: report it to the client as a 404.
    raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
|
||||
@router.patch("/{template_id}", summary="Update a template", response_model=TemplateOut)
def update_template(
    template_id: int,
    payload: TemplateUpdate,
    db: Session = Depends(get_db),
) -> Template:
    """Update template fields partially."""
    target = db.query(Template).filter(Template.id == template_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Template not found")

    # Only fields the client actually sent are applied; omitted ones are left
    # untouched.
    changes = payload.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    db.commit()
    db.refresh(target)
    logger.info("Updated template id=%s", template_id)
    return target
|
||||
|
||||
|
||||
@router.delete("/{template_id}", summary="Delete a template", status_code=status.HTTP_204_NO_CONTENT)
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
    """Delete a template. Associated annotations will have template_id set to NULL."""
    doomed = db.query(Template).filter(Template.id == template_id).first()
    if doomed:
        db.delete(doomed)
        db.commit()
        logger.info("Deleted template id=%s", template_id)
        return
    # Nothing to delete: unknown ids are reported as 404.
    raise HTTPException(status_code=404, detail="Template not found")
|
||||
157
backend/schemas.py
Normal file
157
backend/schemas.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Pydantic schemas for request/response validation."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Project schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ProjectBase(BaseModel):
    # Fields shared by the project create payload and the project response.

    # Required display name.
    name: str
    # Optional free-text description.
    description: Optional[str] = None
    # Storage location of the project's media; optional until something is uploaded.
    video_path: Optional[str] = None
    # Lifecycle marker; a new project defaults to "pending".
    status: Optional[str] = "pending"
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
    # Request body for creating a project; adds nothing to ProjectBase.
    pass
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
    # Partial-update body: every field is optional and defaults to None, so a
    # client may send any subset of them.

    name: Optional[str] = None
    description: Optional[str] = None
    video_path: Optional[str] = None
    status: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectOut(ProjectBase):
    # Response shape for a project: the shared fields plus the identifier and
    # the two timestamps.

    # from_attributes lets the model be populated from attribute access on an
    # object (not only from a dict).
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frame schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class FrameBase(BaseModel):
    # Fields shared by the frame create payload and the frame response.

    # Position of the frame within its project.
    frame_index: int
    # Where the frame image is stored.
    image_url: str
    # Pixel dimensions; optional because they may be unknown.
    width: Optional[int] = None
    height: Optional[int] = None
|
||||
|
||||
|
||||
class FrameCreate(FrameBase):
    # Request body for registering a frame: the shared frame fields plus the
    # id of the owning project (required).
    project_id: int
|
||||
|
||||
|
||||
class FrameOut(FrameBase):
    # Response shape for a frame: the shared fields plus identifiers and the
    # creation timestamp.

    # from_attributes lets the model be populated from attribute access on an
    # object (not only from a dict).
    model_config = ConfigDict(from_attributes=True)

    id: int
    project_id: int
    created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class TemplateBase(BaseModel):
    # Fields shared by the template create payload and the template response.

    name: str
    color: str
    # Stacking order; defaults to 0.
    z_index: int = 0
    # Optional free-form mapping with string keys.
    mapping_rules: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class TemplateCreate(TemplateBase):
    # Request body for creating a template; adds nothing to TemplateBase.
    pass
|
||||
|
||||
|
||||
class TemplateUpdate(BaseModel):
    # Partial-update body: every field is optional and defaults to None, so a
    # client may send any subset of them.

    name: Optional[str] = None
    color: Optional[str] = None
    z_index: Optional[int] = None
    mapping_rules: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class TemplateOut(TemplateBase):
    # Response shape for a template: the shared fields plus the identifier and
    # the creation timestamp.

    # from_attributes lets the model be populated from attribute access on an
    # object (not only from a dict).
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Annotation schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class AnnotationBase(BaseModel):
    # Fields shared by the annotation create payload and the annotation response.

    # Owning project (required); frame and template links are optional.
    project_id: int
    frame_id: Optional[int] = None
    template_id: Optional[int] = None
    # Optional free-form mask payload with string keys.
    mask_data: Optional[dict[str, Any]] = None
    # Optional list of points, each point itself a list of floats.
    points: Optional[list[list[float]]] = None
    # Optional flat list of floats describing a bounding box.
    bbox: Optional[list[float]] = None
|
||||
|
||||
|
||||
class AnnotationCreate(AnnotationBase):
    # Request body for creating an annotation; adds nothing to AnnotationBase.
    pass
|
||||
|
||||
|
||||
class AnnotationUpdate(BaseModel):
    # Partial-update body: every field is optional and defaults to None.
    # project_id and frame_id are not part of this schema.

    mask_data: Optional[dict[str, Any]] = None
    points: Optional[list[list[float]]] = None
    bbox: Optional[list[float]] = None
    template_id: Optional[int] = None
|
||||
|
||||
|
||||
class AnnotationOut(AnnotationBase):
    # Response shape for an annotation: the shared fields plus the identifier
    # and the two timestamps.

    # from_attributes lets the model be populated from attribute access on an
    # object (not only from a dict).
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mask schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class MaskBase(BaseModel):
    # Fields shared by the mask create payload and the mask response.

    # Annotation this mask belongs to.
    annotation_id: int
    # Where the mask file is stored.
    mask_url: str
    # File format of the stored mask; defaults to "png".
    format: str = "png"
|
||||
|
||||
|
||||
class MaskCreate(MaskBase):
    # Request body for creating a mask; adds nothing to MaskBase.
    pass
|
||||
|
||||
|
||||
class MaskOut(MaskBase):
    # Response shape for a mask: the shared fields plus the identifier and the
    # creation timestamp.

    # from_attributes lets the model be populated from attribute access on an
    # object (not only from a dict).
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class PredictRequest(BaseModel):
    # Request body for an AI prediction call.

    image_id: int
    prompt_type: str  # point / box / semantic
    # Left untyped (Any); its shape is not validated by this schema.
    prompt_data: Any
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
    # Response body for an AI prediction call.

    # One entry per polygon; each polygon is a list of points, each point a
    # list of floats.
    polygons: list[list[list[float]]]
    # Optional list of float scores.
    scores: Optional[list[float]] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ExportStatus(BaseModel):
    # Result descriptor for an export: a URL and the export format name.

    url: str
    format: str
|
||||
0
backend/services/__init__.py
Normal file
0
backend/services/__init__.py
Normal file
186
backend/services/frame_parser.py
Normal file
186
backend/services/frame_parser.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Video/DICOM frame parsing and MinIO upload utilities."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pydicom import dcmread
|
||||
|
||||
from minio_client import upload_file, BUCKET_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_video(
    video_path: str,
    output_dir: str,
    fps: int = 30,
    max_frames: Optional[int] = None,
) -> List[str]:
    """Extract frames from a video file using FFmpeg or OpenCV fallback.

    FFmpeg is used when it is on PATH; if it is missing, exits non-zero or
    raises, extraction falls back to OpenCV. Frames are written as
    ``frame_NNNNNN.png`` into ``output_dir``.

    Args:
        video_path: Path to the input video file.
        output_dir: Directory to save extracted frames.
        fps: Target frame extraction rate. Must be positive.
        max_frames: Optional maximum number of frames to extract.

    Returns:
        List of paths to extracted frame images.

    Raises:
        ValueError: If ``fps`` is not positive.
        RuntimeError: If the OpenCV fallback cannot open the video or cannot
            write a frame.
    """
    if fps <= 0:
        # Would otherwise surface as a ZeroDivisionError in the OpenCV path.
        raise ValueError(f"fps must be positive, got {fps!r}")

    os.makedirs(output_dir, exist_ok=True)
    frame_paths: List[str] = []

    # Try FFmpeg first
    if shutil.which("ffmpeg"):
        try:
            pattern = os.path.join(output_dir, "frame_%06d.png")
            cmd = [
                "ffmpeg",
                "-i", video_path,
                "-vf", f"fps={fps},scale='min(1920,iw)':-1",
                "-pix_fmt", "rgb24",
            ]
            if max_frames:
                # Stop decoding once enough frames exist instead of extracting
                # the whole video and discarding the tail.
                cmd += ["-frames:v", str(max_frames)]
            cmd += ["-y", pattern]
            logger.info("Running FFmpeg: %s", " ".join(cmd))
            result = subprocess.run(cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                # Only pick up files matching the output pattern used above.
                frame_paths = sorted(
                    os.path.join(output_dir, name)
                    for name in os.listdir(output_dir)
                    if name.startswith("frame_") and name.endswith(".png")
                )
                if max_frames:
                    frame_paths = frame_paths[:max_frames]
                logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
                return frame_paths
            logger.warning("FFmpeg failed: %s", result.stderr)
        except Exception as exc:  # noqa: BLE001
            logger.warning("FFmpeg exception: %s", exc)

    # OpenCV fallback
    logger.info("Falling back to OpenCV frame extraction")
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")

        # Some containers report 0 fps; assume 30 in that case.
        video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        # Keep every `interval`-th decoded frame to approximate the target rate.
        interval = max(1, int(round(video_fps / fps)))
        count = 0
        saved = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                path = os.path.join(output_dir, f"frame_{saved:06d}.png")
                if not cv2.imwrite(path, frame):
                    raise RuntimeError(f"Cannot write frame: {path}")
                frame_paths.append(path)
                saved += 1
                if max_frames and saved >= max_frames:
                    break
            count += 1
    finally:
        # Release the capture handle even when reading or writing fails.
        cap.release()

    logger.info("Extracted %d frames via OpenCV", len(frame_paths))
    return frame_paths
|
||||
|
||||
|
||||
def _write_dicom_frame(out_path: str, image: np.ndarray, is_rgb: bool) -> str:
    """Write one 8-bit image as PNG and return its path; raise OSError on failure."""
    if is_rgb:
        # cv2.imwrite expects BGR channel order.
        # NOTE(review): assumes the decoded array is RGB; YBR photometric
        # interpretations would need an explicit colour-space conversion.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(out_path, image):
        raise OSError(f"Cannot write frame: {out_path}")
    return out_path


def parse_dicom(
    dicom_dir: str,
    output_dir: str,
    max_frames: Optional[int] = None,
) -> List[str]:
    """Extract frames from DICOM files in a directory.

    Files are processed in file-name order. Pixel data that is not already
    8-bit is min-max scaled to 0-255 over the whole file. A file that cannot
    be read or written is logged and skipped; it does not abort the run.

    Output names are ``frame_<file index>.png`` for single-image files and
    ``frame_<file index>_<frame index>.png`` for multi-frame files.

    Args:
        dicom_dir: Directory containing .dcm files (not searched recursively).
        output_dir: Directory to save extracted frames.
        max_frames: Optional cap on the number of .dcm files read. A
            multi-frame file still contributes all of its frames.

    Returns:
        List of paths to extracted frame images.
    """
    os.makedirs(output_dir, exist_ok=True)
    dcm_files = sorted(
        name for name in os.listdir(dicom_dir) if name.lower().endswith(".dcm")
    )

    frame_paths: List[str] = []
    for idx, fname in enumerate(dcm_files):
        if max_frames and idx >= max_frames:
            break
        path = os.path.join(dicom_dir, fname)
        try:
            ds = dcmread(path)
            pixel_array = ds.pixel_array

            # Normalize to 8-bit
            if pixel_array.dtype != np.uint8:
                pixel_array = pixel_array.astype(np.float32)
                pixel_array = (
                    (pixel_array - pixel_array.min())
                    / (pixel_array.max() - pixel_array.min() + 1e-8)
                    * 255
                )
                pixel_array = pixel_array.astype(np.uint8)

            # A colour image carries a trailing samples axis, so rank alone
            # cannot tell "multi-frame grayscale" from "single RGB image".
            samples = int(getattr(ds, "SamplesPerPixel", 1) or 1)
            image_ndim = 3 if samples > 1 else 2
            is_rgb = samples == 3

            if pixel_array.ndim == image_ndim + 1:
                # Handle multi-frame DICOM: the leading axis indexes frames.
                for f, image in enumerate(pixel_array):
                    out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.png")
                    frame_paths.append(_write_dicom_frame(out_path, image, is_rgb))
            elif pixel_array.ndim == image_ndim:
                out_path = os.path.join(output_dir, f"frame_{idx:06d}.png")
                frame_paths.append(_write_dicom_frame(out_path, pixel_array, is_rgb))
            else:
                raise ValueError(f"Unexpected pixel array shape {pixel_array.shape}")
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to read DICOM %s: %s", path, exc)

    logger.info("Extracted %d frames from DICOM", len(frame_paths))
    return frame_paths
|
||||
|
||||
|
||||
def upload_frames_to_minio(
    frames: List[str],
    project_id: int,
    object_prefix: Optional[str] = None,
) -> List[str]:
    """Push local frame images to MinIO, one object per file.

    Each file is stored as ``<prefix>/<file name>`` with content type
    ``image/png``. A file that cannot be read or uploaded is logged and
    skipped, so the result may be shorter than ``frames``.

    Args:
        frames: Local paths of the frame images.
        project_id: Used to build the default prefix
            ``projects/<project_id>/frames``.
        object_prefix: Replaces the default prefix when given (and non-empty).

    Returns:
        Object names of the frames that were uploaded, in input order.
    """
    base = object_prefix or f"projects/{project_id}/frames"
    uploaded: List[str] = []

    for local_path in frames:
        target = f"{base}/{os.path.basename(local_path)}"
        try:
            with open(local_path, "rb") as handle:
                payload = handle.read()
            upload_file(
                target,
                payload,
                content_type="image/png",
                length=len(payload),
            )
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to upload %s: %s", local_path, exc)
            continue
        uploaded.append(target)

    logger.info("Uploaded %d/%d frames to MinIO", len(uploaded), len(frames))
    return uploaded
|
||||
234
backend/services/sam2_engine.py
Normal file
234
backend/services/sam2_engine.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attempt to import SAM 2; fall back to stubs if unavailable.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional dependency guard: the module must stay importable without torch or
# sam2, so any failure while importing them only clears the availability flag.
SAM2_AVAILABLE = False
try:
    import torch
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except Exception as exc:  # noqa: BLE001
    logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
else:
    SAM2_AVAILABLE = True
    logger.info("SAM2 library imported successfully.")
|
||||
|
||||
|
||||
class SAM2Engine:
    """Lazy-loaded SAM 2 inference engine.

    The model is built on the first prediction call, not at import time. When
    the SAM 2 library is missing, the checkpoint file is absent, model
    construction fails, or a prediction raises, the ``predict_*`` methods
    return a fixed placeholder rectangle with score 0.5 instead of raising.

    All coordinates crossing the public API are normalized to the 0-1 range.
    """

    def __init__(self) -> None:
        self._predictor: Optional[SAM2ImagePredictor] = None
        # True once a load has been attempted, whether or not it succeeded.
        self._model_loaded = False

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------
    def _load_model(self) -> None:
        """Load the SAM 2 model and predictor on first use.

        Only one attempt is ever made; after a failed attempt the engine stays
        in fallback mode.
        """
        if self._model_loaded:
            return

        try:
            if not SAM2_AVAILABLE:
                logger.warning("SAM2 not available; skipping model load.")
                return

            if not os.path.isfile(settings.sam_model_path):
                logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
                return

            # Use the GPU when there is one, otherwise run on CPU rather than
            # failing the load on CUDA-less hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = build_sam2(
                settings.sam_model_config,
                settings.sam_model_path,
                device=device,
            )
            self._predictor = SAM2ImagePredictor(model)
            logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
            logger.info("SAM 2 running on device=%s", device)
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to load SAM 2 model: %s", exc)
        finally:
            # Prevent repeated load attempts, whatever the outcome.
            self._model_loaded = True

    def _ensure_ready(self) -> bool:
        """Ensure the model is loaded; return whether it is usable."""
        self._load_model()
        return SAM2_AVAILABLE and self._predictor is not None

    @classmethod
    def _collect_results(cls, masks, scores) -> tuple[list[list[list[float]]], list[float]]:
        """Convert masks to polygons, keeping each polygon paired with its score.

        Masks without a usable contour are dropped together with their score,
        so the two returned lists always have the same length.
        """
        polygons: list[list[list[float]]] = []
        kept_scores: list[float] = []
        for mask, score in zip(masks, scores):
            poly = cls._mask_to_polygon(mask)
            if poly:
                polygons.append(poly)
                kept_scores.append(float(score))
        return polygons, kept_scores

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------
    def predict_points(
        self,
        image: np.ndarray,
        points: list[list[float]],
        labels: list[int],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run point-prompt segmentation.

        Args:
            image: HWC numpy array (uint8).
            points: List of [x, y] normalized coordinates (0-1).
            labels: 1 for foreground, 0 for background.

        Returns:
            Tuple of (polygons, scores); ``scores[i]`` belongs to
            ``polygons[i]``. In fallback mode a single placeholder rectangle
            with score 0.5 is returned.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

        try:
            h, w = image.shape[:2]
            # The predictor is given pixel coordinates; scale up from 0-1.
            pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
            lbls = np.array(labels, dtype=np.int32)

            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=True,
                )

            return self._collect_results(masks, scores)
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 point prediction failed: %s", exc)
            return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    def predict_box(
        self,
        image: np.ndarray,
        box: list[float],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run box-prompt segmentation.

        Args:
            image: HWC numpy array (uint8).
            box: [x1, y1, x2, y2] normalized coordinates.

        Returns:
            Tuple of (polygons, scores); ``scores[i]`` belongs to
            ``polygons[i]``. In fallback mode a single placeholder rectangle
            with score 0.5 is returned.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

        try:
            h, w = image.shape[:2]
            bbox = np.array(
                [box[0] * w, box[1] * h, box[2] * w, box[3] * h],
                dtype=np.float32,
            )

            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    box=bbox[None, :],
                    multimask_output=False,
                )

            return self._collect_results(masks, scores)
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 box prediction failed: %s", exc)
            return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Run automatic mask generation (grid of points).

        All grid points are passed as foreground prompts to a single predict
        call, and at most the first three returned masks are used.

        Args:
            image: HWC numpy array (uint8).

        Returns:
            Tuple of (polygons, scores); ``scores[i]`` belongs to
            ``polygons[i]``. In fallback mode a single placeholder rectangle
            with score 0.5 is returned.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

        try:
            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                # Uniform 17x17 grid over the 0-1 range, both borders included.
                h, w = image.shape[:2]
                grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
                pts = grid * np.array([w, h])
                lbls = np.ones(pts.shape[0], dtype=np.int32)

                masks, scores, _ = self._predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=True,
                )

            # Limit to the first 3 masks
            return self._collect_results(masks[:3], scores[:3])
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 auto prediction failed: %s", exc)
            return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------
    @staticmethod
    def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
        """Convert a binary mask to a normalized polygon.

        Only the outer contour enclosing the largest area is kept. Returns an
        empty list when the mask has no contour with at least three points.
        """
        import cv2

        if mask.dtype != np.uint8:
            mask = (mask > 0).astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []

        # Rank by enclosed area: with CHAIN_APPROX_SIMPLE the vertex count
        # says little about a region's size.
        largest = max(contours, key=cv2.contourArea)
        if len(largest) < 3:
            return []

        h, w = mask.shape[:2]
        return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]

    @staticmethod
    def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
        """Return a dummy rectangle polygon for fallback mode.

        The rectangle is expressed in normalized coordinates, so ``w`` and
        ``h`` do not influence the result.
        """
        return [
            [
                [0.25, 0.25],
                [0.75, 0.25],
                [0.75, 0.75],
                [0.25, 0.75],
            ]
        ]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
sam_engine = SAM2Engine()
|
||||
Reference in New Issue
Block a user