添加Docker自包含部署分支
- 新增 Seg_Server_Docker 自包含部署内容,包含前后端、FastAPI、Celery、PostgreSQL、Redis、MinIO、演示视频和 DICOM 数据。 - 保留 demo 数据以支持恢复演示出厂设置,排除 SAM 2.1 .pt 权重并在 README 中补充下载命令。 - 补充 GPU 部署、backend/worker 镜像复用、frpc/frps + NPM 公网域名反代部署说明。 - 在 .env/.env.example 中用 # XXXX 标注局域网和公网域名部署需要修改的配置项。 - 添加部署分支 .gitignore,忽略本地模型权重、构建产物、缓存和日志。
This commit is contained in:
21
backend/celery_app.py
Normal file
21
backend/celery_app.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Celery application for background processing."""
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from config import settings
|
||||
|
||||
celery_app = Celery(
|
||||
"seg_server",
|
||||
broker=settings.redis_url,
|
||||
backend=settings.redis_url,
|
||||
include=["worker_tasks"],
|
||||
)
|
||||
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
result_serializer="json",
|
||||
accept_content=["json"],
|
||||
timezone="Asia/Shanghai",
|
||||
enable_utc=True,
|
||||
task_track_started=True,
|
||||
)
|
||||
51
backend/config.py
Normal file
51
backend/config.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Application configuration using Pydantic Settings."""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# Database
|
||||
db_url: str = "postgresql://seguser:segpass123@localhost:5432/segserver"
|
||||
|
||||
# Redis
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
# MinIO
|
||||
minio_endpoint: str = "192.168.3.11:9000"
|
||||
minio_public_endpoint: str | None = None
|
||||
minio_access_key: str = "minioadmin"
|
||||
minio_secret_key: str = "minioadmin"
|
||||
minio_secure: bool = False
|
||||
|
||||
# SAM
|
||||
sam_default_model: str = "sam2.1_hiera_tiny"
|
||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2.1_hiera_tiny.pt"
|
||||
sam_model_config: str = "configs/sam2.1/sam2.1_hiera_t.yaml"
|
||||
sam3_model_version: str = "sam3"
|
||||
sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"
|
||||
sam3_external_enabled: bool = False
|
||||
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
|
||||
sam3_timeout_seconds: int = 300
|
||||
sam3_status_cache_seconds: int = 30
|
||||
sam3_confidence_threshold: float = 0.5
|
||||
|
||||
# App
|
||||
app_env: str = "development"
|
||||
cors_origins: list[str] = ["http://localhost:3000", "http://192.168.3.11:3000"]
|
||||
jwt_secret_key: str = "seg-server-dev-secret-change-me"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 60 * 24
|
||||
default_admin_username: str = "admin"
|
||||
default_admin_password: str = "123456"
|
||||
demo_video_path: str = "/home/wkmgc/Desktop/Seg_Server/demo/演视LC视频序列.mp4"
|
||||
demo_dicom_dir: str = "/home/wkmgc/Desktop/Seg_Server/demo/演视DICOM序列"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
29
backend/database.py
Normal file
29
backend/database.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Database configuration using synchronous SQLAlchemy."""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from fastapi import Depends
|
||||
from typing import Generator
|
||||
|
||||
from config import settings
|
||||
|
||||
engine = create_engine(
|
||||
settings.db_url,
|
||||
pool_pre_ping=True,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""FastAPI dependency that yields a database session."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
317
backend/main.py
Normal file
317
backend/main.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""FastAPI application entrypoint."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import inspect, text
|
||||
|
||||
from config import settings
|
||||
from database import Base, engine, SessionLocal
|
||||
from minio_client import ensure_bucket_exists
|
||||
from progress_events import PROGRESS_CHANNEL
|
||||
from redis_client import get_redis_client, ping as redis_ping
|
||||
|
||||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ensure_runtime_schema_columns() -> None:
|
||||
"""Add nullable columns introduced after initial create_all deployments."""
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
frame_columns = {column["name"] for column in inspector.get_columns("frames")}
|
||||
project_columns = {column["name"] for column in inspector.get_columns("projects")}
|
||||
template_columns = {column["name"] for column in inspector.get_columns("templates")}
|
||||
with engine.begin() as connection:
|
||||
if "timestamp_ms" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"))
|
||||
if "source_frame_number" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"))
|
||||
if "owner_user_id" not in project_columns:
|
||||
connection.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id INTEGER"))
|
||||
if "owner_user_id" not in template_columns:
|
||||
connection.execute(text("ALTER TABLE templates ADD COLUMN owner_user_id INTEGER"))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Runtime schema column check failed: %s", exc)
|
||||
|
||||
|
||||
def _seed_default_admin_sync() -> None:
|
||||
"""Ensure the single default admin exists without rewriting project ownership metadata."""
|
||||
from routers.auth import ensure_default_admin
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = ensure_default_admin(db)
|
||||
db.commit()
|
||||
logger.info("Default admin ready id=%s", admin.id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to seed default admin: %s", exc)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _seed_default_project_sync() -> None:
|
||||
"""Synchronously seed the bundled demo video and DICOM projects on first startup."""
|
||||
from models import Project
|
||||
from routers.auth import ensure_default_admin
|
||||
from services.demo_media import (
|
||||
DEMO_DICOM_PROJECT_NAME,
|
||||
DEMO_VIDEO_PROJECT_NAME,
|
||||
LEGACY_DEMO_DICOM_PROJECT_NAMES,
|
||||
LEGACY_DEMO_VIDEO_PROJECT_NAMES,
|
||||
create_parsed_dicom_demo_project,
|
||||
create_parsed_video_demo_project,
|
||||
demo_dicom_files,
|
||||
)
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = ensure_default_admin(db)
|
||||
legacy_video = (
|
||||
db.query(Project)
|
||||
.filter(Project.name.in_(LEGACY_DEMO_VIDEO_PROJECT_NAMES))
|
||||
.first()
|
||||
)
|
||||
if legacy_video is not None:
|
||||
legacy_video.name = DEMO_VIDEO_PROJECT_NAME
|
||||
db.commit()
|
||||
legacy_dicom = (
|
||||
db.query(Project)
|
||||
.filter(Project.name.in_(LEGACY_DEMO_DICOM_PROJECT_NAMES))
|
||||
.first()
|
||||
)
|
||||
if legacy_dicom is not None:
|
||||
legacy_dicom.name = DEMO_DICOM_PROJECT_NAME
|
||||
db.commit()
|
||||
existing_video = db.query(Project).filter(Project.name == DEMO_VIDEO_PROJECT_NAME).first()
|
||||
if existing_video is None and os.path.exists(settings.demo_video_path):
|
||||
video_project = create_parsed_video_demo_project(
|
||||
db,
|
||||
owner=admin,
|
||||
video_path=settings.demo_video_path,
|
||||
project_name=DEMO_VIDEO_PROJECT_NAME,
|
||||
)
|
||||
logger.info("Seeded default video project id=%s", video_project.id)
|
||||
|
||||
existing_dicom = db.query(Project).filter(Project.name == DEMO_DICOM_PROJECT_NAME).first()
|
||||
if existing_dicom is not None:
|
||||
return
|
||||
|
||||
if not demo_dicom_files(settings.demo_dicom_dir):
|
||||
logger.warning("Default DICOM series not found at %s", settings.demo_dicom_dir)
|
||||
return
|
||||
|
||||
project = create_parsed_dicom_demo_project(
|
||||
db,
|
||||
owner=admin,
|
||||
dicom_dir=settings.demo_dicom_dir,
|
||||
project_name=DEMO_DICOM_PROJECT_NAME,
|
||||
)
|
||||
logger.info("Seeded default DICOM project id=%s with %d frames", project.id, len(project.frames))
|
||||
except Exception as exc:
|
||||
logger.error("Failed to seed default project: %s", exc)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _seed_default_templates_sync() -> None:
|
||||
"""Seed default ontology templates on first startup."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ensure_default_templates(db)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to seed default templates: %s", exc)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def ensure_default_templates(db) -> None:
|
||||
"""Ensure all bundled system templates exist."""
|
||||
from services.default_templates import ensure_default_templates as _ensure_default_templates
|
||||
|
||||
_ensure_default_templates(db)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan: startup and shutdown hooks."""
|
||||
progress_listener: asyncio.Task | None = None
|
||||
# Startup
|
||||
logger.info("Starting up SegServer backend...")
|
||||
|
||||
# Initialize database tables
|
||||
try:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_ensure_runtime_schema_columns()
|
||||
_seed_default_admin_sync()
|
||||
logger.info("Database tables initialized.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Database initialization failed: %s", exc)
|
||||
|
||||
# Check MinIO bucket
|
||||
try:
|
||||
ensure_bucket_exists()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("MinIO bucket check failed: %s", exc)
|
||||
|
||||
# Check Redis
|
||||
if redis_ping():
|
||||
logger.info("Redis connection OK.")
|
||||
else:
|
||||
logger.warning("Redis connection failed.")
|
||||
|
||||
try:
|
||||
progress_listener = asyncio.create_task(_progress_pubsub_loop())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to start Redis progress subscription: %s", exc)
|
||||
|
||||
# Seed default templates
|
||||
try:
|
||||
asyncio.create_task(asyncio.to_thread(_seed_default_templates_sync))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to start default template seeding: %s", exc)
|
||||
|
||||
# Seed default project in background thread so it doesn't block startup
|
||||
try:
|
||||
asyncio.create_task(asyncio.to_thread(_seed_default_project_sync))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to start default project seeding: %s", exc)
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down SegServer backend...")
|
||||
if progress_listener is not None:
|
||||
progress_listener.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await progress_listener
|
||||
engine.dispose()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="SegServer API",
|
||||
description="Semantic Segmentation System Backend",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(projects.router)
|
||||
app.include_router(templates.router)
|
||||
app.include_router(media.router)
|
||||
app.include_router(ai.router)
|
||||
app.include_router(export.router)
|
||||
app.include_router(dashboard.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(admin.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
def health_check() -> dict:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok", "service": "SegServer"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket: 实时进度推送
|
||||
# ---------------------------------------------------------------------------
|
||||
class ConnectionManager:
|
||||
"""Manage WebSocket connections for progress broadcasting."""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: list[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
logger.info("WebSocket client connected. Total: %d", len(self.active_connections))
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""Broadcast a message to all connected clients."""
|
||||
for connection in self.active_connections.copy():
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except Exception as exc:
|
||||
logger.warning("WebSocket send failed: %s", exc)
|
||||
self.disconnect(connection)
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
async def _progress_pubsub_loop() -> None:
|
||||
"""Forward Redis task-progress events to connected WebSocket clients."""
|
||||
while True:
|
||||
pubsub = None
|
||||
try:
|
||||
pubsub = get_redis_client().pubsub()
|
||||
await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
|
||||
logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
|
||||
while True:
|
||||
message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
|
||||
if message is None:
|
||||
await asyncio.sleep(0)
|
||||
continue
|
||||
raw_data = message.get("data")
|
||||
payload = json.loads(raw_data) if isinstance(raw_data, str) else raw_data
|
||||
if isinstance(payload, dict):
|
||||
await manager.broadcast(payload)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Redis progress subscription failed: %s", exc)
|
||||
await asyncio.sleep(5)
|
||||
finally:
|
||||
if pubsub is not None:
|
||||
with suppress(Exception):
|
||||
await asyncio.to_thread(pubsub.close)
|
||||
|
||||
|
||||
@app.websocket("/ws/progress")
|
||||
async def websocket_progress(websocket: WebSocket):
|
||||
"""WebSocket endpoint for real-time parsing/AI progress updates."""
|
||||
await manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
# Receive client messages (heartbeat / subscription requests)
|
||||
data = await websocket.receive_text()
|
||||
logger.debug("WebSocket received: %s", data)
|
||||
|
||||
# Echo heartbeat to keep connection alive
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"status": "connected",
|
||||
"message": "Progress stream active",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
except Exception as exc:
|
||||
logger.error("WebSocket error: %s", exc)
|
||||
manager.disconnect(websocket)
|
||||
142
backend/minio_client.py
Normal file
142
backend/minio_client.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""MinIO client wrapper for object storage operations."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUCKET_NAME = "seg-media"
|
||||
|
||||
_minio_client: Optional[Minio] = None
|
||||
_minio_public_client: Optional[Minio] = None
|
||||
|
||||
|
||||
def get_minio_client() -> Minio:
|
||||
"""Return a singleton MinIO client instance."""
|
||||
global _minio_client
|
||||
if _minio_client is None:
|
||||
_minio_client = Minio(
|
||||
settings.minio_endpoint,
|
||||
access_key=settings.minio_access_key,
|
||||
secret_key=settings.minio_secret_key,
|
||||
secure=settings.minio_secure,
|
||||
)
|
||||
return _minio_client
|
||||
|
||||
|
||||
def get_minio_public_client() -> Minio:
|
||||
"""Return a MinIO client configured for browser-facing presigned URLs."""
|
||||
global _minio_public_client
|
||||
if _minio_public_client is None:
|
||||
endpoint = settings.minio_public_endpoint or settings.minio_endpoint
|
||||
_minio_public_client = Minio(
|
||||
endpoint,
|
||||
access_key=settings.minio_access_key,
|
||||
secret_key=settings.minio_secret_key,
|
||||
secure=settings.minio_secure,
|
||||
)
|
||||
return _minio_public_client
|
||||
|
||||
|
||||
def ensure_bucket_exists() -> None:
|
||||
"""Create the bucket if it does not already exist."""
|
||||
client = get_minio_client()
|
||||
try:
|
||||
if not client.bucket_exists(BUCKET_NAME):
|
||||
client.make_bucket(BUCKET_NAME)
|
||||
logger.info("Created MinIO bucket: %s", BUCKET_NAME)
|
||||
else:
|
||||
logger.info("MinIO bucket %s already exists", BUCKET_NAME)
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO bucket check/creation failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def upload_file(
|
||||
object_name: str,
|
||||
data: bytes,
|
||||
content_type: str = "application/octet-stream",
|
||||
length: int = -1,
|
||||
) -> str:
|
||||
"""Upload bytes to MinIO and return the object name.
|
||||
|
||||
Args:
|
||||
object_name: Destination path inside the bucket.
|
||||
data: Raw bytes or a file-like object.
|
||||
content_type: MIME type of the object.
|
||||
length: Object size; -1 for unknown (uses chunked upload).
|
||||
|
||||
Returns:
|
||||
The object name (same as input).
|
||||
"""
|
||||
client = get_minio_client()
|
||||
if isinstance(data, bytes):
|
||||
data = io.BytesIO(data)
|
||||
length = len(data.getvalue())
|
||||
|
||||
try:
|
||||
client.put_object(
|
||||
BUCKET_NAME,
|
||||
object_name,
|
||||
data,
|
||||
length=length,
|
||||
content_type=content_type,
|
||||
)
|
||||
logger.info("Uploaded to MinIO: %s", object_name)
|
||||
return object_name
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO upload failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
def get_presigned_url(
|
||||
object_name: str,
|
||||
expires: int = 3600,
|
||||
method: str = "GET",
|
||||
) -> str:
|
||||
"""Generate a presigned URL for an object.
|
||||
|
||||
Args:
|
||||
object_name: Path inside the bucket.
|
||||
expires: Expiration time in seconds (default 1 hour).
|
||||
method: HTTP method (GET or PUT).
|
||||
|
||||
Returns:
|
||||
Presigned URL string.
|
||||
"""
|
||||
client = get_minio_public_client()
|
||||
try:
|
||||
url = client.get_presigned_url(method, BUCKET_NAME, object_name, expires=timedelta(seconds=expires))
|
||||
return url
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO presigned URL failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def download_file(object_name: str) -> bytes:
|
||||
"""Download an object from MinIO and return its bytes.
|
||||
|
||||
Args:
|
||||
object_name: Path inside the bucket.
|
||||
|
||||
Returns:
|
||||
Raw bytes of the object.
|
||||
"""
|
||||
client = get_minio_client()
|
||||
try:
|
||||
response = client.get_object(BUCKET_NAME, object_name)
|
||||
data = response.read()
|
||||
response.close()
|
||||
response.release_conn()
|
||||
return data
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO download failed: %s", exc)
|
||||
raise
|
||||
195
backend/models.py
Normal file
195
backend/models.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""SQLAlchemy ORM models."""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
Float,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from database import Base
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Application user used for authentication and data ownership."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String(150), unique=True, index=True, nullable=False)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
role = Column(String(50), default="annotator", nullable=False)
|
||||
is_active = Column(Integer, default=1, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
projects = relationship("Project", back_populates="owner")
|
||||
templates = relationship("Template", back_populates="owner")
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""Project model representing a segmentation project."""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
video_path = Column(String(512), nullable=True)
|
||||
thumbnail_url = Column(String(512), nullable=True)
|
||||
status = Column(String(50), default=PROJECT_STATUS_PENDING, nullable=False)
|
||||
source_type = Column(String(20), default="video", nullable=False) # video | dicom
|
||||
original_fps = Column(Float, nullable=True)
|
||||
parse_fps = Column(Float, default=30.0, nullable=False)
|
||||
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
owner = relationship("User", back_populates="projects")
|
||||
frames = relationship("Frame", back_populates="project", cascade="all, delete-orphan")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
tasks = relationship(
|
||||
"ProcessingTask", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Frame(Base):
|
||||
"""Frame model representing an extracted video frame."""
|
||||
|
||||
__tablename__ = "frames"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
frame_index = Column(Integer, nullable=False)
|
||||
image_url = Column(String(512), nullable=False)
|
||||
width = Column(Integer, nullable=True)
|
||||
height = Column(Integer, nullable=True)
|
||||
timestamp_ms = Column(Float, nullable=True)
|
||||
source_frame_number = Column(Integer, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
project = relationship("Project", back_populates="frames")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="frame", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Template(Base):
|
||||
"""Template (Ontology) model for segmentation classes."""
|
||||
|
||||
__tablename__ = "templates"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
color = Column(String(50), nullable=False)
|
||||
z_index = Column(Integer, default=0, nullable=False)
|
||||
mapping_rules = Column(JSON, nullable=True)
|
||||
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
owner = relationship("User", back_populates="templates")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="template", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Annotation(Base):
|
||||
"""Annotation model for segmentation masks and prompts."""
|
||||
|
||||
__tablename__ = "annotations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
project_id = Column(
|
||||
Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
frame_id = Column(
|
||||
Integer, ForeignKey("frames.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
template_id = Column(
|
||||
Integer, ForeignKey("templates.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
mask_data = Column(JSON, nullable=True)
|
||||
points = Column(JSON, nullable=True)
|
||||
bbox = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
project = relationship("Project", back_populates="annotations")
|
||||
frame = relationship("Frame", back_populates="annotations")
|
||||
template = relationship("Template", back_populates="annotations")
|
||||
masks = relationship("Mask", back_populates="annotation", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Mask(Base):
|
||||
"""Mask model for exported/derived mask files."""
|
||||
|
||||
__tablename__ = "masks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
annotation_id = Column(
|
||||
Integer, ForeignKey("annotations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
mask_url = Column(String(512), nullable=False)
|
||||
format = Column(String(50), default="png", nullable=False) # png / rle / json
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
annotation = relationship("Annotation", back_populates="masks")
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""Audit trail for security and administrative actions."""
|
||||
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
actor_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
action = Column(String(120), nullable=False)
|
||||
target_type = Column(String(80), nullable=True)
|
||||
target_id = Column(String(120), nullable=True)
|
||||
detail = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
actor = relationship("User")
|
||||
|
||||
|
||||
class ProcessingTask(Base):
|
||||
"""Background task state persisted for dashboard and polling."""
|
||||
|
||||
__tablename__ = "processing_tasks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
task_type = Column(String(80), nullable=False)
|
||||
status = Column(String(40), default="queued", nullable=False)
|
||||
progress = Column(Integer, default=0, nullable=False)
|
||||
message = Column(Text, nullable=True)
|
||||
project_id = Column(
|
||||
Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
celery_task_id = Column(String(255), nullable=True)
|
||||
payload = Column(JSON, nullable=True)
|
||||
result = Column(JSON, nullable=True)
|
||||
error = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
finished_at = Column(DateTime(timezone=True), nullable=True)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
project = relationship("Project", back_populates="tasks")
|
||||
66
backend/progress_events.py
Normal file
66
backend/progress_events.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Progress event payloads and Redis publication helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from redis_client import get_redis_client
|
||||
from statuses import TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROGRESS_CHANNEL = "seg:progress"
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _event_type(task_status: str) -> str:
|
||||
if task_status == TASK_STATUS_SUCCESS:
|
||||
return "complete"
|
||||
if task_status == TASK_STATUS_CANCELLED:
|
||||
return "cancelled"
|
||||
if task_status == TASK_STATUS_FAILED:
|
||||
return "error"
|
||||
return "progress"
|
||||
|
||||
|
||||
def task_progress_payload(task: Any) -> dict[str, Any]:
|
||||
"""Build the WebSocket payload from a persisted processing task."""
|
||||
project = getattr(task, "project", None)
|
||||
project_name = getattr(project, "name", None)
|
||||
status = getattr(task, "status", "")
|
||||
updated_at = getattr(task, "updated_at", None)
|
||||
timestamp = updated_at.isoformat() if updated_at is not None else _iso_now()
|
||||
message = getattr(task, "message", None)
|
||||
|
||||
return {
|
||||
"type": _event_type(status),
|
||||
"taskId": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
"project_id": getattr(task, "project_id", None),
|
||||
"projectName": project_name,
|
||||
"filename": project_name,
|
||||
"progress": getattr(task, "progress", 0),
|
||||
"status": message or status,
|
||||
"message": message,
|
||||
"error": getattr(task, "error", None),
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
|
||||
def publish_progress_event(payload: dict[str, Any]) -> None:
|
||||
"""Publish a JSON progress event without failing the worker on Redis errors."""
|
||||
try:
|
||||
get_redis_client().publish(PROGRESS_CHANNEL, json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to publish progress event: %s", exc)
|
||||
|
||||
|
||||
def publish_task_progress_event(task: Any) -> None:
|
||||
"""Publish a progress event for a ProcessingTask ORM object."""
|
||||
publish_progress_event(task_progress_payload(task))
|
||||
61
backend/redis_client.py
Normal file
61
backend/redis_client.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Redis client wrapper for caching and task queuing."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
|
||||
import redis
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_redis_client: Optional[redis.Redis] = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
"""Return a singleton Redis client instance."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def ping() -> bool:
|
||||
"""Check Redis connectivity."""
|
||||
try:
|
||||
return get_redis_client().ping()
|
||||
except redis.ConnectionError as exc:
|
||||
logger.error("Redis ping failed: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
def set_json(key: str, value: Any, expire: Optional[int] = None) -> None:
|
||||
"""Store a JSON-serializable value in Redis."""
|
||||
client = get_redis_client()
|
||||
try:
|
||||
client.set(key, json.dumps(value), ex=expire)
|
||||
except redis.RedisError as exc:
|
||||
logger.error("Redis set_json failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def get_json(key: str) -> Optional[Any]:
|
||||
"""Retrieve and deserialize a JSON value from Redis."""
|
||||
client = get_redis_client()
|
||||
try:
|
||||
data = client.get(key)
|
||||
return json.loads(data) if data is not None else None
|
||||
except redis.RedisError as exc:
|
||||
logger.error("Redis get_json failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def delete_key(key: str) -> int:
|
||||
"""Delete a key from Redis. Returns number of deleted keys."""
|
||||
client = get_redis_client()
|
||||
try:
|
||||
return client.delete(key)
|
||||
except redis.RedisError as exc:
|
||||
logger.error("Redis delete_key failed: %s", exc)
|
||||
raise
|
||||
20
backend/requirements-docker.txt
Normal file
20
backend/requirements-docker.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
python-multipart
|
||||
sqlalchemy
|
||||
psycopg2-binary
|
||||
redis
|
||||
celery
|
||||
minio
|
||||
opencv-python-headless
|
||||
pillow
|
||||
scikit-image
|
||||
pydicom
|
||||
numpy
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
sam2
|
||||
pydantic-settings
|
||||
python-jose[cryptography]
|
||||
passlib[bcrypt]
|
||||
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
299
backend/routers/admin.py
Normal file
299
backend/routers/admin.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""Administrator-only user and audit management endpoints."""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import SUPPORTED_ROLES, ensure_default_admin, hash_password, normalize_user_role, require_admin, write_audit_log
|
||||
from schemas import (
|
||||
AdminUserCreate,
|
||||
AdminUserUpdate,
|
||||
AuditLogOut,
|
||||
DemoFactoryResetOut,
|
||||
DemoFactoryResetRequest,
|
||||
UserOut,
|
||||
)
|
||||
from services.demo_media import (
|
||||
DEMO_DICOM_PROJECT_NAME,
|
||||
DEMO_VIDEO_PROJECT_NAME,
|
||||
create_parsed_dicom_demo_project,
|
||||
create_parsed_video_demo_project,
|
||||
demo_dicom_files,
|
||||
)
|
||||
from services.default_templates import restore_default_templates
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
|
||||
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
|
||||
|
||||
|
||||
def _normalize_role(role: str | None) -> str:
|
||||
normalized = (role or "annotator").strip().lower()
|
||||
if normalized not in SUPPORTED_ROLES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _assert_non_admin_role(role: str) -> None:
|
||||
if role == "admin":
|
||||
raise HTTPException(status_code=400, detail="Only the default admin account can have admin role")
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserOut], summary="List users")
|
||||
def list_users(
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[User]:
|
||||
"""Return all users for the administrator console."""
|
||||
_ = admin_user
|
||||
users = db.query(User).order_by(User.id).all()
|
||||
return [normalize_user_role(db, user) for user in users]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users",
|
||||
response_model=UserOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create user",
|
||||
)
|
||||
def create_user(
|
||||
payload: AdminUserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Create a user with an initial password and role."""
|
||||
username = payload.username.strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
if len(payload.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
role = _normalize_role(payload.role)
|
||||
_assert_non_admin_role(role)
|
||||
user = User(
|
||||
username=username,
|
||||
password_hash=hash_password(payload.password),
|
||||
role=role,
|
||||
is_active=1 if payload.is_active else 0,
|
||||
)
|
||||
db.add(user)
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_created",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username, "role": user.role, "is_active": bool(user.is_active)},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
|
||||
def update_user(
|
||||
user_id: int,
|
||||
payload: AdminUserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Update username, password, role or active state."""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user = normalize_user_role(db, user)
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
|
||||
if "username" in updates:
|
||||
username = (updates["username"] or "").strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
if user.role == "admin" and username != settings.default_admin_username:
|
||||
raise HTTPException(status_code=400, detail="Default admin username cannot be changed")
|
||||
user.username = username
|
||||
if "password" in updates:
|
||||
password = updates["password"] or ""
|
||||
if len(password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user.password_hash = hash_password(password)
|
||||
if "role" in updates:
|
||||
next_role = _normalize_role(updates["role"])
|
||||
if user.username == settings.default_admin_username:
|
||||
if next_role != "admin":
|
||||
raise HTTPException(status_code=400, detail="Cannot remove the default admin role")
|
||||
else:
|
||||
_assert_non_admin_role(next_role)
|
||||
user.role = next_role
|
||||
if "is_active" in updates:
|
||||
if user.id == admin_user.id and not updates["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
|
||||
user.is_active = 1 if updates["is_active"] else 0
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
|
||||
audit_detail["password_changed"] = "password" in updates
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_updated",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail=audit_detail,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> None:
|
||||
"""Delete a user when it is safe to remove the account."""
|
||||
if user_id == admin_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user = normalize_user_role(db, user)
|
||||
if user.role == "admin":
|
||||
raise HTTPException(status_code=400, detail="Cannot delete the default admin account")
|
||||
username = user.username
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_deleted",
|
||||
target_type="user",
|
||||
target_id=user_id,
|
||||
detail={"username": username},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
|
||||
def list_audit_logs(
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[AuditLog]:
|
||||
"""Return recent audit events for administrators."""
|
||||
_ = admin_user
|
||||
safe_limit = min(max(int(limit or 100), 1), 500)
|
||||
return db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc()).limit(safe_limit).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/demo-factory-reset",
|
||||
response_model=DemoFactoryResetOut,
|
||||
summary="Reset demo data to factory defaults",
|
||||
)
|
||||
def reset_demo_factory(
|
||||
payload: DemoFactoryResetRequest,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> dict:
|
||||
"""Reset a demo deployment to one admin account, the demo video, and the demo DICOM project."""
|
||||
if payload.confirmation != DEMO_RESET_CONFIRMATION:
|
||||
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
|
||||
|
||||
if not os.path.exists(settings.demo_video_path):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo video not found: {settings.demo_video_path}",
|
||||
)
|
||||
if not demo_dicom_files(settings.demo_dicom_dir):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo DICOM series not found: {settings.demo_dicom_dir}",
|
||||
)
|
||||
|
||||
requested_by = admin_user.username
|
||||
preserved_admin = ensure_default_admin(db)
|
||||
preserved_admin.username = settings.default_admin_username
|
||||
preserved_admin.password_hash = hash_password(settings.default_admin_password)
|
||||
preserved_admin.role = "admin"
|
||||
preserved_admin.is_active = 1
|
||||
db.flush()
|
||||
|
||||
deleted_counts = {
|
||||
"masks": db.query(Mask).delete(synchronize_session=False),
|
||||
"annotations": db.query(Annotation).delete(synchronize_session=False),
|
||||
"frames": db.query(Frame).delete(synchronize_session=False),
|
||||
"tasks": db.query(ProcessingTask).delete(synchronize_session=False),
|
||||
"projects": db.query(Project).delete(synchronize_session=False),
|
||||
"user_templates": db.query(Template).filter(Template.owner_user_id.is_not(None)).delete(synchronize_session=False),
|
||||
"audit_logs": db.query(AuditLog).delete(synchronize_session=False),
|
||||
"users": db.query(User).filter(User.id != preserved_admin.id).delete(synchronize_session=False),
|
||||
}
|
||||
db.flush()
|
||||
db.expunge_all()
|
||||
|
||||
preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if not preserved_admin:
|
||||
raise HTTPException(status_code=500, detail="Default admin was not preserved")
|
||||
|
||||
restored_templates = restore_default_templates(db)
|
||||
|
||||
video_project = create_parsed_video_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
video_path=settings.demo_video_path,
|
||||
project_name=DEMO_VIDEO_PROJECT_NAME,
|
||||
)
|
||||
|
||||
dicom_project = create_parsed_dicom_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
dicom_dir=settings.demo_dicom_dir,
|
||||
project_name=DEMO_PROJECT_NAME,
|
||||
)
|
||||
db.refresh(preserved_admin)
|
||||
db.refresh(video_project)
|
||||
db.refresh(dicom_project)
|
||||
video_project.frame_count = len(video_project.frames)
|
||||
dicom_project.frame_count = len(dicom_project.frames)
|
||||
projects = [video_project, dicom_project]
|
||||
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=preserved_admin,
|
||||
action="admin.demo_factory_reset",
|
||||
target_type="project",
|
||||
target_id=dicom_project.id,
|
||||
detail={
|
||||
"project_names": [project.name for project in projects],
|
||||
"video_path": video_project.video_path,
|
||||
"dicom_path": dicom_project.video_path,
|
||||
"source_types": [project.source_type for project in projects],
|
||||
"frame_counts": {project.name: len(project.frames) for project in projects},
|
||||
"deleted_counts": deleted_counts,
|
||||
"restored_templates": [template.name for template in restored_templates],
|
||||
"requested_by": requested_by,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"admin_user": preserved_admin,
|
||||
"project": dicom_project,
|
||||
"projects": projects,
|
||||
"deleted_counts": deleted_counts,
|
||||
"message": "演示环境已恢复出厂设置",
|
||||
}
|
||||
1228
backend/routers/ai.py
Normal file
1228
backend/routers/ai.py
Normal file
File diff suppressed because it is too large
Load Diff
222
backend/routers/auth.py
Normal file
222
backend/routers/auth.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Authentication endpoints and dependencies."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from models import AuditLog, User
|
||||
from schemas import LoginResponse, UserOut
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Auth"])
|
||||
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
SUPPORTED_ROLES = {"admin", "annotator"}
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a plain password for storage."""
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a plain password against a stored hash."""
|
||||
return password_context.verify(password, password_hash)
|
||||
|
||||
|
||||
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
|
||||
"""Create a signed JWT access token for a user."""
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def ensure_default_admin(db: Session) -> User:
|
||||
"""Create and enforce the single default administrator account."""
|
||||
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if existing:
|
||||
changed = False
|
||||
if existing.role != "admin":
|
||||
existing.role = "admin"
|
||||
changed = True
|
||||
if not existing.is_active:
|
||||
existing.is_active = 1
|
||||
changed = True
|
||||
extra_admins = db.query(User).filter(
|
||||
User.role == "admin",
|
||||
User.id != existing.id,
|
||||
).all()
|
||||
for user in extra_admins:
|
||||
user.role = "annotator"
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
return existing
|
||||
user = User(
|
||||
username=settings.default_admin_username,
|
||||
password_hash=hash_password(settings.default_admin_password),
|
||||
role="admin",
|
||||
is_active=1,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
extra_admins = db.query(User).filter(
|
||||
User.role == "admin",
|
||||
User.id != user.id,
|
||||
).all()
|
||||
if extra_admins:
|
||||
for extra_user in extra_admins:
|
||||
extra_user.role = "annotator"
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def normalize_user_role(db: Session, user: User) -> User:
|
||||
"""Keep legacy accounts within the current two-role policy."""
|
||||
desired_role = "admin" if user.username == settings.default_admin_username else "annotator"
|
||||
changed = False
|
||||
if user.role != desired_role:
|
||||
user.role = desired_role
|
||||
changed = True
|
||||
if user.username == settings.default_admin_username and not user.is_active:
|
||||
user.is_active = 1
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Resolve and validate the current user from the Bearer token."""
|
||||
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
credentials.credentials,
|
||||
settings.jwt_secret_key,
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
user_id = int(payload.get("sub"))
|
||||
except (JWTError, TypeError, ValueError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user:
|
||||
user = normalize_user_role(db, user)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive or missing user",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require the current user to have the administrator role."""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def require_editor(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require a user role that can modify segmentation data."""
|
||||
if current_user.role not in SUPPORTED_ROLES:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def write_audit_log(
|
||||
db: Session,
|
||||
*,
|
||||
actor: User | None,
|
||||
action: str,
|
||||
target_type: str | None = None,
|
||||
target_id: str | int | None = None,
|
||||
detail: dict[str, Any] | None = None,
|
||||
) -> AuditLog:
|
||||
"""Persist a compact audit event."""
|
||||
log = AuditLog(
|
||||
actor_user_id=actor.id if actor else None,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=str(target_id) if target_id is not None else None,
|
||||
detail=detail or {},
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
db.refresh(log)
|
||||
return log
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Authenticate a user and return a signed JWT."""
|
||||
ensure_default_admin(db)
|
||||
user = db.query(User).filter(User.username == payload.username).first()
|
||||
if user:
|
||||
user = normalize_user_role(db, user)
|
||||
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=None,
|
||||
action="auth.login_failed",
|
||||
target_type="user",
|
||||
target_id=payload.username,
|
||||
detail={"username": payload.username},
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=user,
|
||||
action="auth.login_success",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username},
|
||||
)
|
||||
return {
|
||||
"token": create_access_token(user),
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Return the authenticated user profile."""
|
||||
return current_user
|
||||
164
backend/routers/dashboard.py
Normal file
164
backend/routers/dashboard.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Dashboard overview endpoints."""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
ACTIVE_TASK_STATUSES = {"queued", "running"}
|
||||
MONITORED_TASK_STATUSES = {"queued", "running", "success", "failed", "cancelled"}
|
||||
|
||||
|
||||
def _system_load_percent() -> int:
|
||||
"""Return a real host load estimate without adding a psutil dependency."""
|
||||
try:
|
||||
load_1m = os.getloadavg()[0]
|
||||
cpu_count = os.cpu_count() or 1
|
||||
return min(100, max(0, round((load_1m / cpu_count) * 100)))
|
||||
except (AttributeError, OSError):
|
||||
return 0
|
||||
|
||||
|
||||
def _iso_or_none(value: datetime | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
result = task.result or {}
|
||||
return {
|
||||
"id": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
"project_id": task.project_id or 0,
|
||||
"name": task.project.name if task.project else f"任务 {task.id}",
|
||||
"progress": task.progress,
|
||||
"status": task.message or task.status,
|
||||
"raw_status": task.status,
|
||||
"frame_count": result.get("frames_extracted", result.get("processed_frame_count", 0)),
|
||||
"error": task.error,
|
||||
"updated_at": _iso_or_none(task.updated_at),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/overview", summary="Get dashboard overview")
|
||||
def get_dashboard_overview(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Return live dashboard data derived from persisted backend records."""
|
||||
shared_project_ids_query = db.query(Project.id)
|
||||
project_count = db.query(func.count(Project.id)).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(shared_project_ids_query)).scalar() or 0
|
||||
annotation_count = (
|
||||
db.query(func.count(Annotation.id))
|
||||
.filter(Annotation.project_id.in_(shared_project_ids_query))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
template_count = (
|
||||
db.query(func.count(Template.id))
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
active_task_count = (
|
||||
db.query(func.count(ProcessingTask.id))
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
recent_tasks = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.order_by(ProcessingTask.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
)
|
||||
tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
|
||||
|
||||
activities: list[dict[str, Any]] = []
|
||||
for task in recent_tasks[:10]:
|
||||
project_name = task.project.name if task.project else f"项目 {task.project_id}"
|
||||
activities.append({
|
||||
"id": f"task-{task.id}",
|
||||
"kind": "task",
|
||||
"time": _iso_or_none(task.updated_at),
|
||||
"message": task.message or f"任务状态: {task.status}",
|
||||
"project": project_name,
|
||||
})
|
||||
|
||||
for project in projects[:10]:
|
||||
activities.append({
|
||||
"id": f"project-{project.id}",
|
||||
"kind": "project",
|
||||
"time": _iso_or_none(project.updated_at),
|
||||
"message": f"项目状态: {project.status}",
|
||||
"project": project.name,
|
||||
})
|
||||
|
||||
recent_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id.in_(shared_project_ids_query))
|
||||
.order_by(Annotation.updated_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
for annotation in recent_annotations:
|
||||
project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
|
||||
activities.append({
|
||||
"id": f"annotation-{annotation.id}",
|
||||
"kind": "annotation",
|
||||
"time": _iso_or_none(annotation.updated_at),
|
||||
"message": f"标注已更新 #{annotation.id}",
|
||||
"project": project_name,
|
||||
})
|
||||
|
||||
recent_templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.order_by(Template.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
for template in recent_templates:
|
||||
activities.append({
|
||||
"id": f"template-{template.id}",
|
||||
"kind": "template",
|
||||
"time": _iso_or_none(template.created_at),
|
||||
"message": f"模板可用: {template.name}",
|
||||
"project": "系统",
|
||||
})
|
||||
|
||||
activities.sort(key=lambda item: item["time"] or "", reverse=True)
|
||||
|
||||
return {
|
||||
"summary": {
|
||||
"project_count": project_count,
|
||||
"parsing_task_count": active_task_count,
|
||||
"annotation_count": annotation_count,
|
||||
"frame_count": frame_count,
|
||||
"template_count": template_count,
|
||||
"system_load_percent": _system_load_percent(),
|
||||
},
|
||||
"tasks": tasks,
|
||||
"activity": activities[:10],
|
||||
}
|
||||
764
backend/routers/export.py
Normal file
764
backend/routers/export.py
Normal file
@@ -0,0 +1,764 @@
|
||||
"""Annotation export endpoints (COCO, PNG masks)."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import quote
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Annotation, Frame, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||||
|
||||
|
||||
def _mask_from_polygon(
|
||||
polygon: List[List[float]],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> np.ndarray:
|
||||
"""Render a normalized polygon to a binary mask."""
|
||||
import cv2
|
||||
|
||||
pts = np.array(
|
||||
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
|
||||
dtype=np.int32,
|
||||
)
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [pts], 255)
|
||||
return mask
|
||||
|
||||
|
||||
def _annotation_z_index(annotation: Annotation) -> int:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
|
||||
try:
|
||||
return int(class_meta["zIndex"])
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if annotation.template and annotation.template.z_index is not None:
|
||||
return int(annotation.template.z_index)
|
||||
return 0
|
||||
|
||||
|
||||
def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
for key in ("maskId", "maskid", "mask_id"):
|
||||
if class_meta.get(key) is None:
|
||||
continue
|
||||
try:
|
||||
value = int(class_meta[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value >= 0:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _annotation_category_name(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("category"):
|
||||
return str(class_meta["category"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return ""
|
||||
|
||||
|
||||
def _annotation_class_key(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
if class_meta.get("id"):
|
||||
return f"class:{class_meta['id']}"
|
||||
if class_meta.get("name"):
|
||||
return f"name:{class_meta['name']}"
|
||||
if annotation.template_id:
|
||||
return f"template:{annotation.template_id}"
|
||||
return f"annotation:{annotation.id}"
|
||||
|
||||
|
||||
def _annotation_label(annotation: Annotation) -> str:
|
||||
mask_data = annotation.mask_data or {}
|
||||
class_meta = mask_data.get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("name"):
|
||||
return str(class_meta["name"])
|
||||
if mask_data.get("label"):
|
||||
return str(mask_data["label"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return f"Annotation {annotation.id}"
|
||||
|
||||
|
||||
def _annotation_color(annotation: Annotation) -> str:
|
||||
mask_data = annotation.mask_data or {}
|
||||
class_meta = mask_data.get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("color"):
|
||||
return str(class_meta["color"])
|
||||
if mask_data.get("color"):
|
||||
return str(mask_data["color"])
|
||||
if annotation.template and annotation.template.color:
|
||||
return str(annotation.template.color)
|
||||
return "#ffffff"
|
||||
|
||||
|
||||
def _hex_to_rgb(color: str) -> list[int]:
|
||||
value = str(color or "").strip()
|
||||
if value.startswith("#"):
|
||||
value = value[1:]
|
||||
if len(value) == 3:
|
||||
value = "".join(part * 2 for part in value)
|
||||
if len(value) != 6:
|
||||
return [255, 255, 255]
|
||||
try:
|
||||
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
|
||||
except ValueError:
|
||||
return [255, 255, 255]
|
||||
|
||||
|
||||
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
text = fallback
|
||||
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
|
||||
text = re.sub(r"_+", "_", text).strip("._")
|
||||
return text or fallback
|
||||
|
||||
|
||||
def _project_video_name(project: Project) -> str:
|
||||
if project.video_path:
|
||||
stem = Path(project.video_path).name
|
||||
if "." in stem:
|
||||
stem = ".".join(stem.split(".")[:-1])
|
||||
if stem:
|
||||
return _safe_filename_part(stem, f"project_{project.id}")
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _project_export_name(project: Project) -> str:
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
|
||||
if frame.timestamp_ms is not None:
|
||||
return float(frame.timestamp_ms)
|
||||
fps = project.parse_fps or project.original_fps or 30.0
|
||||
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
|
||||
|
||||
|
||||
def _project_frame_number(frame: Frame) -> int:
|
||||
return int(frame.frame_index) + 1
|
||||
|
||||
|
||||
def _format_timestamp_ms(value: float) -> str:
|
||||
total_ms = max(0, int(round(float(value))))
|
||||
hours = total_ms // 3_600_000
|
||||
minutes = (total_ms % 3_600_000) // 60_000
|
||||
seconds = (total_ms % 60_000) // 1_000
|
||||
milliseconds = total_ms % 1_000
|
||||
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
|
||||
|
||||
|
||||
def _frame_export_stem(project: Project, frame: Frame) -> str:
|
||||
return "_".join([
|
||||
_project_video_name(project),
|
||||
_format_timestamp_ms(_frame_timestamp_ms(frame, project)),
|
||||
f"frame{_project_frame_number(frame):06d}",
|
||||
])
|
||||
|
||||
|
||||
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
|
||||
if not frames:
|
||||
return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
|
||||
first_frame = frames[0]
|
||||
last_frame = frames[-1]
|
||||
return (
|
||||
f"{_project_export_name(project)}"
|
||||
f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
|
||||
f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
|
||||
f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
|
||||
)
|
||||
|
||||
|
||||
def _download_content_disposition(filename: str) -> str:
|
||||
ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
|
||||
ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
|
||||
if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
|
||||
ascii_fallback = f"{ascii_fallback}.zip"
|
||||
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
|
||||
def _frame_image_extension(frame: Frame) -> str:
|
||||
suffix = Path(frame.image_url or "").suffix.lower()
|
||||
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
|
||||
|
||||
|
||||
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
_ = current_user
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
def _project_frames(project_id: int, db: Session) -> list[Frame]:
|
||||
return (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _filter_frames(
|
||||
frames: list[Frame],
|
||||
*,
|
||||
scope: str = "all",
|
||||
start_frame: int | None = None,
|
||||
end_frame: int | None = None,
|
||||
frame_id: int | None = None,
|
||||
) -> list[Frame]:
|
||||
if scope == "current":
|
||||
if frame_id is None:
|
||||
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
|
||||
selected = [frame for frame in frames if frame.id == frame_id]
|
||||
if not selected:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return selected
|
||||
|
||||
if scope == "range":
|
||||
if start_frame is None or end_frame is None:
|
||||
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
|
||||
start = max(1, min(int(start_frame), int(end_frame)))
|
||||
end = max(1, max(int(start_frame), int(end_frame)))
|
||||
return frames[start - 1:end]
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
|
||||
if not frame_ids:
|
||||
return []
|
||||
return (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.filter(Annotation.frame_id.in_(frame_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
|
||||
images = []
|
||||
for frame in frames:
|
||||
images.append({
|
||||
"id": frame.id,
|
||||
"file_name": frame.image_url,
|
||||
"width": frame.width or 1920,
|
||||
"height": frame.height or 1080,
|
||||
"frame_index": frame.frame_index,
|
||||
})
|
||||
|
||||
categories = []
|
||||
template_id_to_cat_id: Dict[int, int] = {}
|
||||
for cat_idx, tmpl in enumerate(templates, start=1):
|
||||
categories.append({
|
||||
"id": cat_idx,
|
||||
"name": tmpl.name,
|
||||
"color": tmpl.color,
|
||||
})
|
||||
template_id_to_cat_id[tmpl.id] = cat_idx
|
||||
|
||||
coco_annotations = []
|
||||
ann_id = 1
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for ann in annotations:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
first_poly = polygons[0]
|
||||
xs = [p[0] for p in first_poly]
|
||||
ys = [p[1] for p in first_poly]
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
bbox = [
|
||||
min(xs) * width,
|
||||
min(ys) * height,
|
||||
(max(xs) - min(xs)) * width,
|
||||
(max(ys) - min(ys)) * height,
|
||||
]
|
||||
area = bbox[2] * bbox[3]
|
||||
|
||||
segmentation = []
|
||||
for poly in polygons:
|
||||
flat = []
|
||||
for p in poly:
|
||||
flat.append(p[0] * width)
|
||||
flat.append(p[1] * height)
|
||||
segmentation.append(flat)
|
||||
|
||||
coco_annotations.append({
|
||||
"id": ann_id,
|
||||
"image_id": ann.frame_id,
|
||||
"category_id": template_id_to_cat_id.get(ann.template_id, 0),
|
||||
"segmentation": segmentation,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
"iscrowd": 0,
|
||||
})
|
||||
ann_id += 1
|
||||
|
||||
return {
|
||||
"info": {
|
||||
"description": f"Annotations for {project.name}",
|
||||
"version": "1.0",
|
||||
"year": datetime.now().year,
|
||||
"date_created": datetime.now().isoformat(),
|
||||
},
|
||||
"images": images,
|
||||
"annotations": coco_annotations,
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
|
||||
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
|
||||
return {
|
||||
"key": _annotation_class_key(annotation),
|
||||
"className": _annotation_label(annotation),
|
||||
"chineseName": _annotation_label(annotation),
|
||||
"categoryName": _annotation_category_name(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"internalPriority": _annotation_z_index(annotation),
|
||||
"maskidHint": _annotation_mask_id(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
|
||||
entries_by_key: dict[str, dict[str, Any]] = {}
|
||||
for annotation in annotations:
|
||||
if not annotation.mask_data or not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
entry = _class_mapping_entry(annotation)
|
||||
entries_by_key.setdefault(entry["key"], entry)
|
||||
|
||||
ordered = sorted(
|
||||
entries_by_key.values(),
|
||||
key=lambda item: (
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
|
||||
str(item["className"]),
|
||||
str(item["key"]),
|
||||
),
|
||||
)
|
||||
key_to_value: dict[str, int] = {}
|
||||
classes: list[dict[str, Any]] = []
|
||||
used_maskids: set[int] = set()
|
||||
next_maskid = 1
|
||||
|
||||
def next_available_maskid() -> int:
|
||||
nonlocal next_maskid
|
||||
while next_maskid in used_maskids:
|
||||
next_maskid += 1
|
||||
if next_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
value = next_maskid
|
||||
used_maskids.add(value)
|
||||
next_maskid += 1
|
||||
return value
|
||||
|
||||
for entry in ordered:
|
||||
hinted_maskid = entry.get("maskidHint")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid == 0:
|
||||
maskid = 0
|
||||
used_maskids.add(maskid)
|
||||
elif isinstance(hinted_maskid, int) and 0 < hinted_maskid <= 255 and hinted_maskid not in used_maskids:
|
||||
maskid = hinted_maskid
|
||||
used_maskids.add(maskid)
|
||||
else:
|
||||
maskid = next_available_maskid()
|
||||
key_to_value[entry["key"]] = maskid
|
||||
classes.append({
|
||||
"gt_pixel_value": maskid,
|
||||
"maskid": maskid,
|
||||
"chineseName": entry["chineseName"],
|
||||
"className": entry["className"],
|
||||
"categoryName": entry["categoryName"],
|
||||
"rgb": _hex_to_rgb(entry["color"]),
|
||||
"color": entry["color"],
|
||||
"key": entry["key"],
|
||||
"template_id": entry["template_id"],
|
||||
})
|
||||
return key_to_value, classes
|
||||
|
||||
|
||||
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
|
||||
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
|
||||
if outputs:
|
||||
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
|
||||
invalid = parsed - allowed
|
||||
if invalid:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
|
||||
return parsed or allowed
|
||||
|
||||
if mask_type == "separate":
|
||||
return {"separate"}
|
||||
if mask_type == "gt_label":
|
||||
return {"gt_label"}
|
||||
if mask_type == "pro_label":
|
||||
return {"pro_label"}
|
||||
if mask_type == "mix_label":
|
||||
return {"mix_label"}
|
||||
return allowed
|
||||
|
||||
|
||||
def _write_original_frames(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
) -> dict[int, bytes]:
|
||||
image_bytes_by_frame: dict[int, bytes] = {}
|
||||
for frame in frames:
|
||||
image_bytes = download_file(frame.image_url)
|
||||
image_bytes_by_frame[frame.id] = image_bytes
|
||||
zf.writestr(
|
||||
f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
|
||||
image_bytes,
|
||||
)
|
||||
return image_bytes_by_frame
|
||||
|
||||
|
||||
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
if image_bytes:
|
||||
decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if decoded is not None:
|
||||
if decoded.shape[1] != width or decoded.shape[0] != height:
|
||||
decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
|
||||
return decoded
|
||||
return np.zeros((height, width, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def _write_result_mask_outputs(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
outputs: set[str],
|
||||
class_values: dict[str, int],
|
||||
class_mapping: list[dict[str, Any]],
|
||||
original_images: dict[int, bytes],
|
||||
mix_opacity: float,
|
||||
) -> None:
|
||||
import cv2
|
||||
|
||||
include_individual = "separate" in outputs
|
||||
include_semantic = "gt_label" in outputs
|
||||
include_pro_label = "pro_label" in outputs
|
||||
include_mix_label = "mix_label" in outputs
|
||||
class_rgb_by_key = {
|
||||
item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
|
||||
for item in class_mapping
|
||||
}
|
||||
annotations_by_frame: dict[int, list[Annotation]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for annotation in annotations:
|
||||
if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
|
||||
continue
|
||||
if not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)
|
||||
|
||||
for frame in frames:
|
||||
frame_annotations = annotations_by_frame.get(frame.id, [])
|
||||
if not frame_annotations:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
frame_stem = _frame_export_stem(project, frame)
|
||||
|
||||
if include_individual:
|
||||
class_masks: dict[str, np.ndarray] = {}
|
||||
class_meta: dict[str, dict[str, Any]] = {}
|
||||
for annotation in frame_annotations:
|
||||
key = _annotation_class_key(annotation)
|
||||
combined = class_masks.setdefault(key, np.zeros((height, width), dtype=np.uint8))
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined[:] = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
class_meta.setdefault(key, _class_mapping_entry(annotation))
|
||||
|
||||
folder = f"分开Mask分割结果/{frame_stem}_分别导出"
|
||||
for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
|
||||
meta = class_meta[key]
|
||||
maskid = class_values.get(key)
|
||||
if maskid is None:
|
||||
continue
|
||||
_, encoded = cv2.imencode(".png", mask)
|
||||
class_name = _safe_filename_part(meta["className"], "class")
|
||||
zf.writestr(
|
||||
f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
|
||||
encoded.tobytes(),
|
||||
)
|
||||
|
||||
needs_fused_output = include_semantic or include_pro_label or include_mix_label
|
||||
semantic = np.zeros((height, width), dtype=np.uint8) if needs_fused_output else None
|
||||
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
|
||||
|
||||
if needs_fused_output:
|
||||
for annotation in sorted(frame_annotations, key=_annotation_z_index):
|
||||
key = _annotation_class_key(annotation)
|
||||
value = class_values.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
if semantic is not None:
|
||||
semantic[combined > 0] = value
|
||||
if pro_label is not None:
|
||||
rgb = class_rgb_by_key.get(key, [255, 255, 255])
|
||||
bgr = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)
|
||||
pro_label[combined > 0] = bgr
|
||||
|
||||
if include_semantic and semantic is not None:
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_pro_label and pro_label is not None:
|
||||
_, encoded = cv2.imencode(".png", pro_label)
|
||||
zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_mix_label and pro_label is not None:
|
||||
original = _decode_original_image(original_images.get(frame.id), width, height)
|
||||
mask_pixels = np.any(pro_label > 0, axis=2)
|
||||
mixed = original.copy()
|
||||
opacity = min(max(float(mix_opacity), 0.0), 1.0)
|
||||
mixed[mask_pixels] = (
|
||||
original[mask_pixels].astype(np.float32) * (1.0 - opacity)
|
||||
+ pro_label[mask_pixels].astype(np.float32) * opacity
|
||||
).clip(0, 255).astype(np.uint8)
|
||||
_, encoded = cv2.imencode(".png", mixed)
|
||||
zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
|
||||
def _write_mask_pngs(
|
||||
zf: zipfile.ZipFile,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
mask_type: str,
|
||||
individual_prefix: str = "",
|
||||
semantic_prefix: str = "",
|
||||
semantic_file_stem: str = "semantic_frame",
|
||||
semantic_dtype: Any = np.uint8,
|
||||
) -> list[dict[str, Any]]:
|
||||
import cv2
|
||||
|
||||
class_values: dict[str, int] = {}
|
||||
semantic_classes: list[dict[str, Any]] = []
|
||||
|
||||
def class_value(annotation: Annotation) -> int:
|
||||
key = _annotation_class_key(annotation)
|
||||
if key not in class_values:
|
||||
value = len(class_values) + 1
|
||||
class_values[key] = value
|
||||
semantic_classes.append({
|
||||
"value": value,
|
||||
"key": key,
|
||||
"label": _annotation_label(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"zIndex": _annotation_z_index(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
})
|
||||
return class_values[key]
|
||||
|
||||
include_individual = mask_type in {"separate", "both"}
|
||||
include_semantic = mask_type in {"gt_label", "both"}
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
|
||||
for ann in annotations:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
if include_individual:
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
|
||||
if include_semantic and ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
|
||||
if include_semantic:
|
||||
for frame in frames:
|
||||
entries = frame_masks.get(frame.id, [])
|
||||
if not entries:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
semantic = np.zeros((height, width), dtype=semantic_dtype)
|
||||
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
|
||||
semantic[mask > 0] = class_value(ann)
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
|
||||
if include_semantic:
|
||||
zf.writestr(
|
||||
f"{semantic_prefix}semantic_classes.json",
|
||||
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
return semantic_classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
_project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
_write_mask_pngs(
|
||||
zf,
|
||||
frames,
|
||||
annotations,
|
||||
mask_type="both",
|
||||
semantic_prefix="",
|
||||
individual_prefix="",
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/results",
|
||||
summary="Export segmentation results as a ZIP archive",
|
||||
)
|
||||
def export_results(
|
||||
project_id: int,
|
||||
scope: str = Query("all", pattern="^(all|range|current)$"),
|
||||
mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
|
||||
outputs: str | None = Query(None),
|
||||
mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
|
||||
start_frame: int | None = Query(None, ge=1),
|
||||
end_frame: int | None = Query(None, ge=1),
|
||||
frame_id: int | None = Query(None, ge=1),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export JSON annotations plus selected PNG mask outputs inside one ZIP.
|
||||
|
||||
`scope=all` exports the whole video. `scope=range` uses 1-based frame
|
||||
numbers from the sorted project frame sequence. `scope=current` uses the
|
||||
concrete backend `frame_id`.
|
||||
"""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _filter_frames(
|
||||
_project_frames(project_id, db),
|
||||
scope=scope,
|
||||
start_frame=start_frame,
|
||||
end_frame=end_frame,
|
||||
frame_id=frame_id,
|
||||
)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
class_values, class_mapping = _build_gt_class_mapping(annotations)
|
||||
selected_outputs = _parse_result_outputs(mask_type, outputs)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr(
|
||||
"annotations_coco.json",
|
||||
json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
zf.writestr(
|
||||
"maskid_GT像素值_类别映射.json",
|
||||
json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
original_images = _write_original_frames(zf, project, frames)
|
||||
_write_result_mask_outputs(
|
||||
zf,
|
||||
project,
|
||||
frames,
|
||||
annotations,
|
||||
outputs=selected_outputs,
|
||||
class_values=class_values,
|
||||
class_mapping=class_mapping,
|
||||
original_images=original_images,
|
||||
mix_opacity=mix_opacity,
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = _segmentation_results_filename(project, frames)
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": _download_content_disposition(filename)},
|
||||
)
|
||||
234
backend/routers/media.py
Normal file
234
backend/routers/media.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
|
||||
from worker_tasks import parse_project_media
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
|
||||
|
||||
|
||||
def natural_filename_key(filename: str) -> tuple[object, ...]:
|
||||
return tuple(
|
||||
int(part) if part.isdigit() else part.casefold()
|
||||
for part in re.split(r"(\d+)", Path(filename).name)
|
||||
)
|
||||
|
||||
|
||||
def _get_ext(filename: str) -> str:
|
||||
return Path(filename).suffix.lower()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Upload a media file",
|
||||
)
|
||||
async def upload_media(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Accept a video, image, or DICOM file and store it in MinIO.
|
||||
|
||||
If project_id is provided, the video_path of the project is updated.
|
||||
Returns the presigned URL of the uploaded object.
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Missing filename")
|
||||
|
||||
ext = _get_ext(file.filename)
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {ext}",
|
||||
)
|
||||
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id or 'general'}/{file.filename}"
|
||||
|
||||
try:
|
||||
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Upload failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Upload to storage failed") from exc
|
||||
|
||||
file_url = get_presigned_url(object_name, expires=3600)
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
# Auto-create a project named after the file
|
||||
project = Project(
|
||||
name=file.filename,
|
||||
description="Auto-created from upload",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
video_path=object_name,
|
||||
source_type="video",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
project_id = project.id
|
||||
object_name = f"uploads/{project_id}/{file.filename}"
|
||||
# Re-upload with corrected path
|
||||
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Auto-created project id=%s for upload %s", project_id, file.filename)
|
||||
|
||||
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
|
||||
|
||||
return {
|
||||
"object_name": object_name,
|
||||
"file_url": file_url,
|
||||
"size": len(data),
|
||||
"project_id": project_id,
|
||||
"message": "Upload successful. Parsing job queued.",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload/dicom",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Upload multiple DICOM files",
|
||||
)
|
||||
async def upload_dicom_batch(
|
||||
files: List[UploadFile] = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Upload multiple .dcm files for a DICOM series.
|
||||
|
||||
If project_id is provided, files are added to the existing project.
|
||||
Otherwise a new DICOM project is created.
|
||||
"""
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files uploaded")
|
||||
|
||||
sorted_files = sorted(
|
||||
[file for file in files if file.filename and file.filename.lower().endswith(".dcm")],
|
||||
key=lambda file: natural_filename_key(file.filename or ""),
|
||||
)
|
||||
if not sorted_files:
|
||||
raise HTTPException(status_code=400, detail="No valid DICOM files uploaded")
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
# Create new DICOM project
|
||||
first_name = sorted_files[0].filename or "DICOM_Series"
|
||||
project = Project(
|
||||
name=first_name,
|
||||
description=f"DICOM series with {len(sorted_files)} files",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
project_id = project.id
|
||||
logger.info("Auto-created DICOM project id=%s", project_id)
|
||||
|
||||
for file in sorted_files:
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id}/dicom/{file.filename}"
|
||||
try:
|
||||
upload_file(object_name, data, content_type="application/dicom", length=len(data))
|
||||
uploaded.append(object_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to upload DICOM %s: %s", file.filename, exc)
|
||||
|
||||
project.video_path = f"uploads/{project_id}/dicom"
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"uploaded_count": len(uploaded),
|
||||
"message": f"Uploaded {len(uploaded)} DICOM files. Parsing job queued.",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Trigger frame extraction",
|
||||
)
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: Optional[str] = None,
|
||||
parse_fps: Optional[float] = Query(None, gt=0, le=120),
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
|
||||
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
|
||||
updates the persisted task record as it progresses.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
effective_source = source_type or project.source_type or "video"
|
||||
effective_parse_fps = parse_fps or project.parse_fps or 30.0
|
||||
task = ProcessingTask(
|
||||
task_type=f"parse_{effective_source}",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="解析任务已入队",
|
||||
project_id=project_id,
|
||||
payload={
|
||||
"source_type": effective_source,
|
||||
"parse_fps": effective_parse_fps,
|
||||
"max_frames": max_frames,
|
||||
"target_width": target_width,
|
||||
},
|
||||
)
|
||||
project.parse_fps = effective_parse_fps
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = parse_project_media.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, async_result.id)
|
||||
return task
|
||||
310
backend/routers/projects.py
Normal file
310
backend/routers/projects.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Project and Frame CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Mask, Project, Frame, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProjectCopyRequest, ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from minio_client import get_presigned_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
|
||||
def _next_project_copy_name(db: Session, source_name: str) -> str:
|
||||
base_name = f"{source_name} 副本"
|
||||
existing_names = {
|
||||
row[0]
|
||||
for row in db.query(Project.name)
|
||||
.filter(Project.name.like(f"{base_name}%"))
|
||||
.all()
|
||||
}
|
||||
if base_name not in existing_names:
|
||||
return base_name
|
||||
suffix = 2
|
||||
while f"{base_name} {suffix}" in existing_names:
|
||||
suffix += 1
|
||||
return f"{base_name} {suffix}"
|
||||
|
||||
|
||||
def _prepare_project_response(project: Project) -> Project:
|
||||
project.frame_count = len(project.frames)
|
||||
if project.thumbnail_url:
|
||||
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
|
||||
return project
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projects
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new project",
|
||||
)
|
||||
def create_project(
|
||||
payload: ProjectCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Create a new segmentation project."""
|
||||
project = Project(**payload.model_dump(), owner_user_id=current_user.id)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Created project id=%s name=%s", project.id, project.name)
|
||||
return project
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[ProjectOut],
|
||||
summary="List all projects",
|
||||
)
|
||||
def list_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Project]:
|
||||
"""Retrieve a paginated list of projects."""
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for p in projects:
|
||||
_prepare_project_response(p)
|
||||
return projects
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Get a single project",
|
||||
)
|
||||
def get_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Project:
|
||||
"""Retrieve a project by its ID."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return _prepare_project_response(project)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{project_id}/copy",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Copy a project",
|
||||
)
|
||||
def copy_project(
|
||||
project_id: int,
|
||||
payload: ProjectCopyRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Copy a project. Reset copies media/frame sequence; full also copies annotations and mask metadata."""
|
||||
source = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
next_name = (payload.name or "").strip() if payload.name is not None else ""
|
||||
if not next_name:
|
||||
next_name = _next_project_copy_name(db, source.name)
|
||||
|
||||
copied = Project(
|
||||
name=next_name,
|
||||
description=source.description,
|
||||
video_path=source.video_path,
|
||||
thumbnail_url=source.thumbnail_url,
|
||||
status=source.status,
|
||||
source_type=source.source_type,
|
||||
original_fps=source.original_fps,
|
||||
parse_fps=source.parse_fps,
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(copied)
|
||||
db.flush()
|
||||
|
||||
frame_id_map: dict[int, int] = {}
|
||||
for frame in sorted(source.frames, key=lambda item: item.frame_index):
|
||||
copied_frame = Frame(
|
||||
project_id=copied.id,
|
||||
frame_index=frame.frame_index,
|
||||
image_url=frame.image_url,
|
||||
width=frame.width,
|
||||
height=frame.height,
|
||||
timestamp_ms=frame.timestamp_ms,
|
||||
source_frame_number=frame.source_frame_number,
|
||||
)
|
||||
db.add(copied_frame)
|
||||
db.flush()
|
||||
frame_id_map[frame.id] = copied_frame.id
|
||||
|
||||
if payload.mode == "full":
|
||||
for annotation in sorted(source.annotations, key=lambda item: item.id):
|
||||
copied_annotation = Annotation(
|
||||
project_id=copied.id,
|
||||
frame_id=frame_id_map.get(annotation.frame_id) if annotation.frame_id is not None else None,
|
||||
template_id=annotation.template_id,
|
||||
mask_data=annotation.mask_data,
|
||||
points=annotation.points,
|
||||
bbox=annotation.bbox,
|
||||
)
|
||||
db.add(copied_annotation)
|
||||
db.flush()
|
||||
for mask in annotation.masks:
|
||||
db.add(Mask(
|
||||
annotation_id=copied_annotation.id,
|
||||
mask_url=mask.mask_url,
|
||||
format=mask.format,
|
||||
))
|
||||
|
||||
db.commit()
|
||||
db.refresh(copied)
|
||||
logger.info("Copied project id=%s to id=%s mode=%s", project_id, copied.id, payload.mode)
|
||||
return _prepare_project_response(copied)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Update a project",
|
||||
)
|
||||
def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Update project fields partially."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
if key == "name":
|
||||
value = (value or "").strip()
|
||||
if not value:
|
||||
raise HTTPException(status_code=400, detail="Project name is required")
|
||||
setattr(project, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Updated project id=%s", project_id)
|
||||
return project
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{project_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a project",
|
||||
)
|
||||
def delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a project and all related frames and annotations."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
db.delete(project)
|
||||
db.commit()
|
||||
logger.info("Deleted project id=%s", project_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frames
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"/{project_id}/frames",
|
||||
response_model=FrameOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Add a frame to a project",
|
||||
)
|
||||
def create_frame(
|
||||
project_id: int,
|
||||
payload: FrameCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Frame:
|
||||
"""Register a new frame under a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
frame = Frame(project_id=project_id, **payload.model_dump(exclude={"project_id"}))
|
||||
db.add(frame)
|
||||
db.commit()
|
||||
db.refresh(frame)
|
||||
return frame
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames",
|
||||
response_model=List[FrameOut],
|
||||
summary="List frames for a project",
|
||||
)
|
||||
def list_frames(
|
||||
project_id: int,
|
||||
skip: int = 0,
|
||||
limit: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Frame]:
|
||||
"""Retrieve all frames belonging to a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
query = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.offset(skip)
|
||||
)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
frames = query.all()
|
||||
for frame in frames:
|
||||
frame.image_url = get_presigned_url(frame.image_url, expires=3600)
|
||||
return frames
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames/{frame_id}",
|
||||
response_model=FrameOut,
|
||||
summary="Get a single frame",
|
||||
)
|
||||
def get_frame(
|
||||
project_id: int,
|
||||
frame_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Frame:
|
||||
"""Retrieve a specific frame by ID."""
|
||||
frame = (
|
||||
db.query(Frame)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(
|
||||
Frame.project_id == project_id,
|
||||
Frame.id == frame_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
161
backend/routers/tasks.py
Normal file
161
backend/routers/tasks.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Processing task query endpoints."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import get_db
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PARSING,
|
||||
PROJECT_STATUS_PENDING,
|
||||
PROJECT_STATUS_READY,
|
||||
TASK_ACTIVE_STATUSES,
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_QUEUED,
|
||||
)
|
||||
from worker_tasks import parse_project_media, propagate_project_masks
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
|
||||
_ = current_user
|
||||
task = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter(ProcessingTask.id == task_id)
|
||||
.first()
|
||||
)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
|
||||
|
||||
def _project_status_after_stop(project: Project) -> str:
|
||||
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
||||
def list_tasks(
|
||||
project_id: int | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[ProcessingTask]:
|
||||
"""Return recent background processing tasks."""
|
||||
_ = current_user
|
||||
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
if project_id is not None:
|
||||
query = query.filter(ProcessingTask.project_id == project_id)
|
||||
if status is not None:
|
||||
query = query.filter(ProcessingTask.status == status)
|
||||
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||||
def get_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ProcessingTask:
|
||||
"""Return one background task by id."""
|
||||
return _get_task_or_404(task_id, db, current_user)
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||||
def cancel_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||||
task = _get_task_or_404(task_id, db, current_user)
|
||||
if task.status not in TASK_ACTIVE_STATUSES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Task is not cancellable in status: {task.status}",
|
||||
)
|
||||
|
||||
if task.celery_task_id:
|
||||
try:
|
||||
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
|
||||
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = "任务已取消"
|
||||
task.error = "Cancelled by user"
|
||||
task.finished_at = _now()
|
||||
if task.project:
|
||||
task.project.status = _project_status_after_stop(task.project)
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||||
def retry_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a fresh queued task from a failed or cancelled task."""
|
||||
previous = _get_task_or_404(task_id, db, current_user)
|
||||
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Task is not retryable in status: {previous.status}",
|
||||
)
|
||||
if previous.project_id is None:
|
||||
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
is_propagation_task = previous.task_type == "propagate_masks"
|
||||
if not is_propagation_task and not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
payload = dict(previous.payload or {})
|
||||
payload.setdefault("source_type", project.source_type or "video")
|
||||
payload["retry_of"] = previous.id
|
||||
|
||||
task = ProcessingTask(
|
||||
task_type=previous.task_type,
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message=f"重试任务已入队(源任务 #{previous.id})",
|
||||
project_id=project.id,
|
||||
payload=payload,
|
||||
)
|
||||
if not is_propagation_task:
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return task
|
||||
183
backend/routers/templates.py
Normal file
183
backend/routers/templates.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Template (Ontology) CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/templates", tags=["Templates"])
|
||||
RESERVED_UNCLASSIFIED_CLASS = {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"zIndex": 0,
|
||||
"maskId": 0,
|
||||
"category": "系统保留",
|
||||
}
|
||||
|
||||
|
||||
def _is_reserved_class(item: dict) -> bool:
|
||||
return (
|
||||
item.get("id") == RESERVED_UNCLASSIFIED_CLASS["id"]
|
||||
or item.get("name") == RESERVED_UNCLASSIFIED_CLASS["name"]
|
||||
or item.get("maskId") == 0
|
||||
)
|
||||
|
||||
|
||||
def _normalize_template_classes(classes: list[dict] | None) -> list[dict]:
|
||||
normalized = [item for item in (classes or []) if not _is_reserved_class(item)]
|
||||
return [*normalized, dict(RESERVED_UNCLASSIFIED_CLASS)]
|
||||
|
||||
|
||||
def _pack_mapping_rules(data: dict) -> dict:
|
||||
"""Pack classes/rules into mapping_rules for DB storage."""
|
||||
mapping = data.get("mapping_rules") or {}
|
||||
if "classes" in data and data["classes"] is not None:
|
||||
mapping["classes"] = _normalize_template_classes(data.pop("classes"))
|
||||
if "rules" in data and data["rules"] is not None:
|
||||
mapping["rules"] = data.pop("rules")
|
||||
if "classes" in mapping:
|
||||
mapping["classes"] = _normalize_template_classes(mapping.get("classes"))
|
||||
data["mapping_rules"] = mapping
|
||||
return data
|
||||
|
||||
|
||||
def _unpack_template(template: Template) -> Template:
|
||||
"""Unpack mapping_rules into classes/rules for response."""
|
||||
mapping = template.mapping_rules or {}
|
||||
# Set as attributes so Pydantic from_attributes can pick them up
|
||||
template.classes = _normalize_template_classes(mapping.get("classes", []))
|
||||
template.rules = mapping.get("rules", [])
|
||||
return template
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=TemplateOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new template",
|
||||
)
|
||||
def create_template(
|
||||
payload: TemplateCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Create a new ontology template / segmentation class."""
|
||||
data = payload.model_dump()
|
||||
data = _pack_mapping_rules(data)
|
||||
template = Template(**data, owner_user_id=current_user.id)
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
_unpack_template(template)
|
||||
logger.info("Created template id=%s name=%s", template.id, template.name)
|
||||
return template
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[TemplateOut],
|
||||
summary="List all templates",
|
||||
)
|
||||
def list_templates(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Template]:
|
||||
"""Retrieve all ontology templates."""
|
||||
templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for t in templates:
|
||||
_unpack_template(t)
|
||||
return templates
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Get a single template",
|
||||
)
|
||||
def get_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Template:
|
||||
"""Retrieve a template by its ID."""
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_unpack_template(template)
|
||||
return template
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Update a template",
|
||||
)
|
||||
def update_template(
|
||||
template_id: int,
|
||||
payload: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Update template fields partially."""
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
if "classes" in data or "rules" in data:
|
||||
data = _pack_mapping_rules(data)
|
||||
|
||||
for key, value in data.items():
|
||||
setattr(template, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
_unpack_template(template)
|
||||
logger.info("Updated template id=%s", template_id)
|
||||
return template
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{template_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a template",
|
||||
)
|
||||
def delete_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a template. Associated annotations will have template_id set to NULL."""
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
db.delete(template)
|
||||
db.commit()
|
||||
logger.info("Deleted template id=%s", template_id)
|
||||
385
backend/schemas.py
Normal file
385
backend/schemas.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""Pydantic schemas for request/response validation."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth / user schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class UserOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
username: str
|
||||
role: str
|
||||
is_active: int
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
token_type: str = "bearer"
|
||||
username: str
|
||||
user: UserOut
|
||||
|
||||
|
||||
class AdminUserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
role: str = "annotator"
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class AdminUserUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class AuditLogOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
actor_user_id: Optional[int] = None
|
||||
action: str
|
||||
target_type: Optional[str] = None
|
||||
target_id: Optional[str] = None
|
||||
detail: Optional[dict[str, Any]] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class DemoFactoryResetRequest(BaseModel):
|
||||
confirmation: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Project schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ProjectBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
video_path: Optional[str] = None
|
||||
thumbnail_url: Optional[str] = None
|
||||
status: Optional[str] = "pending"
|
||||
source_type: Optional[str] = "video"
|
||||
original_fps: Optional[float] = None
|
||||
parse_fps: Optional[float] = 30.0
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
pass
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
video_path: Optional[str] = None
|
||||
thumbnail_url: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
original_fps: Optional[float] = None
|
||||
parse_fps: Optional[float] = None
|
||||
|
||||
|
||||
class ProjectCopyRequest(BaseModel):
|
||||
mode: Literal["reset", "full"] = "reset"
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectOut(ProjectBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
owner_user_id: Optional[int] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
frame_count: int = 0
|
||||
|
||||
|
||||
class DemoFactoryResetOut(BaseModel):
|
||||
admin_user: UserOut
|
||||
project: ProjectOut
|
||||
projects: list[ProjectOut]
|
||||
deleted_counts: dict[str, int]
|
||||
message: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frame schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class FrameBase(BaseModel):
|
||||
frame_index: int
|
||||
image_url: str
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
timestamp_ms: Optional[float] = None
|
||||
source_frame_number: Optional[int] = None
|
||||
|
||||
|
||||
class FrameCreate(FrameBase):
|
||||
project_id: int
|
||||
|
||||
|
||||
class FrameOut(FrameBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
project_id: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class TemplateBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
color: str
|
||||
z_index: int = 0
|
||||
mapping_rules: Optional[dict[str, Any]] = None
|
||||
classes: Optional[list[dict[str, Any]]] = None
|
||||
rules: Optional[list[dict[str, Any]]] = None
|
||||
|
||||
|
||||
class TemplateCreate(TemplateBase):
|
||||
pass
|
||||
|
||||
|
||||
class TemplateUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
z_index: Optional[int] = None
|
||||
mapping_rules: Optional[dict[str, Any]] = None
|
||||
classes: Optional[list[dict[str, Any]]] = None
|
||||
rules: Optional[list[dict[str, Any]]] = None
|
||||
|
||||
|
||||
class TemplateOut(TemplateBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
owner_user_id: Optional[int] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Annotation schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class AnnotationBase(BaseModel):
|
||||
project_id: int
|
||||
frame_id: Optional[int] = None
|
||||
template_id: Optional[int] = None
|
||||
mask_data: Optional[dict[str, Any]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
|
||||
|
||||
class AnnotationCreate(AnnotationBase):
|
||||
pass
|
||||
|
||||
|
||||
class AnnotationUpdate(BaseModel):
|
||||
mask_data: Optional[dict[str, Any]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
template_id: Optional[int] = None
|
||||
|
||||
|
||||
class AnnotationOut(AnnotationBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mask schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class MaskBase(BaseModel):
|
||||
annotation_id: int
|
||||
mask_url: str
|
||||
format: str = "png"
|
||||
|
||||
|
||||
class MaskCreate(MaskBase):
|
||||
pass
|
||||
|
||||
|
||||
class MaskOut(MaskBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Processing task schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ProcessingTaskOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
task_type: str
|
||||
status: str
|
||||
progress: int
|
||||
message: Optional[str] = None
|
||||
project_id: Optional[int] = None
|
||||
celery_task_id: Optional[str] = None
|
||||
payload: Optional[dict[str, Any]] = None
|
||||
result: Optional[dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime] = None
|
||||
finished_at: Optional[datetime] = None
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class PredictRequest(BaseModel):
|
||||
image_id: int
|
||||
prompt_type: str # point / box / semantic
|
||||
prompt_data: Any
|
||||
model: Optional[str] = None
|
||||
options: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
polygons: list[list[list[float]]]
|
||||
scores: Optional[list[float]] = None
|
||||
|
||||
|
||||
class MaskAnalysisRequest(BaseModel):
|
||||
frame_id: Optional[int] = None
|
||||
mask_data: dict[str, Any]
|
||||
points: Optional[list[list[float]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
extract_skeleton: bool = False
|
||||
|
||||
|
||||
class MaskAnalysisResponse(BaseModel):
|
||||
confidence: Optional[float] = None
|
||||
confidence_source: str
|
||||
topology_anchor_count: int
|
||||
topology_anchors: list[list[float]]
|
||||
area: float
|
||||
bbox: Optional[list[float]] = None
|
||||
source: Optional[str] = None
|
||||
message: str
|
||||
|
||||
|
||||
class SmoothMaskRequest(BaseModel):
|
||||
frame_id: Optional[int] = None
|
||||
mask_data: dict[str, Any]
|
||||
points: Optional[list[list[float]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
strength: float = 0.0
|
||||
method: str = "chaikin"
|
||||
|
||||
|
||||
class SmoothMaskResponse(BaseModel):
|
||||
polygons: list[list[list[float]]]
|
||||
topology_anchor_count: int
|
||||
topology_anchors: list[list[float]]
|
||||
area: float
|
||||
bbox: Optional[list[float]] = None
|
||||
smoothing: dict[str, Any]
|
||||
message: str
|
||||
|
||||
|
||||
class PropagationSeed(BaseModel):
|
||||
polygons: Optional[list[list[list[float]]]] = None
|
||||
holes: Optional[list[list[list[list[float]]]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
labels: Optional[list[int]] = None
|
||||
label: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
class_metadata: Optional[dict[str, Any]] = None
|
||||
template_id: Optional[int] = None
|
||||
source_mask_id: Optional[str] = None
|
||||
source_annotation_id: Optional[int] = None
|
||||
source_instance_id: Optional[str] = None
|
||||
propagation_seed_signature: Optional[str] = None
|
||||
smoothing: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class PropagateRequest(BaseModel):
|
||||
project_id: int
|
||||
frame_id: int
|
||||
model: Optional[str] = "sam2.1_hiera_tiny"
|
||||
seed: PropagationSeed
|
||||
direction: str = "forward"
|
||||
max_frames: int = 30
|
||||
include_source: bool = False
|
||||
save_annotations: bool = True
|
||||
|
||||
|
||||
class PropagateResponse(BaseModel):
|
||||
model: str
|
||||
direction: str
|
||||
source_frame_id: int
|
||||
processed_frame_count: int
|
||||
created_annotation_count: int
|
||||
annotations: list[AnnotationOut]
|
||||
|
||||
|
||||
class PropagateTaskStep(BaseModel):
|
||||
seed: PropagationSeed
|
||||
direction: str = "forward"
|
||||
max_frames: int = 30
|
||||
|
||||
|
||||
class PropagateTaskRequest(BaseModel):
|
||||
project_id: int
|
||||
frame_id: int
|
||||
model: Optional[str] = "sam2.1_hiera_tiny"
|
||||
steps: list[PropagateTaskStep]
|
||||
include_source: bool = False
|
||||
save_annotations: bool = True
|
||||
|
||||
|
||||
class AiModelStatus(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
available: bool
|
||||
loaded: bool = False
|
||||
device: str
|
||||
supports: list[str]
|
||||
message: str
|
||||
package_available: bool = False
|
||||
checkpoint_exists: bool = False
|
||||
checkpoint_path: Optional[str] = None
|
||||
python_ok: bool = True
|
||||
torch_ok: bool = True
|
||||
cuda_required: bool = False
|
||||
external_available: bool = False
|
||||
external_python: Optional[str] = None
|
||||
|
||||
|
||||
class GpuStatus(BaseModel):
|
||||
available: bool
|
||||
device: str
|
||||
name: Optional[str] = None
|
||||
torch_available: bool
|
||||
torch_version: Optional[str] = None
|
||||
cuda_version: Optional[str] = None
|
||||
|
||||
|
||||
class AiRuntimeStatus(BaseModel):
|
||||
selected_model: str
|
||||
gpu: GpuStatus
|
||||
models: list[AiModelStatus]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ExportStatus(BaseModel):
|
||||
url: str
|
||||
format: str
|
||||
0
backend/services/__init__.py
Normal file
0
backend/services/__init__.py
Normal file
164
backend/services/default_templates.py
Normal file
164
backend/services/default_templates.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Bundled system ontology templates and restore helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Template
|
||||
|
||||
RESERVED_UNCLASSIFIED_CLASS = {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"zIndex": 0,
|
||||
"maskId": 0,
|
||||
"category": "系统保留",
|
||||
}
|
||||
|
||||
|
||||
def _with_reserved_unclassified_class(classes: list[dict]) -> list[dict]:
|
||||
filtered = [
|
||||
item for item in classes
|
||||
if item.get("id") != RESERVED_UNCLASSIFIED_CLASS["id"]
|
||||
and item.get("name") != RESERVED_UNCLASSIFIED_CLASS["name"]
|
||||
and item.get("maskId") != 0
|
||||
]
|
||||
return [*filtered, dict(RESERVED_UNCLASSIFIED_CLASS)]
|
||||
|
||||
|
||||
def _template_classes(
|
||||
template_name: str,
|
||||
names: list[str],
|
||||
colors: list[tuple[int, int, int]],
|
||||
*,
|
||||
id_prefix: str,
|
||||
) -> list[dict]:
|
||||
classes = []
|
||||
for idx, (rgb, name) in enumerate(zip(colors, names)):
|
||||
color_hex = f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
|
||||
classes.append({
|
||||
"id": f"{id_prefix}-{idx}",
|
||||
"name": name,
|
||||
"color": color_hex,
|
||||
"zIndex": (len(names) - idx) * 10,
|
||||
"maskId": idx + 1,
|
||||
"category": template_name,
|
||||
})
|
||||
return classes
|
||||
|
||||
|
||||
def bundled_default_template_definitions() -> list[dict]:
|
||||
"""Return fresh definitions for all bundled system templates."""
|
||||
return [
|
||||
{
|
||||
"name": "腹腔镜胆囊切除术",
|
||||
"description": "腹腔镜胆囊切除术(LC)手术器械与解剖结构语义分割模板,共35个分类",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": _with_reserved_unclassified_class(_template_classes(
|
||||
"腹腔镜胆囊切除术",
|
||||
[
|
||||
"针", "线", "肿瘤", "血管阻断夹", "棉球", "双极电凝",
|
||||
"肝脏", "胆囊", "分离钳", "脂肪", "止血海绵", "肝总管",
|
||||
"吸引器", "剪刀", "超声刀", "止血纱布", "胆总管", "生物夹",
|
||||
"无损伤钳", "钳夹", "喷洒", "胆囊管", "动脉", "电凝",
|
||||
"静脉", "标本袋", "引流管", "纱布", "金属钛夹", "韧带",
|
||||
"肝蒂", "推结器", "乳胶管-血管阻断", "吻合器", "术中超声",
|
||||
],
|
||||
[
|
||||
(134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
|
||||
(255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
|
||||
(255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
|
||||
(52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
|
||||
(42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
|
||||
(113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
|
||||
],
|
||||
id_prefix="cls-lap",
|
||||
)),
|
||||
},
|
||||
{
|
||||
"name": "头颈部CT分割",
|
||||
"description": "头颈部CT分割",
|
||||
"color": "#ef4444",
|
||||
"z_index": 10,
|
||||
"classes": _with_reserved_unclassified_class(_template_classes(
|
||||
"头颈部CT分割",
|
||||
[
|
||||
"肿瘤/结节",
|
||||
"下颌骨",
|
||||
"甲状腺",
|
||||
"气管",
|
||||
"颈椎",
|
||||
"颈动脉",
|
||||
"颈静脉",
|
||||
"腮腺",
|
||||
"下颌下腺",
|
||||
"舌骨",
|
||||
],
|
||||
[
|
||||
(255, 0, 0),
|
||||
(0, 255, 0),
|
||||
(0, 0, 255),
|
||||
(255, 255, 0),
|
||||
(255, 0, 255),
|
||||
(0, 255, 255),
|
||||
(255, 128, 0),
|
||||
(128, 0, 128),
|
||||
(0, 128, 128),
|
||||
(128, 128, 0),
|
||||
],
|
||||
id_prefix="cls-head-neck-ct",
|
||||
)),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _has_legacy_head_neck_english_labels(template: Template) -> bool:
|
||||
if template.name != "头颈部CT分割":
|
||||
return False
|
||||
classes = (template.mapping_rules or {}).get("classes") or []
|
||||
return any(
|
||||
isinstance(item, dict)
|
||||
and isinstance(item.get("name"), str)
|
||||
and "(" in item["name"]
|
||||
and ")" in item["name"]
|
||||
for item in classes
|
||||
)
|
||||
|
||||
|
||||
def ensure_default_templates(db: Session, *, restore_existing: bool = False) -> list[Template]:
|
||||
"""Create bundled system templates, optionally restoring existing ones exactly."""
|
||||
templates: list[Template] = []
|
||||
for definition in bundled_default_template_definitions():
|
||||
existing = db.query(Template).filter(
|
||||
Template.name == definition["name"],
|
||||
Template.owner_user_id.is_(None),
|
||||
).first()
|
||||
if existing is None:
|
||||
existing = Template(owner_user_id=None)
|
||||
db.add(existing)
|
||||
elif not restore_existing and not _has_legacy_head_neck_english_labels(existing):
|
||||
templates.append(existing)
|
||||
continue
|
||||
|
||||
existing.name = definition["name"]
|
||||
existing.description = definition["description"]
|
||||
existing.color = definition["color"]
|
||||
existing.z_index = definition["z_index"]
|
||||
existing.owner_user_id = None
|
||||
existing.mapping_rules = {
|
||||
"classes": deepcopy(definition["classes"]),
|
||||
"rules": [],
|
||||
}
|
||||
templates.append(existing)
|
||||
db.commit()
|
||||
for template in templates:
|
||||
db.refresh(template)
|
||||
return templates
|
||||
|
||||
|
||||
def restore_default_templates(db: Session) -> list[Template]:
|
||||
"""Restore bundled system templates after demo factory reset."""
|
||||
return ensure_default_templates(db, restore_existing=True)
|
||||
217
backend/services/demo_media.py
Normal file
217
backend/services/demo_media.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Helpers for seeding the bundled demo media project."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import upload_file
|
||||
from models import Frame, Project, User
|
||||
from services.frame_parser import (
|
||||
extract_thumbnail,
|
||||
natural_filename_key,
|
||||
parse_dicom,
|
||||
parse_video,
|
||||
upload_frames_to_minio,
|
||||
)
|
||||
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
|
||||
|
||||
DEMO_DICOM_PROJECT_NAME = "演视DICOM序列"
|
||||
DEMO_DICOM_PARSE_FPS = 30.0
|
||||
DEMO_VIDEO_PROJECT_NAME = "演视LC视频序列"
|
||||
DEMO_VIDEO_PARSE_FPS = 30.0
|
||||
DEMO_VIDEO_TARGET_WIDTH = 640
|
||||
LEGACY_DEMO_VIDEO_PROJECT_NAMES = {"Data_MyVideo_1"}
|
||||
LEGACY_DEMO_DICOM_PROJECT_NAMES = {"演示DICOM序列"}
|
||||
|
||||
|
||||
def demo_dicom_files(dicom_dir: str) -> list[Path]:
|
||||
"""Return .dcm files in natural file-name order."""
|
||||
root = Path(dicom_dir)
|
||||
if not root.exists() or not root.is_dir():
|
||||
return []
|
||||
return sorted(
|
||||
[path for path in root.iterdir() if path.is_file() and path.name.lower().endswith(".dcm")],
|
||||
key=lambda path: natural_filename_key(path.name),
|
||||
)
|
||||
|
||||
|
||||
def create_unparsed_video_demo_project(
|
||||
db: Session,
|
||||
*,
|
||||
owner: User,
|
||||
video_path: str,
|
||||
project_name: str = DEMO_VIDEO_PROJECT_NAME,
|
||||
) -> Project:
|
||||
"""Create the bundled demo video project without extracting frames."""
|
||||
source = Path(video_path)
|
||||
if not source.exists() or not source.is_file():
|
||||
raise FileNotFoundError(f"Demo video not found: {video_path}")
|
||||
|
||||
project = Project(
|
||||
name=project_name,
|
||||
description="默认演示视频,尚未生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
original_fps=None,
|
||||
owner_user_id=owner.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
data = source.read_bytes()
|
||||
object_name = f"uploads/{project.id}/{source.name}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
project.thumbnail_url = None
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
def create_parsed_video_demo_project(
|
||||
db: Session,
|
||||
*,
|
||||
owner: User,
|
||||
video_path: str,
|
||||
project_name: str = DEMO_VIDEO_PROJECT_NAME,
|
||||
) -> Project:
|
||||
"""Create the bundled demo video project and register its extracted frame sequence."""
|
||||
source = Path(video_path)
|
||||
if not source.exists() or not source.is_file():
|
||||
raise FileNotFoundError(f"Demo video not found: {video_path}")
|
||||
|
||||
project = Project(
|
||||
name=project_name,
|
||||
description="默认演示视频,已生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=DEMO_VIDEO_PARSE_FPS,
|
||||
original_fps=None,
|
||||
owner_user_id=owner.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
data = source.read_bytes()
|
||||
object_name = f"uploads/{project.id}/{source.name}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_demo_video_{project.id}_")
|
||||
try:
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
frame_files, original_fps = parse_video(
|
||||
str(source),
|
||||
output_dir,
|
||||
fps=int(DEMO_VIDEO_PARSE_FPS),
|
||||
target_width=DEMO_VIDEO_TARGET_WIDTH,
|
||||
)
|
||||
project.original_fps = original_fps
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
image = cv2.imread(frame_files[idx])
|
||||
height, width = image.shape[:2] if image is not None else (None, None)
|
||||
db.add(Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=width,
|
||||
height=height,
|
||||
timestamp_ms=idx * 1000.0 / DEMO_VIDEO_PARSE_FPS,
|
||||
source_frame_number=idx,
|
||||
))
|
||||
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(str(source), thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as thumbnail_file:
|
||||
thumbnail_data = thumbnail_file.read()
|
||||
thumbnail_object = f"projects/{project.id}/thumbnail.jpg"
|
||||
upload_file(
|
||||
thumbnail_object,
|
||||
thumbnail_data,
|
||||
content_type="image/jpeg",
|
||||
length=len(thumbnail_data),
|
||||
)
|
||||
project.thumbnail_url = thumbnail_object
|
||||
except Exception: # noqa: BLE001
|
||||
if object_names:
|
||||
project.thumbnail_url = object_names[0]
|
||||
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return project
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def create_parsed_dicom_demo_project(
|
||||
db: Session,
|
||||
*,
|
||||
owner: User,
|
||||
dicom_dir: str,
|
||||
project_name: str = DEMO_DICOM_PROJECT_NAME,
|
||||
) -> Project:
|
||||
"""Create the demo DICOM project, upload the series, and register parsed frames."""
|
||||
dcm_files = demo_dicom_files(dicom_dir)
|
||||
if not dcm_files:
|
||||
raise FileNotFoundError(f"Demo DICOM series not found: {dicom_dir}")
|
||||
|
||||
project = Project(
|
||||
name=project_name,
|
||||
description=f"默认演示 DICOM 序列,已按文件名自然顺序生成 {len(dcm_files)} 帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
parse_fps=DEMO_DICOM_PARSE_FPS,
|
||||
original_fps=None,
|
||||
owner_user_id=owner.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
dicom_prefix = f"uploads/{project.id}/dicom"
|
||||
for dcm_file in dcm_files:
|
||||
data = dcm_file.read_bytes()
|
||||
upload_file(
|
||||
f"{dicom_prefix}/{dcm_file.name}",
|
||||
data,
|
||||
content_type="application/dicom",
|
||||
length=len(data),
|
||||
)
|
||||
project.video_path = dicom_prefix
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_demo_dicom_{project.id}_")
|
||||
try:
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
frame_files = parse_dicom(dicom_dir, output_dir)
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
image = cv2.imread(frame_files[idx])
|
||||
height, width = image.shape[:2] if image is not None else (None, None)
|
||||
db.add(Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=width,
|
||||
height=height,
|
||||
timestamp_ms=idx * 1000.0 / DEMO_DICOM_PARSE_FPS,
|
||||
source_frame_number=idx,
|
||||
))
|
||||
if object_names:
|
||||
project.thumbnail_url = object_names[0]
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return project
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
237
backend/services/frame_parser.py
Normal file
237
backend/services/frame_parser.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Video/DICOM frame parsing and MinIO upload utilities."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pydicom import dcmread
|
||||
|
||||
from minio_client import upload_file, BUCKET_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def natural_filename_key(filename: str) -> Tuple[object, ...]:
|
||||
"""Sort file names by their visible numeric order instead of pure lexicographic order."""
|
||||
return tuple(
|
||||
int(part) if part.isdigit() else part.casefold()
|
||||
for part in re.split(r"(\d+)", Path(filename).name)
|
||||
)
|
||||
|
||||
|
||||
def get_video_fps(video_path: str) -> float:
|
||||
"""Read the original frame rate of a video file."""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
return 30.0
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
cap.release()
|
||||
return fps if fps > 0 else 30.0
|
||||
|
||||
|
||||
def extract_thumbnail(video_path: str, output_path: str, width: int = 640) -> str:
|
||||
"""Extract the first frame of a video as a thumbnail JPEG."""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError(f"Cannot open video for thumbnail: {video_path}")
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"Cannot read first frame from: {video_path}")
|
||||
|
||||
h, w = frame.shape[:2]
|
||||
if w > width:
|
||||
scale = width / w
|
||||
new_w = int(w * scale)
|
||||
new_h = int(h * scale)
|
||||
frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
|
||||
cv2.imwrite(output_path, frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
return output_path
|
||||
|
||||
|
||||
def parse_video(
|
||||
video_path: str,
|
||||
output_dir: str,
|
||||
fps: int = 30,
|
||||
max_frames: Optional[int] = None,
|
||||
target_width: int = 640,
|
||||
) -> Tuple[List[str], float]:
|
||||
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
|
||||
|
||||
Args:
|
||||
video_path: Path to the input video file.
|
||||
output_dir: Directory to save extracted frames.
|
||||
fps: Target frame extraction rate.
|
||||
max_frames: Optional maximum number of frames to extract.
|
||||
target_width: Output frame width for model-friendly frame sequences.
|
||||
|
||||
Returns:
|
||||
Tuple of (frame_paths, original_fps).
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
frame_paths: List[str] = []
|
||||
original_fps = get_video_fps(video_path)
|
||||
safe_fps = max(int(fps), 1)
|
||||
safe_width = max(int(target_width), 1)
|
||||
|
||||
# Try FFmpeg first
|
||||
if shutil.which("ffmpeg"):
|
||||
try:
|
||||
pattern = os.path.join(output_dir, "frame_%06d.jpg")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i", video_path,
|
||||
"-vf", f"fps={safe_fps},scale={safe_width}:-1",
|
||||
"-start_number", "0",
|
||||
"-q:v", "5",
|
||||
"-y",
|
||||
pattern,
|
||||
]
|
||||
logger.info("Running FFmpeg: %s", " ".join(cmd))
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode == 0:
|
||||
frame_paths = sorted(
|
||||
[os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".jpg")]
|
||||
)
|
||||
if max_frames:
|
||||
frame_paths = frame_paths[:max_frames]
|
||||
logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
|
||||
return frame_paths, original_fps
|
||||
else:
|
||||
logger.warning("FFmpeg failed: %s", result.stderr)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("FFmpeg exception: %s", exc)
|
||||
|
||||
# OpenCV fallback
|
||||
logger.info("Falling back to OpenCV frame extraction")
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError(f"Cannot open video: {video_path}")
|
||||
|
||||
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
||||
interval = max(1, int(round(video_fps / safe_fps)))
|
||||
count = 0
|
||||
saved = 0
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
if count % interval == 0:
|
||||
path = os.path.join(output_dir, f"frame_{saved:06d}.jpg")
|
||||
h, w = frame.shape[:2]
|
||||
if safe_width > 0 and w != safe_width:
|
||||
scale = safe_width / max(w, 1)
|
||||
frame = cv2.resize(frame, (safe_width, max(1, int(round(h * scale)))), interpolation=cv2.INTER_AREA)
|
||||
cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
||||
frame_paths.append(path)
|
||||
saved += 1
|
||||
if max_frames and saved >= max_frames:
|
||||
break
|
||||
count += 1
|
||||
|
||||
cap.release()
|
||||
logger.info("Extracted %d frames via OpenCV", len(frame_paths))
|
||||
return frame_paths, original_fps
|
||||
|
||||
|
||||
def parse_dicom(
|
||||
dicom_dir: str,
|
||||
output_dir: str,
|
||||
max_frames: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""Extract frames from DICOM files in a directory.
|
||||
|
||||
Args:
|
||||
dicom_dir: Directory containing .dcm files.
|
||||
output_dir: Directory to save extracted frames.
|
||||
max_frames: Optional maximum number of frames to extract.
|
||||
|
||||
Returns:
|
||||
List of paths to extracted frame images.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
dcm_files = sorted(
|
||||
[f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")],
|
||||
key=natural_filename_key,
|
||||
)
|
||||
|
||||
frame_paths: List[str] = []
|
||||
for idx, fname in enumerate(dcm_files):
|
||||
if max_frames and idx >= max_frames:
|
||||
break
|
||||
path = os.path.join(dicom_dir, fname)
|
||||
try:
|
||||
ds = dcmread(path)
|
||||
pixel_array = ds.pixel_array
|
||||
|
||||
# Normalize to 8-bit
|
||||
if pixel_array.dtype != np.uint8:
|
||||
pixel_array = pixel_array.astype(np.float32)
|
||||
pixel_array = (
|
||||
(pixel_array - pixel_array.min())
|
||||
/ (pixel_array.max() - pixel_array.min() + 1e-8)
|
||||
* 255
|
||||
)
|
||||
pixel_array = pixel_array.astype(np.uint8)
|
||||
|
||||
# Handle multi-frame DICOM
|
||||
if pixel_array.ndim == 3:
|
||||
for f in range(pixel_array.shape[0]):
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.jpg")
|
||||
cv2.imwrite(out_path, pixel_array[f], [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
frame_paths.append(out_path)
|
||||
else:
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}.jpg")
|
||||
cv2.imwrite(out_path, pixel_array, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
frame_paths.append(out_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to read DICOM %s: %s", path, exc)
|
||||
|
||||
logger.info("Extracted %d frames from DICOM", len(frame_paths))
|
||||
return frame_paths
|
||||
|
||||
|
||||
def upload_frames_to_minio(
|
||||
frames: List[str],
|
||||
project_id: int,
|
||||
object_prefix: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""Upload a list of local frame images to MinIO.
|
||||
|
||||
Args:
|
||||
frames: List of local file paths.
|
||||
project_id: Project ID used for bucket path organization.
|
||||
object_prefix: Optional prefix override.
|
||||
|
||||
Returns:
|
||||
List of object names (paths) in MinIO.
|
||||
"""
|
||||
prefix = object_prefix or f"projects/{project_id}/frames"
|
||||
object_names: List[str] = []
|
||||
|
||||
for frame_path in frames:
|
||||
fname = os.path.basename(frame_path)
|
||||
object_name = f"{prefix}/{fname}"
|
||||
try:
|
||||
with open(frame_path, "rb") as f:
|
||||
data = f.read()
|
||||
upload_file(
|
||||
object_name,
|
||||
data,
|
||||
content_type="image/jpeg",
|
||||
length=len(data),
|
||||
)
|
||||
object_names.append(object_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to upload %s: %s", frame_path, exc)
|
||||
|
||||
logger.info("Uploaded %d/%d frames to MinIO", len(object_names), len(frames))
|
||||
return object_names
|
||||
340
backend/services/media_task_runner.py
Normal file
340
backend/services/media_task_runner.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Background media parsing runner used by Celery workers."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
|
||||
from models import Annotation, Frame, Mask, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.frame_parser import (
|
||||
extract_thumbnail,
|
||||
natural_filename_key,
|
||||
parse_dicom,
|
||||
parse_video,
|
||||
upload_frames_to_minio,
|
||||
)
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PENDING,
|
||||
PROJECT_STATUS_ERROR,
|
||||
PROJECT_STATUS_PARSING,
|
||||
PROJECT_STATUS_READY,
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskCancelled(RuntimeError):
|
||||
"""Raised internally when a persisted task has been cancelled."""
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _set_task_state(
|
||||
db: Session,
|
||||
task: ProcessingTask,
|
||||
*,
|
||||
status: str | None = None,
|
||||
progress: int | None = None,
|
||||
message: str | None = None,
|
||||
result: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
started: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = max(0, min(100, progress))
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if started:
|
||||
task.started_at = _now()
|
||||
if finished:
|
||||
task.finished_at = _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def _project_status_after_stop(project: Project) -> str:
|
||||
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
def _positive_int(value: Any, default: int | None = None) -> int | None:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return parsed if parsed > 0 else default
|
||||
|
||||
|
||||
def _positive_float(value: Any, default: float) -> float:
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return parsed if parsed > 0 else default
|
||||
|
||||
|
||||
def _frame_sequence_metadata(
|
||||
index: int,
|
||||
parse_fps: float,
|
||||
original_fps: float | None,
|
||||
) -> dict[str, float | int | None]:
|
||||
safe_parse_fps = max(float(parse_fps or 1.0), 1e-6)
|
||||
timestamp_ms = index * 1000.0 / safe_parse_fps
|
||||
source_frame_number = None
|
||||
if original_fps and original_fps > 0:
|
||||
source_frame_number = int(round(index * original_fps / safe_parse_fps))
|
||||
return {
|
||||
"timestamp_ms": timestamp_ms,
|
||||
"source_frame_number": source_frame_number,
|
||||
}
|
||||
|
||||
|
||||
def _clear_existing_project_outputs(db: Session, project: Project) -> None:
|
||||
"""Remove stale frame sequence and annotations before regenerating frames."""
|
||||
annotation_ids = db.query(Annotation.id).filter(Annotation.project_id == project.id)
|
||||
db.query(Mask).filter(Mask.annotation_id.in_(annotation_ids)).delete(synchronize_session=False)
|
||||
db.query(Annotation).filter(Annotation.project_id == project.id).delete(synchronize_session=False)
|
||||
db.query(Frame).filter(Frame.project_id == project.id).delete(synchronize_session=False)
|
||||
project.thumbnail_url = None
|
||||
db.commit()
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||
db.refresh(task)
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
raise TaskCancelled("Task was cancelled")
|
||||
|
||||
|
||||
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Parse one project's media and update task progress in the database."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"status": TASK_STATUS_CANCELLED,
|
||||
"message": task.message or "任务已取消",
|
||||
}
|
||||
|
||||
if task.project_id is None:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="任务缺少 project_id",
|
||||
error="Task has no project_id",
|
||||
finished=True,
|
||||
)
|
||||
raise ValueError("Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == task.project_id).first()
|
||||
if not project:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="项目不存在",
|
||||
error="Project not found",
|
||||
finished=True,
|
||||
)
|
||||
raise ValueError(f"Project not found: {task.project_id}")
|
||||
|
||||
if not project.video_path:
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="项目没有可解析媒体",
|
||||
error="Project has no media uploaded",
|
||||
finished=True,
|
||||
)
|
||||
project.status = PROJECT_STATUS_ERROR
|
||||
db.commit()
|
||||
raise ValueError("Project has no media uploaded")
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
_clear_existing_project_outputs(db, project)
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
||||
|
||||
payload = task.payload or {}
|
||||
effective_source = payload.get("source_type") or project.source_type or "video"
|
||||
parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0)
|
||||
max_frames = _positive_int(payload.get("max_frames"))
|
||||
target_width = _positive_int(payload.get("target_width"), 640) or 640
|
||||
project.parse_fps = parse_fps
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
|
||||
if effective_source == "dicom":
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
|
||||
client = get_minio_client()
|
||||
objects = sorted(
|
||||
list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True)),
|
||||
key=lambda obj: natural_filename_key(obj.object_name),
|
||||
)
|
||||
for obj in objects:
|
||||
_ensure_not_cancelled(db, task)
|
||||
if obj.object_name.lower().endswith(".dcm"):
|
||||
data = download_file(obj.object_name)
|
||||
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
||||
with open(local_dcm, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
||||
frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames)
|
||||
else:
|
||||
_ensure_not_cancelled(db, task)
|
||||
media_bytes = download_file(project.video_path)
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
||||
frame_files, original_fps = parse_video(
|
||||
local_path,
|
||||
output_dir,
|
||||
fps=int(parse_fps),
|
||||
max_frames=max_frames,
|
||||
target_width=target_width,
|
||||
)
|
||||
project.original_fps = original_fps
|
||||
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(local_path, thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as f:
|
||||
thumb_data = f.read()
|
||||
thumb_object = f"projects/{project.id}/thumbnail.jpg"
|
||||
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
|
||||
project.thumbnail_url = thumb_object
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=85, message="正在写入帧索引")
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
_ensure_not_cancelled(db, task)
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
|
||||
img = cv2.imread(local_frame)
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps)
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
timestamp_ms=sequence_meta["timestamp_ms"],
|
||||
source_frame_number=sequence_meta["source_frame_number"],
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
|
||||
project.status = PROJECT_STATUS_READY
|
||||
db.commit()
|
||||
|
||||
result = {
|
||||
"project_id": project.id,
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": PROJECT_STATUS_READY,
|
||||
"message": "Frame extraction completed successfully.",
|
||||
"frame_sequence": {
|
||||
"original_fps": project.original_fps,
|
||||
"parse_fps": parse_fps,
|
||||
"frame_count": len(frames_out),
|
||||
"duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0,
|
||||
"target_width": target_width,
|
||||
"frame_width": frames_out[0].width if frames_out else None,
|
||||
"frame_height": frames_out[0].height if frames_out else None,
|
||||
"max_frames": max_frames,
|
||||
"object_prefix": f"projects/{project.id}/frames",
|
||||
},
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_SUCCESS,
|
||||
progress=100,
|
||||
message="解析完成",
|
||||
result=result,
|
||||
finished=True,
|
||||
)
|
||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
|
||||
return result
|
||||
except TaskCancelled:
|
||||
project.status = _project_status_after_stop(project)
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = task.message or "任务已取消"
|
||||
task.error = task.error or "Cancelled by user"
|
||||
task.finished_at = task.finished_at or _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id)
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"project_id": project.id,
|
||||
"status": TASK_STATUS_CANCELLED,
|
||||
"message": task.message,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
project.status = PROJECT_STATUS_ERROR
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_FAILED,
|
||||
progress=100,
|
||||
message="解析失败",
|
||||
error=str(exc),
|
||||
finished=True,
|
||||
)
|
||||
logger.error("Frame extraction failed: %s", exc)
|
||||
raise
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
842
backend/services/propagation_task_runner.py
Normal file
842
backend/services/propagation_task_runner.py
Normal file
@@ -0,0 +1,842 @@
|
||||
"""Background SAM video propagation runner used by Celery workers."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from minio_client import download_file
|
||||
from models import Annotation, Frame, ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||||
from statuses import (
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PropagationTaskCancelled(RuntimeError):
|
||||
"""Raised internally when a persisted propagation task has been cancelled."""
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _set_task_state(
|
||||
db: Session,
|
||||
task: ProcessingTask,
|
||||
*,
|
||||
status: str | None = None,
|
||||
progress: int | None = None,
|
||||
message: str | None = None,
|
||||
result: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
started: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = max(0, min(100, progress))
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if started:
|
||||
task.started_at = _now()
|
||||
if finished:
|
||||
task.finished_at = _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||
db.refresh(task)
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
raise PropagationTaskCancelled("Task was cancelled")
|
||||
|
||||
|
||||
def _clamp01(value: float) -> float:
|
||||
return min(max(float(value), 0.0), 1.0)
|
||||
|
||||
|
||||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
|
||||
xs = [_clamp01(point[0]) for point in polygon]
|
||||
ys = [_clamp01(point[1]) for point in polygon]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _polygons_bbox(polygons: list[list[list[float]]]) -> list[float]:
|
||||
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
|
||||
if not points:
|
||||
return [0.0, 0.0, 0.0, 0.0]
|
||||
xs = [_clamp01(point[0]) for point in points]
|
||||
ys = [_clamp01(point[1]) for point in points]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
|
||||
return [[_clamp01(point[0]), _clamp01(point[1])] for point in polygon if len(point) >= 2]
|
||||
|
||||
|
||||
def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
try:
|
||||
strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
|
||||
except (TypeError, ValueError):
|
||||
strength = 0.0
|
||||
if strength <= 0:
|
||||
return None
|
||||
method = str(value.get("method") or "chaikin").lower()
|
||||
if method != "chaikin":
|
||||
method = "chaikin"
|
||||
return {"strength": round(strength, 2), "method": method}
|
||||
|
||||
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = _normalize_polygon(polygon)
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
next_points: list[list[float]] = []
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
|
||||
|
||||
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
|
||||
if len(polygon) < 3:
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in approx]
|
||||
|
||||
|
||||
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
|
||||
if not smoothing:
|
||||
return _normalize_polygon(polygon)
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
def _bbox_area(bbox: list[float]) -> float:
|
||||
return max(float(bbox[2]), 0.0) * max(float(bbox[3]), 0.0)
|
||||
|
||||
|
||||
def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
|
||||
ax1, ay1, aw, ah = a
|
||||
bx1, by1, bw, bh = b
|
||||
ax2 = ax1 + aw
|
||||
ay2 = ay1 + ah
|
||||
bx2 = bx1 + bw
|
||||
by2 = by1 + bh
|
||||
overlap_width = max(0.0, min(ax2, bx2) - max(ax1, bx1))
|
||||
overlap_height = max(0.0, min(ay2, by2) - max(ay1, by1))
|
||||
overlap_area = overlap_width * overlap_height
|
||||
smallest_area = min(_bbox_area(a), _bbox_area(b))
|
||||
return overlap_area / smallest_area if smallest_area > 0 else 0.0
|
||||
|
||||
|
||||
def _stable_json(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _canonicalize_signature_value(value: Any) -> Any:
|
||||
if isinstance(value, float):
|
||||
return round(value, 6)
|
||||
if isinstance(value, list):
|
||||
return [_canonicalize_signature_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: _canonicalize_signature_value(value[key]) for key in sorted(value)}
|
||||
return value
|
||||
|
||||
|
||||
def _seed_signature(seed: dict[str, Any]) -> str:
|
||||
"""Return a stable signature for seed geometry and semantic attrs."""
|
||||
inherited_signature = seed.get("propagation_seed_signature")
|
||||
if inherited_signature:
|
||||
return str(inherited_signature)
|
||||
signature_payload = {
|
||||
"polygons": seed.get("polygons") or [],
|
||||
"holes": seed.get("holes") or [],
|
||||
"bbox": seed.get("bbox") or [],
|
||||
"points": seed.get("points") or [],
|
||||
"labels": seed.get("labels") or [],
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
"class_metadata": seed.get("class_metadata") or {},
|
||||
"template_id": seed.get("template_id"),
|
||||
"smoothing": _normalize_smoothing_options(seed.get("smoothing")),
|
||||
}
|
||||
return hashlib.sha256(_stable_json(_canonicalize_signature_value(signature_payload)).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _seed_key(seed: dict[str, Any]) -> str:
|
||||
"""Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
if source_instance_id:
|
||||
return f"instance:{source_instance_id}"
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
return f"annotation:{source_annotation_id}"
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
return f"mask:{source_mask_id}"
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return _stable_json({
|
||||
"template_id": seed.get("template_id"),
|
||||
"class_id": class_id,
|
||||
"label": seed.get("label"),
|
||||
"color": seed.get("color"),
|
||||
})
|
||||
|
||||
|
||||
def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
|
||||
"""Best-effort match when a manually edited replacement lacks old lineage ids."""
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
previous_class = mask_data.get("class") or {}
|
||||
previous_class_id = previous_class.get("id") or previous_class.get("name")
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
if previous_class_id and class_id and str(previous_class_id) != str(class_id):
|
||||
return False
|
||||
return (
|
||||
mask_data.get("label") == seed.get("label")
|
||||
and mask_data.get("color") == seed.get("color")
|
||||
)
|
||||
|
||||
|
||||
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
|
||||
"""Best-effort match for propagation annotations created before seed keys."""
|
||||
class_metadata = seed.get("class_metadata") or {}
|
||||
previous_class = mask_data.get("class") or {}
|
||||
previous_class_id = previous_class.get("id") or previous_class.get("name")
|
||||
class_id = class_metadata.get("id") or class_metadata.get("name")
|
||||
return (
|
||||
mask_data.get("label") == seed.get("label")
|
||||
and mask_data.get("color") == seed.get("color")
|
||||
and previous_class_id == class_id
|
||||
)
|
||||
|
||||
|
||||
def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
|
||||
return str(mask_data.get("source") or "") == f"{model_id}_propagation"
|
||||
|
||||
|
||||
def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key == seed_key:
|
||||
return True
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
if source_instance_id and (
|
||||
mask_data.get("source_instance_id") == source_instance_id
|
||||
or mask_data.get("instance_id") == source_instance_id
|
||||
):
|
||||
return True
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None and str(mask_data.get("source_annotation_id") or "") == str(source_annotation_id):
|
||||
return True
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id and mask_data.get("source_mask_id") == source_mask_id:
|
||||
return True
|
||||
has_persisted_seed_identity = bool(source_instance_id) or source_annotation_id is not None or bool(source_mask_id)
|
||||
has_previous_identity = (
|
||||
bool(previous_seed_key)
|
||||
or mask_data.get("source_instance_id") is not None
|
||||
or mask_data.get("instance_id") is not None
|
||||
or mask_data.get("source_annotation_id") is not None
|
||||
or bool(mask_data.get("source_mask_id"))
|
||||
)
|
||||
if has_persisted_seed_identity or has_previous_identity:
|
||||
return False
|
||||
return _legacy_seed_matches(mask_data, seed)
|
||||
|
||||
|
||||
def _seed_identity_markers(seed: dict[str, Any]) -> set[str]:
|
||||
markers = {f"seed:{_seed_key(seed)}"}
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
if source_instance_id:
|
||||
markers.add(f"instance:{source_instance_id}")
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
markers.add(f"annotation:{source_annotation_id}")
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
markers.add(f"mask:{source_mask_id}")
|
||||
return markers
|
||||
|
||||
|
||||
def _mask_identity_markers(mask_data: dict[str, Any]) -> set[str]:
|
||||
markers: set[str] = set()
|
||||
previous_seed_key = mask_data.get("propagation_seed_key")
|
||||
if previous_seed_key:
|
||||
markers.add(f"seed:{previous_seed_key}")
|
||||
source_instance_id = mask_data.get("source_instance_id")
|
||||
if source_instance_id:
|
||||
markers.add(f"instance:{source_instance_id}")
|
||||
instance_id = mask_data.get("instance_id")
|
||||
if instance_id:
|
||||
markers.add(f"instance:{instance_id}")
|
||||
source_annotation_id = mask_data.get("source_annotation_id")
|
||||
if source_annotation_id is not None:
|
||||
markers.add(f"annotation:{source_annotation_id}")
|
||||
source_mask_id = mask_data.get("source_mask_id")
|
||||
if source_mask_id:
|
||||
markers.add(f"mask:{source_mask_id}")
|
||||
return markers
|
||||
|
||||
|
||||
def _payload_seed_identity_markers(payload: dict[str, Any]) -> set[str]:
|
||||
markers: set[str] = set()
|
||||
for step in payload.get("steps") or []:
|
||||
seed = step.get("seed") or {}
|
||||
markers.update(_seed_identity_markers(seed))
|
||||
return markers
|
||||
|
||||
|
||||
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
|
||||
mask_data = annotation.mask_data or {}
|
||||
source = str(mask_data.get("source") or "")
|
||||
if not source.endswith("_propagation"):
|
||||
return False
|
||||
return _seed_identity_matches(mask_data, seed_key, seed)
|
||||
|
||||
|
||||
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
|
||||
previous_direction = mask_data.get("propagation_direction")
|
||||
return previous_direction in {None, direction}
|
||||
|
||||
|
||||
def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
|
||||
"""Use target-frame overlap as a final guard before replacing same-object propagation."""
|
||||
candidate_bbox = _polygon_bbox(polygon)
|
||||
for previous_polygon in (annotation.mask_data or {}).get("polygons") or []:
|
||||
if len(previous_polygon) < 3:
|
||||
continue
|
||||
if _bbox_overlap_ratio(_polygon_bbox(previous_polygon), candidate_bbox) >= 0.15:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _delete_replaced_frame_annotations(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
frame_id: int,
|
||||
seed_key: str,
|
||||
seed: dict[str, Any],
|
||||
polygon: list[list[float]],
|
||||
) -> int:
|
||||
"""Delete old propagated masks for the same object immediately before writing a new result."""
|
||||
previous_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == int(payload["project_id"]))
|
||||
.filter(Annotation.frame_id == frame_id)
|
||||
.all()
|
||||
)
|
||||
deleted_count = 0
|
||||
current_seed_markers = _seed_identity_markers(seed)
|
||||
task_seed_markers = _payload_seed_identity_markers(payload)
|
||||
for annotation in previous_annotations:
|
||||
mask_data = annotation.mask_data or {}
|
||||
source = str(mask_data.get("source") or "")
|
||||
if not source.endswith("_propagation"):
|
||||
continue
|
||||
source_instance_id = seed.get("source_instance_id")
|
||||
mask_instance_ids = {
|
||||
str(value)
|
||||
for value in (mask_data.get("source_instance_id"), mask_data.get("instance_id"))
|
||||
if value
|
||||
}
|
||||
if source_instance_id and mask_instance_ids and str(source_instance_id) not in mask_instance_ids:
|
||||
continue
|
||||
mask_markers = _mask_identity_markers(mask_data)
|
||||
# Keep sibling seeds in the same propagation task from deleting each other.
|
||||
if mask_markers and mask_markers.isdisjoint(current_seed_markers) and not mask_markers.isdisjoint(task_seed_markers):
|
||||
continue
|
||||
same_lineage = _seed_identity_matches(mask_data, seed_key, seed)
|
||||
same_manual_replacement = (
|
||||
_semantic_seed_matches(mask_data, seed)
|
||||
and _annotation_spatially_matches(annotation, polygon)
|
||||
)
|
||||
if same_lineage or same_manual_replacement:
|
||||
db.delete(annotation)
|
||||
deleted_count += 1
|
||||
if deleted_count:
|
||||
db.commit()
|
||||
return deleted_count
|
||||
|
||||
|
||||
def _prepare_seed_propagation(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
model_id: str,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
target_frame_ids: set[int],
|
||||
) -> dict[str, Any]:
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
if not target_frame_ids:
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": 0,
|
||||
}
|
||||
previous_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == int(payload["project_id"]))
|
||||
.filter(Annotation.frame_id.in_(target_frame_ids))
|
||||
.all()
|
||||
)
|
||||
matching = [
|
||||
annotation for annotation in previous_annotations
|
||||
if _is_propagation_annotation(annotation, seed_key, seed)
|
||||
and _direction_matches(annotation.mask_data or {}, direction)
|
||||
]
|
||||
covered_frame_ids = {int(annotation.frame_id) for annotation in matching}
|
||||
if matching and all(
|
||||
(annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature
|
||||
and _source_model_matches(annotation.mask_data or {}, model_id)
|
||||
for annotation in matching
|
||||
) and target_frame_ids.issubset(covered_frame_ids):
|
||||
return {
|
||||
"skip": True,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": 0,
|
||||
}
|
||||
|
||||
deleted_count = 0
|
||||
if matching:
|
||||
for annotation in matching:
|
||||
db.delete(annotation)
|
||||
deleted_count += 1
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"skip": False,
|
||||
"seed_key": seed_key,
|
||||
"seed_signature": seed_signature,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
}
|
||||
|
||||
|
||||
def _frame_window(
|
||||
frames: list[Frame],
|
||||
source_position: int,
|
||||
direction: str,
|
||||
max_frames: int,
|
||||
) -> tuple[list[Frame], int]:
|
||||
count = max(1, min(max_frames, len(frames)))
|
||||
if direction == "backward":
|
||||
start = max(0, source_position - count + 1)
|
||||
return frames[start:source_position + 1], source_position - start
|
||||
end = min(len(frames), source_position + count)
|
||||
return frames[source_position:end], 0
|
||||
|
||||
|
||||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
|
||||
paths = []
|
||||
for index, frame in enumerate(frames):
|
||||
data = download_file(frame.image_url)
|
||||
# SAM2VideoPredictor sorts frames by converting the filename stem to int.
|
||||
path = directory / f"{index:06d}.jpg"
|
||||
path.write_bytes(data)
|
||||
paths.append(str(path))
|
||||
return paths
|
||||
|
||||
|
||||
def _save_propagated_annotations(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
selected_frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
propagated: list[dict[str, Any]],
|
||||
seed: dict[str, Any],
|
||||
) -> tuple[list[Annotation], int]:
|
||||
created: list[Annotation] = []
|
||||
if payload.get("save_annotations", True) is False:
|
||||
return created, 0
|
||||
|
||||
class_metadata = seed.get("class_metadata")
|
||||
template_id = seed.get("template_id")
|
||||
label = seed.get("label") or "Propagated Mask"
|
||||
color = seed.get("color") or "#06b6d4"
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
include_source = bool(payload.get("include_source", False))
|
||||
seed_key = _seed_key(seed)
|
||||
seed_signature = _seed_signature(seed)
|
||||
source_annotation_id = seed.get("source_annotation_id")
|
||||
source_mask_id = seed.get("source_mask_id")
|
||||
source_instance_id = seed.get("source_instance_id") or seed_key
|
||||
smoothing = _normalize_smoothing_options(seed.get("smoothing"))
|
||||
direction = str(payload.get("current_direction") or "")
|
||||
deleted_count = 0
|
||||
cleaned_frame_ids: set[int] = set()
|
||||
|
||||
for frame_result in propagated:
|
||||
relative_index = int(frame_result.get("frame_index", -1))
|
||||
if relative_index < 0 or relative_index >= len(selected_frames):
|
||||
continue
|
||||
frame = selected_frames[relative_index]
|
||||
if not include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
result_holes = frame_result.get("holes") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
prepared_polygons = [
|
||||
(polygon_index, _smooth_polygon(polygon, smoothing))
|
||||
for polygon_index, polygon in enumerate(result_polygons)
|
||||
if len(polygon) >= 3
|
||||
]
|
||||
cleanup_polygon = next((polygon for _polygon_index, polygon in prepared_polygons if len(polygon) >= 3), None)
|
||||
if cleanup_polygon is not None and frame.id not in cleaned_frame_ids:
|
||||
deleted_count += _delete_replaced_frame_annotations(
|
||||
db,
|
||||
payload=payload,
|
||||
frame_id=int(frame.id),
|
||||
seed_key=seed_key,
|
||||
seed=seed,
|
||||
polygon=cleanup_polygon,
|
||||
)
|
||||
cleaned_frame_ids.add(int(frame.id))
|
||||
polygons_to_save: list[list[list[float]]] = []
|
||||
holes_to_save: list[list[list[list[float]]]] = []
|
||||
score_values: list[float] = []
|
||||
for polygon_index, polygon in prepared_polygons:
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
polygons_to_save.append(polygon)
|
||||
hole_group = result_holes[polygon_index] if polygon_index < len(result_holes) and isinstance(result_holes[polygon_index], list) else []
|
||||
holes_to_save.append(hole_group if isinstance(hole_group, list) else [])
|
||||
if polygon_index < len(scores):
|
||||
try:
|
||||
score_values.append(float(scores[polygon_index]))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if not polygons_to_save:
|
||||
continue
|
||||
annotation = Annotation(
|
||||
project_id=int(payload["project_id"]),
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": polygons_to_save,
|
||||
**({"holes": holes_to_save, "hasHoles": True} if any(holes_to_save) else {}),
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
"propagated_from_frame_id": source_frame.id,
|
||||
"propagated_from_frame_index": source_frame.frame_index,
|
||||
"propagation_seed_key": seed_key,
|
||||
"propagation_seed_signature": seed_signature,
|
||||
"propagation_direction": direction,
|
||||
"instance_id": source_instance_id,
|
||||
"source_instance_id": source_instance_id,
|
||||
"source_annotation_id": source_annotation_id,
|
||||
"source_mask_id": source_mask_id,
|
||||
"score": max(score_values) if score_values else None,
|
||||
**({"scores": score_values} if len(score_values) > 1 else {}),
|
||||
**({"geometry_smoothing": smoothing} if smoothing else {}),
|
||||
**({"class": class_metadata} if class_metadata else {}),
|
||||
},
|
||||
points=None,
|
||||
bbox=_polygons_bbox(polygons_to_save),
|
||||
)
|
||||
db.add(annotation)
|
||||
created.append(annotation)
|
||||
|
||||
db.commit()
|
||||
for annotation in created:
|
||||
db.refresh(annotation)
|
||||
return created, deleted_count
|
||||
|
||||
|
||||
def _run_one_step(
|
||||
db: Session,
|
||||
*,
|
||||
payload: dict[str, Any],
|
||||
frames: list[Frame],
|
||||
source_frame: Frame,
|
||||
source_position: int,
|
||||
step: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
direction = str(step.get("direction") or "forward").lower()
|
||||
if direction not in {"forward", "backward"}:
|
||||
raise ValueError("direction must be forward or backward")
|
||||
max_frames = max(1, min(int(step.get("max_frames") or payload.get("max_frames") or 30), 500))
|
||||
seed = step.get("seed") or {}
|
||||
if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
|
||||
raise ValueError("Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
include_source = bool(payload.get("include_source", False))
|
||||
target_frame_ids = {
|
||||
int(frame.id)
|
||||
for frame in selected_frames
|
||||
if include_source or frame.id != source_frame.id
|
||||
}
|
||||
seed_state = _prepare_seed_propagation(
|
||||
db,
|
||||
payload=payload,
|
||||
model_id=model_id,
|
||||
seed=seed,
|
||||
direction=direction,
|
||||
target_frame_ids=target_frame_ids,
|
||||
)
|
||||
if seed_state["skip"]:
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": 0,
|
||||
"created_annotation_count": 0,
|
||||
"deleted_annotation_count": 0,
|
||||
"skipped_seed_count": 1,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload['project_id']}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
model_id,
|
||||
frame_paths,
|
||||
source_relative_index,
|
||||
seed,
|
||||
direction,
|
||||
len(selected_frames),
|
||||
)
|
||||
|
||||
save_payload = {**payload, "current_direction": direction}
|
||||
created, write_cleanup_count = _save_propagated_annotations(
|
||||
db,
|
||||
payload=save_payload,
|
||||
selected_frames=selected_frames,
|
||||
source_frame=source_frame,
|
||||
propagated=propagated,
|
||||
seed=seed,
|
||||
)
|
||||
return {
|
||||
"model": model_id,
|
||||
"direction": direction,
|
||||
"processed_frame_count": len(selected_frames),
|
||||
"created_annotation_count": len(created),
|
||||
"deleted_annotation_count": int(seed_state["deleted_annotation_count"]) + write_cleanup_count,
|
||||
"skipped_seed_count": 0,
|
||||
"seed_label": seed.get("label"),
|
||||
"seed_key": seed_state["seed_key"],
|
||||
}
|
||||
|
||||
|
||||
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Run one queued SAM propagation task and update persisted progress."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}
|
||||
|
||||
payload = task.payload or {}
|
||||
project_id = int(payload.get("project_id") or task.project_id or 0)
|
||||
source_frame_id = int(payload.get("frame_id") or 0)
|
||||
try:
|
||||
model_id = sam_registry.normalize_model_id(payload.get("model"))
|
||||
except ValueError as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="项目不存在", error="Project not found", finished=True)
|
||||
raise ValueError(f"Project not found: {project_id}")
|
||||
|
||||
source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
|
||||
if not source_frame:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不存在", error="Frame not found", finished=True)
|
||||
raise ValueError(f"Frame not found: {source_frame_id}")
|
||||
|
||||
frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
|
||||
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
|
||||
if source_position is None:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="参考帧不在项目帧序列中", error="Source frame is not in project frame sequence", finished=True)
|
||||
raise ValueError("Source frame is not in project frame sequence")
|
||||
|
||||
steps = payload.get("steps") or []
|
||||
if not steps:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="传播任务缺少步骤", error="Propagation task has no steps", finished=True)
|
||||
raise ValueError("Propagation task has no steps")
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)
|
||||
|
||||
step_results: list[dict[str, Any]] = []
|
||||
created_count = 0
|
||||
processed_count = 0
|
||||
deleted_count = 0
|
||||
skipped_count = 0
|
||||
total_steps = len(steps)
|
||||
|
||||
try:
|
||||
for index, step in enumerate(steps, start=1):
|
||||
_ensure_not_cancelled(db, task)
|
||||
seed_label = (step.get("seed") or {}).get("label") or "mask"
|
||||
direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播"
|
||||
progress_before = 5 + int(((index - 1) / total_steps) * 90)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=progress_before,
|
||||
message=f"{direction_label} {seed_label} ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index - 1,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = _run_one_step(
|
||||
db,
|
||||
payload=payload,
|
||||
frames=frames,
|
||||
source_frame=source_frame,
|
||||
source_position=source_position,
|
||||
step=step,
|
||||
)
|
||||
step_results.append(result)
|
||||
created_count += int(result["created_annotation_count"])
|
||||
processed_count += int(result["processed_frame_count"])
|
||||
deleted_count += int(result.get("deleted_annotation_count") or 0)
|
||||
skipped_count += int(result.get("skipped_seed_count") or 0)
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
progress=5 + int((index / total_steps) * 90),
|
||||
message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
|
||||
result={
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": index,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
},
|
||||
)
|
||||
|
||||
result = {
|
||||
"project_id": project_id,
|
||||
"source_frame_id": source_frame_id,
|
||||
"model": model_id,
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": total_steps,
|
||||
"processed_frame_count": processed_count,
|
||||
"created_annotation_count": created_count,
|
||||
"deleted_annotation_count": deleted_count,
|
||||
"skipped_seed_count": skipped_count,
|
||||
"steps": step_results,
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
task,
|
||||
status=TASK_STATUS_SUCCESS,
|
||||
progress=100,
|
||||
message="自动传播完成" if created_count > 0 else (
|
||||
"自动传播完成,未改变的 mask 已跳过" if skipped_count > 0 else "自动传播完成,但没有生成新的 mask"
|
||||
),
|
||||
result=result,
|
||||
finished=True,
|
||||
)
|
||||
return result
|
||||
except PropagationTaskCancelled:
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = task.message or "任务已取消"
|
||||
task.error = task.error or "Cancelled by user"
|
||||
task.finished_at = task.finished_at or _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
|
||||
except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Propagation task failed: task_id=%s", task.id)
|
||||
_set_task_state(db, task, status=TASK_STATUS_FAILED, progress=100, message="自动传播失败", error=str(exc), finished=True)
|
||||
raise
|
||||
690
backend/services/sam2_engine.py
Normal file
690
backend/services/sam2_engine.py
Normal file
@@ -0,0 +1,690 @@
|
||||
"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SAM2_MODEL_ID = "sam2.1_hiera_tiny"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SAM2Variant:
|
||||
"""One selectable SAM 2.1 runtime variant."""
|
||||
|
||||
id: str
|
||||
label: str
|
||||
short_label: str
|
||||
config: str
|
||||
legacy_config: str
|
||||
checkpoint_filename: str
|
||||
legacy_checkpoint_filename: str
|
||||
|
||||
|
||||
SAM2_VARIANTS: dict[str, SAM2Variant] = {
|
||||
"sam2.1_hiera_tiny": SAM2Variant(
|
||||
id="sam2.1_hiera_tiny",
|
||||
label="SAM 2.1 Tiny",
|
||||
short_label="tiny",
|
||||
config="configs/sam2.1/sam2.1_hiera_t.yaml",
|
||||
legacy_config="configs/sam2/sam2_hiera_t.yaml",
|
||||
checkpoint_filename="sam2.1_hiera_tiny.pt",
|
||||
legacy_checkpoint_filename="sam2_hiera_tiny.pt",
|
||||
),
|
||||
"sam2.1_hiera_small": SAM2Variant(
|
||||
id="sam2.1_hiera_small",
|
||||
label="SAM 2.1 Small",
|
||||
short_label="small",
|
||||
config="configs/sam2.1/sam2.1_hiera_s.yaml",
|
||||
legacy_config="configs/sam2/sam2_hiera_s.yaml",
|
||||
checkpoint_filename="sam2.1_hiera_small.pt",
|
||||
legacy_checkpoint_filename="sam2_hiera_small.pt",
|
||||
),
|
||||
"sam2.1_hiera_base_plus": SAM2Variant(
|
||||
id="sam2.1_hiera_base_plus",
|
||||
label="SAM 2.1 Base+",
|
||||
short_label="base+",
|
||||
config="configs/sam2.1/sam2.1_hiera_b+.yaml",
|
||||
legacy_config="configs/sam2/sam2_hiera_b+.yaml",
|
||||
checkpoint_filename="sam2.1_hiera_base_plus.pt",
|
||||
legacy_checkpoint_filename="sam2_hiera_base_plus.pt",
|
||||
),
|
||||
"sam2.1_hiera_large": SAM2Variant(
|
||||
id="sam2.1_hiera_large",
|
||||
label="SAM 2.1 Large",
|
||||
short_label="large",
|
||||
config="configs/sam2.1/sam2.1_hiera_l.yaml",
|
||||
legacy_config="configs/sam2/sam2_hiera_l.yaml",
|
||||
checkpoint_filename="sam2.1_hiera_large.pt",
|
||||
legacy_checkpoint_filename="sam2_hiera_large.pt",
|
||||
),
|
||||
}
|
||||
|
||||
SAM2_MODEL_ALIASES = {
|
||||
"sam2": DEFAULT_SAM2_MODEL_ID,
|
||||
"sam2.1": DEFAULT_SAM2_MODEL_ID,
|
||||
"sam2_tiny": DEFAULT_SAM2_MODEL_ID,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attempt to import PyTorch and SAM 2; fall back to stubs if unavailable.
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
TORCH_AVAILABLE = False
|
||||
torch = None # type: ignore[assignment]
|
||||
logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
SAM2_AVAILABLE = True
|
||||
logger.info("SAM2 library imported successfully.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
SAM2_AVAILABLE = False
|
||||
logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
|
||||
|
||||
|
||||
class SAM2Engine:
|
||||
"""Lazy-loaded SAM 2 inference engine."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._predictors: dict[str, Optional[SAM2ImagePredictor]] = {}
|
||||
self._video_predictors: dict[str, object | None] = {}
|
||||
self._model_loaded: dict[str, bool] = {}
|
||||
self._video_model_loaded: dict[str, bool] = {}
|
||||
self._loaded_device: dict[str, str] = {}
|
||||
self._last_error: dict[str, str | None] = {}
|
||||
self._video_last_error: dict[str, str | None] = {}
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------------------------------------------------
|
||||
def variant_ids(self) -> list[str]:
|
||||
return list(SAM2_VARIANTS.keys())
|
||||
|
||||
def normalize_model_id(self, model_id: str | None) -> str:
|
||||
selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
|
||||
selected = SAM2_MODEL_ALIASES.get(selected, selected)
|
||||
if selected not in SAM2_VARIANTS:
|
||||
raise ValueError(f"Unsupported SAM2 model: {model_id}")
|
||||
return selected
|
||||
|
||||
def is_sam2_model(self, model_id: str | None) -> bool:
|
||||
try:
|
||||
self.normalize_model_id(model_id)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _models_dir(self) -> Path:
|
||||
configured_path = Path(settings.sam_model_path)
|
||||
return configured_path.parent if configured_path.parent else Path("models")
|
||||
|
||||
def _variant(self, model_id: str | None) -> SAM2Variant:
|
||||
return SAM2_VARIANTS[self.normalize_model_id(model_id)]
|
||||
|
||||
def _checkpoint_config(self, model_id: str | None) -> tuple[str, str]:
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
variant = SAM2_VARIANTS[variant_id]
|
||||
models_dir = self._models_dir()
|
||||
candidates: list[tuple[str, str]] = []
|
||||
|
||||
configured_path = Path(settings.sam_model_path)
|
||||
if variant_id == DEFAULT_SAM2_MODEL_ID and configured_path.is_file():
|
||||
candidates.append((settings.sam_model_config, str(configured_path)))
|
||||
|
||||
candidates.extend([
|
||||
(variant.config, str(models_dir / variant.checkpoint_filename)),
|
||||
(variant.legacy_config, str(models_dir / variant.legacy_checkpoint_filename)),
|
||||
])
|
||||
|
||||
for config, checkpoint_path in candidates:
|
||||
if os.path.isfile(checkpoint_path):
|
||||
return config, checkpoint_path
|
||||
return candidates[0]
|
||||
|
||||
def _load_model(self, model_id: str | None = None) -> None:
|
||||
"""Load the SAM 2 model and predictor on first use."""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
if self._model_loaded.get(variant_id):
|
||||
return
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self._last_error[variant_id] = "PyTorch is not installed."
|
||||
logger.warning("PyTorch not available; skipping SAM2 model load.")
|
||||
self._model_loaded[variant_id] = True
|
||||
return
|
||||
|
||||
if not SAM2_AVAILABLE:
|
||||
self._last_error[variant_id] = "sam2 package is not installed."
|
||||
logger.warning("SAM2 not available; skipping model load.")
|
||||
self._model_loaded[variant_id] = True
|
||||
return
|
||||
|
||||
config, checkpoint_path = self._checkpoint_config(variant_id)
|
||||
if not os.path.isfile(checkpoint_path):
|
||||
self._last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
|
||||
logger.error("SAM checkpoint not found at %s", checkpoint_path)
|
||||
self._model_loaded[variant_id] = True
|
||||
return
|
||||
|
||||
try:
|
||||
device = self._best_device()
|
||||
model = build_sam2(
|
||||
config,
|
||||
checkpoint_path,
|
||||
device=device,
|
||||
)
|
||||
self._predictors[variant_id] = SAM2ImagePredictor(model)
|
||||
self._model_loaded[variant_id] = True
|
||||
self._loaded_device[variant_id] = device
|
||||
self._last_error[variant_id] = None
|
||||
logger.info("SAM 2 model %s loaded from %s on %s", variant_id, checkpoint_path, device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._last_error[variant_id] = str(exc)
|
||||
logger.error("Failed to load SAM 2 model %s: %s", variant_id, exc)
|
||||
self._model_loaded[variant_id] = True # Prevent repeated load attempts
|
||||
|
||||
def _load_video_model(self, model_id: str | None = None) -> None:
|
||||
"""Load the SAM 2 video predictor on first propagation use."""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
if self._video_model_loaded.get(variant_id):
|
||||
return
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self._video_last_error[variant_id] = "PyTorch is not installed."
|
||||
self._video_model_loaded[variant_id] = True
|
||||
return
|
||||
if not SAM2_AVAILABLE:
|
||||
self._video_last_error[variant_id] = "sam2 package is not installed."
|
||||
self._video_model_loaded[variant_id] = True
|
||||
return
|
||||
|
||||
config, checkpoint_path = self._checkpoint_config(variant_id)
|
||||
if not os.path.isfile(checkpoint_path):
|
||||
self._video_last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
|
||||
self._video_model_loaded[variant_id] = True
|
||||
return
|
||||
|
||||
try:
|
||||
device = self._best_device()
|
||||
self._video_predictors[variant_id] = build_sam2_video_predictor(
|
||||
config,
|
||||
checkpoint_path,
|
||||
device=device,
|
||||
)
|
||||
self._video_model_loaded[variant_id] = True
|
||||
self._loaded_device[variant_id] = device
|
||||
self._video_last_error[variant_id] = None
|
||||
logger.info("SAM 2 video predictor %s loaded from %s on %s", variant_id, checkpoint_path, device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._video_last_error[variant_id] = str(exc)
|
||||
self._video_model_loaded[variant_id] = True
|
||||
logger.error("Failed to load SAM 2 video predictor %s: %s", variant_id, exc)
|
||||
|
||||
def _best_device(self) -> str:
|
||||
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
def _ensure_ready(self, model_id: str | None = None) -> bool:
|
||||
"""Ensure the model is loaded; return whether it is usable."""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
self._load_model(variant_id)
|
||||
return SAM2_AVAILABLE and self._predictors.get(variant_id) is not None
|
||||
|
||||
def _ensure_video_ready(self, model_id: str | None = None) -> bool:
|
||||
"""Ensure the video predictor is loaded; return whether it is usable."""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
self._load_video_model(variant_id)
|
||||
return SAM2_AVAILABLE and self._video_predictors.get(variant_id) is not None
|
||||
|
||||
def status(self, model_id: str | None = None) -> dict:
|
||||
"""Return lightweight, real runtime status without forcing model load."""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
variant = SAM2_VARIANTS[variant_id]
|
||||
_, checkpoint_path = self._checkpoint_config(variant_id)
|
||||
checkpoint_exists = os.path.isfile(checkpoint_path)
|
||||
using_legacy_checkpoint = Path(checkpoint_path).name == variant.legacy_checkpoint_filename
|
||||
predictor = self._predictors.get(variant_id)
|
||||
device = self._loaded_device.get(variant_id) or self._best_device()
|
||||
available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)
|
||||
if predictor is not None:
|
||||
message = f"{variant.label} model loaded and ready."
|
||||
elif available:
|
||||
message = f"{variant.label} dependencies and checkpoint are present; model will load on first inference."
|
||||
if using_legacy_checkpoint:
|
||||
message += " Using legacy SAM 2 checkpoint fallback."
|
||||
else:
|
||||
missing = []
|
||||
if not TORCH_AVAILABLE:
|
||||
missing.append("PyTorch")
|
||||
if not SAM2_AVAILABLE:
|
||||
missing.append("sam2 package")
|
||||
if not checkpoint_exists:
|
||||
missing.append("checkpoint")
|
||||
message = f"{variant.label} unavailable: missing {', '.join(missing)}."
|
||||
last_error = self._last_error.get(variant_id)
|
||||
if last_error and not predictor:
|
||||
message = last_error
|
||||
return {
|
||||
"id": variant.id,
|
||||
"label": variant.label,
|
||||
"available": available,
|
||||
"loaded": predictor is not None,
|
||||
"device": device,
|
||||
"supports": ["point", "box", "interactive", "auto", "propagate"],
|
||||
"message": message,
|
||||
"package_available": SAM2_AVAILABLE,
|
||||
"checkpoint_exists": checkpoint_exists,
|
||||
"checkpoint_path": checkpoint_path,
|
||||
"python_ok": True,
|
||||
"torch_ok": TORCH_AVAILABLE,
|
||||
"cuda_required": False,
|
||||
}
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
# -----------------------------------------------------------------------
|
||||
def predict_points(
|
||||
self,
|
||||
model_id: str | None,
|
||||
image: np.ndarray,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run point-prompt segmentation.
|
||||
|
||||
Args:
|
||||
image: HWC numpy array (uint8).
|
||||
points: List of [x, y] normalized coordinates (0-1).
|
||||
labels: 1 for foreground, 0 for background.
|
||||
|
||||
Returns:
|
||||
Tuple of (polygons, scores).
|
||||
"""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
if not self._ensure_ready(variant_id):
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
predictor = self._predictors[variant_id]
|
||||
h, w = image.shape[:2]
|
||||
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
|
||||
lbls = np.array(labels, dtype=np.int32)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
predictor.set_image(image)
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 point prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_box(
|
||||
self,
|
||||
model_id: str | None,
|
||||
image: np.ndarray,
|
||||
box: list[float],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run box-prompt segmentation.
|
||||
|
||||
Args:
|
||||
image: HWC numpy array (uint8).
|
||||
box: [x1, y1, x2, y2] normalized coordinates.
|
||||
|
||||
Returns:
|
||||
Tuple of (polygons, scores).
|
||||
"""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
if not self._ensure_ready(variant_id):
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
predictor = self._predictors[variant_id]
|
||||
h, w = image.shape[:2]
|
||||
bbox = np.array(
|
||||
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
predictor.set_image(image)
|
||||
masks, scores, _ = predictor.predict(
|
||||
box=bbox[None, :],
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 box prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_interactive(
|
||||
self,
|
||||
model_id: str | None,
|
||||
image: np.ndarray,
|
||||
box: list[float] | None,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run combined box and point prompt segmentation for refinement."""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
if not self._ensure_ready(variant_id):
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
predictor = self._predictors[variant_id]
|
||||
h, w = image.shape[:2]
|
||||
bbox = None
|
||||
if box:
|
||||
bbox = np.array(
|
||||
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
|
||||
dtype=np.float32,
|
||||
)
|
||||
pts = None
|
||||
lbls = None
|
||||
if points:
|
||||
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
|
||||
lbls = np.array(labels, dtype=np.int32)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
predictor.set_image(image)
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
box=bbox,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 interactive prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_auto(self, model_id: str | None, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run automatic mask generation (grid of points).
|
||||
|
||||
Args:
|
||||
image: HWC numpy array (uint8).
|
||||
|
||||
Returns:
|
||||
Tuple of (polygons, scores).
|
||||
"""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
if not self._ensure_ready(variant_id):
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
predictor = self._predictors[variant_id]
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
predictor.set_image(image)
|
||||
# Generate a uniform 16x16 grid of point prompts
|
||||
h, w = image.shape[:2]
|
||||
grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
|
||||
pts = grid * np.array([w, h])
|
||||
lbls = np.ones(pts.shape[0], dtype=np.int32)
|
||||
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks[:1]:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores[:1].tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 auto prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
model_id: str | None,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict,
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Propagate one seed mask across a prepared frame directory with SAM 2 video."""
|
||||
variant_id = self.normalize_model_id(model_id)
|
||||
if not self._ensure_video_ready(variant_id):
|
||||
raise RuntimeError(self._video_last_error.get(variant_id) or self.status(variant_id)["message"])
|
||||
video_predictor = self._video_predictors[variant_id]
|
||||
if not frame_paths:
|
||||
return []
|
||||
if source_frame_index < 0 or source_frame_index >= len(frame_paths):
|
||||
raise ValueError("source_frame_index is outside the frame sequence.")
|
||||
|
||||
import cv2
|
||||
|
||||
source_image = cv2.imread(frame_paths[source_frame_index])
|
||||
if source_image is None:
|
||||
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
|
||||
height, width = source_image.shape[:2]
|
||||
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height, seed.get("holes") or [])
|
||||
if not seed_mask.any():
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
seed_mask = self._bbox_to_mask(bbox, width, height)
|
||||
if not seed_mask.any():
|
||||
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
|
||||
|
||||
inference_state = video_predictor.init_state(
|
||||
video_path=os.path.dirname(frame_paths[0]),
|
||||
offload_video_to_cpu=True,
|
||||
offload_state_to_cpu=True,
|
||||
)
|
||||
video_predictor.add_new_mask(
|
||||
inference_state,
|
||||
frame_idx=source_frame_index,
|
||||
obj_id=1,
|
||||
mask=seed_mask,
|
||||
)
|
||||
|
||||
results: dict[int, dict] = {}
|
||||
|
||||
def collect(reverse: bool) -> None:
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(
|
||||
inference_state,
|
||||
start_frame_idx=source_frame_index,
|
||||
max_frame_num_to_track=max_frames,
|
||||
reverse=reverse,
|
||||
):
|
||||
masks = out_mask_logits
|
||||
if hasattr(masks, "detach"):
|
||||
masks = masks.detach().cpu().numpy()
|
||||
masks = np.asarray(masks)
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
polygons = []
|
||||
holes = []
|
||||
scores = []
|
||||
for mask in masks:
|
||||
mask_polygons, mask_holes = self._mask_to_polygon_data(mask > 0)
|
||||
for polygon_index, polygon in enumerate(mask_polygons):
|
||||
polygons.append(polygon)
|
||||
holes.append(mask_holes[polygon_index] if polygon_index < len(mask_holes) else [])
|
||||
scores.append(1.0)
|
||||
results[int(out_frame_idx)] = {
|
||||
"frame_index": int(out_frame_idx),
|
||||
"polygons": polygons,
|
||||
"holes": holes,
|
||||
"scores": scores,
|
||||
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
|
||||
}
|
||||
|
||||
normalized_direction = direction.lower()
|
||||
if normalized_direction in {"forward", "both"}:
|
||||
collect(reverse=False)
|
||||
if normalized_direction in {"backward", "both"}:
|
||||
collect(reverse=True)
|
||||
|
||||
try:
|
||||
video_predictor.reset_state(inference_state)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
return [results[index] for index in sorted(results)]
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
|
||||
"""Convert a binary mask to a normalized polygon."""
|
||||
polygons, _holes = SAM2Engine._mask_to_polygon_data(mask)
|
||||
return polygons[0] if polygons else []
|
||||
|
||||
@staticmethod
|
||||
def _mask_to_polygon_data(mask: np.ndarray) -> tuple[list[list[list[float]]], list[list[list[list[float]]]]]:
|
||||
"""Convert a binary mask to normalized outer polygons and aligned hole rings."""
|
||||
import cv2
|
||||
|
||||
if mask.dtype != np.uint8:
|
||||
mask = (mask > 0).astype(np.uint8)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
|
||||
h, w = mask.shape[:2]
|
||||
if hierarchy is None:
|
||||
return [], []
|
||||
|
||||
def contour_to_polygon(contour: np.ndarray) -> list[list[float]]:
|
||||
if len(contour) < 3:
|
||||
return []
|
||||
return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in contour]
|
||||
|
||||
hierarchy_rows = hierarchy[0]
|
||||
outer_indices = [
|
||||
index for index, row in enumerate(hierarchy_rows)
|
||||
if int(row[3]) < 0 and len(contours[index]) >= 3
|
||||
]
|
||||
outer_indices.sort(key=lambda index: cv2.contourArea(contours[index]), reverse=True)
|
||||
|
||||
polygons: list[list[list[float]]] = []
|
||||
holes: list[list[list[list[float]]]] = []
|
||||
for outer_index in outer_indices:
|
||||
outer = contour_to_polygon(contours[outer_index])
|
||||
if not outer:
|
||||
continue
|
||||
child_index = int(hierarchy_rows[outer_index][2])
|
||||
hole_group: list[list[list[float]]] = []
|
||||
while child_index >= 0:
|
||||
hole = contour_to_polygon(contours[child_index])
|
||||
if hole:
|
||||
hole_group.append(hole)
|
||||
child_index = int(hierarchy_rows[child_index][0])
|
||||
polygons.append(outer)
|
||||
holes.append(hole_group)
|
||||
return polygons, holes
|
||||
|
||||
@staticmethod
|
||||
def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
|
||||
"""Return a dummy rectangle polygon for fallback mode."""
|
||||
return [
|
||||
[
|
||||
[0.25, 0.25],
|
||||
[0.75, 0.25],
|
||||
[0.75, 0.75],
|
||||
[0.25, 0.75],
|
||||
]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _polygons_to_mask(
|
||||
polygons: list[list[list[float]]],
|
||||
width: int,
|
||||
height: int,
|
||||
holes_by_polygon: list[list[list[list[float]]]] | None = None,
|
||||
) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
for polygon_index, polygon in enumerate(polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
pts = np.array(
|
||||
[
|
||||
[
|
||||
int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
|
||||
int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
|
||||
]
|
||||
for x, y in polygon
|
||||
],
|
||||
dtype=np.int32,
|
||||
)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
holes = holes_by_polygon[polygon_index] if holes_by_polygon and polygon_index < len(holes_by_polygon) else []
|
||||
for hole in holes:
|
||||
if len(hole) < 3:
|
||||
continue
|
||||
hole_pts = np.array(
|
||||
[
|
||||
[
|
||||
int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
|
||||
int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
|
||||
]
|
||||
for x, y in hole
|
||||
],
|
||||
dtype=np.int32,
|
||||
)
|
||||
cv2.fillPoly(mask, [hole_pts], 0)
|
||||
return mask.astype(bool)
|
||||
|
||||
@staticmethod
|
||||
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
|
||||
x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
|
||||
left = int(round(x * max(width - 1, 1)))
|
||||
top = int(round(y * max(height - 1, 1)))
|
||||
right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
|
||||
bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
|
||||
mask = np.zeros((height, width), dtype=bool)
|
||||
mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
|
||||
return mask
|
||||
|
||||
|
||||
# Singleton instance
|
||||
sam_engine = SAM2Engine()
|
||||
447
backend/services/sam3_engine.py
Normal file
447
backend/services/sam3_engine.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""SAM 3 engine adapter and runtime status.
|
||||
|
||||
The official facebookresearch/sam3 package currently targets Python 3.12+
|
||||
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
|
||||
only performs inference when the local runtime can actually import and execute
|
||||
the package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from config import settings
|
||||
from services.sam2_engine import SAM2Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except Exception as exc: # noqa: BLE001
|
||||
TORCH_AVAILABLE = False
|
||||
torch = None # type: ignore[assignment]
|
||||
logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)
|
||||
|
||||
SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None
|
||||
|
||||
|
||||
class SAM3Engine:
|
||||
"""Lazy SAM 3 image inference adapter."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._model: Any | None = None
|
||||
self._processor: Any | None = None
|
||||
self._model_loaded = False
|
||||
self._last_error: str | None = None
|
||||
self._external_status_cache: dict[str, Any] | None = None
|
||||
self._external_status_checked_at = 0.0
|
||||
|
||||
def _python_ok(self) -> bool:
|
||||
return sys.version_info >= (3, 12)
|
||||
|
||||
def _gpu_ok(self) -> bool:
|
||||
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
|
||||
def _checkpoint_path(self) -> str | None:
|
||||
path = settings.sam3_checkpoint_path.strip()
|
||||
return path if path else None
|
||||
|
||||
def _checkpoint_exists(self) -> bool:
|
||||
path = self._checkpoint_path()
|
||||
return bool(path and os.path.isfile(path))
|
||||
|
||||
def _can_load(self) -> bool:
|
||||
return bool(
|
||||
SAM3_PACKAGE_AVAILABLE
|
||||
and TORCH_AVAILABLE
|
||||
and self._python_ok()
|
||||
and self._gpu_ok()
|
||||
and self._checkpoint_exists()
|
||||
)
|
||||
|
||||
def _worker_path(self) -> Path:
|
||||
return Path(__file__).with_name("sam3_external_worker.py")
|
||||
|
||||
def _external_python_exists(self) -> bool:
|
||||
return bool(settings.sam3_external_enabled and os.path.isfile(settings.sam3_external_python))
|
||||
|
||||
def _external_status(self, force: bool = False) -> dict[str, Any]:
|
||||
now = time.monotonic()
|
||||
if (
|
||||
not force
|
||||
and self._external_status_cache is not None
|
||||
and now - self._external_status_checked_at < settings.sam3_status_cache_seconds
|
||||
):
|
||||
return self._external_status_cache
|
||||
|
||||
if not settings.sam3_external_enabled:
|
||||
status = {
|
||||
"available": False,
|
||||
"package_available": False,
|
||||
"python_ok": False,
|
||||
"torch_ok": False,
|
||||
"cuda_available": False,
|
||||
"device": "unavailable",
|
||||
"message": "SAM 3 external runtime is disabled.",
|
||||
}
|
||||
elif not self._external_python_exists():
|
||||
status = {
|
||||
"available": False,
|
||||
"package_available": False,
|
||||
"python_ok": False,
|
||||
"torch_ok": False,
|
||||
"cuda_available": False,
|
||||
"device": "unavailable",
|
||||
"message": f"SAM 3 external Python not found: {settings.sam3_external_python}",
|
||||
}
|
||||
else:
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--status"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=min(settings.sam3_timeout_seconds, 30),
|
||||
check=False,
|
||||
env=env,
|
||||
)
|
||||
if completed.returncode != 0:
|
||||
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||
status = {
|
||||
"available": False,
|
||||
"package_available": False,
|
||||
"python_ok": False,
|
||||
"torch_ok": False,
|
||||
"cuda_available": False,
|
||||
"device": "unavailable",
|
||||
"message": f"SAM 3 external status failed: {detail}",
|
||||
}
|
||||
else:
|
||||
status = json.loads(completed.stdout)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
status = {
|
||||
"available": False,
|
||||
"package_available": False,
|
||||
"python_ok": False,
|
||||
"torch_ok": False,
|
||||
"cuda_available": False,
|
||||
"device": "unavailable",
|
||||
"message": f"SAM 3 external status failed: {exc}",
|
||||
}
|
||||
|
||||
self._external_status_cache = status
|
||||
self._external_status_checked_at = now
|
||||
return status
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model_loaded:
|
||||
return
|
||||
if not self._can_load():
|
||||
self._last_error = self._status_message()
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
self._model = build_sam3_image_model(
|
||||
checkpoint_path=self._checkpoint_path(),
|
||||
load_from_HF=False,
|
||||
)
|
||||
self._processor = Sam3Processor(self._model)
|
||||
self._model_loaded = True
|
||||
self._last_error = None
|
||||
logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._last_error = str(exc)
|
||||
self._model_loaded = True
|
||||
logger.error("Failed to load SAM 3 model: %s", exc)
|
||||
|
||||
def _ensure_ready(self) -> bool:
|
||||
self._load_model()
|
||||
return self._processor is not None
|
||||
|
||||
def _status_message(self) -> str:
|
||||
missing = []
|
||||
if not SAM3_PACKAGE_AVAILABLE:
|
||||
missing.append("sam3 package")
|
||||
if not self._python_ok():
|
||||
missing.append("Python 3.12+ runtime")
|
||||
if not TORCH_AVAILABLE:
|
||||
missing.append("PyTorch")
|
||||
if not self._gpu_ok():
|
||||
missing.append("CUDA GPU")
|
||||
if not self._checkpoint_exists():
|
||||
missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
|
||||
if missing:
|
||||
return f"SAM 3 unavailable: missing {', '.join(missing)}."
|
||||
return "SAM 3 dependencies are present; model will load on first inference."
|
||||
|
||||
def status(self) -> dict:
|
||||
external_status = self._external_status()
|
||||
available = bool(self._can_load() or external_status.get("available"))
|
||||
external_ready = bool(external_status.get("available"))
|
||||
message = self._last_error or self._status_message()
|
||||
if self._processor is not None:
|
||||
message = "SAM 3 model loaded and ready."
|
||||
elif external_ready:
|
||||
message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
|
||||
elif external_status.get("message") and not self._can_load():
|
||||
message = str(external_status["message"])
|
||||
return {
|
||||
"id": "sam3",
|
||||
"label": "SAM 3",
|
||||
"available": available,
|
||||
"loaded": self._processor is not None,
|
||||
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
|
||||
"supports": ["semantic", "box", "video_track"],
|
||||
"message": message,
|
||||
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
|
||||
"checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
|
||||
"checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
|
||||
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
|
||||
"cuda_required": True,
|
||||
"external_available": external_ready,
|
||||
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
|
||||
}
|
||||
|
||||
def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
|
||||
if len(box) != 4:
|
||||
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
|
||||
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
|
||||
left, right = sorted([x1, x2])
|
||||
top, bottom = sorted([y1, y2])
|
||||
width = max(right - left, 1e-6)
|
||||
height = max(bottom - top, 1e-6)
|
||||
return [left + width / 2, top + height / 2, width, height]
|
||||
|
||||
def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
|
||||
def _predict_external(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
prompt_type: str,
|
||||
*,
|
||||
text: str = "",
|
||||
box: list[float] | None = None,
|
||||
confidence_threshold: float | None = None,
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir:
|
||||
tmp_path = Path(tmpdir)
|
||||
image_path = tmp_path / "image.png"
|
||||
request_path = tmp_path / "request.json"
|
||||
Image.fromarray(image).save(image_path)
|
||||
request_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"image_path": str(image_path),
|
||||
"prompt_type": prompt_type,
|
||||
"text": text.strip(),
|
||||
"box": box,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": (
|
||||
confidence_threshold
|
||||
if confidence_threshold is not None
|
||||
else settings.sam3_confidence_threshold
|
||||
),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=settings.sam3_timeout_seconds,
|
||||
check=False,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if completed.returncode != 0:
|
||||
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||
try:
|
||||
parsed = json.loads(detail)
|
||||
detail = parsed.get("error", detail)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
raise RuntimeError(f"SAM 3 external inference failed: {detail}")
|
||||
|
||||
payload = json.loads(completed.stdout)
|
||||
if payload.get("error"):
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("polygons", []), payload.get("scores", [])
|
||||
|
||||
def _predict_semantic_external(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
text: str,
|
||||
confidence_threshold: float | None = None,
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(
|
||||
image,
|
||||
"semantic",
|
||||
text=text,
|
||||
confidence_threshold=confidence_threshold,
|
||||
)
|
||||
|
||||
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(image, "box", box=box)
|
||||
|
||||
def _propagate_video_external(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
max_frames: int | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
if not frame_paths:
|
||||
return []
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
|
||||
request_path = Path(tmpdir) / "request.json"
|
||||
request_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt_type": "video_track",
|
||||
"frame_dir": str(Path(frame_paths[0]).parent),
|
||||
"source_frame_index": source_frame_index,
|
||||
"seed": seed,
|
||||
"direction": direction,
|
||||
"max_frames": max_frames,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=settings.sam3_timeout_seconds,
|
||||
check=False,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if completed.returncode != 0:
|
||||
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||
try:
|
||||
parsed = json.loads(detail)
|
||||
detail = parsed.get("error", detail)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
raise RuntimeError(f"SAM 3 external video tracking failed: {detail}")
|
||||
|
||||
payload = json.loads(completed.stdout)
|
||||
if payload.get("error"):
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("frames", [])
|
||||
|
||||
def predict_semantic(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
text: str,
|
||||
confidence_threshold: float | None = None,
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not text.strip():
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
if not self._can_load() and self._external_status().get("available"):
|
||||
return self._predict_semantic_external(image, text, confidence_threshold=confidence_threshold)
|
||||
if not self._ensure_ready():
|
||||
raise RuntimeError(self.status()["message"])
|
||||
|
||||
pil_image = Image.fromarray(image)
|
||||
with torch.inference_mode(): # type: ignore[union-attr]
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
|
||||
|
||||
return self._prediction_to_polygons(output)
|
||||
|
||||
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
|
||||
|
||||
def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not self._can_load() and self._external_status().get("available"):
|
||||
return self._predict_box_external(image, box)
|
||||
if not self._ensure_ready():
|
||||
raise RuntimeError(self.status()["message"])
|
||||
|
||||
pil_image = Image.fromarray(image)
|
||||
with torch.inference_mode(): # type: ignore[union-attr]
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.add_geometric_prompt(
|
||||
state=state,
|
||||
box=self._xyxy_to_cxcywh(box),
|
||||
label=True,
|
||||
)
|
||||
|
||||
return self._prediction_to_polygons(output)
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
|
||||
|
||||
|
||||
sam3_engine = SAM3Engine()
|
||||
343
backend/services/sam3_external_worker.py
Normal file
343
backend/services/sam3_external_worker.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime.
|
||||
|
||||
The main FastAPI backend can keep running in the existing Python 3.11/SAM 2
|
||||
environment while this helper is executed with a separate conda env that meets
|
||||
SAM 3's stricter runtime requirements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def _torch_status() -> tuple[bool, str | None, str | None, str | None]:
|
||||
try:
|
||||
import torch
|
||||
|
||||
cuda_available = bool(torch.cuda.is_available())
|
||||
return (
|
||||
cuda_available,
|
||||
getattr(torch, "__version__", None),
|
||||
getattr(torch.version, "cuda", None),
|
||||
torch.cuda.get_device_name(0) if cuda_available else None,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
return False, None, None, None
|
||||
|
||||
|
||||
def _compact_error(exc: Exception) -> str:
|
||||
lines = [line.strip() for line in str(exc).splitlines() if line.strip()]
|
||||
for line in lines:
|
||||
if "Access to model" in line or "Cannot access gated repo" in line:
|
||||
return line
|
||||
return lines[0] if lines else exc.__class__.__name__
|
||||
|
||||
|
||||
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
|
||||
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip()
|
||||
if checkpoint_path:
|
||||
path = Path(checkpoint_path)
|
||||
if path.is_file():
|
||||
return True, None
|
||||
return False, f"local checkpoint not found: {checkpoint_path}"
|
||||
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
repo_id = "facebook/sam3.1" if model_version == "sam3.1" else "facebook/sam3"
|
||||
hf_hub_download(repo_id=repo_id, filename="config.json")
|
||||
return True, None
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return False, _compact_error(exc)
|
||||
|
||||
|
||||
def runtime_status() -> dict[str, Any]:
|
||||
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
|
||||
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None
|
||||
package_error = None
|
||||
package_available = importlib.util.find_spec("sam3") is not None
|
||||
if package_available:
|
||||
try:
|
||||
import sam3 # noqa: F401
|
||||
except Exception as exc: # noqa: BLE001
|
||||
package_available = False
|
||||
package_error = str(exc)
|
||||
cuda_available, torch_version, cuda_version, device_name = _torch_status()
|
||||
python_ok = sys.version_info >= (3, 12)
|
||||
checkpoint_access = False
|
||||
checkpoint_error = None
|
||||
if package_available:
|
||||
checkpoint_access, checkpoint_error = _checkpoint_access(model_version)
|
||||
available = bool(package_available and python_ok and cuda_available and checkpoint_access)
|
||||
missing = []
|
||||
if not python_ok:
|
||||
missing.append("Python 3.12+ runtime")
|
||||
if not package_available:
|
||||
missing.append(f"sam3 package ({package_error})" if package_error else "sam3 package")
|
||||
if torch_version is None:
|
||||
missing.append("PyTorch")
|
||||
if not cuda_available:
|
||||
missing.append("CUDA GPU")
|
||||
if package_available and not checkpoint_access:
|
||||
missing.append(f"Hugging Face checkpoint access ({checkpoint_error})")
|
||||
return {
|
||||
"available": available,
|
||||
"package_available": package_available,
|
||||
"checkpoint_access": checkpoint_access,
|
||||
"checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})",
|
||||
"python_ok": python_ok,
|
||||
"torch_ok": torch_version is not None,
|
||||
"torch_version": torch_version,
|
||||
"cuda_version": cuda_version,
|
||||
"cuda_available": cuda_available,
|
||||
"device": "cuda" if cuda_available else "unavailable",
|
||||
"device_name": device_name,
|
||||
"message": (
|
||||
"SAM 3 external runtime is ready."
|
||||
if available
|
||||
else f"SAM 3 external runtime unavailable: missing {', '.join(missing)}."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
|
||||
import cv2
|
||||
|
||||
if mask.dtype != np.uint8:
|
||||
mask = (mask > 0).astype(np.uint8)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
height, width = mask.shape[:2]
|
||||
largest = []
|
||||
for contour in contours:
|
||||
if len(contour) > len(largest):
|
||||
largest = contour
|
||||
if len(largest) < 3:
|
||||
return []
|
||||
return [[float(point[0][0]) / width, float(point[0][1]) / height] for point in largest]
|
||||
|
||||
|
||||
def _to_numpy(value: Any) -> np.ndarray:
|
||||
if hasattr(value, "detach"):
|
||||
value = value.detach()
|
||||
if hasattr(value, "is_floating_point") and value.is_floating_point():
|
||||
value = value.float()
|
||||
value = value.cpu().numpy()
|
||||
elif hasattr(value, "cpu"):
|
||||
value = value.cpu()
|
||||
if hasattr(value, "is_floating_point") and value.is_floating_point():
|
||||
value = value.float()
|
||||
value = value.numpy()
|
||||
return np.asarray(value)
|
||||
|
||||
|
||||
def _xyxy_to_cxcywh(box: list[float]) -> list[float]:
|
||||
if len(box) != 4:
|
||||
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
|
||||
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
|
||||
left, right = sorted([x1, x2])
|
||||
top, bottom = sorted([y1, y2])
|
||||
width = max(right - left, 1e-6)
|
||||
height = max(bottom - top, 1e-6)
|
||||
return [left + width / 2, top + height / 2, width, height]
|
||||
|
||||
|
||||
def _bbox_from_seed(seed: dict[str, Any]) -> list[float]:
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
return [min(max(float(value), 0.0), 1.0) for value in bbox]
|
||||
|
||||
polygons = seed.get("polygons") or []
|
||||
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
|
||||
if not points:
|
||||
raise ValueError("SAM 3 video tracking requires seed bbox or polygons.")
|
||||
xs = [min(max(float(point[0]), 0.0), 1.0) for point in points]
|
||||
ys = [min(max(float(point[1]), 0.0), 1.0) for point in points]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)]
|
||||
|
||||
|
||||
def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
|
||||
masks = _to_numpy(outputs.get("out_binary_masks", []))
|
||||
scores = _to_numpy(outputs.get("out_probs", []))
|
||||
obj_ids = _to_numpy(outputs.get("out_obj_ids", []))
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
elif masks.ndim == 2:
|
||||
masks = masks[None, ...]
|
||||
|
||||
polygons = []
|
||||
out_scores = []
|
||||
out_ids = []
|
||||
for index, mask in enumerate(masks):
|
||||
polygon = _mask_to_polygon(mask)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
out_scores.append(float(scores[index]) if scores.size > index else 1.0)
|
||||
out_ids.append(int(obj_ids[index]) if obj_ids.size > index else index + 1)
|
||||
return {"polygons": polygons, "scores": out_scores, "object_ids": out_ids}
|
||||
|
||||
|
||||
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
|
||||
masks = _to_numpy(output.get("masks", []))
|
||||
scores = _to_numpy(output.get("scores", []))
|
||||
if masks.ndim == 2:
|
||||
masks = masks[None, :, :]
|
||||
elif masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
elif masks.ndim == 3 and masks.shape[0] == 1:
|
||||
masks = masks[None, 0]
|
||||
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
polygon = _mask_to_polygon(mask)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
|
||||
return {
|
||||
"polygons": polygons,
|
||||
"scores": scores.astype(float).tolist() if scores.size else [],
|
||||
}
|
||||
|
||||
|
||||
def predict_video(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model_builder import build_sam3_video_predictor
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
frame_dir = Path(payload["frame_dir"])
|
||||
source_frame_index = int(payload.get("source_frame_index", 0))
|
||||
seed = payload.get("seed") or {}
|
||||
direction = str(payload.get("direction") or "forward").lower()
|
||||
max_frames = payload.get("max_frames")
|
||||
max_frames = int(max_frames) if max_frames else None
|
||||
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
if direction not in {"forward", "backward", "both"}:
|
||||
raise ValueError(f"Unsupported propagation direction: {direction}")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
predictor = build_sam3_video_predictor(
|
||||
checkpoint_path=checkpoint_path or None,
|
||||
async_loading_frames=False,
|
||||
)
|
||||
session_id = None
|
||||
try:
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
session = predictor.handle_request(
|
||||
{
|
||||
"type": "start_session",
|
||||
"resource_path": str(frame_dir),
|
||||
"offload_video_to_cpu": True,
|
||||
"offload_state_to_cpu": True,
|
||||
}
|
||||
)
|
||||
session_id = session["session_id"]
|
||||
predictor.handle_request(
|
||||
{
|
||||
"type": "add_prompt",
|
||||
"session_id": session_id,
|
||||
"frame_index": source_frame_index,
|
||||
"bounding_boxes": [_bbox_from_seed(seed)],
|
||||
"bounding_box_labels": [1],
|
||||
"output_prob_thresh": threshold,
|
||||
"rel_coordinates": True,
|
||||
}
|
||||
)
|
||||
frames = []
|
||||
for item in predictor.handle_stream_request(
|
||||
{
|
||||
"type": "propagate_in_video",
|
||||
"session_id": session_id,
|
||||
"propagation_direction": direction,
|
||||
"start_frame_index": source_frame_index,
|
||||
"max_frame_num_to_track": max_frames,
|
||||
"output_prob_thresh": threshold,
|
||||
}
|
||||
):
|
||||
frame_response = _video_outputs_to_response(item.get("outputs") or {})
|
||||
frame_response["frame_index"] = int(item["frame_index"])
|
||||
frames.append(frame_response)
|
||||
finally:
|
||||
if session_id:
|
||||
predictor.handle_request({"type": "close_session", "session_id": session_id})
|
||||
|
||||
return {"frames": frames}
|
||||
|
||||
|
||||
def predict(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
if str(payload.get("prompt_type") or "").strip().lower() == "video_track":
|
||||
return predict_video(request_path)
|
||||
|
||||
image_path = Path(payload["image_path"])
|
||||
prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower()
|
||||
text = str(payload.get("text") or "").strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
|
||||
if prompt_type == "semantic" and not text:
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
if prompt_type not in {"semantic", "box"}:
|
||||
raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
model = build_sam3_image_model(
|
||||
checkpoint_path=checkpoint_path or None,
|
||||
load_from_HF=not bool(checkpoint_path),
|
||||
)
|
||||
processor = Sam3Processor(model, confidence_threshold=threshold)
|
||||
state = processor.set_image(image)
|
||||
if prompt_type == "box":
|
||||
output = processor.add_geometric_prompt(
|
||||
state=state,
|
||||
box=_xyxy_to_cxcywh(payload.get("box") or []),
|
||||
label=True,
|
||||
)
|
||||
else:
|
||||
output = processor.set_text_prompt(state=state, prompt=text)
|
||||
|
||||
return _prediction_to_response(output)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
|
||||
parser.add_argument("--status", action="store_true")
|
||||
parser.add_argument("--request", type=Path)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
if args.status:
|
||||
print(json.dumps(runtime_status(), ensure_ascii=False))
|
||||
return 0
|
||||
if args.request:
|
||||
print(json.dumps(predict(args.request), ensure_ascii=False))
|
||||
return 0
|
||||
parser.error("Use --status or --request")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(json.dumps({"error": str(exc)}, ensure_ascii=False), file=sys.stderr)
|
||||
return 1
|
||||
|
||||
return 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
130
backend/services/sam_registry.py
Normal file
130
backend/services/sam_registry.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Model registry for SAM runtimes and GPU status."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from config import settings
|
||||
from services.sam2_engine import DEFAULT_SAM2_MODEL_ID, TORCH_AVAILABLE, sam_engine as sam2_engine
|
||||
|
||||
# SAM 3 integration is intentionally disabled for the current product flow.
|
||||
# The source files are kept in the repository so the integration can be
|
||||
# restored later, but the active registry only exposes SAM 2.
|
||||
# from services.sam3_engine import sam3_engine
|
||||
|
||||
try:
|
||||
import torch
|
||||
except Exception: # noqa: BLE001
|
||||
torch = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class ModelUnavailableError(RuntimeError):
|
||||
"""Raised when a selected model cannot run in this environment."""
|
||||
|
||||
|
||||
class SAMRegistry:
|
||||
"""Dispatch predictions to the selected SAM backend."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._engines = {
|
||||
"sam2": sam2_engine,
|
||||
# "sam3": sam3_engine,
|
||||
}
|
||||
|
||||
def normalize_model_id(self, model_id: str | None) -> str:
|
||||
selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
|
||||
if self._engines["sam2"].is_sam2_model(selected):
|
||||
return self._engines["sam2"].normalize_model_id(selected)
|
||||
if selected not in self._engines:
|
||||
raise ValueError(f"Unsupported model: {model_id}")
|
||||
return selected
|
||||
|
||||
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
|
||||
selected = self.normalize_model_id(selected_model)
|
||||
return {
|
||||
"selected_model": selected,
|
||||
"gpu": self.gpu_status(),
|
||||
"models": [sam2_engine.status(model_id) for model_id in sam2_engine.variant_ids()],
|
||||
}
|
||||
|
||||
def gpu_status(self) -> dict[str, Any]:
|
||||
cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
return {
|
||||
"available": cuda_available,
|
||||
"device": "cuda" if cuda_available else "cpu",
|
||||
"name": torch.cuda.get_device_name(0) if cuda_available else None,
|
||||
"torch_available": bool(TORCH_AVAILABLE),
|
||||
"torch_version": getattr(torch, "__version__", None) if torch is not None else None,
|
||||
"cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
|
||||
}
|
||||
|
||||
def _engine(self, model_id: str | None) -> Any:
|
||||
normalized = self.normalize_model_id(model_id)
|
||||
if self._engines["sam2"].is_sam2_model(normalized):
|
||||
return self._engines["sam2"]
|
||||
return self._engines[normalized]
|
||||
|
||||
def _ensure_available(self, model_id: str | None) -> Any:
|
||||
normalized = self.normalize_model_id(model_id)
|
||||
engine = self._engine(model_id)
|
||||
status = engine.status(normalized) if engine is sam2_engine else engine.status()
|
||||
if not status["available"]:
|
||||
raise ModelUnavailableError(status["message"])
|
||||
return engine
|
||||
|
||||
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
|
||||
model = self.normalize_model_id(model_id)
|
||||
return self._ensure_available(model).predict_points(model, image, points, labels)
|
||||
|
||||
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
|
||||
model = self.normalize_model_id(model_id)
|
||||
return self._ensure_available(model).predict_box(model, image, box)
|
||||
|
||||
def predict_interactive(
|
||||
self,
|
||||
model_id: str | None,
|
||||
image: Any,
|
||||
box: list[float] | None,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
):
|
||||
model = self.normalize_model_id(model_id)
|
||||
if not sam2_engine.is_sam2_model(model):
|
||||
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
|
||||
return self._ensure_available(model).predict_interactive(model, image, box, points, labels)
|
||||
|
||||
def predict_auto(self, model_id: str | None, image: Any):
|
||||
model = self.normalize_model_id(model_id)
|
||||
return self._ensure_available(model).predict_auto(model, image)
|
||||
|
||||
def predict_semantic(
|
||||
self,
|
||||
model_id: str | None,
|
||||
image: Any,
|
||||
text: str,
|
||||
confidence_threshold: float | None = None,
|
||||
):
|
||||
self.normalize_model_id(model_id)
|
||||
raise NotImplementedError("Semantic text prompting is disabled; use SAM 2 point or box prompts.")
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
model_id: str | None,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
max_frames: int | None,
|
||||
):
|
||||
model = self.normalize_model_id(model_id)
|
||||
return self._ensure_available(model).propagate_video(
|
||||
model,
|
||||
frame_paths,
|
||||
source_frame_index,
|
||||
seed,
|
||||
direction=direction,
|
||||
max_frames=max_frames,
|
||||
)
|
||||
|
||||
|
||||
sam_registry = SAMRegistry()
|
||||
15
backend/statuses.py
Normal file
15
backend/statuses.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Shared status constants used across backend project/task flows."""
|
||||
|
||||
PROJECT_STATUS_PENDING = "pending"
|
||||
PROJECT_STATUS_PARSING = "parsing"
|
||||
PROJECT_STATUS_READY = "ready"
|
||||
PROJECT_STATUS_ERROR = "error"
|
||||
|
||||
TASK_STATUS_QUEUED = "queued"
|
||||
TASK_STATUS_RUNNING = "running"
|
||||
TASK_STATUS_SUCCESS = "success"
|
||||
TASK_STATUS_FAILED = "failed"
|
||||
TASK_STATUS_CANCELLED = "cancelled"
|
||||
|
||||
TASK_ACTIVE_STATUSES = {TASK_STATUS_QUEUED, TASK_STATUS_RUNNING}
|
||||
TASK_TERMINAL_STATUSES = {TASK_STATUS_SUCCESS, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}
|
||||
36
backend/worker_tasks.py
Normal file
36
backend/worker_tasks.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Celery task definitions."""
|
||||
|
||||
import logging
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import SessionLocal
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
from services.propagation_task_runner import run_propagate_project_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name="media.parse_project")
|
||||
def parse_project_media(task_id: int) -> dict:
|
||||
"""Run media parsing for one queued task."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return run_parse_media_task(db, task_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Parse media task failed: task_id=%s", task_id)
|
||||
raise exc
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@celery_app.task(name="ai.propagate_project")
|
||||
def propagate_project_masks(task_id: int) -> dict:
|
||||
"""Run SAM video propagation for one queued task."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return run_propagate_project_task(db, task_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("Propagation task failed: task_id=%s", task_id)
|
||||
raise exc
|
||||
finally:
|
||||
db.close()
|
||||
Reference in New Issue
Block a user