添加Docker自包含部署分支

- 新增 Seg_Server_Docker 自包含部署内容,包含前后端、FastAPI、Celery、PostgreSQL、Redis、MinIO、演示视频和 DICOM 数据。

- 保留 demo 数据以支持恢复演示出厂设置,排除 SAM 2.1 .pt 权重并在 README 中补充下载命令。

- 补充 GPU 部署、backend/worker 镜像复用、frpc/frps + NPM 公网域名反代部署说明。

- 在 .env/.env.example 中用 # XXXX 标注局域网和公网域名部署需要修改的配置项。

- 添加部署分支 .gitignore,忽略本地模型权重、构建产物、缓存和日志。
This commit is contained in:
2026-05-07 19:06:07 +08:00
commit b5413066a0
396 changed files with 32742 additions and 0 deletions

21
backend/celery_app.py Normal file
View File

@@ -0,0 +1,21 @@
"""Celery application for background processing."""
from celery import Celery
from config import settings
# Shared Celery application. The same Redis URL is used as message broker and
# as result backend; tasks are imported from the "worker_tasks" module.
celery_app = Celery(
    "seg_server",
    broker=settings.redis_url,
    backend=settings.redis_url,
    include=["worker_tasks"],
)
# Restrict task payloads and results to JSON and pin the scheduling timezone.
celery_app.conf.update(
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    timezone="Asia/Shanghai",
    enable_utc=True,
    task_track_started=True,
)

51
backend/config.py Normal file
View File

@@ -0,0 +1,51 @@
"""Application configuration using Pydantic Settings."""
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Every field below carries a default value. The nested ``Config`` also
    points at a UTF-8 ``.env`` file and ignores keys that match no field.
    """

    # Database
    # NOTE(review): the default URL embeds development credentials; override
    # it through the environment for any shared deployment.
    db_url: str = "postgresql://seguser:segpass123@localhost:5432/segserver"
    # Redis
    redis_url: str = "redis://localhost:6379/0"
    # MinIO
    minio_endpoint: str = "192.168.3.11:9000"
    # Optional browser-facing endpoint; None means "reuse minio_endpoint".
    minio_public_endpoint: str | None = None
    # NOTE(review): well-known MinIO default credentials; change for deployment.
    minio_access_key: str = "minioadmin"
    minio_secret_key: str = "minioadmin"
    minio_secure: bool = False
    # SAM
    # NOTE(review): the path defaults below are absolute paths on a developer
    # machine and presumably must be overridden inside a container; confirm.
    sam_default_model: str = "sam2.1_hiera_tiny"
    sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2.1_hiera_tiny.pt"
    sam_model_config: str = "configs/sam2.1/sam2.1_hiera_t.yaml"
    sam3_model_version: str = "sam3"
    sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"
    sam3_external_enabled: bool = False
    sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
    sam3_timeout_seconds: int = 300
    sam3_status_cache_seconds: int = 30
    sam3_confidence_threshold: float = 0.5
    # App
    app_env: str = "development"
    cors_origins: list[str] = ["http://localhost:3000", "http://192.168.3.11:3000"]
    # NOTE(review): placeholder signing secret; must be replaced before exposure.
    jwt_secret_key: str = "seg-server-dev-secret-change-me"
    jwt_algorithm: str = "HS256"
    # Token lifetime in minutes (60 * 24 = one day).
    access_token_expire_minutes: int = 60 * 24
    default_admin_username: str = "admin"
    # NOTE(review): weak default admin password; override in deployment.
    default_admin_password: str = "123456"
    demo_video_path: str = "/home/wkmgc/Desktop/Seg_Server/demo/演视LC视频序列.mp4"
    demo_dicom_dir: str = "/home/wkmgc/Desktop/Seg_Server/demo/演视DICOM序列"

    # NOTE(review): pydantic v2 prefers a ``model_config`` attribute over a
    # nested ``Config`` class; consider migrating.
    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        extra = "ignore"


# Module-level settings instance imported by the other backend modules.
settings = Settings()

29
backend/database.py Normal file
View File

@@ -0,0 +1,29 @@
"""Database configuration using synchronous SQLAlchemy."""
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from fastapi import Depends
from typing import Generator
from config import settings
# Synchronous engine bound to the configured database URL.
engine = create_engine(
    settings.db_url,
    # Test pooled connections before use so stale ones are replaced.
    pool_pre_ping=True,
    pool_size=10,
    max_overflow=20,
    echo=False,
)
# Session factory; callers commit explicitly (no autocommit, no autoflush).
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class shared by all ORM models.
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
    """Yield a database session for one request and close it afterwards.

    Intended for use as a FastAPI dependency. The session is closed when the
    generator is finalised, whether the request succeeded or raised.
    """
    # The session's context manager closes it on exit, like try/finally.
    with SessionLocal() as session:
        yield session

317
backend/main.py Normal file
View File

@@ -0,0 +1,317 @@
"""FastAPI application entrypoint."""
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timezone
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import inspect, text
from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists
from progress_events import PROGRESS_CHANNEL
from redis_client import get_redis_client, ping as redis_ping
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
# Process-wide logging: INFO level, "time | level | logger name | message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)
def _ensure_runtime_schema_columns() -> None:
    """Add nullable columns introduced after initial create_all deployments.

    Inspects the ``frames``, ``projects`` and ``templates`` tables and issues
    an ``ALTER TABLE ... ADD COLUMN`` for each expected column that is absent.
    Any failure is logged as a warning and otherwise ignored.
    """
    try:
        db_inspector = inspect(engine)

        def _existing(table_name):
            return {col["name"] for col in db_inspector.get_columns(table_name)}

        # Inspect all three tables before opening the transaction.
        frames_have = _existing("frames")
        projects_have = _existing("projects")
        templates_have = _existing("templates")

        # (column name, columns already present on its table, DDL to add it)
        additions = (
            ("timestamp_ms", frames_have, "ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"),
            (
                "source_frame_number",
                frames_have,
                "ALTER TABLE frames ADD COLUMN source_frame_number INTEGER",
            ),
            ("owner_user_id", projects_have, "ALTER TABLE projects ADD COLUMN owner_user_id INTEGER"),
            ("owner_user_id", templates_have, "ALTER TABLE templates ADD COLUMN owner_user_id INTEGER"),
        )
        with engine.begin() as conn:
            for column_name, present, ddl in additions:
                if column_name not in present:
                    conn.execute(text(ddl))
    except Exception as exc:  # noqa: BLE001
        logger.warning("Runtime schema column check failed: %s", exc)
def _seed_default_admin_sync() -> None:
    """Make sure the single default admin account exists.

    Project ownership metadata is not rewritten. Failures are logged and
    swallowed; the session is always closed.
    """
    from routers.auth import ensure_default_admin

    # Leaving the ``with`` block closes the session on every path.
    with SessionLocal() as session:
        try:
            admin_user = ensure_default_admin(session)
            session.commit()
            logger.info("Default admin ready id=%s", admin_user.id)
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to seed default admin: %s", exc)
def _rename_legacy_demo_project(db, project_model, legacy_names, current_name) -> None:
    """Rename the first project that still uses a legacy demo name."""
    legacy_project = (
        db.query(project_model)
        .filter(project_model.name.in_(legacy_names))
        .first()
    )
    if legacy_project is not None:
        legacy_project.name = current_name
        db.commit()


def _seed_default_project_sync() -> None:
    """Synchronously seed the bundled demo video and DICOM projects on first startup.

    Steps, in order:
      1. Ensure the default admin exists (it owns the seeded projects).
      2. Rename projects still carrying a legacy demo name.
      3. Create the demo video project if missing and the video file exists.
      4. Create the demo DICOM project if missing and the DICOM series exists.

    Any exception is logged and swallowed; the session is always closed.
    """
    # Imported lazily, as in the rest of this module's seeding helpers.
    from models import Project
    from routers.auth import ensure_default_admin
    from services.demo_media import (
        DEMO_DICOM_PROJECT_NAME,
        DEMO_VIDEO_PROJECT_NAME,
        LEGACY_DEMO_DICOM_PROJECT_NAMES,
        LEGACY_DEMO_VIDEO_PROJECT_NAMES,
        create_parsed_dicom_demo_project,
        create_parsed_video_demo_project,
        demo_dicom_files,
    )

    db = SessionLocal()
    try:
        admin = ensure_default_admin(db)

        _rename_legacy_demo_project(
            db, Project, LEGACY_DEMO_VIDEO_PROJECT_NAMES, DEMO_VIDEO_PROJECT_NAME
        )
        _rename_legacy_demo_project(
            db, Project, LEGACY_DEMO_DICOM_PROJECT_NAMES, DEMO_DICOM_PROJECT_NAME
        )

        existing_video = db.query(Project).filter(Project.name == DEMO_VIDEO_PROJECT_NAME).first()
        if existing_video is None and os.path.exists(settings.demo_video_path):
            video_project = create_parsed_video_demo_project(
                db,
                owner=admin,
                video_path=settings.demo_video_path,
                project_name=DEMO_VIDEO_PROJECT_NAME,
            )
            logger.info("Seeded default video project id=%s", video_project.id)

        existing_dicom = db.query(Project).filter(Project.name == DEMO_DICOM_PROJECT_NAME).first()
        if existing_dicom is not None:
            return
        if not demo_dicom_files(settings.demo_dicom_dir):
            logger.warning("Default DICOM series not found at %s", settings.demo_dicom_dir)
            return

        project = create_parsed_dicom_demo_project(
            db,
            owner=admin,
            dicom_dir=settings.demo_dicom_dir,
            project_name=DEMO_DICOM_PROJECT_NAME,
        )
        logger.info("Seeded default DICOM project id=%s with %d frames", project.id, len(project.frames))
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to seed default project: %s", exc)
    finally:
        db.close()
def _seed_default_templates_sync() -> None:
    """Seed the default ontology templates on first startup.

    Failures are logged and swallowed; the session is always closed.
    """
    # Leaving the ``with`` block closes the session on every path.
    with SessionLocal() as session:
        try:
            ensure_default_templates(session)
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to seed default templates: %s", exc)
def ensure_default_templates(db) -> None:
    """Ensure all bundled system templates exist.

    Thin wrapper that imports the service module lazily and delegates to its
    function of the same name.
    """
    import services.default_templates as _default_templates

    _default_templates.ensure_default_templates(db)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup and shutdown hooks.

    Startup: create tables, patch late-added columns, seed the default admin,
    check the MinIO bucket and Redis, start the Redis progress listener, then
    start template and demo-project seeding in worker threads.
    Shutdown: cancel the progress listener and dispose of the engine.

    Each startup step logs its own failure so one unavailable dependency does
    not stop the application from starting.
    """
    progress_listener: asyncio.Task | None = None
    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task may be garbage-collected before it finishes. Hold strong references
    # here until each seeding task completes.
    background_tasks: set[asyncio.Task] = set()

    def _spawn_in_thread(func) -> None:
        """Run *func* in a worker thread and keep its task referenced until done."""
        task = asyncio.create_task(asyncio.to_thread(func))
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    # Startup
    logger.info("Starting up SegServer backend...")

    # Initialize database tables
    try:
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_schema_columns()
        _seed_default_admin_sync()
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)

    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    # Check Redis
    if redis_ping():
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")

    # Started even when the ping failed: the listener retries on its own.
    try:
        progress_listener = asyncio.create_task(_progress_pubsub_loop())
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start Redis progress subscription: %s", exc)

    # Seed default templates
    try:
        _spawn_in_thread(_seed_default_templates_sync)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default template seeding: %s", exc)

    # Seed default project in background thread so it doesn't block startup
    try:
        _spawn_in_thread(_seed_default_project_sync)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default project seeding: %s", exc)

    yield

    # Shutdown
    logger.info("Shutting down SegServer backend...")
    if progress_listener is not None:
        progress_listener.cancel()
        with suppress(asyncio.CancelledError):
            await progress_listener
    engine.dispose()
# ASGI application; startup and shutdown are handled by ``lifespan``.
app = FastAPI(
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: only the configured origins, with credentials, any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers, registered in this fixed order.
for _router_module in (
    auth,
    projects,
    templates,
    media,
    ai,
    export,
    dashboard,
    tasks,
    admin,
):
    app.include_router(_router_module.router)
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Return a static payload showing that the service is up."""
    return dict(status="ok", service="SegServer")
# ---------------------------------------------------------------------------
# WebSocket: 实时进度推送
# ---------------------------------------------------------------------------
class ConnectionManager:
    """Manage WebSocket connections for progress broadcasting."""

    def __init__(self) -> None:
        # Connections that have been accepted and not yet removed.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake and start tracking the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking *websocket*; safe to call more than once.

        A failed broadcast and the endpoint's own error handler may both call
        this for the same socket, so a socket that is already gone is ignored
        without logging a second "disconnected" line.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            return
        logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict) -> None:
        """Send *message* as JSON to every tracked client.

        Clients whose send fails are dropped from the tracked list.
        """
        # Iterate over a copy because disconnect() mutates the list.
        for connection in self.active_connections.copy():
            try:
                await connection.send_json(message)
            except Exception as exc:  # noqa: BLE001
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(connection)


# Shared manager used by the WebSocket endpoint and the Redis listener.
manager = ConnectionManager()
async def _progress_pubsub_loop() -> None:
    """Forward Redis task-progress events to connected WebSocket clients.

    Runs until cancelled. The outer loop (re)subscribes to PROGRESS_CHANNEL.
    The inner loop polls for messages and broadcasts every payload that is,
    or decodes to, a dict. Any non-cancellation error is logged and the
    subscription is rebuilt after a 5 second pause.
    """
    while True:
        pubsub = None
        try:
            pubsub = get_redis_client().pubsub()
            # Blocking redis calls run in a worker thread, off the event loop.
            await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
            logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
            while True:
                # NOTE(review): positional args are presumably
                # ignore_subscribe_messages=True, timeout=1.0; confirm
                # against redis-py's PubSub.get_message.
                message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
                if message is None:
                    # Nothing received: yield to the event loop and poll again.
                    await asyncio.sleep(0)
                    continue
                raw_data = message.get("data")
                # String payloads are decoded as JSON; others pass through as-is.
                payload = json.loads(raw_data) if isinstance(raw_data, str) else raw_data
                if isinstance(payload, dict):
                    await manager.broadcast(payload)
        except asyncio.CancelledError:
            # Let cancellation propagate so shutdown can complete.
            raise
        except Exception as exc:  # noqa: BLE001
            logger.error("Redis progress subscription failed: %s", exc)
            await asyncio.sleep(5)
        finally:
            # Best-effort close of the subscription on every exit path.
            if pubsub is not None:
                with suppress(Exception):
                    await asyncio.to_thread(pubsub.close)
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """Serve the real-time parsing/AI progress stream over WebSocket.

    Registers the client with the shared manager, then answers every text
    message (heartbeat / subscription request) with a status frame. The client
    is unregistered on disconnect or on any other error.
    """
    await manager.connect(websocket)
    try:
        while True:
            incoming = await websocket.receive_text()
            logger.debug("WebSocket received: %s", incoming)
            # Reply to each client message to keep the connection alive.
            status_frame = {
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
            await websocket.send_json(status_frame)
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        logger.error("WebSocket error: %s", exc)
    # Reached only after one of the handlers above; the loop never ends normally.
    manager.disconnect(websocket)

142
backend/minio_client.py Normal file
View File

@@ -0,0 +1,142 @@
"""MinIO client wrapper for object storage operations."""
import io
import logging
from typing import Optional
from minio import Minio
from minio.error import S3Error
from config import settings
logger = logging.getLogger(__name__)
# Bucket used by every helper in this module.
BUCKET_NAME = "seg-media"
# Lazily created client singletons, filled in by the get_* accessors.
_minio_client: Optional[Minio] = None
_minio_public_client: Optional[Minio] = None
def get_minio_client() -> Minio:
    """Return the shared MinIO client, creating it on first use."""
    global _minio_client
    if _minio_client is not None:
        return _minio_client
    _minio_client = Minio(
        settings.minio_endpoint,
        access_key=settings.minio_access_key,
        secret_key=settings.minio_secret_key,
        secure=settings.minio_secure,
    )
    return _minio_client
def get_minio_public_client() -> Minio:
    """Return the shared MinIO client used for browser-facing presigned URLs.

    Uses ``settings.minio_public_endpoint`` when set and falls back to the
    internal endpoint otherwise. Credentials and the ``secure`` flag are the
    same as for the internal client.
    """
    global _minio_public_client
    if _minio_public_client is not None:
        return _minio_public_client
    public_endpoint = settings.minio_public_endpoint or settings.minio_endpoint
    _minio_public_client = Minio(
        public_endpoint,
        access_key=settings.minio_access_key,
        secret_key=settings.minio_secret_key,
        secure=settings.minio_secure,
    )
    return _minio_public_client
def ensure_bucket_exists() -> None:
    """Create ``BUCKET_NAME`` unless it is already present.

    Raises:
        S3Error: re-raised after logging when the check or creation fails.
    """
    storage = get_minio_client()
    try:
        if storage.bucket_exists(BUCKET_NAME):
            logger.info("MinIO bucket %s already exists", BUCKET_NAME)
            return
        storage.make_bucket(BUCKET_NAME)
        logger.info("Created MinIO bucket: %s", BUCKET_NAME)
    except S3Error as exc:
        logger.error("MinIO bucket check/creation failed: %s", exc)
        raise
def upload_file(
    object_name: str,
    data: bytes | bytearray | io.IOBase,
    content_type: str = "application/octet-stream",
    length: int = -1,
) -> str:
    """Upload bytes or a binary stream to MinIO and return the object name.

    Args:
        object_name: Destination path inside the bucket.
        data: Raw bytes (or bytearray), or a readable binary file-like object.
        content_type: MIME type of the object.
        length: Object size in bytes. Ignored for in-memory data, whose real
            size is always used. For streams, -1 means "unknown" and switches
            to a multipart upload.

    Returns:
        The object name (same as input).

    Raises:
        S3Error: re-raised after logging when the upload fails.
    """
    # Multipart part size used only when the stream length is unknown. MinIO
    # rejects an unknown length unless a valid part_size is supplied.
    unknown_length_part_size = 10 * 1024 * 1024

    client = get_minio_client()
    if isinstance(data, (bytes, bytearray)):
        # Measure before wrapping so the buffer is not copied a second time.
        length = len(data)
        data = io.BytesIO(data)

    extra_kwargs = {}
    if length < 0:
        extra_kwargs["part_size"] = unknown_length_part_size

    try:
        client.put_object(
            BUCKET_NAME,
            object_name,
            data,
            length=length,
            content_type=content_type,
            **extra_kwargs,
        )
        logger.info("Uploaded to MinIO: %s", object_name)
        return object_name
    except S3Error as exc:
        logger.error("MinIO upload failed: %s", exc)
        raise
from datetime import timedelta
def get_presigned_url(
    object_name: str,
    expires: int = 3600,
    method: str = "GET",
) -> str:
    """Build a presigned URL for an object using the public-facing client.

    Args:
        object_name: Path inside the bucket.
        expires: Lifetime of the URL in seconds (default 1 hour).
        method: HTTP method (GET or PUT).

    Returns:
        Presigned URL string.

    Raises:
        S3Error: re-raised after logging when signing fails.
    """
    public_client = get_minio_public_client()
    lifetime = timedelta(seconds=expires)
    try:
        return public_client.get_presigned_url(
            method, BUCKET_NAME, object_name, expires=lifetime
        )
    except S3Error as exc:
        logger.error("MinIO presigned URL failed: %s", exc)
        raise
def download_file(object_name: str) -> bytes:
    """Download an object from MinIO and return its bytes.

    Args:
        object_name: Path inside the bucket.

    Returns:
        Raw bytes of the object.

    Raises:
        S3Error: re-raised after logging when the download fails.
    """
    client = get_minio_client()
    try:
        response = client.get_object(BUCKET_NAME, object_name)
        try:
            return response.read()
        finally:
            # Always hand the connection back to the pool, even if read() fails.
            response.close()
            response.release_conn()
    except S3Error as exc:
        logger.error("MinIO download failed: %s", exc)
        raise

195
backend/models.py Normal file
View File

@@ -0,0 +1,195 @@
"""SQLAlchemy ORM models."""
from sqlalchemy import (
Column,
Integer,
String,
Text,
DateTime,
ForeignKey,
JSON,
Float,
)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from database import Base
from statuses import PROJECT_STATUS_PENDING
class User(Base):
    """Application user used for authentication and data ownership.

    Owns projects and templates through the ``owner_user_id`` foreign keys on
    those tables.
    """

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(150), unique=True, index=True, nullable=False)
    password_hash = Column(String(255), nullable=False)
    role = Column(String(50), default="annotator", nullable=False)
    # Active flag stored as an integer (default 1).
    is_active = Column(Integer, default=1, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    projects = relationship("Project", back_populates="owner")
    templates = relationship("Template", back_populates="owner")
class Project(Base):
    """Project model representing a segmentation project.

    Frames, annotations and processing tasks cascade with the project
    ("all, delete-orphan"). The owner reference is nulled when the owning
    user row is deleted.
    """

    __tablename__ = "projects"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)
    video_path = Column(String(512), nullable=True)
    thumbnail_url = Column(String(512), nullable=True)
    status = Column(String(50), default=PROJECT_STATUS_PENDING, nullable=False)
    source_type = Column(String(20), default="video", nullable=False)  # video | dicom
    original_fps = Column(Float, nullable=True)
    parse_fps = Column(Float, default=30.0, nullable=False)
    owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    owner = relationship("User", back_populates="projects")
    frames = relationship("Frame", back_populates="project", cascade="all, delete-orphan")
    annotations = relationship(
        "Annotation", back_populates="project", cascade="all, delete-orphan"
    )
    tasks = relationship(
        "ProcessingTask", back_populates="project", cascade="all, delete-orphan"
    )
class Frame(Base):
    """Frame model representing an extracted video frame.

    Rows are removed with their project (``ondelete="CASCADE"``), and a
    frame's annotations cascade with the frame.
    """

    __tablename__ = "frames"

    id = Column(Integer, primary_key=True, index=True)
    project_id = Column(Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
    frame_index = Column(Integer, nullable=False)
    image_url = Column(String(512), nullable=False)
    width = Column(Integer, nullable=True)
    height = Column(Integer, nullable=True)
    # NOTE(review): presumably the frame's position in the source media, in
    # milliseconds, and its original frame number; confirm with the parser.
    timestamp_ms = Column(Float, nullable=True)
    source_frame_number = Column(Integer, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    project = relationship("Project", back_populates="frames")
    annotations = relationship(
        "Annotation", back_populates="frame", cascade="all, delete-orphan"
    )
class Template(Base):
    """Template (Ontology) model for segmentation classes."""

    __tablename__ = "templates"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)
    color = Column(String(50), nullable=False)
    z_index = Column(Integer, default=0, nullable=False)
    mapping_rules = Column(JSON, nullable=True)
    owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    owner = relationship("User", back_populates="templates")
    # NOTE(review): this ORM cascade deletes a template's annotations, while
    # the annotations.template_id foreign key uses ondelete="SET NULL";
    # confirm which behavior is intended.
    annotations = relationship(
        "Annotation", back_populates="template", cascade="all, delete-orphan"
    )
class Annotation(Base):
    """Annotation model for segmentation masks and prompts.

    Always belongs to a project; the frame and template links are optional.
    Derived mask files cascade with the annotation.
    """

    __tablename__ = "annotations"

    id = Column(Integer, primary_key=True, index=True)
    project_id = Column(
        Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False
    )
    frame_id = Column(
        Integer, ForeignKey("frames.id", ondelete="CASCADE"), nullable=True
    )
    template_id = Column(
        Integer, ForeignKey("templates.id", ondelete="SET NULL"), nullable=True
    )
    # Free-form JSON payloads; their schema is not defined in this module.
    mask_data = Column(JSON, nullable=True)
    points = Column(JSON, nullable=True)
    bbox = Column(JSON, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    project = relationship("Project", back_populates="annotations")
    frame = relationship("Frame", back_populates="annotations")
    template = relationship("Template", back_populates="annotations")
    masks = relationship("Mask", back_populates="annotation", cascade="all, delete-orphan")
class Mask(Base):
    """Mask model for exported/derived mask files.

    Each row points at one stored mask file for an annotation and is removed
    with that annotation.
    """

    __tablename__ = "masks"

    id = Column(Integer, primary_key=True, index=True)
    annotation_id = Column(
        Integer, ForeignKey("annotations.id", ondelete="CASCADE"), nullable=False
    )
    mask_url = Column(String(512), nullable=False)
    format = Column(String(50), default="png", nullable=False)  # png / rle / json
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    annotation = relationship("Annotation", back_populates="masks")
class AuditLog(Base):
    """Audit trail for security and administrative actions.

    The actor reference is nulled when the acting user row is deleted, so log
    entries outlive the account.
    """

    __tablename__ = "audit_logs"

    id = Column(Integer, primary_key=True, index=True)
    actor_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
    action = Column(String(120), nullable=False)
    target_type = Column(String(80), nullable=True)
    # Stored as a string column, not as a foreign key.
    target_id = Column(String(120), nullable=True)
    detail = Column(JSON, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    actor = relationship("User")
class ProcessingTask(Base):
    """Background task state persisted for dashboard and polling.

    The project link is optional; task rows are removed with their project.
    """

    __tablename__ = "processing_tasks"

    id = Column(Integer, primary_key=True, index=True)
    task_type = Column(String(80), nullable=False)
    status = Column(String(40), default="queued", nullable=False)
    # NOTE(review): presumably a 0-100 percentage; confirm with the workers.
    progress = Column(Integer, default=0, nullable=False)
    message = Column(Text, nullable=True)
    project_id = Column(
        Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=True
    )
    celery_task_id = Column(String(255), nullable=True)
    payload = Column(JSON, nullable=True)
    result = Column(JSON, nullable=True)
    error = Column(Text, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    started_at = Column(DateTime(timezone=True), nullable=True)
    finished_at = Column(DateTime(timezone=True), nullable=True)
    updated_at = Column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    project = relationship("Project", back_populates="tasks")

View File

@@ -0,0 +1,66 @@
"""Progress event payloads and Redis publication helpers."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any
from redis_client import get_redis_client
from statuses import TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
logger = logging.getLogger(__name__)
# Redis pub/sub channel on which progress events are published.
PROGRESS_CHANNEL = "seg:progress"
def _iso_now() -> str:
return datetime.now(timezone.utc).isoformat()
def _event_type(task_status: str) -> str:
    """Map a task status to the event type sent to WebSocket clients.

    Success, cancelled and failed statuses map to "complete", "cancelled" and
    "error"; every other status is reported as "progress".
    """
    terminal_events = (
        (TASK_STATUS_SUCCESS, "complete"),
        (TASK_STATUS_CANCELLED, "cancelled"),
        (TASK_STATUS_FAILED, "error"),
    )
    for known_status, event_name in terminal_events:
        if task_status == known_status:
            return event_name
    return "progress"
def task_progress_payload(task: Any) -> dict[str, Any]:
    """Build the WebSocket payload from a persisted processing task.

    Only ``task.id`` is required; every other attribute is read with a
    fallback. The timestamp is the task's ``updated_at`` when present and the
    current UTC time otherwise.
    """
    project_name = getattr(getattr(task, "project", None), "name", None)
    state = getattr(task, "status", "")
    modified = getattr(task, "updated_at", None)
    if modified is None:
        stamp = _iso_now()
    else:
        stamp = modified.isoformat()
    note = getattr(task, "message", None)

    event: dict[str, Any] = {}
    event["type"] = _event_type(state)
    event["taskId"] = f"task-{task.id}"
    event["task_id"] = task.id
    event["project_id"] = getattr(task, "project_id", None)
    event["projectName"] = project_name
    event["filename"] = project_name
    event["progress"] = getattr(task, "progress", 0)
    # The human-readable message takes precedence over the raw status.
    event["status"] = note or state
    event["message"] = note
    event["error"] = getattr(task, "error", None)
    event["timestamp"] = stamp
    return event
def publish_progress_event(payload: dict[str, Any]) -> None:
    """Publish *payload* as JSON on the progress channel, best effort.

    Any failure is logged as a warning and swallowed so that a Redis problem
    never fails the calling worker.
    """
    try:
        redis_conn = get_redis_client()
        encoded = json.dumps(payload, ensure_ascii=False)
        redis_conn.publish(PROGRESS_CHANNEL, encoded)
    except Exception as exc:  # noqa: BLE001
        logger.warning("Failed to publish progress event: %s", exc)
def publish_task_progress_event(task: Any) -> None:
    """Build the progress payload for a ProcessingTask and publish it."""
    event = task_progress_payload(task)
    publish_progress_event(event)

61
backend/redis_client.py Normal file
View File

@@ -0,0 +1,61 @@
"""Redis client wrapper for caching and task queuing."""
import json
import logging
from typing import Optional, Any
import redis
from config import settings
logger = logging.getLogger(__name__)
# Lazily created client singleton, filled in by get_redis_client().
_redis_client: Optional[redis.Redis] = None
def get_redis_client() -> redis.Redis:
    """Return the shared Redis client, creating it on first use.

    The client is built from ``settings.redis_url`` with
    ``decode_responses=True``.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = redis.from_url(settings.redis_url, decode_responses=True)
    return _redis_client
def ping() -> bool:
    """Check Redis connectivity.

    Returns:
        True when the server answers the PING, False on any Redis error.
        The error is logged, never raised.
    """
    try:
        return bool(get_redis_client().ping())
    except redis.RedisError as exc:
        # Catch the base class: timeouts and response errors are not
        # ConnectionError subclasses, and this health check must never raise.
        logger.error("Redis ping failed: %s", exc)
        return False
def set_json(key: str, value: Any, expire: Optional[int] = None) -> None:
    """Serialize *value* to JSON and store it under *key*.

    Args:
        key: Redis key.
        value: JSON-serializable value.
        expire: Optional expiry, passed to redis as ``ex``.

    Raises:
        redis.RedisError: re-raised after logging when the write fails.
    """
    connection = get_redis_client()
    serialized = json.dumps(value)
    try:
        connection.set(key, serialized, ex=expire)
    except redis.RedisError as exc:
        logger.error("Redis set_json failed: %s", exc)
        raise
def get_json(key: str) -> Optional[Any]:
    """Fetch *key* from Redis and decode it from JSON.

    Returns:
        The decoded value, or None when the key does not exist.

    Raises:
        redis.RedisError: re-raised after logging when the read fails.
    """
    connection = get_redis_client()
    try:
        raw = connection.get(key)
    except redis.RedisError as exc:
        logger.error("Redis get_json failed: %s", exc)
        raise
    if raw is None:
        return None
    return json.loads(raw)
def delete_key(key: str) -> int:
    """Delete *key* from Redis and return the number of keys removed.

    Raises:
        redis.RedisError: re-raised after logging when the delete fails.
    """
    connection = get_redis_client()
    try:
        removed = connection.delete(key)
    except redis.RedisError as exc:
        logger.error("Redis delete_key failed: %s", exc)
        raise
    return removed

View File

@@ -0,0 +1,20 @@
fastapi
uvicorn[standard]
python-multipart
sqlalchemy
psycopg2-binary
redis
celery
minio
opencv-python-headless
pillow
scikit-image
pydicom
numpy
torch
torchvision
torchaudio
sam2
pydantic-settings
python-jose[cryptography]
passlib[bcrypt]

View File

299
backend/routers/admin.py Normal file
View File

@@ -0,0 +1,299 @@
"""Administrator-only user and audit management endpoints."""
import os
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from config import settings
from database import get_db
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import SUPPORTED_ROLES, ensure_default_admin, hash_password, normalize_user_role, require_admin, write_audit_log
from schemas import (
AdminUserCreate,
AdminUserUpdate,
AuditLogOut,
DemoFactoryResetOut,
DemoFactoryResetRequest,
UserOut,
)
from services.demo_media import (
DEMO_DICOM_PROJECT_NAME,
DEMO_VIDEO_PROJECT_NAME,
create_parsed_dicom_demo_project,
create_parsed_video_demo_project,
demo_dicom_files,
)
from services.default_templates import restore_default_templates
# Every route in this module is served under /api/admin.
router = APIRouter(prefix="/api/admin", tags=["Admin"])
# Exact confirmation string a demo factory reset request must carry.
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
# Alias of the demo DICOM project name.
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
def _normalize_role(role: str | None) -> str:
    """Return *role* trimmed and lower-cased, defaulting to "annotator".

    Raises:
        HTTPException: 400 when the normalized role is not supported.
    """
    candidate = (role or "annotator").strip().lower()
    if candidate in SUPPORTED_ROLES:
        return candidate
    raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
def _assert_non_admin_role(role: str) -> None:
    """Reject the "admin" role, which only the default admin account may hold.

    Raises:
        HTTPException: 400 when *role* is "admin".
    """
    if role != "admin":
        return
    raise HTTPException(status_code=400, detail="Only the default admin account can have admin role")
@router.get("/users", response_model=List[UserOut], summary="List users")
def list_users(
    db: Session = Depends(get_db),
    admin_user: User = Depends(require_admin),
) -> List[User]:
    """Return every user, ordered by id, for the administrator console.

    Each user is passed through ``normalize_user_role`` before being returned.
    """
    # The dependency is needed only for its access check.
    _ = admin_user
    return [
        normalize_user_role(db, account)
        for account in db.query(User).order_by(User.id).all()
    ]
@router.post(
    "/users",
    response_model=UserOut,
    status_code=status.HTTP_201_CREATED,
    summary="Create user",
)
def create_user(
    payload: AdminUserCreate,
    db: Session = Depends(get_db),
    admin_user: User = Depends(require_admin),
) -> User:
    """Create a user with an initial password and role.

    Raises:
        HTTPException: 400 for an empty username, a password shorter than six
            characters, or an unsupported or admin role; 409 when the
            username already exists.
    """
    cleaned_username = payload.username.strip()
    if not cleaned_username:
        raise HTTPException(status_code=400, detail="Username is required")
    if len(payload.password) < 6:
        raise HTTPException(status_code=400, detail="Password must be at least 6 characters")

    # New accounts may never be created with the admin role.
    requested_role = _normalize_role(payload.role)
    _assert_non_admin_role(requested_role)

    new_user = User(
        username=cleaned_username,
        password_hash=hash_password(payload.password),
        role=requested_role,
        is_active=int(bool(payload.is_active)),
    )
    db.add(new_user)
    try:
        db.commit()
    except IntegrityError as exc:
        db.rollback()
        raise HTTPException(status_code=409, detail="Username already exists") from exc
    db.refresh(new_user)

    audit_detail = {
        "username": new_user.username,
        "role": new_user.role,
        "is_active": bool(new_user.is_active),
    }
    write_audit_log(
        db,
        actor=admin_user,
        action="admin.user_created",
        target_type="user",
        target_id=new_user.id,
        detail=audit_detail,
    )
    return new_user
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
def update_user(
    user_id: int,
    payload: AdminUserUpdate,
    db: Session = Depends(get_db),
    admin_user: User = Depends(require_admin),
) -> User:
    """Update username, password, role or active state.

    Only fields present in the request are applied. Changes are validated in
    the order username, password, role, is_active and committed together.

    Raises:
        HTTPException: 404 when the user does not exist; 400 when a rule
            below is violated; 409 when the new username is already taken.
    """
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    user = normalize_user_role(db, user)
    # exclude_unset keeps "field omitted" apart from "field sent as null".
    updates = payload.model_dump(exclude_unset=True)
    # Snapshot taken before any mutation, for the audit entry.
    audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
    if "username" in updates:
        username = (updates["username"] or "").strip()
        if not username:
            raise HTTPException(status_code=400, detail="Username is required")
        # The admin account must keep the configured default username.
        if user.role == "admin" and username != settings.default_admin_username:
            raise HTTPException(status_code=400, detail="Default admin username cannot be changed")
        user.username = username
    if "password" in updates:
        password = updates["password"] or ""
        if len(password) < 6:
            raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
        user.password_hash = hash_password(password)
    if "role" in updates:
        next_role = _normalize_role(updates["role"])
        # This check sees the username as already updated by this request.
        if user.username == settings.default_admin_username:
            if next_role != "admin":
                raise HTTPException(status_code=400, detail="Cannot remove the default admin role")
        else:
            _assert_non_admin_role(next_role)
        user.role = next_role
    if "is_active" in updates:
        if user.id == admin_user.id and not updates["is_active"]:
            raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
        user.is_active = 1 if updates["is_active"] else 0
    try:
        db.commit()
    except IntegrityError as exc:
        db.rollback()
        raise HTTPException(status_code=409, detail="Username already exists") from exc
    db.refresh(user)
    audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
    # Record only that the password changed, never its value.
    audit_detail["password_changed"] = "password" in updates
    write_audit_log(
        db,
        actor=admin_user,
        action="admin.user_updated",
        target_type="user",
        target_id=user.id,
        detail=audit_detail,
    )
    return user
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
def delete_user(
    user_id: int,
    db: Session = Depends(get_db),
    admin_user: User = Depends(require_admin),
) -> None:
    """Delete a user account unless it is the caller or an admin account.

    Raises:
        HTTPException: 400 when deleting yourself or an admin account; 404
            when the user does not exist.
    """
    if user_id == admin_user.id:
        raise HTTPException(status_code=400, detail="Cannot delete yourself")

    target = db.query(User).filter(User.id == user_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="User not found")

    target = normalize_user_role(db, target)
    if target.role == "admin":
        raise HTTPException(status_code=400, detail="Cannot delete the default admin account")

    # Read the name before the row is deleted; the audit entry needs it.
    removed_username = target.username
    db.delete(target)
    db.commit()

    write_audit_log(
        db,
        actor=admin_user,
        action="admin.user_deleted",
        target_type="user",
        target_id=user_id,
        detail={"username": removed_username},
    )
    return None
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
def list_audit_logs(
    limit: int = 100,
    db: Session = Depends(get_db),
    admin_user: User = Depends(require_admin),
) -> List[AuditLog]:
    """Return the most recent audit events, newest first.

    ``limit`` is clamped to the range 1-500; a falsy value (such as 0) is
    treated as 100.
    """
    # The dependency is only needed for its admin check.
    _ = admin_user

    requested = int(limit or 100)
    if requested < 1:
        safe_limit = 1
    elif requested > 500:
        safe_limit = 500
    else:
        safe_limit = requested

    newest_first = db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc())
    return newest_first.limit(safe_limit).all()
@router.post(
    "/demo-factory-reset",
    response_model=DemoFactoryResetOut,
    summary="Reset demo data to factory defaults",
)
def reset_demo_factory(
    payload: DemoFactoryResetRequest,
    db: Session = Depends(get_db),
    admin_user: User = Depends(require_admin),
) -> dict:
    """Reset a demo deployment to one admin account, the demo video, and the demo DICOM project.

    The default admin is restored to the configured username and password,
    all masks, annotations, frames, tasks, projects, user-owned templates,
    audit logs and other users are deleted, then the default templates and
    the two demo projects are recreated.

    Raises:
        HTTPException: 400 for a wrong confirmation string; 409 when the demo
            video or the demo DICOM series is missing; 500 when the default
            admin cannot be found again after the cleanup.
    """
    if payload.confirmation != DEMO_RESET_CONFIRMATION:
        raise HTTPException(status_code=400, detail="Invalid reset confirmation")

    # Verify both demo sources exist before anything is deleted.
    if not os.path.exists(settings.demo_video_path):
        raise HTTPException(
            status_code=409,
            detail=f"Demo video not found: {settings.demo_video_path}",
        )
    if not demo_dicom_files(settings.demo_dicom_dir):
        raise HTTPException(
            status_code=409,
            detail=f"Demo DICOM series not found: {settings.demo_dicom_dir}",
        )

    # Read the caller's name now; the session is expunged further down.
    requested_by = admin_user.username

    preserved_admin = ensure_default_admin(db)
    preserved_admin.username = settings.default_admin_username
    preserved_admin.password_hash = hash_password(settings.default_admin_password)
    preserved_admin.role = "admin"
    preserved_admin.is_active = 1
    db.flush()

    # Bulk deletes run in this exact order; each value is the delete's return value.
    deleted_counts: dict = {}
    deleted_counts["masks"] = db.query(Mask).delete(synchronize_session=False)
    deleted_counts["annotations"] = db.query(Annotation).delete(synchronize_session=False)
    deleted_counts["frames"] = db.query(Frame).delete(synchronize_session=False)
    deleted_counts["tasks"] = db.query(ProcessingTask).delete(synchronize_session=False)
    deleted_counts["projects"] = db.query(Project).delete(synchronize_session=False)
    deleted_counts["user_templates"] = (
        db.query(Template)
        .filter(Template.owner_user_id.is_not(None))
        .delete(synchronize_session=False)
    )
    deleted_counts["audit_logs"] = db.query(AuditLog).delete(synchronize_session=False)
    deleted_counts["users"] = (
        db.query(User)
        .filter(User.id != preserved_admin.id)
        .delete(synchronize_session=False)
    )
    db.flush()

    # Drop every loaded instance, then load the admin again from the database.
    db.expunge_all()
    preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
    if not preserved_admin:
        raise HTTPException(status_code=500, detail="Default admin was not preserved")

    restored_templates = restore_default_templates(db)
    video_project = create_parsed_video_demo_project(
        db,
        owner=preserved_admin,
        video_path=settings.demo_video_path,
        project_name=DEMO_VIDEO_PROJECT_NAME,
    )
    dicom_project = create_parsed_dicom_demo_project(
        db,
        owner=preserved_admin,
        dicom_dir=settings.demo_dicom_dir,
        project_name=DEMO_PROJECT_NAME,
    )

    for instance in (preserved_admin, video_project, dicom_project):
        db.refresh(instance)
    video_project.frame_count = len(video_project.frames)
    dicom_project.frame_count = len(dicom_project.frames)
    projects = [video_project, dicom_project]

    audit_detail = {
        "project_names": [project.name for project in projects],
        "video_path": video_project.video_path,
        "dicom_path": dicom_project.video_path,
        "source_types": [project.source_type for project in projects],
        "frame_counts": {project.name: len(project.frames) for project in projects},
        "deleted_counts": deleted_counts,
        "restored_templates": [template.name for template in restored_templates],
        "requested_by": requested_by,
    }
    write_audit_log(
        db,
        actor=preserved_admin,
        action="admin.demo_factory_reset",
        target_type="project",
        target_id=dicom_project.id,
        detail=audit_detail,
    )

    return {
        "admin_user": preserved_admin,
        "project": dicom_project,
        "projects": projects,
        "deleted_counts": deleted_counts,
        "message": "演示环境已恢复出厂设置",
    }

1228
backend/routers/ai.py Normal file

File diff suppressed because it is too large Load Diff

222
backend/routers/auth.py Normal file
View File

@@ -0,0 +1,222 @@
"""Authentication endpoints and dependencies."""
from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session
from config import settings
from database import get_db
from models import AuditLog, User
from schemas import LoginResponse, UserOut
# Routes declared on this router are served under the /api/auth prefix.
router = APIRouter(prefix="/api/auth", tags=["Auth"])
# Password hashing context; pbkdf2_sha256 is the only configured scheme.
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# auto_error=False: get_current_user checks for missing credentials itself
# and raises its own 401 response.
bearer_scheme = HTTPBearer(auto_error=False)
# Roles accepted by require_editor.
SUPPORTED_ROLES = {"admin", "annotator"}
class LoginRequest(BaseModel):
    """Request body accepted by the login endpoint."""

    # Account name; the login handler matches it against User.username.
    username: str
    # Plain-text password; checked against the stored hash with verify_password.
    password: str
def hash_password(password: str) -> str:
    """Return the hash of ``password`` produced by the module password context."""
    hashed = password_context.hash(password)
    return hashed
def verify_password(password: str, password_hash: str) -> bool:
    """Report whether ``password`` matches the stored ``password_hash``."""
    matches = password_context.verify(password, password_hash)
    return matches
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
    """Build a signed JWT for ``user``.

    The token carries the user id (as a string ``sub``), username, role and an
    ``exp`` claim. When ``expires_delta`` is falsy the lifetime comes from
    ``settings.access_token_expire_minutes``.
    """
    lifetime = expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
    expires_at = datetime.now(timezone.utc) + lifetime

    claims: dict[str, Any] = {}
    claims["sub"] = str(user.id)
    claims["username"] = user.username
    claims["role"] = user.role
    claims["exp"] = expires_at

    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
def ensure_default_admin(db: Session) -> User:
    """Make sure the configured default admin exists and is the only admin.

    An existing account with the default admin username is forced to the
    ``admin`` role and reactivated if needed; a missing one is created with the
    configured default password. Every other user holding the ``admin`` role
    is changed to ``annotator``.
    """
    admin = db.query(User).filter(User.username == settings.default_admin_username).first()

    if not admin:
        # No default admin yet: create it first, then demote any other admins.
        admin = User(
            username=settings.default_admin_username,
            password_hash=hash_password(settings.default_admin_password),
            role="admin",
            is_active=1,
        )
        db.add(admin)
        db.commit()
        db.refresh(admin)

        other_admins = db.query(User).filter(
            User.role == "admin",
            User.id != admin.id,
        ).all()
        if other_admins:
            for other in other_admins:
                other.role = "annotator"
            db.commit()
            db.refresh(admin)
        return admin

    # Existing account: repair it in place and commit only when something changed.
    dirty = False
    if admin.role != "admin":
        admin.role = "admin"
        dirty = True
    if not admin.is_active:
        admin.is_active = 1
        dirty = True

    other_admins = db.query(User).filter(
        User.role == "admin",
        User.id != admin.id,
    ).all()
    for other in other_admins:
        other.role = "annotator"
    dirty = dirty or bool(other_admins)

    if dirty:
        db.commit()
        db.refresh(admin)
    return admin
def normalize_user_role(db: Session, user: User) -> User:
    """Force ``user`` into the two-role policy and persist any correction.

    The account named ``settings.default_admin_username`` must be an active
    ``admin``; every other account must be an ``annotator``. A commit and
    refresh happen only when a field was changed.
    """
    is_default_admin = user.username == settings.default_admin_username
    expected_role = "admin" if is_default_admin else "annotator"

    needs_commit = False
    if user.role != expected_role:
        user.role = expected_role
        needs_commit = True
    if is_default_admin and not user.is_active:
        user.is_active = 1
        needs_commit = True

    if not needs_commit:
        return user
    db.commit()
    db.refresh(user)
    return user
def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
    db: Session = Depends(get_db),
) -> User:
    """Return the active user identified by the request's Bearer token.

    Raises:
        HTTPException: 401 when credentials are missing or not a Bearer
            scheme, when the token cannot be decoded or has no integer
            ``sub``, or when the user is missing or inactive.
    """

    def unauthorized(detail: str) -> HTTPException:
        # All failures share the same status code and challenge header.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    if credentials is None or credentials.scheme.lower() != "bearer":
        raise unauthorized("Not authenticated")

    try:
        claims = jwt.decode(
            credentials.credentials,
            settings.jwt_secret_key,
            algorithms=[settings.jwt_algorithm],
        )
        # int(None) raises TypeError, so a missing "sub" is handled below too.
        user_id = int(claims.get("sub"))
    except (JWTError, TypeError, ValueError) as exc:
        raise unauthorized("Invalid or expired token") from exc

    user = db.query(User).filter(User.id == user_id).first()
    if user:
        user = normalize_user_role(db, user)
    if not user or not user.is_active:
        raise unauthorized("Inactive or missing user")
    return user
def require_admin(current_user: User = Depends(get_current_user)) -> User:
    """Return the authenticated user, or raise 403 unless their role is ``admin``."""
    if current_user.role == "admin":
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
def require_editor(current_user: User = Depends(get_current_user)) -> User:
    """Return the authenticated user, or raise 403 unless their role is in ``SUPPORTED_ROLES``."""
    if current_user.role in SUPPORTED_ROLES:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
def write_audit_log(
    db: Session,
    *,
    actor: User | None,
    action: str,
    target_type: str | None = None,
    target_id: str | int | None = None,
    detail: dict[str, Any] | None = None,
) -> AuditLog:
    """Store one audit event and return the refreshed row.

    ``actor`` may be ``None`` for events with no authenticated user.
    ``target_id`` is stored as a string, and a missing ``detail`` is stored as
    an empty dict. The session is committed as part of this call.
    """
    actor_id = actor.id if actor else None
    stored_target_id = None if target_id is None else str(target_id)

    entry = AuditLog(
        actor_user_id=actor_id,
        action=action,
        target_type=target_type,
        target_id=stored_target_id,
        detail=detail or {},
    )
    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
    """Check the submitted credentials and return a signed JWT with the user profile.

    Both outcomes are written to the audit log. An unknown user, an inactive
    user and a wrong password all produce the same 401 response.
    """
    # Make sure the default admin exists before looking anyone up.
    ensure_default_admin(db)

    user = db.query(User).filter(User.username == payload.username).first()
    if user:
        user = normalize_user_role(db, user)

    # Short-circuits so the password is only verified for an active, existing user.
    authenticated = user and user.is_active and verify_password(payload.password, user.password_hash)
    if not authenticated:
        write_audit_log(
            db,
            actor=None,
            action="auth.login_failed",
            target_type="user",
            target_id=payload.username,
            detail={"username": payload.username},
        )
        raise HTTPException(status_code=401, detail="Invalid credentials")

    write_audit_log(
        db,
        actor=user,
        action="auth.login_success",
        target_type="user",
        target_id=user.id,
        detail={"username": user.username},
    )
    access_token = create_access_token(user)
    return {
        "token": access_token,
        "token_type": "bearer",
        "username": user.username,
        "user": user,
    }
@router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
    """Return the profile of the authenticated user.

    Authentication is done by the ``get_current_user`` dependency; this
    handler returns the resolved user unchanged.
    """
    return current_user

View File

@@ -0,0 +1,164 @@
"""Dashboard overview endpoints."""
import os
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Frame, ProcessingTask, Project, Template, User
from routers.auth import get_current_user
# Routes declared on this router are served under the /api/dashboard prefix.
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
# Task statuses counted for the summary's parsing_task_count.
ACTIVE_TASK_STATUSES = {"queued", "running"}
# Only tasks in one of these statuses are listed in the overview's "tasks" payload.
MONITORED_TASK_STATUSES = {"queued", "running", "success", "failed", "cancelled"}
def _system_load_percent() -> int:
"""Return a real host load estimate without adding a psutil dependency."""
try:
load_1m = os.getloadavg()[0]
cpu_count = os.cpu_count() or 1
return min(100, max(0, round((load_1m / cpu_count) * 100)))
except (AttributeError, OSError):
return 0
def _iso_or_none(value: datetime | None) -> str | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.isoformat()
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
    """Convert a processing task into the dict shape returned by the dashboard."""
    result = task.result or {}

    # Tasks without a linked project get a generic display name.
    if task.project:
        display_name = task.project.name
    else:
        display_name = f"任务 {task.id}"

    # "frames_extracted" wins; "processed_frame_count" is the fallback, then 0.
    fallback_frames = result.get("processed_frame_count", 0)
    frame_total = result.get("frames_extracted", fallback_frames)

    return {
        "id": f"task-{task.id}",
        "task_id": task.id,
        "project_id": task.project_id or 0,
        "name": display_name,
        "progress": task.progress,
        "status": task.message or task.status,
        "raw_status": task.status,
        "frame_count": frame_total,
        "error": task.error,
        "updated_at": _iso_or_none(task.updated_at),
    }
@router.get("/overview", summary="Get dashboard overview")
def get_dashboard_overview(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> dict[str, Any]:
    """Return live dashboard data derived from persisted backend records.

    The response has three keys:

    * ``summary``: project, frame, annotation, template and active-task
      counts plus the host load percentage.
    * ``tasks``: payloads for the 50 most recently created tasks whose status
      is in ``MONITORED_TASK_STATUSES``.
    * ``activity``: the 10 newest events among recent tasks, projects,
      annotations and templates, sorted by timestamp string, newest first.

    Projects, frames and annotations are counted across all projects;
    templates are limited to those owned by the caller or by no one.
    """
    shared_project_ids_query = db.query(Project.id)
    # Templates visible to the caller: their own plus ownerless ones.
    visible_templates = or_(
        Template.owner_user_id == current_user.id,
        Template.owner_user_id.is_(None),
    )

    project_count = db.query(func.count(Project.id)).scalar() or 0
    frame_count = (
        db.query(func.count(Frame.id))
        .filter(Frame.project_id.in_(shared_project_ids_query))
        .scalar()
        or 0
    )
    annotation_count = (
        db.query(func.count(Annotation.id))
        .filter(Annotation.project_id.in_(shared_project_ids_query))
        .scalar()
        or 0
    )
    template_count = (
        db.query(func.count(Template.id))
        .filter(visible_templates)
        .scalar()
        or 0
    )
    active_task_count = (
        db.query(func.count(ProcessingTask.id))
        .outerjoin(Project, Project.id == ProcessingTask.project_id)
        .filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
        .scalar()
        or 0
    )

    # Only the 10 most recently updated projects feed the activity list, so
    # the limit is applied in SQL instead of loading every project row.
    recent_projects = (
        db.query(Project)
        .order_by(Project.updated_at.desc())
        .limit(10)
        .all()
    )
    recent_tasks = (
        db.query(ProcessingTask)
        .outerjoin(Project, Project.id == ProcessingTask.project_id)
        .order_by(ProcessingTask.created_at.desc())
        .limit(50)
        .all()
    )
    tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]

    activities: list[dict[str, Any]] = []
    for task in recent_tasks[:10]:
        project_name = task.project.name if task.project else f"项目 {task.project_id}"
        activities.append({
            "id": f"task-{task.id}",
            "kind": "task",
            "time": _iso_or_none(task.updated_at),
            "message": task.message or f"任务状态: {task.status}",
            "project": project_name,
        })
    for project in recent_projects:
        activities.append({
            "id": f"project-{project.id}",
            "kind": "project",
            "time": _iso_or_none(project.updated_at),
            "message": f"项目状态: {project.status}",
            "project": project.name,
        })

    recent_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id.in_(shared_project_ids_query))
        .order_by(Annotation.updated_at.desc())
        .limit(10)
        .all()
    )
    for annotation in recent_annotations:
        project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
        activities.append({
            "id": f"annotation-{annotation.id}",
            "kind": "annotation",
            "time": _iso_or_none(annotation.updated_at),
            "message": f"标注已更新 #{annotation.id}",
            "project": project_name,
        })

    recent_templates = (
        db.query(Template)
        .filter(visible_templates)
        .order_by(Template.created_at.desc())
        .limit(10)
        .all()
    )
    for template in recent_templates:
        activities.append({
            "id": f"template-{template.id}",
            "kind": "template",
            "time": _iso_or_none(template.created_at),
            "message": f"模板可用: {template.name}",
            "project": "系统",
        })

    # Entries without a timestamp compare as "" and therefore sort last.
    activities.sort(key=lambda item: item["time"] or "", reverse=True)

    return {
        "summary": {
            "project_count": project_count,
            "parsing_task_count": active_task_count,
            "annotation_count": annotation_count,
            "frame_count": frame_count,
            "template_count": template_count,
            "system_load_percent": _system_load_percent(),
        },
        "tasks": tasks,
        "activity": activities[:10],
    }

764
backend/routers/export.py Normal file
View File

@@ -0,0 +1,764 @@
"""Annotation export endpoints (COCO, PNG masks)."""
import io
import json
import logging
import os
import re
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from urllib.parse import quote
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Project, Annotation, Frame, Template, User
from routers.auth import get_current_user
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Routes declared on this router are served under the /api/export prefix.
router = APIRouter(prefix="/api/export", tags=["Export"])
def _mask_from_polygon(
polygon: List[List[float]],
width: int,
height: int,
) -> np.ndarray:
"""Render a normalized polygon to a binary mask."""
import cv2
pts = np.array(
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
dtype=np.int32,
)
mask = np.zeros((height, width), dtype=np.uint8)
cv2.fillPoly(mask, [pts], 255)
return mask
def _annotation_z_index(annotation: Annotation) -> int:
    """Return the stacking priority of an annotation.

    The ``zIndex`` in ``mask_data["class"]`` is used when it converts to an
    int; otherwise the template's ``z_index``; otherwise 0.
    """
    class_meta = (annotation.mask_data or {}).get("class") or {}
    if isinstance(class_meta, dict):
        raw = class_meta.get("zIndex")
        if raw is not None:
            try:
                return int(raw)
            except (TypeError, ValueError):
                # Unusable value: fall through to the template.
                pass

    template = annotation.template
    if template and template.z_index is not None:
        return int(template.z_index)
    return 0
def _annotation_mask_id(annotation: Annotation) -> int | None:
    """Return the first usable mask id stored in ``mask_data["class"]``.

    The keys ``maskId``, ``maskid`` and ``mask_id`` are tried in that order.
    A value counts when it converts to a non-negative int. Returns ``None``
    when no key qualifies.
    """
    class_meta = (annotation.mask_data or {}).get("class") or {}
    if not isinstance(class_meta, dict):
        return None

    for key in ("maskId", "maskid", "mask_id"):
        raw = class_meta.get(key)
        if raw is None:
            continue
        try:
            candidate = int(raw)
        except (TypeError, ValueError):
            continue
        if candidate >= 0:
            return candidate
    return None
def _annotation_category_name(annotation: Annotation) -> str:
    """Return the category from ``mask_data["class"]``, else the template name, else ""."""
    class_meta = (annotation.mask_data or {}).get("class") or {}
    category = class_meta.get("category") if isinstance(class_meta, dict) else None
    if category:
        return str(category)

    template = annotation.template
    if template and template.name:
        return str(template.name)
    return ""
def _annotation_class_key(annotation: Annotation) -> str:
    """Return the key used to group annotations into one class.

    Precedence: class ``id``, class ``name``, the annotation's template id,
    and finally the annotation's own id.
    """
    class_meta = (annotation.mask_data or {}).get("class") or {}
    if isinstance(class_meta, dict):
        for field, prefix in (("id", "class"), ("name", "name")):
            if class_meta.get(field):
                return f"{prefix}:{class_meta[field]}"

    if annotation.template_id:
        return f"template:{annotation.template_id}"
    return f"annotation:{annotation.id}"
def _annotation_label(annotation: Annotation) -> str:
    """Return the display label of an annotation.

    Precedence: class ``name``, ``mask_data["label"]``, the template name, and
    finally ``"Annotation <id>"``.
    """
    mask_data = annotation.mask_data or {}
    class_meta = mask_data.get("class") or {}

    class_name = class_meta.get("name") if isinstance(class_meta, dict) else None
    if class_name:
        return str(class_name)

    plain_label = mask_data.get("label")
    if plain_label:
        return str(plain_label)

    template = annotation.template
    if template and template.name:
        return str(template.name)
    return f"Annotation {annotation.id}"
def _annotation_color(annotation: Annotation) -> str:
    """Return the color of an annotation.

    Precedence: class ``color``, ``mask_data["color"]``, the template color,
    and finally ``"#ffffff"``.
    """
    mask_data = annotation.mask_data or {}
    class_meta = mask_data.get("class") or {}

    class_color = class_meta.get("color") if isinstance(class_meta, dict) else None
    if class_color:
        return str(class_color)

    plain_color = mask_data.get("color")
    if plain_color:
        return str(plain_color)

    template = annotation.template
    if template and template.color:
        return str(template.color)
    return "#ffffff"
def _hex_to_rgb(color: str) -> list[int]:
value = str(color or "").strip()
if value.startswith("#"):
value = value[1:]
if len(value) == 3:
value = "".join(part * 2 for part in value)
if len(value) != 6:
return [255, 255, 255]
try:
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
except ValueError:
return [255, 255, 255]
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
text = str(value or "").strip()
if not text:
text = fallback
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
text = re.sub(r"_+", "_", text).strip("._")
return text or fallback
def _project_video_name(project: Project) -> str:
    """Return a file-name-safe name taken from the project's source path.

    The final path component of ``video_path`` is used with its last
    extension removed. When there is no path, or that leaves an empty string,
    the project name is used instead. ``project_<id>`` is the final fallback.
    """
    default_name = f"project_{project.id}"
    if project.video_path:
        base = Path(project.video_path).name
        if "." in base:
            # Drop only the last dot-separated segment.
            base = base.rsplit(".", 1)[0]
        if base:
            return _safe_filename_part(base, default_name)
    return _safe_filename_part(project.name, default_name)
def _project_export_name(project: Project) -> str:
    """Return the file-name-safe project name, or ``project_<id>`` if it is empty."""
    default_name = f"project_{project.id}"
    return _safe_filename_part(project.name, default_name)
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
    """Return the frame's timestamp in milliseconds.

    The stored ``timestamp_ms`` is used when present. Otherwise it is derived
    from ``frame_index`` and the project's ``parse_fps``, then
    ``original_fps``, then 30.0, with the rate clamped to at least 1.0.
    """
    stored = frame.timestamp_ms
    if stored is not None:
        return float(stored)

    fps = project.parse_fps or project.original_fps or 30.0
    effective_fps = max(float(fps), 1.0)
    return float(frame.frame_index) * 1000.0 / effective_fps
def _project_frame_number(frame: Frame) -> int:
    """Return the 1-based frame number, which is ``frame_index`` plus one."""
    zero_based = int(frame.frame_index)
    return zero_based + 1
def _format_timestamp_ms(value: float) -> str:
total_ms = max(0, int(round(float(value))))
hours = total_ms // 3_600_000
minutes = (total_ms % 3_600_000) // 60_000
seconds = (total_ms % 60_000) // 1_000
milliseconds = total_ms % 1_000
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
def _frame_export_stem(project: Project, frame: Frame) -> str:
    """Return the extension-less export name of a frame.

    Format: ``<video name>_<timestamp>_frame<6-digit 1-based number>``.
    """
    video_name = _project_video_name(project)
    timestamp = _format_timestamp_ms(_frame_timestamp_ms(frame, project))
    frame_number = _project_frame_number(frame)
    return f"{video_name}_{timestamp}_frame{frame_number:06d}"
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
    """Return the ZIP file name for a segmentation export.

    The name contains the project name, the time range (``T``) and the 1-based
    frame number range (``P``) of the first and last frame in ``frames``. An
    empty ``frames`` list produces zeroed ranges.
    """
    export_name = _project_export_name(project)
    if not frames:
        return f"{export_name}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"

    first, last = frames[0], frames[-1]
    start_time = _format_timestamp_ms(_frame_timestamp_ms(first, project))
    end_time = _format_timestamp_ms(_frame_timestamp_ms(last, project))
    start_number = _project_frame_number(first)
    end_number = _project_frame_number(last)
    return f"{export_name}_seg_T_{start_time}-{end_time}_P_{start_number}-{end_number}.zip"
def _download_content_disposition(filename: str) -> str:
    """Build a ``Content-Disposition`` attachment header value for ``filename``.

    The header holds an ASCII-only ``filename`` (non-ASCII characters dropped,
    then sanitized) and a percent-encoded UTF-8 ``filename*`` with the full
    name.
    """
    default_name = "segmentation_results.zip"
    ascii_only = filename.encode("ascii", "ignore").decode("ascii")
    ascii_fallback = _safe_filename_part(ascii_only or default_name, default_name)
    # Sanitizing can remove a trailing ".zip"; put it back when the real name has one.
    if filename.endswith(".zip") and not ascii_fallback.endswith(".zip"):
        ascii_fallback = f"{ascii_fallback}.zip"

    encoded_name = quote(filename)
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded_name}"
def _frame_image_extension(frame: Frame) -> str:
    """Return the lower-cased extension of the frame's image URL, or ``.jpg`` if it is not a known image type."""
    known_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    extension = Path(frame.image_url or "").suffix.lower()
    if extension in known_extensions:
        return extension
    return ".jpg"
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
    """Load a project by id, raising a 404 ``HTTPException`` when it does not exist.

    ``current_user`` is accepted but not used in the lookup.
    """
    _ = current_user
    found = db.query(Project).filter(Project.id == project_id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Project not found")
def _project_frames(project_id: int, db: Session) -> list[Frame]:
    """Return every frame of the project, ordered by ``frame_index``."""
    frames_query = db.query(Frame).filter(Frame.project_id == project_id)
    ordered = frames_query.order_by(Frame.frame_index)
    return ordered.all()
def _filter_frames(
    frames: list[Frame],
    *,
    scope: str = "all",
    start_frame: int | None = None,
    end_frame: int | None = None,
    frame_id: int | None = None,
) -> list[Frame]:
    """Select the frames to export.

    * ``scope="current"``: the single frame whose id equals ``frame_id``.
    * ``scope="range"``: the 1-based positions ``start_frame`` to
      ``end_frame`` (inclusive) within ``frames``; the two bounds may be given
      in either order and are clamped to at least 1.
    * any other scope: ``frames`` unchanged.

    Raises:
        HTTPException: 400 when a required argument for the scope is missing;
            404 when ``frame_id`` matches no frame.
    """
    if scope == "current":
        if frame_id is None:
            raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
        matches = [candidate for candidate in frames if candidate.id == frame_id]
        if not matches:
            raise HTTPException(status_code=404, detail="Frame not found")
        return matches

    if scope != "range":
        return frames

    if start_frame is None or end_frame is None:
        raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
    low, high = sorted((int(start_frame), int(end_frame)))
    first = max(1, low)
    last = max(1, high)
    return frames[first - 1:last]
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
    """Return the project's annotations that belong to one of ``frame_ids``.

    An empty ``frame_ids`` returns an empty list without querying.
    """
    if not frame_ids:
        return []
    in_project = db.query(Annotation).filter(Annotation.project_id == project_id)
    on_selected_frames = in_project.filter(Annotation.frame_id.in_(frame_ids))
    return on_selected_frames.all()
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
    """Build a COCO-style dict for the given frames and annotations.

    Args:
        project: Project whose name goes into the ``info`` section.
        frames: Frames to export; each becomes one ``images`` entry.
        annotations: Candidate annotations. Those on other frames, or without
            mask data or non-empty polygons, are skipped.
        templates: Templates exported as ``categories``, numbered from 1 in
            list order.

    Returns:
        A dict with ``info``, ``images``, ``annotations`` and ``categories``.
        Polygon coordinates are multiplied by the frame size. ``bbox`` is
        ``[x, y, width, height]`` covering every polygon of the annotation.
        ``area`` is the area of that bounding box, not of the polygons
        themselves. Annotations whose template is not in ``templates`` get
        ``category_id`` 0.
    """
    images = [
        {
            "id": frame.id,
            "file_name": frame.image_url,
            "width": frame.width or 1920,
            "height": frame.height or 1080,
            "frame_index": frame.frame_index,
        }
        for frame in frames
    ]

    categories = []
    template_id_to_cat_id: Dict[int, int] = {}
    for cat_idx, tmpl in enumerate(templates, start=1):
        categories.append({
            "id": cat_idx,
            "name": tmpl.name,
            "color": tmpl.color,
        })
        template_id_to_cat_id[tmpl.id] = cat_idx

    coco_annotations = []
    selected_frame_ids = {frame.id for frame in frames}
    for ann in annotations:
        if ann.frame_id not in selected_frame_ids or not ann.mask_data:
            continue
        # Drop empty polygons so min()/max() below always have points.
        polygons = [poly for poly in (ann.mask_data.get("polygons") or []) if poly]
        if not polygons:
            continue

        # Use the same size fallback as the images list: a frame with no
        # stored width/height must not multiply coordinates by None.
        frame = ann.frame
        width = (frame.width if frame else None) or 1920
        height = (frame.height if frame else None) or 1080

        # The bounding box spans all polygons, matching the segmentation list.
        xs = [p[0] for poly in polygons for p in poly]
        ys = [p[1] for poly in polygons for p in poly]
        bbox = [
            min(xs) * width,
            min(ys) * height,
            (max(xs) - min(xs)) * width,
            (max(ys) - min(ys)) * height,
        ]
        area = bbox[2] * bbox[3]

        segmentation = [
            [coord for p in poly for coord in (p[0] * width, p[1] * height)]
            for poly in polygons
        ]
        coco_annotations.append({
            "id": len(coco_annotations) + 1,
            "image_id": ann.frame_id,
            "category_id": template_id_to_cat_id.get(ann.template_id, 0),
            "segmentation": segmentation,
            "area": area,
            "bbox": bbox,
            "iscrowd": 0,
        })

    created = datetime.now()
    return {
        "info": {
            "description": f"Annotations for {project.name}",
            "version": "1.0",
            "year": created.year,
            "date_created": created.isoformat(),
        },
        "images": images,
        "annotations": coco_annotations,
        "categories": categories,
    }
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
    """Collect the class metadata of an annotation for the export class mapping.

    ``className`` and ``chineseName`` both carry the annotation label.
    """
    label = _annotation_label(annotation)
    entry: dict[str, Any] = {"key": _annotation_class_key(annotation)}
    entry["className"] = label
    entry["chineseName"] = label
    entry["categoryName"] = _annotation_category_name(annotation)
    entry["color"] = _annotation_color(annotation)
    entry["internalPriority"] = _annotation_z_index(annotation)
    entry["maskidHint"] = _annotation_mask_id(annotation)
    entry["template_id"] = annotation.template_id
    return entry
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
    """Assign an 8-bit pixel value (mask id) to every annotated class.

    Only annotations with non-empty polygons contribute. Classes are ordered
    by their mask id hint (un-hinted classes last), then class name, then key.

    Ids are assigned in two passes so that a fallback allocation can never
    take an id that another class explicitly asked for:

    1. Each class with a usable hint gets it: 0 is always honored, and a
       value in 1-255 is honored for the first class (in order) that asks.
    2. The remaining classes (no hint, or a hint already taken) receive the
       smallest unused id starting from 1.

    Returns:
        ``(key_to_value, classes)``: a mapping from class key to mask id, and
        the per-class metadata list in the order described above.

    Raises:
        HTTPException: 400 when a hint is above 255 or no id of 255 or below
            is left.
    """
    entries_by_key: dict[str, dict[str, Any]] = {}
    for annotation in annotations:
        if not annotation.mask_data or not annotation.mask_data.get("polygons"):
            continue
        entry = _class_mapping_entry(annotation)
        # The first annotation seen for a class defines that class's metadata.
        entries_by_key.setdefault(entry["key"], entry)

    ordered = sorted(
        entries_by_key.values(),
        key=lambda item: (
            item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
            str(item["className"]),
            str(item["key"]),
        ),
    )

    # Pass 1: reserve every usable hint before any fallback id is handed out.
    assigned: dict[str, int] = {}
    used_maskids: set[int] = set()
    for entry in ordered:
        hinted_maskid = entry.get("maskidHint")
        if not isinstance(hinted_maskid, int):
            continue
        if hinted_maskid > 255:
            raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid类别值必须在 1-255 之间")
        if hinted_maskid == 0:
            assigned[entry["key"]] = 0
            used_maskids.add(0)
        elif hinted_maskid > 0 and hinted_maskid not in used_maskids:
            assigned[entry["key"]] = hinted_maskid
            used_maskids.add(hinted_maskid)

    next_maskid = 1

    def next_available_maskid() -> int:
        nonlocal next_maskid
        while next_maskid in used_maskids:
            next_maskid += 1
        if next_maskid > 255:
            raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid类别值必须在 1-255 之间")
        value = next_maskid
        used_maskids.add(value)
        next_maskid += 1
        return value

    # Pass 2: fill the gaps and emit the metadata in sorted order.
    key_to_value: dict[str, int] = {}
    classes: list[dict[str, Any]] = []
    for entry in ordered:
        key = entry["key"]
        maskid = assigned[key] if key in assigned else next_available_maskid()
        key_to_value[key] = maskid
        classes.append({
            "gt_pixel_value": maskid,
            "maskid": maskid,
            "chineseName": entry["chineseName"],
            "className": entry["className"],
            "categoryName": entry["categoryName"],
            "rgb": _hex_to_rgb(entry["color"]),
            "color": entry["color"],
            "key": key,
            "template_id": entry["template_id"],
        })
    return key_to_value, classes
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
if outputs:
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
invalid = parsed - allowed
if invalid:
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
return parsed or allowed
if mask_type == "separate":
return {"separate"}
if mask_type == "gt_label":
return {"gt_label"}
if mask_type == "pro_label":
return {"pro_label"}
if mask_type == "mix_label":
return {"mix_label"}
return allowed
def _write_original_frames(
    zf: zipfile.ZipFile,
    project: Project,
    frames: list[Frame],
) -> dict[int, bytes]:
    """Download each frame image, add it to the archive, and return the bytes by frame id.

    Images are written under the ``原始图片/`` folder using the frame's export
    stem and image extension.
    """
    downloaded: dict[int, bytes] = {}
    for frame in frames:
        payload = download_file(frame.image_url)
        downloaded[frame.id] = payload
        archive_name = f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}"
        zf.writestr(archive_name, payload)
    return downloaded
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
    """Decode encoded image bytes into a ``(height, width, 3)`` color array.

    The decoded image is resized when its size differs from the requested
    one. Missing bytes or a failed decode yield an all-zero image.
    """
    import cv2

    blank = np.zeros((height, width, 3), dtype=np.uint8)
    if not image_bytes:
        return blank

    decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        return blank

    rows, cols = decoded.shape[0], decoded.shape[1]
    if cols != width or rows != height:
        decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
    return decoded
def _write_result_mask_outputs(
    zf: zipfile.ZipFile,
    project: Project,
    frames: list[Frame],
    annotations: list[Annotation],
    *,
    outputs: set[str],
    class_values: dict[str, int],
    class_mapping: list[dict[str, Any]],
    original_images: dict[int, bytes],
    mix_opacity: float,
) -> None:
    """Write the requested mask images for every annotated frame into the archive.

    Depending on ``outputs`` each frame that has annotations produces:

    * ``separate``: one binary PNG per class, ordered by class priority.
    * ``gt_label``: a single-channel PNG holding each class's value from
      ``class_values``.
    * ``pro_label``: a color PNG painted with the class colors.
    * ``mix_label``: the original image blended with the color mask at
      ``mix_opacity`` (clamped to 0-1).

    For the fused outputs, annotations are drawn in ascending z-index order,
    so higher z-index annotations overwrite lower ones. Frames without
    annotations and classes missing from ``class_values`` are skipped.
    """
    import cv2

    want_separate = "separate" in outputs
    want_gt = "gt_label" in outputs
    want_pro = "pro_label" in outputs
    want_mix = "mix_label" in outputs
    want_fused = want_gt or want_pro or want_mix
    want_color = want_pro or want_mix

    rgb_by_class_key = {}
    for item in class_mapping:
        rgb_by_class_key[item["key"]] = item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))

    # Group the drawable annotations by frame.
    frame_ids = {frame.id for frame in frames}
    grouped: dict[int, list[Annotation]] = {}
    for annotation in annotations:
        if annotation.frame_id not in frame_ids or not annotation.mask_data:
            continue
        if not annotation.mask_data.get("polygons"):
            continue
        grouped.setdefault(annotation.frame_id, []).append(annotation)

    def write_png(path: str, image: np.ndarray) -> None:
        _, encoded = cv2.imencode(".png", image)
        zf.writestr(path, encoded.tobytes())

    def rasterize(annotation: Annotation, width: int, height: int) -> np.ndarray:
        # Union of all polygons of one annotation as a 0/255 mask.
        union = np.zeros((height, width), dtype=np.uint8)
        for poly in (annotation.mask_data or {}).get("polygons", []):
            union = np.maximum(union, _mask_from_polygon(poly, width, height))
        return union

    for frame in frames:
        frame_annotations = grouped.get(frame.id, [])
        if not frame_annotations:
            continue
        width = frame.width or 1920
        height = frame.height or 1080
        frame_stem = _frame_export_stem(project, frame)

        if want_separate:
            masks_by_class: dict[str, np.ndarray] = {}
            meta_by_class: dict[str, dict[str, Any]] = {}
            for annotation in frame_annotations:
                key = _annotation_class_key(annotation)
                if key not in masks_by_class:
                    masks_by_class[key] = np.zeros((height, width), dtype=np.uint8)
                class_mask = masks_by_class[key]
                np.maximum(class_mask, rasterize(annotation, width, height), out=class_mask)
                if key not in meta_by_class:
                    meta_by_class[key] = _class_mapping_entry(annotation)

            folder = f"分开Mask分割结果/{frame_stem}_分别导出"
            ordered_keys = sorted(
                masks_by_class,
                key=lambda class_key: int(meta_by_class[class_key]["internalPriority"]),
            )
            for key in ordered_keys:
                maskid = class_values.get(key)
                if maskid is None:
                    continue
                class_name = _safe_filename_part(meta_by_class[key]["className"], "class")
                write_png(f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png", masks_by_class[key])

        if not want_fused:
            continue

        semantic = np.zeros((height, width), dtype=np.uint8)
        color_label = np.zeros((height, width, 3), dtype=np.uint8) if want_color else None
        for annotation in sorted(frame_annotations, key=_annotation_z_index):
            key = _annotation_class_key(annotation)
            value = class_values.get(key)
            if value is None:
                continue
            footprint = rasterize(annotation, width, height) > 0
            semantic[footprint] = value
            if color_label is not None:
                rgb = rgb_by_class_key.get(key, [255, 255, 255])
                # Stored in reversed channel order (B, G, R).
                color_label[footprint] = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)

        if want_gt:
            write_png(f"GT_label图/{frame_stem}.png", semantic)
        if want_pro:
            write_png(f"Pro_label彩色分割结果/{frame_stem}.png", color_label)
        if want_mix:
            original = _decode_original_image(original_images.get(frame.id), width, height)
            # Blend only where the color mask has a non-zero channel.
            mask_pixels = np.any(color_label > 0, axis=2)
            mixed = original.copy()
            opacity = min(max(float(mix_opacity), 0.0), 1.0)
            mixed[mask_pixels] = (
                original[mask_pixels].astype(np.float32) * (1.0 - opacity)
                + color_label[mask_pixels].astype(np.float32) * opacity
            ).clip(0, 255).astype(np.uint8)
            write_png(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", mixed)
def _write_mask_pngs(
    zf: zipfile.ZipFile,
    frames: list[Frame],
    annotations: list[Annotation],
    *,
    mask_type: str,
    individual_prefix: str = "",
    semantic_prefix: str = "",
    semantic_file_stem: str = "semantic_frame",
    semantic_dtype: Any = np.uint8,
) -> list[dict[str, Any]]:
    """Write PNG masks for the selected frames into ``zf``.

    ``mask_type`` selects the outputs: ``"separate"`` writes one binary mask
    per annotation, ``"gt_label"`` writes one fused semantic mask per frame
    plus ``semantic_classes.json``, and ``"both"`` writes all of them. Any
    other value writes nothing.

    Args:
        zf: Open, writable ZIP archive.
        frames: Frames to export; annotations on other frames are ignored.
        annotations: Candidate annotations; those without polygons are skipped.
        mask_type: Output selector, see above.
        individual_prefix: Archive path prefix for per-annotation masks.
        semantic_prefix: Archive path prefix for semantic masks and the JSON.
        semantic_file_stem: File-name stem of the per-frame semantic masks.
        semantic_dtype: NumPy dtype of the semantic canvas.

    Returns:
        The class table (value, key, label, color, zIndex, template_id) in
        the order classes were first painted.
    """
    import cv2

    class_values: dict[str, int] = {}
    semantic_classes: list[dict[str, Any]] = []

    def class_value(annotation: Annotation) -> int:
        # Values are handed out in first-seen order starting at 1, so 0 stays
        # the untouched background of the zero-initialised canvas.
        # NOTE(review): nothing here guards against values exceeding the range
        # of ``semantic_dtype`` (e.g. more than 255 classes with uint8).
        key = _annotation_class_key(annotation)
        if key not in class_values:
            value = len(class_values) + 1
            class_values[key] = value
            semantic_classes.append({
                "value": value,
                "key": key,
                "label": _annotation_label(annotation),
                "color": _annotation_color(annotation),
                "zIndex": _annotation_z_index(annotation),
                "template_id": annotation.template_id,
            })
        return class_values[key]

    include_individual = mask_type in {"separate", "both"}
    include_semantic = mask_type in {"gt_label", "both"}

    # Resolve each annotation's frame through the exported frame list so the
    # per-annotation mask and the semantic canvas always share one size, and
    # a frame with missing width/height falls back to 1920x1080 in both places.
    frames_by_id = {frame.id: frame for frame in frames}
    frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
    for ann in annotations:
        frame = frames_by_id.get(ann.frame_id)
        if frame is None or not ann.mask_data:
            continue
        polygons = ann.mask_data.get("polygons", [])
        if not polygons:
            continue
        width = frame.width or 1920
        height = frame.height or 1080
        combined = np.zeros((height, width), dtype=np.uint8)
        for poly in polygons:
            combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
        if include_individual:
            _, encoded = cv2.imencode(".png", combined)
            zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
        if include_semantic:
            frame_masks.setdefault(frame.id, []).append((ann, combined))

    if include_semantic:
        for frame in frames:
            entries = frame_masks.get(frame.id, [])
            if not entries:
                continue
            width = frame.width or 1920
            height = frame.height or 1080
            semantic = np.zeros((height, width), dtype=semantic_dtype)
            # Paint in ascending z-index so higher layers overwrite lower ones.
            for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
                semantic[mask > 0] = class_value(ann)
            _, encoded = cv2.imencode(".png", semantic)
            zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
        zf.writestr(
            f"{semantic_prefix}semantic_classes.json",
            json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
        )
    return semantic_classes
@router.get(
    "/{project_id}/coco",
    summary="Export annotations in COCO format",
)
def export_coco(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Return the project's annotations as a downloadable COCO JSON document."""
    project = _project_or_404(project_id, db, current_user)
    project_frames = _project_frames(project_id, db)
    frame_ids = {item.id for item in project_frames}
    project_annotations = _filtered_annotations(project_id, frame_ids, db)
    all_templates = db.query(Template).all()

    document = _build_coco(project, project_frames, project_annotations, all_templates)
    body = json.dumps(document, ensure_ascii=False, indent=2).encode("utf-8")

    download_name = f"project_{project_id}_coco.json"
    response_headers = {"Content-Disposition": f'attachment; filename="{download_name}"'}
    return StreamingResponse(
        io.BytesIO(body),
        media_type="application/json",
        headers=response_headers,
    )
@router.get(
    "/{project_id}/masks",
    summary="Export PNG masks as a ZIP archive",
)
def export_masks(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Return a ZIP holding per-annotation masks and z-index fused semantic masks."""
    # Called only for its access / existence check; the project itself is unused.
    _project_or_404(project_id, db, current_user)
    project_frames = _project_frames(project_id, db)
    frame_ids = {item.id for item in project_frames}
    project_annotations = _filtered_annotations(project_id, frame_ids, db)

    archive_bytes = io.BytesIO()
    with zipfile.ZipFile(archive_bytes, "w", zipfile.ZIP_DEFLATED) as archive:
        _write_mask_pngs(
            archive,
            project_frames,
            project_annotations,
            mask_type="both",
            semantic_prefix="",
            individual_prefix="",
        )
    # Rewind so the response streams the archive from its first byte.
    archive_bytes.seek(0)

    download_name = f"project_{project_id}_masks.zip"
    response_headers = {"Content-Disposition": f'attachment; filename="{download_name}"'}
    return StreamingResponse(
        archive_bytes,
        media_type="application/zip",
        headers=response_headers,
    )
@router.get(
    "/{project_id}/results",
    summary="Export segmentation results as a ZIP archive",
)
def export_results(
    project_id: int,
    scope: str = Query("all", pattern="^(all|range|current)$"),
    mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
    outputs: str | None = Query(None),
    mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
    start_frame: int | None = Query(None, ge=1),
    end_frame: int | None = Query(None, ge=1),
    frame_id: int | None = Query(None, ge=1),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Bundle COCO JSON, the class mapping, original frames and mask outputs in one ZIP.

    The frame selection is delegated to `_filter_frames`: `scope=all` takes
    every frame, `scope=range` takes 1-based frame numbers from the sorted
    project frame sequence, and `scope=current` takes the concrete backend
    `frame_id`.
    """
    project = _project_or_404(project_id, db, current_user)
    every_frame = _project_frames(project_id, db)
    selected_frames = _filter_frames(
        every_frame,
        scope=scope,
        start_frame=start_frame,
        end_frame=end_frame,
        frame_id=frame_id,
    )
    selected_ids = {item.id for item in selected_frames}
    selected_annotations = _filtered_annotations(project_id, selected_ids, db)
    all_templates = db.query(Template).all()

    coco_document = _build_coco(project, selected_frames, selected_annotations, all_templates)
    class_values, class_mapping = _build_gt_class_mapping(selected_annotations)
    selected_outputs = _parse_result_outputs(mask_type, outputs)

    coco_bytes = json.dumps(coco_document, ensure_ascii=False, indent=2).encode("utf-8")
    mapping_bytes = json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8")

    archive_bytes = io.BytesIO()
    with zipfile.ZipFile(archive_bytes, "w", zipfile.ZIP_DEFLATED) as archive:
        archive.writestr("annotations_coco.json", coco_bytes)
        archive.writestr("maskid_GT像素值_类别映射.json", mapping_bytes)
        # The value returned here is handed straight to the mask writer below.
        original_images = _write_original_frames(archive, project, selected_frames)
        _write_result_mask_outputs(
            archive,
            project,
            selected_frames,
            selected_annotations,
            outputs=selected_outputs,
            class_values=class_values,
            class_mapping=class_mapping,
            original_images=original_images,
            mix_opacity=mix_opacity,
        )
    archive_bytes.seek(0)

    download_name = _segmentation_results_filename(project, selected_frames)
    return StreamingResponse(
        archive_bytes,
        media_type="application/zip",
        headers={"Content-Disposition": _download_content_disposition(download_name)},
    )

234
backend/routers/media.py Normal file
View File

@@ -0,0 +1,234 @@
"""Media upload and parsing endpoints."""
import logging
import re
from pathlib import Path
from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url
from models import ProcessingTask, Project, User
from progress_events import publish_task_progress_event
from routers.auth import require_editor
from schemas import ProcessingTaskOut
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
from worker_tasks import parse_project_media
logger = logging.getLogger(__name__)
# All media endpoints are mounted under /api/media.
router = APIRouter(prefix="/api/media", tags=["Media"])
# Lower-case suffixes (with leading dot) accepted by the single-file upload
# endpoint; compared against the value returned by _get_ext().
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
def natural_filename_key(filename: str) -> tuple[object, ...]:
    """Return a sort key that orders file names naturally (``img2`` before ``img10``).

    Only the final path component is considered. It is split into alternating
    text and digit runs; digit runs become ``int`` and text runs are
    case-folded, so comparison is numeric-aware and case-insensitive.

    Args:
        filename: File name or path.

    Returns:
        Tuple alternating ``str`` (even positions) and ``int`` (odd positions).
    """
    return tuple(
        # ``isdecimal`` matches exactly what ``\d`` splits on and what ``int``
        # accepts. ``isdigit`` would also accept characters such as "²", which
        # ``\d`` leaves in a text run and ``int`` rejects with ValueError.
        int(part) if part.isdecimal() else part.casefold()
        for part in re.split(r"(\d+)", Path(filename).name)
    )
def _get_ext(filename: str) -> str:
return Path(filename).suffix.lower()
@router.post(
    "/upload",
    status_code=status.HTTP_201_CREATED,
    summary="Upload a media file",
)
async def upload_media(
    file: UploadFile = File(...),
    project_id: Optional[int] = Form(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Accept a video, image, or DICOM file and store it in MinIO.

    If project_id is provided, the upload is linked to that project;
    otherwise a project named after the file is created. The object is
    stored once, under ``uploads/<project_id>/<filename>``, and the returned
    presigned URL points at that stored object.

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension,
            404 for an unknown project_id, 500 when the storage upload fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    ext = _get_ext(file.filename)
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {ext}",
        )

    # Resolve the target project before touching storage, so an unknown
    # project_id no longer leaves an orphaned object behind.
    auto_created = False
    if project_id:
        project = db.query(Project).filter(Project.id == project_id).first()
        if not project:
            raise HTTPException(status_code=404, detail="Project not found")
    else:
        # Auto-create a project named after the file
        project = Project(
            name=file.filename,
            description="Auto-created from upload",
            status=PROJECT_STATUS_PENDING,
            source_type="video",
            owner_user_id=current_user.id,
        )
        db.add(project)
        # flush() assigns the primary key without committing, so the row can
        # still be rolled back if the upload below fails.
        db.flush()
        project_id = project.id
        auto_created = True

    data = await file.read()
    object_name = f"uploads/{project_id}/{file.filename}"

    try:
        upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
    except Exception as exc:  # noqa: BLE001
        db.rollback()
        logger.error("Upload failed: %s", exc)
        raise HTTPException(status_code=500, detail="Upload to storage failed") from exc

    project.video_path = object_name
    db.commit()

    if auto_created:
        logger.info("Auto-created project id=%s for upload %s", project_id, file.filename)
    else:
        logger.info("Linked upload to project_id=%s", project_id)

    # Sign the final object key; signing before the project id was known
    # produced a URL for the temporary "uploads/general/..." key.
    file_url = get_presigned_url(object_name, expires=3600)

    logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
    return {
        "object_name": object_name,
        "file_url": file_url,
        "size": len(data),
        "project_id": project_id,
        "message": "Upload successful. Parsing job queued.",
    }
@router.post(
    "/upload/dicom",
    status_code=status.HTTP_201_CREATED,
    summary="Upload multiple DICOM files",
)
async def upload_dicom_batch(
    files: List[UploadFile] = File(...),
    project_id: Optional[int] = Form(None),
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Upload multiple .dcm files for a DICOM series.

    If project_id is provided, files are added to the existing project.
    Otherwise a new DICOM project is created. Files are stored in natural
    file-name order under ``uploads/<project_id>/dicom/``. Individual upload
    failures are logged and counted; the request only fails when no file at
    all could be stored.

    Raises:
        HTTPException: 400 when no (valid) files are supplied, 404 for an
            unknown project_id, 500 when every upload failed.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    sorted_files = sorted(
        [file for file in files if file.filename and file.filename.lower().endswith(".dcm")],
        key=lambda file: natural_filename_key(file.filename or ""),
    )
    if not sorted_files:
        raise HTTPException(status_code=400, detail="No valid DICOM files uploaded")

    if project_id:
        project = db.query(Project).filter(Project.id == project_id).first()
        if not project:
            raise HTTPException(status_code=404, detail="Project not found")
    else:
        # Create new DICOM project
        first_name = sorted_files[0].filename or "DICOM_Series"
        project = Project(
            name=first_name,
            description=f"DICOM series with {len(sorted_files)} files",
            status=PROJECT_STATUS_PENDING,
            source_type="dicom",
            owner_user_id=current_user.id,
        )
        db.add(project)
        # flush() yields the id while keeping the new row uncommitted, so a
        # completely failed batch does not leave an empty project behind.
        db.flush()
        project_id = project.id
        logger.info("Auto-created DICOM project id=%s", project_id)

    uploaded: list[str] = []
    failed: list[str] = []
    for file in sorted_files:
        data = await file.read()
        object_name = f"uploads/{project_id}/dicom/{file.filename}"
        try:
            upload_file(object_name, data, content_type="application/dicom", length=len(data))
        except Exception as exc:  # noqa: BLE001
            # Best effort per file: remember the failure and keep going.
            logger.error("Failed to upload DICOM %s: %s", file.filename, exc)
            failed.append(file.filename or "")
        else:
            uploaded.append(object_name)

    if not uploaded:
        # Nothing reached storage: do not point the project at an empty
        # prefix or report success.
        db.rollback()
        raise HTTPException(status_code=500, detail="Upload to storage failed")

    project.video_path = f"uploads/{project_id}/dicom"
    db.commit()

    return {
        "project_id": project_id,
        "uploaded_count": len(uploaded),
        "failed_count": len(failed),
        "message": f"Uploaded {len(uploaded)} DICOM files. Parsing job queued.",
    }
@router.post(
    "/parse",
    status_code=status.HTTP_202_ACCEPTED,
    response_model=ProcessingTaskOut,
    summary="Trigger frame extraction",
)
def parse_media(
    project_id: int,
    source_type: Optional[str] = None,
    parse_fps: Optional[float] = Query(None, gt=0, le=120),
    max_frames: Optional[int] = Query(None, gt=0),
    target_width: int = Query(640, ge=64, le=4096),
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> ProcessingTask:
    """Persist a queued parse task for the project and hand it to Celery.

    Request values take precedence over the project's stored settings, which
    take precedence over the defaults ("video", 30.0 fps). The task record is
    committed before it is dispatched and then updated with the Celery id.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    if not project.video_path:
        raise HTTPException(status_code=400, detail="Project has no media uploaded")

    resolved_source = source_type or project.source_type or "video"
    resolved_fps = parse_fps or project.parse_fps or 30.0
    task_payload = {
        "source_type": resolved_source,
        "parse_fps": resolved_fps,
        "max_frames": max_frames,
        "target_width": target_width,
    }

    task = ProcessingTask(
        task_type=f"parse_{resolved_source}",
        status=TASK_STATUS_QUEUED,
        progress=0,
        message="解析任务已入队",
        project_id=project_id,
        payload=task_payload,
    )
    project.parse_fps = resolved_fps
    project.status = PROJECT_STATUS_PARSING
    db.add(task)
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)

    # Dispatch only after the row is committed, then record the Celery id.
    celery_result = parse_project_media.delay(task.id)
    task.celery_task_id = celery_result.id
    db.commit()
    db.refresh(task)

    logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, celery_result.id)
    return task

310
backend/routers/projects.py Normal file
View File

@@ -0,0 +1,310 @@
"""Project and Frame CRUD endpoints."""
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Mask, Project, Frame, User
from routers.auth import get_current_user, require_editor
from schemas import ProjectCopyRequest, ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
from minio_client import get_presigned_url
logger = logging.getLogger(__name__)
# Project and frame endpoints are mounted under /api/projects.
router = APIRouter(prefix="/api/projects", tags=["Projects"])
def _next_project_copy_name(db: Session, source_name: str) -> str:
    """Return the first unused name of the form "<source> 副本" or "<source> 副本 N" (N >= 2)."""
    base_name = f"{source_name} 副本"
    rows = (
        db.query(Project.name)
        .filter(Project.name.like(f"{base_name}%"))
        .all()
    )
    taken = {row[0] for row in rows}

    candidate = base_name
    suffix = 1
    # Probe the bare name first, then numbered variants starting at 2.
    while candidate in taken:
        suffix += 1
        candidate = f"{base_name} {suffix}"
    return candidate
def _prepare_project_response(project: Project) -> Project:
    """Attach ``frame_count`` and swap a stored thumbnail value for a 1-hour presigned URL."""
    project.frame_count = len(project.frames)
    stored_thumbnail = project.thumbnail_url
    if stored_thumbnail:
        project.thumbnail_url = get_presigned_url(stored_thumbnail, expires=3600)
    return project
# ---------------------------------------------------------------------------
# Projects
# ---------------------------------------------------------------------------
@router.post(
    "",
    response_model=ProjectOut,
    status_code=status.HTTP_201_CREATED,
    summary="Create a new project",
)
def create_project(
    payload: ProjectCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Project:
    """Persist a project built from the payload, owned by the calling user."""
    fields = payload.model_dump()
    new_project = Project(**fields, owner_user_id=current_user.id)
    db.add(new_project)
    db.commit()
    # Reload the row after the commit before logging and returning it.
    db.refresh(new_project)
    logger.info("Created project id=%s name=%s", new_project.id, new_project.name)
    return new_project
@router.get(
    "",
    response_model=List[ProjectOut],
    summary="List all projects",
)
def list_projects(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> List[Project]:
    """Retrieve a paginated list of projects, ordered by id.

    Args:
        skip: Number of projects to skip.
        limit: Maximum number of projects to return.

    Returns:
        Projects prepared with ``_prepare_project_response``.
    """
    projects = (
        db.query(Project)
        # OFFSET/LIMIT without ORDER BY returns rows in an unspecified order,
        # so pages could overlap or skip rows between requests.
        .order_by(Project.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    for p in projects:
        _prepare_project_response(p)
    return projects
@router.get(
    "/{project_id}",
    response_model=ProjectOut,
    summary="Get a single project",
)
def get_project(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Project:
    """Return the project with the given id, or 404 when it does not exist."""
    found = db.query(Project).filter(Project.id == project_id).first()
    if found:
        return _prepare_project_response(found)
    raise HTTPException(status_code=404, detail="Project not found")
@router.post(
    "/{project_id}/copy",
    response_model=ProjectOut,
    status_code=status.HTTP_201_CREATED,
    summary="Copy a project",
)
def copy_project(
    project_id: int,
    payload: ProjectCopyRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Project:
    """Duplicate a project for the calling user.

    The project row and its frames are always copied (reusing the same media
    and image references). With ``mode == "full"`` the annotations and their
    mask records are copied too, re-pointed at the new frames.
    """
    source = db.query(Project).filter(Project.id == project_id).first()
    if not source:
        raise HTTPException(status_code=404, detail="Project not found")

    # A missing or blank requested name falls back to a generated one.
    next_name = (payload.name or "").strip()
    if not next_name:
        next_name = _next_project_copy_name(db, source.name)

    copied = Project(
        name=next_name,
        description=source.description,
        video_path=source.video_path,
        thumbnail_url=source.thumbnail_url,
        status=source.status,
        source_type=source.source_type,
        original_fps=source.original_fps,
        parse_fps=source.parse_fps,
        owner_user_id=current_user.id,
    )
    db.add(copied)
    db.flush()

    # Old frame id -> new frame id, used to re-point copied annotations.
    frame_id_map: dict[int, int] = {}
    for old_frame in sorted(source.frames, key=lambda entry: entry.frame_index):
        new_frame = Frame(
            project_id=copied.id,
            frame_index=old_frame.frame_index,
            image_url=old_frame.image_url,
            width=old_frame.width,
            height=old_frame.height,
            timestamp_ms=old_frame.timestamp_ms,
            source_frame_number=old_frame.source_frame_number,
        )
        db.add(new_frame)
        # Flush so the new primary key is available for the mapping.
        db.flush()
        frame_id_map[old_frame.id] = new_frame.id

    wants_annotations = payload.mode == "full"
    old_annotations = sorted(source.annotations, key=lambda entry: entry.id) if wants_annotations else []
    for old_annotation in old_annotations:
        new_annotation = Annotation(
            project_id=copied.id,
            # A frame-less annotation (frame_id None) maps to None as well.
            frame_id=frame_id_map.get(old_annotation.frame_id),
            template_id=old_annotation.template_id,
            mask_data=old_annotation.mask_data,
            points=old_annotation.points,
            bbox=old_annotation.bbox,
        )
        db.add(new_annotation)
        db.flush()
        for old_mask in old_annotation.masks:
            new_mask = Mask(
                annotation_id=new_annotation.id,
                mask_url=old_mask.mask_url,
                format=old_mask.format,
            )
            db.add(new_mask)

    db.commit()
    db.refresh(copied)
    logger.info("Copied project id=%s to id=%s mode=%s", project_id, copied.id, payload.mode)
    return _prepare_project_response(copied)
@router.patch(
    "/{project_id}",
    response_model=ProjectOut,
    summary="Update a project",
)
def update_project(
    project_id: int,
    payload: ProjectUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Project:
    """Apply the fields present in the payload to the project.

    A supplied ``name`` is stripped and must not end up empty (400 otherwise).
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    changes = payload.model_dump(exclude_unset=True)
    for field, new_value in changes.items():
        if field == "name":
            new_value = (new_value or "").strip()
            if not new_value:
                raise HTTPException(status_code=400, detail="Project name is required")
        setattr(project, field, new_value)

    db.commit()
    db.refresh(project)
    logger.info("Updated project id=%s", project_id)
    return project
@router.delete(
    "/{project_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a project",
)
def delete_project(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> None:
    """Delete the project row; 404 when the id is unknown.

    NOTE(review): removal of related frames and annotations presumably relies
    on cascade rules declared on the models, which are not visible here.
    """
    target = db.query(Project).filter(Project.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    db.delete(target)
    db.commit()
    logger.info("Deleted project id=%s", project_id)
# ---------------------------------------------------------------------------
# Frames
# ---------------------------------------------------------------------------
@router.post(
    "/{project_id}/frames",
    response_model=FrameOut,
    status_code=status.HTTP_201_CREATED,
    summary="Add a frame to a project",
)
def create_frame(
    project_id: int,
    payload: FrameCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Frame:
    """Create a frame under the project named in the URL (404 if it is unknown)."""
    parent = db.query(Project).filter(Project.id == project_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Project not found")

    # The URL's project_id wins over any project_id carried in the body.
    frame_fields = payload.model_dump(exclude={"project_id"})
    new_frame = Frame(project_id=project_id, **frame_fields)
    db.add(new_frame)
    db.commit()
    db.refresh(new_frame)
    return new_frame
@router.get(
    "/{project_id}/frames",
    response_model=List[FrameOut],
    summary="List frames for a project",
)
def list_frames(
    project_id: int,
    skip: int = 0,
    limit: Optional[int] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> List[Frame]:
    """Return the project's frames ordered by ``frame_index``.

    ``skip`` is always applied; ``limit`` only when given. Each returned
    frame's ``image_url`` is replaced by a 1-hour presigned URL.
    """
    parent = db.query(Project).filter(Project.id == project_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Project not found")

    frame_query = db.query(Frame).filter(Frame.project_id == project_id)
    frame_query = frame_query.order_by(Frame.frame_index).offset(skip)
    if limit is not None:
        frame_query = frame_query.limit(limit)

    page = frame_query.all()
    for item in page:
        item.image_url = get_presigned_url(item.image_url, expires=3600)
    return page
@router.get(
    "/{project_id}/frames/{frame_id}",
    response_model=FrameOut,
    summary="Get a single frame",
)
def get_frame(
    project_id: int,
    frame_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Frame:
    """Return the frame with ``frame_id`` if it belongs to ``project_id``, else 404."""
    frame_query = db.query(Frame).join(Project, Project.id == Frame.project_id)
    found = frame_query.filter(
        Frame.project_id == project_id,
        Frame.id == frame_id,
    ).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Frame not found")

161
backend/routers/tasks.py Normal file
View File

@@ -0,0 +1,161 @@
"""Processing task query endpoints."""
import logging
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from celery_app import celery_app
from database import get_db
from models import ProcessingTask, Project, User
from progress_events import publish_task_progress_event
from routers.auth import get_current_user, require_editor
from schemas import ProcessingTaskOut
from statuses import (
PROJECT_STATUS_PARSING,
PROJECT_STATUS_PENDING,
PROJECT_STATUS_READY,
TASK_ACTIVE_STATUSES,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_QUEUED,
)
from worker_tasks import parse_project_media, propagate_project_masks
# Task query / control endpoints are mounted under /api/tasks.
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
    """Load the task with ``task_id`` or raise a 404 ``HTTPException``.

    ``current_user`` is accepted but not used for filtering.
    """
    _ = current_user
    task_query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id)
    found = task_query.filter(ProcessingTask.id == task_id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
def _project_status_after_stop(project: Project) -> str:
    """Return the ready status when the project has frames, else the pending status."""
    if project.frames:
        return PROJECT_STATUS_READY
    return PROJECT_STATUS_PENDING
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
def list_tasks(
    project_id: int | None = None,
    status: str | None = None,
    limit: int = 50,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> List[ProcessingTask]:
    """List the newest tasks first, optionally narrowed by project and/or status."""
    _ = current_user
    task_query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id)

    # Collect only the filters the caller actually asked for.
    conditions = []
    if project_id is not None:
        conditions.append(ProcessingTask.project_id == project_id)
    if status is not None:
        conditions.append(ProcessingTask.status == status)
    for condition in conditions:
        task_query = task_query.filter(condition)

    newest_first = task_query.order_by(ProcessingTask.created_at.desc())
    return newest_first.limit(limit).all()
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> ProcessingTask:
    """Fetch a single background task by id (404 when it does not exist)."""
    task = _get_task_or_404(task_id, db, current_user)
    return task
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
def cancel_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> ProcessingTask:
    """Mark an active task as cancelled and try to revoke its Celery job.

    Responds with 409 when the task is not in an active status. A failed
    revoke is logged and does not prevent the task from being marked
    cancelled.
    """
    task = _get_task_or_404(task_id, db, current_user)
    if task.status not in TASK_ACTIVE_STATUSES:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Task is not cancellable in status: {task.status}",
        )

    celery_id = task.celery_task_id
    if celery_id:
        try:
            celery_app.control.revoke(celery_id, terminate=True, signal="SIGTERM")
        except Exception as exc:  # noqa: BLE001
            logger.warning("Failed to revoke celery task %s: %s", celery_id, exc)

    task.status = TASK_STATUS_CANCELLED
    task.progress = 100
    task.message = "任务已取消"
    task.error = "Cancelled by user"
    task.finished_at = _now()

    # Move the owning project out of its in-progress state.
    owning_project = task.project
    if owning_project:
        owning_project.status = _project_status_after_stop(owning_project)

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
    return task
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
def retry_task(
    task_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> ProcessingTask:
    """Create a fresh queued task from a failed or cancelled task.

    The new task reuses the previous task's type and payload (plus a
    ``retry_of`` reference) and is dispatched to the matching Celery job.

    Raises:
        HTTPException: 409 when the task is neither failed nor cancelled,
            400 when it has no project or (for parse tasks) the project has
            no media, 404 when the task or project does not exist.
    """
    previous = _get_task_or_404(task_id, db, current_user)
    if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Task is not retryable in status: {previous.status}",
        )
    if previous.project_id is None:
        raise HTTPException(status_code=400, detail="Task has no project_id")

    project = db.query(Project).filter(Project.id == previous.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    is_propagation_task = previous.task_type == "propagate_masks"
    # Only parse tasks need uploaded media to work on.
    if not is_propagation_task and not project.video_path:
        raise HTTPException(status_code=400, detail="Project has no media uploaded")

    # Copy so the previous task's stored payload is left untouched.
    payload = dict(previous.payload or {})
    payload.setdefault("source_type", project.source_type or "video")
    payload["retry_of"] = previous.id

    task = ProcessingTask(
        task_type=previous.task_type,
        status=TASK_STATUS_QUEUED,
        progress=0,
        # The closing fullwidth parenthesis was missing from this message.
        message=f"重试任务已入队(源任务 #{previous.id})",
        project_id=project.id,
        payload=payload,
    )
    if not is_propagation_task:
        project.status = PROJECT_STATUS_PARSING
    db.add(task)
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)

    async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
    task.celery_task_id = async_result.id
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
    return task

View File

@@ -0,0 +1,183 @@
"""Template (Ontology) CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from database import get_db
from models import Template, User
from routers.auth import get_current_user, require_editor
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
logger = logging.getLogger(__name__)
# Template endpoints are mounted under /api/templates.
router = APIRouter(prefix="/api/templates", tags=["Templates"])
# System-reserved class entry. _normalize_template_classes() drops any
# client-supplied entry matching it (see _is_reserved_class) and appends a
# fresh copy of this dict as the last element of every class list.
RESERVED_UNCLASSIFIED_CLASS = {
    "id": "reserved-unclassified",
    "name": "待分类",
    "color": "#000000",
    "zIndex": 0,
    "maskId": 0,
    "category": "系统保留",
}
def _is_reserved_class(item: dict) -> bool:
    """Tell whether ``item`` collides with the reserved class by id, name, or maskId 0."""
    if item.get("id") == RESERVED_UNCLASSIFIED_CLASS["id"]:
        return True
    if item.get("name") == RESERVED_UNCLASSIFIED_CLASS["name"]:
        return True
    return item.get("maskId") == 0
def _normalize_template_classes(classes: list[dict] | None) -> list[dict]:
    """Return the non-reserved classes followed by a fresh copy of the reserved class."""
    kept: list[dict] = []
    for entry in classes or []:
        if _is_reserved_class(entry):
            continue
        kept.append(entry)
    # Copy so callers never share (and mutate) the module-level constant.
    kept.append(dict(RESERVED_UNCLASSIFIED_CLASS))
    return kept
def _pack_mapping_rules(data: dict) -> dict:
    """Pack classes/rules into mapping_rules for DB storage.

    The top-level ``classes`` and ``rules`` keys are always removed from
    ``data``; non-None values are stored under ``mapping_rules``. Any class
    list in the resulting mapping is normalized.

    Args:
        data: Field dict from a request payload; modified in place.

    Returns:
        The same ``data`` dict, with ``mapping_rules`` set.
    """
    # Work on a copy so the caller's mapping_rules dict is not mutated.
    mapping = dict(data.get("mapping_rules") or {})

    # Always pop: a leftover ``classes=None`` / ``rules=None`` key would
    # otherwise be forwarded to ``Template(**data)`` or ``setattr``.
    classes = data.pop("classes", None)
    rules = data.pop("rules", None)
    if classes is not None:
        mapping["classes"] = classes
    if rules is not None:
        mapping["rules"] = rules

    if "classes" in mapping:
        mapping["classes"] = _normalize_template_classes(mapping.get("classes"))
    data["mapping_rules"] = mapping
    return data
def _unpack_template(template: Template) -> Template:
    """Expose ``mapping_rules`` content as ``classes`` / ``rules`` attributes for the response."""
    stored = template.mapping_rules or {}
    raw_classes = stored.get("classes", [])
    raw_rules = stored.get("rules", [])
    # Plain attributes, so Pydantic's from_attributes mode can read them.
    template.classes = _normalize_template_classes(raw_classes)
    template.rules = raw_rules
    return template
@router.post(
    "",
    response_model=TemplateOut,
    status_code=status.HTTP_201_CREATED,
    summary="Create a new template",
)
def create_template(
    payload: TemplateCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Template:
    """Store a new template owned by the caller and return it with classes/rules unpacked."""
    fields = _pack_mapping_rules(payload.model_dump())
    new_template = Template(**fields, owner_user_id=current_user.id)
    db.add(new_template)
    db.commit()
    db.refresh(new_template)

    _unpack_template(new_template)
    logger.info("Created template id=%s name=%s", new_template.id, new_template.name)
    return new_template
@router.get(
    "",
    response_model=List[TemplateOut],
    summary="List all templates",
)
def list_templates(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> List[Template]:
    """Retrieve a page of templates owned by the caller or owned by nobody.

    Args:
        skip: Number of templates to skip.
        limit: Maximum number of templates to return.

    Returns:
        Templates ordered by id, with classes/rules unpacked.
    """
    templates = (
        db.query(Template)
        .filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
        # OFFSET/LIMIT needs a stable ORDER BY; without it the row order is
        # unspecified and pages may overlap or skip rows.
        .order_by(Template.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
    for t in templates:
        _unpack_template(t)
    return templates
@router.get(
    "/{template_id}",
    response_model=TemplateOut,
    summary="Get a single template",
)
def get_template(
    template_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> Template:
    """Return one template the caller owns or that has no owner; 404 otherwise."""
    visible_to_caller = or_(
        Template.owner_user_id == current_user.id,
        Template.owner_user_id.is_(None),
    )
    found = db.query(Template).filter(Template.id == template_id, visible_to_caller).first()
    if not found:
        raise HTTPException(status_code=404, detail="Template not found")
    return _unpack_template(found)
@router.patch(
    "/{template_id}",
    response_model=TemplateOut,
    summary="Update a template",
)
def update_template(
    template_id: int,
    payload: TemplateUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Template:
    """Update template fields partially.

    When only ``classes`` or only ``rules`` is sent, the other part of the
    stored ``mapping_rules`` is preserved.

    Raises:
        HTTPException: 404 when the template does not exist or belongs to
            another user.
    """
    template = db.query(Template).filter(
        Template.id == template_id,
        or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
    ).first()
    if not template:
        raise HTTPException(status_code=404, detail="Template not found")

    data = payload.model_dump(exclude_unset=True)
    if "classes" in data or "rules" in data:
        if data.get("mapping_rules") is None:
            # Start from a copy of the stored mapping; packing into an empty
            # dict would silently drop whichever of classes/rules was not sent.
            data["mapping_rules"] = dict(template.mapping_rules or {})
        data = _pack_mapping_rules(data)
        # classes/rules now live inside mapping_rules; do not set them as
        # attributes of their own.
        data.pop("classes", None)
        data.pop("rules", None)

    for key, value in data.items():
        setattr(template, key, value)

    db.commit()
    db.refresh(template)
    _unpack_template(template)
    logger.info("Updated template id=%s", template_id)
    return template
@router.delete(
    "/{template_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a template",
)
def delete_template(
    template_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> None:
    """Delete a template the caller owns or that has no owner; 404 otherwise.

    NOTE(review): annotations referencing the template presumably get
    ``template_id`` set to NULL by a rule on the model, not visible here.
    """
    visible_to_caller = or_(
        Template.owner_user_id == current_user.id,
        Template.owner_user_id.is_(None),
    )
    target = db.query(Template).filter(Template.id == template_id, visible_to_caller).first()
    if not target:
        raise HTTPException(status_code=404, detail="Template not found")

    db.delete(target)
    db.commit()
    logger.info("Deleted template id=%s", template_id)

385
backend/schemas.py Normal file
View File

@@ -0,0 +1,385 @@
"""Pydantic schemas for request/response validation."""
from datetime import datetime
from typing import Literal, Optional, Any
from pydantic import BaseModel, ConfigDict
# ---------------------------------------------------------------------------
# Auth / user schemas
# ---------------------------------------------------------------------------
class UserOut(BaseModel):
    """User fields exposed to clients (embedded in LoginResponse and DemoFactoryResetOut)."""

    # from_attributes lets Pydantic populate the model from object attributes
    # rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)
    id: int
    username: str
    role: str
    # Declared as int here while AdminUserCreate/AdminUserUpdate use bool.
    # NOTE(review): presumably mirrors an integer column — confirm.
    is_active: int
class LoginResponse(BaseModel):
    """Login result: a token, its type (defaults to "bearer"), and the user."""

    token: str
    token_type: str = "bearer"
    # The username is carried both here and inside `user`.
    username: str
    user: UserOut
class AdminUserCreate(BaseModel):
    """Fields for creating a user; role defaults to "annotator" and is_active to True."""

    username: str
    password: str
    role: str = "annotator"
    is_active: bool = True
class AdminUserUpdate(BaseModel):
    """Partial user update; every field is optional and defaults to None."""

    username: Optional[str] = None
    password: Optional[str] = None
    role: Optional[str] = None
    is_active: Optional[bool] = None
class AuditLogOut(BaseModel):
    """Audit log entry as returned to clients."""

    # Allows building the model from object attributes.
    model_config = ConfigDict(from_attributes=True)
    id: int
    actor_user_id: Optional[int] = None
    action: str
    target_type: Optional[str] = None
    # Note: the target id is a string, unlike the integer ids used elsewhere.
    target_id: Optional[str] = None
    detail: Optional[dict[str, Any]] = None
    created_at: datetime
class DemoFactoryResetRequest(BaseModel):
    """Request body for the demo factory reset; carries a confirmation string."""

    # NOTE(review): the expected confirmation value is checked elsewhere — not visible here.
    confirmation: str
# ---------------------------------------------------------------------------
# Project schemas
# ---------------------------------------------------------------------------
class ProjectBase(BaseModel):
    """Project fields shared by ProjectCreate and ProjectOut."""

    name: str
    description: Optional[str] = None
    video_path: Optional[str] = None
    thumbnail_url: Optional[str] = None
    status: Optional[str] = "pending"
    source_type: Optional[str] = "video"
    original_fps: Optional[float] = None
    parse_fps: Optional[float] = 30.0
class ProjectCreate(ProjectBase):
    """Project creation payload; adds nothing to ProjectBase."""

    pass
class ProjectUpdate(BaseModel):
    """Partial project update; every field is optional and defaults to None."""

    name: Optional[str] = None
    description: Optional[str] = None
    video_path: Optional[str] = None
    thumbnail_url: Optional[str] = None
    status: Optional[str] = None
    source_type: Optional[str] = None
    original_fps: Optional[float] = None
    parse_fps: Optional[float] = None
class ProjectCopyRequest(BaseModel):
    """Options for copying a project."""

    # Restricted to "reset" or "full"; defaults to "reset".
    # NOTE(review): what each mode copies is decided by the handler — not visible here.
    mode: Literal["reset", "full"] = "reset"
    name: Optional[str] = None
class ProjectOut(ProjectBase):
    """Project as returned to clients: ProjectBase plus id, owner, timestamps and frame count."""

    # Allows building the model from object attributes.
    model_config = ConfigDict(from_attributes=True)
    id: int
    owner_user_id: Optional[int] = None
    created_at: datetime
    updated_at: datetime
    # Defaults to 0 when the source object does not supply a value.
    frame_count: int = 0
class DemoFactoryResetOut(BaseModel):
    """Result of the demo factory reset."""

    admin_user: UserOut
    # Both a single project and a list of projects are returned.
    project: ProjectOut
    projects: list[ProjectOut]
    # Integer counts keyed by string.
    # NOTE(review): keys are presumably entity/table names — confirm with the handler.
    deleted_counts: dict[str, int]
    message: str
# ---------------------------------------------------------------------------
# Frame schemas
# ---------------------------------------------------------------------------
class FrameBase(BaseModel):
    """Frame fields shared by FrameCreate and FrameOut."""

    frame_index: int
    image_url: str
    width: Optional[int] = None
    height: Optional[int] = None
    timestamp_ms: Optional[float] = None
    source_frame_number: Optional[int] = None
class FrameCreate(FrameBase):
    """Frame creation payload: FrameBase plus the owning project id."""

    project_id: int
class FrameOut(FrameBase):
    """Frame as returned to clients: FrameBase plus id, project id and creation time."""

    # Allows building the model from object attributes.
    model_config = ConfigDict(from_attributes=True)
    id: int
    project_id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# Template schemas
# ---------------------------------------------------------------------------
class TemplateBase(BaseModel):
    """Template fields shared by TemplateCreate and TemplateOut."""

    name: str
    description: Optional[str] = None
    color: str
    z_index: int = 0
    # The schema accepts a raw mapping_rules dict as well as separate
    # classes/rules lists; all three are optional.
    mapping_rules: Optional[dict[str, Any]] = None
    classes: Optional[list[dict[str, Any]]] = None
    rules: Optional[list[dict[str, Any]]] = None
class TemplateCreate(TemplateBase):
    """Template creation payload; adds nothing to TemplateBase."""

    pass
class TemplateUpdate(BaseModel):
    """Partial template update; every field is optional and defaults to None."""

    name: Optional[str] = None
    description: Optional[str] = None
    color: Optional[str] = None
    z_index: Optional[int] = None
    mapping_rules: Optional[dict[str, Any]] = None
    classes: Optional[list[dict[str, Any]]] = None
    rules: Optional[list[dict[str, Any]]] = None
class TemplateOut(TemplateBase):
    """Template as returned to clients: TemplateBase plus id, owner and creation time."""

    # Allows building the model from object attributes.
    model_config = ConfigDict(from_attributes=True)
    id: int
    # May be None (no owning user).
    owner_user_id: Optional[int] = None
    created_at: datetime
# ---------------------------------------------------------------------------
# Annotation schemas
# ---------------------------------------------------------------------------
class AnnotationBase(BaseModel):
    """Annotation fields shared by AnnotationCreate and AnnotationOut."""

    project_id: int
    frame_id: Optional[int] = None
    template_id: Optional[int] = None
    mask_data: Optional[dict[str, Any]] = None
    # List of float lists.
    # NOTE(review): presumably [x, y] pairs — confirm with the client.
    points: Optional[list[list[float]]] = None
    # Flat list of floats.
    # NOTE(review): coordinate convention (xyxy vs xywh) is not visible here.
    bbox: Optional[list[float]] = None
class AnnotationCreate(AnnotationBase):
    """Annotation creation payload; adds nothing to AnnotationBase."""

    pass
class AnnotationUpdate(BaseModel):
    """Partial annotation update; project_id and frame_id are not part of this schema."""

    mask_data: Optional[dict[str, Any]] = None
    points: Optional[list[list[float]]] = None
    bbox: Optional[list[float]] = None
    template_id: Optional[int] = None
class AnnotationOut(AnnotationBase):
    """Annotation as returned to clients: AnnotationBase plus id and timestamps."""

    # Allows building the model from object attributes.
    model_config = ConfigDict(from_attributes=True)
    id: int
    created_at: datetime
    updated_at: datetime
# ---------------------------------------------------------------------------
# Mask schemas
# ---------------------------------------------------------------------------
class MaskBase(BaseModel):
    """Mask fields shared by MaskCreate and MaskOut; format defaults to "png"."""

    annotation_id: int
    mask_url: str
    format: str = "png"
class MaskCreate(MaskBase):
    """Mask creation payload; adds nothing to MaskBase."""

    pass
class MaskOut(MaskBase):
    """Mask as returned to clients: MaskBase plus id and creation time."""

    # Allows building the model from object attributes.
    model_config = ConfigDict(from_attributes=True)
    id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# Processing task schemas
# ---------------------------------------------------------------------------
class ProcessingTaskOut(BaseModel):
    """Background processing task as returned to clients."""

    # Allows building the model from object attributes.
    model_config = ConfigDict(from_attributes=True)
    id: int
    task_type: str
    status: str
    # Integer progress value.
    # NOTE(review): presumably a 0-100 percentage — confirm with the task runner.
    progress: int
    message: Optional[str] = None
    project_id: Optional[int] = None
    celery_task_id: Optional[str] = None
    payload: Optional[dict[str, Any]] = None
    result: Optional[dict[str, Any]] = None
    error: Optional[str] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    updated_at: datetime
# ---------------------------------------------------------------------------
# AI schemas
# ---------------------------------------------------------------------------
class PredictRequest(BaseModel):
    """Request body for a prediction on one image."""

    image_id: int
    prompt_type: str  # point / box / semantic
    # Untyped on purpose (Any).
    # NOTE(review): presumably its shape depends on prompt_type — confirm with the handler.
    prompt_data: Any
    model: Optional[str] = None
    options: Optional[dict[str, Any]] = None
class PredictResponse(BaseModel):
    """Prediction result: polygons and optional scores."""

    # Three nesting levels: list of polygons -> list of points -> list of floats.
    polygons: list[list[list[float]]]
    scores: Optional[list[float]] = None
class MaskAnalysisRequest(BaseModel):
    """Request body for mask analysis; mask_data is the only required field."""

    frame_id: Optional[int] = None
    mask_data: dict[str, Any]
    points: Optional[list[list[float]]] = None
    bbox: Optional[list[float]] = None
    # Off by default.
    extract_skeleton: bool = False
class MaskAnalysisResponse(BaseModel):
    """Mask analysis result: confidence, topology anchors, area and bounding box."""

    confidence: Optional[float] = None
    confidence_source: str
    topology_anchor_count: int
    topology_anchors: list[list[float]]
    area: float
    bbox: Optional[list[float]] = None
    source: Optional[str] = None
    message: str
class SmoothMaskRequest(BaseModel):
    """Request body for mask smoothing; mask_data is the only required field."""

    frame_id: Optional[int] = None
    mask_data: dict[str, Any]
    points: Optional[list[list[float]]] = None
    bbox: Optional[list[float]] = None
    # Defaults: strength 0.0 and method "chaikin".
    # NOTE(review): the valid strength range and other methods are not visible here.
    strength: float = 0.0
    method: str = "chaikin"
class SmoothMaskResponse(BaseModel):
    """Mask smoothing result: polygons, topology anchors, area, bbox and smoothing info."""

    # Three nesting levels: list of polygons -> list of points -> list of floats.
    polygons: list[list[list[float]]]
    topology_anchor_count: int
    topology_anchors: list[list[float]]
    area: float
    bbox: Optional[list[float]] = None
    smoothing: dict[str, Any]
    message: str
class PropagationSeed(BaseModel):
    """Seed used by PropagateRequest and PropagateTaskStep; every field is optional."""

    polygons: Optional[list[list[list[float]]]] = None
    # One nesting level deeper than `polygons`.
    # NOTE(review): presumably one list of hole polygons per polygon — confirm.
    holes: Optional[list[list[list[list[float]]]]] = None
    bbox: Optional[list[float]] = None
    points: Optional[list[list[float]]] = None
    # `labels` is a list of ints while `label` is a single string.
    labels: Optional[list[int]] = None
    label: Optional[str] = None
    color: Optional[str] = None
    class_metadata: Optional[dict[str, Any]] = None
    template_id: Optional[int] = None
    # Provenance fields; mask and instance ids are strings, the annotation id an int.
    source_mask_id: Optional[str] = None
    source_annotation_id: Optional[int] = None
    source_instance_id: Optional[str] = None
    propagation_seed_signature: Optional[str] = None
    smoothing: Optional[dict[str, Any]] = None
class PropagateRequest(BaseModel):
    """Request body for propagating one seed from a frame of a project."""

    project_id: int
    frame_id: int
    model: Optional[str] = "sam2.1_hiera_tiny"
    seed: PropagationSeed
    # Defaults: "forward" direction, at most 30 frames.
    # NOTE(review): other accepted direction values are not visible here.
    direction: str = "forward"
    max_frames: int = 30
    include_source: bool = False
    save_annotations: bool = True
class PropagateResponse(BaseModel):
    """Propagation result: model, direction, source frame, counts and annotations."""

    model: str
    direction: str
    source_frame_id: int
    processed_frame_count: int
    created_annotation_count: int
    annotations: list[AnnotationOut]
class PropagateTaskStep(BaseModel):
    """One step of a PropagateTaskRequest: a seed with its direction and frame limit."""

    seed: PropagationSeed
    direction: str = "forward"
    max_frames: int = 30
class PropagateTaskRequest(BaseModel):
    """Request body for a propagation task made of several steps."""

    project_id: int
    frame_id: int
    model: Optional[str] = "sam2.1_hiera_tiny"
    # Each step carries its own seed, direction and max_frames.
    steps: list[PropagateTaskStep]
    include_source: bool = False
    save_annotations: bool = True
class AiModelStatus(BaseModel):
    """Status report for one AI model (listed in AiRuntimeStatus.models)."""

    id: str
    label: str
    available: bool
    loaded: bool = False
    device: str
    supports: list[str]
    message: str
    # Diagnostic flags with their defaults; the producer of these values is
    # not visible in this file section.
    package_available: bool = False
    checkpoint_exists: bool = False
    checkpoint_path: Optional[str] = None
    python_ok: bool = True
    torch_ok: bool = True
    cuda_required: bool = False
    external_available: bool = False
    external_python: Optional[str] = None
class GpuStatus(BaseModel):
    """GPU and torch status (embedded in AiRuntimeStatus.gpu)."""

    available: bool
    device: str
    name: Optional[str] = None
    torch_available: bool
    torch_version: Optional[str] = None
    cuda_version: Optional[str] = None
class AiRuntimeStatus(BaseModel):
    """Overall AI runtime status: selected model, GPU status and per-model status."""

    selected_model: str
    gpu: GpuStatus
    models: list[AiModelStatus]
# ---------------------------------------------------------------------------
# Export schemas
# ---------------------------------------------------------------------------
class ExportStatus(BaseModel):
    """Export result: a URL and the export format."""

    url: str
    format: str

View File

View File

@@ -0,0 +1,164 @@
"""Bundled system ontology templates and restore helpers."""
from __future__ import annotations
from copy import deepcopy
from sqlalchemy.orm import Session
from models import Template
RESERVED_UNCLASSIFIED_CLASS = {
"id": "reserved-unclassified",
"name": "待分类",
"color": "#000000",
"zIndex": 0,
"maskId": 0,
"category": "系统保留",
}
def _with_reserved_unclassified_class(classes: list[dict]) -> list[dict]:
filtered = [
item for item in classes
if item.get("id") != RESERVED_UNCLASSIFIED_CLASS["id"]
and item.get("name") != RESERVED_UNCLASSIFIED_CLASS["name"]
and item.get("maskId") != 0
]
return [*filtered, dict(RESERVED_UNCLASSIFIED_CLASS)]
def _template_classes(
template_name: str,
names: list[str],
colors: list[tuple[int, int, int]],
*,
id_prefix: str,
) -> list[dict]:
classes = []
for idx, (rgb, name) in enumerate(zip(colors, names)):
color_hex = f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
classes.append({
"id": f"{id_prefix}-{idx}",
"name": name,
"color": color_hex,
"zIndex": (len(names) - idx) * 10,
"maskId": idx + 1,
"category": template_name,
})
return classes
def bundled_default_template_definitions() -> list[dict]:
    """Return fresh definitions for all bundled system templates.

    Each definition is a dict with "name", "description", "color", "z_index"
    and "classes". The class lists are rebuilt on every call through
    _template_classes and always end with the reserved class added by
    _with_reserved_unclassified_class.
    """
    return [
        {
            "name": "腹腔镜胆囊切除术",
            "description": "腹腔镜胆囊切除术LC手术器械与解剖结构语义分割模板共35个分类",
            "color": "#06b6d4",
            "z_index": 0,
            "classes": _with_reserved_unclassified_class(_template_classes(
                "腹腔镜胆囊切除术",
                # 35 class names, paired positionally with the 35 colours below.
                # NOTE(review): the first name is an empty string — confirm it is intentional.
                [
                    "", "线", "肿瘤", "血管阻断夹", "棉球", "双极电凝",
                    "肝脏", "胆囊", "分离钳", "脂肪", "止血海绵", "肝总管",
                    "吸引器", "剪刀", "超声刀", "止血纱布", "胆总管", "生物夹",
                    "无损伤钳", "钳夹", "喷洒", "胆囊管", "动脉", "电凝",
                    "静脉", "标本袋", "引流管", "纱布", "金属钛夹", "韧带",
                    "肝蒂", "推结器", "乳胶管-血管阻断", "吻合器", "术中超声",
                ],
                # (R, G, B) tuples, one per class name above.
                [
                    (134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
                    (255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
                    (255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
                    (52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
                    (42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
                    (113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
                ],
                id_prefix="cls-lap",
            )),
        },
        {
            "name": "头颈部CT分割",
            "description": "头颈部CT分割",
            "color": "#ef4444",
            "z_index": 10,
            "classes": _with_reserved_unclassified_class(_template_classes(
                "头颈部CT分割",
                # 10 class names, paired positionally with the 10 colours below.
                [
                    "肿瘤/结节",
                    "下颌骨",
                    "甲状腺",
                    "气管",
                    "颈椎",
                    "颈动脉",
                    "颈静脉",
                    "腮腺",
                    "下颌下腺",
                    "舌骨",
                ],
                # (R, G, B) tuples, one per class name above.
                [
                    (255, 0, 0),
                    (0, 255, 0),
                    (0, 0, 255),
                    (255, 255, 0),
                    (255, 0, 255),
                    (0, 255, 255),
                    (255, 128, 0),
                    (128, 0, 128),
                    (0, 128, 128),
                    (128, 128, 0),
                ],
                id_prefix="cls-head-neck-ct",
            )),
        },
    ]
def _has_legacy_head_neck_english_labels(template: Template) -> bool:
if template.name != "头颈部CT分割":
return False
classes = (template.mapping_rules or {}).get("classes") or []
return any(
isinstance(item, dict)
and isinstance(item.get("name"), str)
and "(" in item["name"]
and ")" in item["name"]
for item in classes
)
def ensure_default_templates(db: Session, *, restore_existing: bool = False) -> list[Template]:
    """Make sure every bundled system template exists and return them all.

    Templates are matched by name among rows without an owner. Missing ones
    are created. Existing ones are left untouched unless *restore_existing*
    is true or the row still has legacy head-and-neck labels, in which case
    its fields and mapping rules are overwritten with the bundled definition.
    """
    ensured: list[Template] = []
    for spec in bundled_default_template_definitions():
        template = (
            db.query(Template)
            .filter(
                Template.name == spec["name"],
                Template.owner_user_id.is_(None),
            )
            .first()
        )
        if template is None:
            template = Template(owner_user_id=None)
            db.add(template)
        elif not (restore_existing or _has_legacy_head_neck_english_labels(template)):
            # Existing and up to date: keep whatever is stored.
            ensured.append(template)
            continue

        for attribute in ("name", "description", "color", "z_index"):
            setattr(template, attribute, spec[attribute])
        template.owner_user_id = None
        # Deep copy so the stored rules never alias the definition dicts.
        template.mapping_rules = {
            "classes": deepcopy(spec["classes"]),
            "rules": [],
        }
        ensured.append(template)

    db.commit()
    for template in ensured:
        db.refresh(template)
    return ensured
def restore_default_templates(db: Session) -> list[Template]:
    """Overwrite every bundled system template with its bundled definition.

    Equivalent to ``ensure_default_templates`` with ``restore_existing=True``.
    """
    restored = ensure_default_templates(db, restore_existing=True)
    return restored

View File

@@ -0,0 +1,217 @@
"""Helpers for seeding the bundled demo media project."""
from __future__ import annotations
import os
import shutil
import tempfile
from pathlib import Path
import cv2
from sqlalchemy.orm import Session
from minio_client import upload_file
from models import Frame, Project, User
from services.frame_parser import (
extract_thumbnail,
natural_filename_key,
parse_dicom,
parse_video,
upload_frames_to_minio,
)
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
# Default project name and frame rate used when seeding the demo DICOM project.
DEMO_DICOM_PROJECT_NAME = "演视DICOM序列"
DEMO_DICOM_PARSE_FPS = 30.0
# Default project name, frame rate and output frame width used when seeding
# the demo video project.
DEMO_VIDEO_PROJECT_NAME = "演视LC视频序列"
DEMO_VIDEO_PARSE_FPS = 30.0
DEMO_VIDEO_TARGET_WIDTH = 640
# NOTE(review): presumably project names used by earlier releases; they are
# not referenced in this file section — confirm with the callers.
LEGACY_DEMO_VIDEO_PROJECT_NAMES = {"Data_MyVideo_1"}
LEGACY_DEMO_DICOM_PROJECT_NAMES = {"演示DICOM序列"}
def demo_dicom_files(dicom_dir: str) -> list[Path]:
    """List the ``.dcm`` files directly inside *dicom_dir* in natural name order.

    The extension match is case-insensitive and sub-directories are not
    searched. An empty list is returned when *dicom_dir* is missing or is not
    a directory.
    """
    root = Path(dicom_dir)
    if not (root.exists() and root.is_dir()):
        return []
    series = [
        entry
        for entry in root.iterdir()
        if entry.is_file() and entry.name.lower().endswith(".dcm")
    ]
    series.sort(key=lambda entry: natural_filename_key(entry.name))
    return series
def create_unparsed_video_demo_project(
    db: Session,
    *,
    owner: User,
    video_path: str,
    project_name: str = DEMO_VIDEO_PROJECT_NAME,
) -> Project:
    """Create the bundled demo video project without extracting frames.

    The video is uploaded to object storage under ``uploads/{project.id}/``
    and the project is left in the pending state with no thumbnail.

    Args:
        db: Session used to persist the project; it is committed here.
        owner: User recorded as the project owner.
        video_path: Local path of the demo video file.
        project_name: Name given to the new project.

    Returns:
        The committed and refreshed project.

    Raises:
        FileNotFoundError: If *video_path* is missing or is not a file.
    """
    source = Path(video_path)
    if not source.exists() or not source.is_file():
        raise FileNotFoundError(f"Demo video not found: {video_path}")

    project = Project(
        name=project_name,
        description="默认演示视频,尚未生成帧",
        status=PROJECT_STATUS_PENDING,
        source_type="video",
        # Use the shared constant so both demo video variants advertise the
        # same parse rate instead of a separately hard-coded 30.0.
        parse_fps=DEMO_VIDEO_PARSE_FPS,
        original_fps=None,
        owner_user_id=owner.id,
    )
    db.add(project)
    # Flush to obtain project.id, which is part of the object name.
    db.flush()

    data = source.read_bytes()
    object_name = f"uploads/{project.id}/{source.name}"
    upload_file(object_name, data, content_type="video/mp4", length=len(data))
    project.video_path = object_name
    project.thumbnail_url = None
    db.commit()
    db.refresh(project)
    return project
def create_parsed_video_demo_project(
    db: Session,
    *,
    owner: User,
    video_path: str,
    project_name: str = DEMO_VIDEO_PROJECT_NAME,
) -> Project:
    """Create the bundled demo video project and register its extracted frame sequence.

    The video is uploaded, frames are extracted into a temporary directory,
    uploaded, and stored as Frame rows. A thumbnail is extracted from the
    video; if that fails the first uploaded frame is used instead.

    Args:
        db: Session used to persist the project and frames; it is committed here.
        owner: User recorded as the project owner.
        video_path: Local path of the demo video file.
        project_name: Name given to the new project.

    Returns:
        The committed and refreshed project, in the ready state.

    Raises:
        FileNotFoundError: If *video_path* is missing or is not a file.
    """
    source = Path(video_path)
    if not source.exists() or not source.is_file():
        raise FileNotFoundError(f"Demo video not found: {video_path}")

    project = Project(
        name=project_name,
        description="默认演示视频,已生成帧",
        status=PROJECT_STATUS_PENDING,
        source_type="video",
        parse_fps=DEMO_VIDEO_PARSE_FPS,
        original_fps=None,
        owner_user_id=owner.id,
    )
    db.add(project)
    # Flush to obtain project.id, which is part of the object names.
    db.flush()

    data = source.read_bytes()
    object_name = f"uploads/{project.id}/{source.name}"
    upload_file(object_name, data, content_type="video/mp4", length=len(data))
    project.video_path = object_name

    tmp_dir = tempfile.mkdtemp(prefix=f"seg_demo_video_{project.id}_")
    try:
        output_dir = os.path.join(tmp_dir, "frames")
        frame_files, original_fps = parse_video(
            str(source),
            output_dir,
            fps=int(DEMO_VIDEO_PARSE_FPS),
            target_width=DEMO_VIDEO_TARGET_WIDTH,
        )
        project.original_fps = original_fps
        object_names = upload_frames_to_minio(frame_files, project.id)

        # upload_frames_to_minio leaves out frames whose upload failed, so the
        # two lists cannot be paired by position. Each object name ends with
        # the local file's base name; use that to find the matching file.
        local_by_name = {os.path.basename(path): path for path in frame_files}
        for idx, obj_name in enumerate(object_names):
            local_path = local_by_name.get(obj_name.rsplit("/", 1)[-1])
            image = cv2.imread(local_path) if local_path else None
            height, width = image.shape[:2] if image is not None else (None, None)
            db.add(Frame(
                project_id=project.id,
                frame_index=idx,
                image_url=obj_name,
                width=width,
                height=height,
                timestamp_ms=idx * 1000.0 / DEMO_VIDEO_PARSE_FPS,
                # NOTE(review): unlike the background parser, this ignores
                # original_fps and stores the parsed index — confirm intended.
                source_frame_number=idx,
            ))

        thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
        try:
            extract_thumbnail(str(source), thumbnail_path)
            with open(thumbnail_path, "rb") as thumbnail_file:
                thumbnail_data = thumbnail_file.read()
            thumbnail_object = f"projects/{project.id}/thumbnail.jpg"
            upload_file(
                thumbnail_object,
                thumbnail_data,
                content_type="image/jpeg",
                length=len(thumbnail_data),
            )
            project.thumbnail_url = thumbnail_object
        except Exception:  # noqa: BLE001
            # Best effort: fall back to the first uploaded frame.
            if object_names:
                project.thumbnail_url = object_names[0]

        project.status = PROJECT_STATUS_READY
        db.commit()
        db.refresh(project)
        return project
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
def create_parsed_dicom_demo_project(
    db: Session,
    *,
    owner: User,
    dicom_dir: str,
    project_name: str = DEMO_DICOM_PROJECT_NAME,
) -> Project:
    """Create the demo DICOM project, upload the series, and register parsed frames.

    Every ``.dcm`` file found by ``demo_dicom_files`` is uploaded under
    ``uploads/{project.id}/dicom``. The series is then converted to images in
    a temporary directory, uploaded, and stored as Frame rows. The first
    uploaded frame becomes the thumbnail.

    Args:
        db: Session used to persist the project and frames; it is committed here.
        owner: User recorded as the project owner.
        dicom_dir: Local directory containing the demo ``.dcm`` files.
        project_name: Name given to the new project.

    Returns:
        The committed and refreshed project, in the ready state.

    Raises:
        FileNotFoundError: If no ``.dcm`` file is found in *dicom_dir*.
    """
    dcm_files = demo_dicom_files(dicom_dir)
    if not dcm_files:
        raise FileNotFoundError(f"Demo DICOM series not found: {dicom_dir}")

    project = Project(
        name=project_name,
        description=f"默认演示 DICOM 序列,已按文件名自然顺序生成 {len(dcm_files)}",
        status=PROJECT_STATUS_PENDING,
        source_type="dicom",
        parse_fps=DEMO_DICOM_PARSE_FPS,
        original_fps=None,
        owner_user_id=owner.id,
    )
    db.add(project)
    # Flush to obtain project.id, which is part of the object names.
    db.flush()

    dicom_prefix = f"uploads/{project.id}/dicom"
    for dcm_file in dcm_files:
        data = dcm_file.read_bytes()
        upload_file(
            f"{dicom_prefix}/{dcm_file.name}",
            data,
            content_type="application/dicom",
            length=len(data),
        )
    project.video_path = dicom_prefix

    tmp_dir = tempfile.mkdtemp(prefix=f"seg_demo_dicom_{project.id}_")
    try:
        output_dir = os.path.join(tmp_dir, "frames")
        frame_files = parse_dicom(dicom_dir, output_dir)
        object_names = upload_frames_to_minio(frame_files, project.id)

        # upload_frames_to_minio leaves out frames whose upload failed, so the
        # two lists cannot be paired by position. Each object name ends with
        # the local file's base name; use that to find the matching file.
        local_by_name = {os.path.basename(path): path for path in frame_files}
        for idx, obj_name in enumerate(object_names):
            local_path = local_by_name.get(obj_name.rsplit("/", 1)[-1])
            image = cv2.imread(local_path) if local_path else None
            height, width = image.shape[:2] if image is not None else (None, None)
            db.add(Frame(
                project_id=project.id,
                frame_index=idx,
                image_url=obj_name,
                width=width,
                height=height,
                timestamp_ms=idx * 1000.0 / DEMO_DICOM_PARSE_FPS,
                source_frame_number=idx,
            ))

        if object_names:
            project.thumbnail_url = object_names[0]
        project.status = PROJECT_STATUS_READY
        db.commit()
        db.refresh(project)
        return project
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

View File

@@ -0,0 +1,237 @@
"""Video/DICOM frame parsing and MinIO upload utilities."""
import logging
import os
import re
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional, Tuple
import cv2
import numpy as np
from pydicom import dcmread
from minio_client import upload_file, BUCKET_NAME
logger = logging.getLogger(__name__)
def natural_filename_key(filename: str) -> Tuple[object, ...]:
    """Sort file names by their visible numeric order instead of pure lexicographic order.

    Only the final path component is used. It is split into alternating text
    and digit runs; digit runs compare as integers and text runs compare
    case-insensitively, so ``f2`` sorts before ``f10``.

    Args:
        filename: File name or path.

    Returns:
        A tuple usable as a ``sorted`` key.
    """
    return tuple(
        # isdecimal() matches exactly what the regex captures and what int()
        # accepts. isdigit() is also true for characters such as "²", which
        # int() rejects with ValueError.
        int(part) if part.isdecimal() else part.casefold()
        for part in re.split(r"(\d+)", Path(filename).name)
    )
def get_video_fps(video_path: str) -> float:
    """Return the frame rate reported for *video_path*.

    Falls back to 30.0 when the video cannot be opened or the reported rate
    is not greater than zero.
    """
    capture = cv2.VideoCapture(video_path)
    if capture.isOpened():
        reported_fps = capture.get(cv2.CAP_PROP_FPS)
        capture.release()
        if reported_fps > 0:
            return reported_fps
    return 30.0
def extract_thumbnail(video_path: str, output_path: str, width: int = 640) -> str:
    """Extract the first frame of a video as a thumbnail JPEG.

    Frames wider than *width* are scaled down proportionally; narrower frames
    are written unchanged.

    Args:
        video_path: Path to the input video file.
        output_path: Path of the JPEG file to write.
        width: Maximum thumbnail width in pixels.

    Returns:
        *output_path*.

    Raises:
        RuntimeError: If the video cannot be opened, its first frame cannot
            be read, or the thumbnail cannot be written.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video for thumbnail: {video_path}")
    try:
        ret, frame = cap.read()
    finally:
        # Release the capture even if reading raises.
        cap.release()
    if not ret or frame is None:
        raise RuntimeError(f"Cannot read first frame from: {video_path}")

    h, w = frame.shape[:2]
    if w > width:
        scale = width / w
        # Clamp to 1 pixel: int() truncation can give 0 for extreme aspect
        # ratios, which cv2.resize rejects.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(output_path, frame, [cv2.IMWRITE_JPEG_QUALITY, 85]):
        raise RuntimeError(f"Cannot write thumbnail: {output_path}")
    return output_path
def parse_video(
    video_path: str,
    output_dir: str,
    fps: int = 30,
    max_frames: Optional[int] = None,
    target_width: int = 640,
) -> Tuple[List[str], float]:
    """Extract frames from a video file using FFmpeg or OpenCV fallback.

    FFmpeg is used when it is on PATH and exits successfully; otherwise the
    frames are extracted with OpenCV. Frames are written as
    ``frame_%06d.jpg`` starting at 0.

    Args:
        video_path: Path to the input video file.
        output_dir: Directory to save extracted frames.
        fps: Target frame extraction rate.
        max_frames: Optional maximum number of frames to extract.
        target_width: Output frame width for model-friendly frame sequences.

    Returns:
        Tuple of (frame_paths, original_fps).

    Raises:
        RuntimeError: If the OpenCV fallback cannot open the video.
    """
    os.makedirs(output_dir, exist_ok=True)
    frame_paths: List[str] = []
    original_fps = get_video_fps(video_path)
    safe_fps = max(int(fps), 1)
    safe_width = max(int(target_width), 1)

    # Try FFmpeg first
    if shutil.which("ffmpeg"):
        try:
            pattern = os.path.join(output_dir, "frame_%06d.jpg")
            cmd = [
                "ffmpeg",
                "-i", video_path,
                "-vf", f"fps={safe_fps},scale={safe_width}:-1",
                "-start_number", "0",
                "-q:v", "5",
            ]
            if max_frames and max_frames > 0:
                # Stop encoding once enough frames are written instead of
                # decoding the whole video and discarding the surplus.
                cmd += ["-frames:v", str(int(max_frames))]
            cmd += ["-y", pattern]
            logger.info("Running FFmpeg: %s", " ".join(cmd))
            result = subprocess.run(cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                frame_paths = sorted(
                    [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".jpg")]
                )
                if max_frames:
                    frame_paths = frame_paths[:max_frames]
                logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
                return frame_paths, original_fps
            logger.warning("FFmpeg failed: %s", result.stderr)
        except Exception as exc:  # noqa: BLE001
            logger.warning("FFmpeg exception: %s", exc)

    # OpenCV fallback
    logger.info("Falling back to OpenCV frame extraction")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Cannot open video: {video_path}")

    try:
        video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        # Keep every interval-th frame to approximate the requested rate.
        interval = max(1, int(round(video_fps / safe_fps)))
        count = 0
        saved = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                path = os.path.join(output_dir, f"frame_{saved:06d}.jpg")
                h, w = frame.shape[:2]
                if w != safe_width:
                    scale = safe_width / max(w, 1)
                    frame = cv2.resize(
                        frame,
                        (safe_width, max(1, int(round(h * scale)))),
                        interpolation=cv2.INTER_AREA,
                    )
                cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
                frame_paths.append(path)
                saved += 1
                if max_frames and saved >= max_frames:
                    break
            count += 1
    finally:
        # Release the capture even when resizing or writing a frame raises.
        cap.release()
    logger.info("Extracted %d frames via OpenCV", len(frame_paths))
    return frame_paths, original_fps
def parse_dicom(
    dicom_dir: str,
    output_dir: str,
    max_frames: Optional[int] = None,
) -> List[str]:
    """Extract frames from DICOM files in a directory.

    Files are processed in natural file-name order. Pixel data that is not
    already 8-bit is min-max normalised to 0-255 per file. Files that cannot
    be read are logged and skipped.

    Args:
        dicom_dir: Directory containing .dcm files.
        output_dir: Directory to save extracted frames.
        max_frames: Optional maximum number of frames to extract.

    Returns:
        List of paths to extracted frame images.
    """
    os.makedirs(output_dir, exist_ok=True)
    dcm_files = sorted(
        [f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")],
        key=natural_filename_key,
    )

    frame_paths: List[str] = []
    jpeg_params = [cv2.IMWRITE_JPEG_QUALITY, 85]
    for idx, fname in enumerate(dcm_files):
        # Cap the number of frames written, not the number of files visited,
        # so multi-frame files cannot exceed max_frames.
        if max_frames and len(frame_paths) >= max_frames:
            break
        path = os.path.join(dicom_dir, fname)
        try:
            ds = dcmread(path)
            pixel_array = ds.pixel_array

            # Normalize to 8-bit
            if pixel_array.dtype != np.uint8:
                pixel_array = pixel_array.astype(np.float32)
                pixel_array = (
                    (pixel_array - pixel_array.min())
                    / (pixel_array.max() - pixel_array.min() + 1e-8)
                    * 255
                )
                pixel_array = pixel_array.astype(np.uint8)

            # A 3-D array is ambiguous: it is either (frames, rows, cols) or a
            # single colour image (rows, cols, samples). NumberOfFrames tells
            # them apart.
            frame_total = int(getattr(ds, "NumberOfFrames", 1) or 1)
            if frame_total > 1:
                planes = [
                    (os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.jpg"), pixel_array[f])
                    for f in range(pixel_array.shape[0])
                ]
            else:
                planes = [(os.path.join(output_dir, f"frame_{idx:06d}.jpg"), pixel_array)]

            is_rgb = getattr(ds, "PhotometricInterpretation", "") == "RGB"
            for out_path, image in planes:
                if max_frames and len(frame_paths) >= max_frames:
                    break
                if is_rgb and image.ndim == 3 and image.shape[-1] == 3:
                    # cv2.imwrite expects BGR channel order.
                    image = np.ascontiguousarray(image[..., ::-1])
                # Record only frames that were actually written to disk.
                if cv2.imwrite(out_path, image, jpeg_params):
                    frame_paths.append(out_path)
                else:
                    logger.error("Failed to write DICOM frame %s", out_path)
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to read DICOM %s: %s", path, exc)

    logger.info("Extracted %d frames from DICOM", len(frame_paths))
    return frame_paths
def upload_frames_to_minio(
    frames: List[str],
    project_id: int,
    object_prefix: Optional[str] = None,
) -> List[str]:
    """Upload local frame images to MinIO, skipping any that fail.

    A frame that cannot be read or uploaded is logged and left out of the
    result, so the returned list may be shorter than *frames*.

    Args:
        frames: List of local file paths.
        project_id: Project ID used to build the default object prefix.
        object_prefix: Optional prefix override.

    Returns:
        List of object names (paths) in MinIO, in upload order.
    """
    prefix = object_prefix if object_prefix else f"projects/{project_id}/frames"
    uploaded: List[str] = []

    for local_path in frames:
        target = f"{prefix}/{os.path.basename(local_path)}"
        try:
            with open(local_path, "rb") as handle:
                payload = handle.read()
            upload_file(
                target,
                payload,
                content_type="image/jpeg",
                length=len(payload),
            )
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to upload %s: %s", local_path, exc)
            continue
        uploaded.append(target)

    logger.info("Uploaded %d/%d frames to MinIO", len(uploaded), len(frames))
    return uploaded

View File

@@ -0,0 +1,340 @@
"""Background media parsing runner used by Celery workers."""
import logging
import os
import shutil
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlalchemy.orm import Session
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
from models import Annotation, Frame, Mask, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.frame_parser import (
extract_thumbnail,
natural_filename_key,
parse_dicom,
parse_video,
upload_frames_to_minio,
)
from statuses import (
PROJECT_STATUS_PENDING,
PROJECT_STATUS_ERROR,
PROJECT_STATUS_PARSING,
PROJECT_STATUS_READY,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
logger = logging.getLogger(__name__)
class TaskCancelled(RuntimeError):
    """Raised internally when a persisted task has been cancelled.

    Raised by ``_ensure_not_cancelled`` and handled inside
    ``run_parse_media_task``.
    """
def _now() -> datetime:
return datetime.now(timezone.utc)
def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Update the given task fields, persist them, and publish a progress event.

    Arguments left as None keep the task's current value. Progress is clamped
    to the 0-100 range. *started* and *finished* stamp ``started_at`` and
    ``finished_at`` with the current UTC time.
    """
    if progress is not None:
        progress = max(0, min(100, progress))

    requested = {
        "status": status,
        "progress": progress,
        "message": message,
        "result": result,
        "error": error,
    }
    for attribute, value in requested.items():
        if value is not None:
            setattr(task, attribute, value)

    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
def _project_status_after_stop(project: Project) -> str:
    """Return the ready status when the project has frames, otherwise pending."""
    if project.frames:
        return PROJECT_STATUS_READY
    return PROJECT_STATUS_PENDING
def _positive_int(value: Any, default: int | None = None) -> int | None:
try:
parsed = int(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _positive_float(value: Any, default: float) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _frame_sequence_metadata(
index: int,
parse_fps: float,
original_fps: float | None,
) -> dict[str, float | int | None]:
safe_parse_fps = max(float(parse_fps or 1.0), 1e-6)
timestamp_ms = index * 1000.0 / safe_parse_fps
source_frame_number = None
if original_fps and original_fps > 0:
source_frame_number = int(round(index * original_fps / safe_parse_fps))
return {
"timestamp_ms": timestamp_ms,
"source_frame_number": source_frame_number,
}
def _clear_existing_project_outputs(db: Session, project: Project) -> None:
    """Delete the project's masks, annotations and frames, then clear its thumbnail.

    Deletions run in that order (masks, annotations, frames) as bulk deletes
    without session synchronisation, and the session is committed at the end.
    """
    project_annotation_ids = db.query(Annotation.id).filter(Annotation.project_id == project.id)
    stale_queries = (
        db.query(Mask).filter(Mask.annotation_id.in_(project_annotation_ids)),
        db.query(Annotation).filter(Annotation.project_id == project.id),
        db.query(Frame).filter(Frame.project_id == project.id),
    )
    for stale in stale_queries:
        stale.delete(synchronize_session=False)
    project.thumbnail_url = None
    db.commit()
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload the task from the database and raise TaskCancelled if it is cancelled."""
    db.refresh(task)
    was_cancelled = task.status == TASK_STATUS_CANCELLED
    if was_cancelled:
        raise TaskCancelled("Task was cancelled")
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
"""Parse one project's media and update task progress in the database."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise ValueError(f"Task not found: {task_id}")
if task.status == TASK_STATUS_CANCELLED:
return {
"task_id": task.id,
"status": TASK_STATUS_CANCELLED,
"message": task.message or "任务已取消",
}
if task.project_id is None:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="任务缺少 project_id",
error="Task has no project_id",
finished=True,
)
raise ValueError("Task has no project_id")
project = db.query(Project).filter(Project.id == task.project_id).first()
if not project:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="项目不存在",
error="Project not found",
finished=True,
)
raise ValueError(f"Project not found: {task.project_id}")
if not project.video_path:
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="项目没有可解析媒体",
error="Project has no media uploaded",
finished=True,
)
project.status = PROJECT_STATUS_ERROR
db.commit()
raise ValueError("Project has no media uploaded")
_ensure_not_cancelled(db, task)
project.status = PROJECT_STATUS_PARSING
_clear_existing_project_outputs(db, project)
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
payload = task.payload or {}
effective_source = payload.get("source_type") or project.source_type or "video"
parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0)
max_frames = _positive_int(payload.get("max_frames"))
target_width = _positive_int(payload.get("target_width"), 640) or 640
project.parse_fps = parse_fps
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
try:
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
if effective_source == "dicom":
dcm_dir = os.path.join(tmp_dir, "dcm")
os.makedirs(dcm_dir, exist_ok=True)
client = get_minio_client()
objects = sorted(
list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True)),
key=lambda obj: natural_filename_key(obj.object_name),
)
for obj in objects:
_ensure_not_cancelled(db, task)
if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f:
f.write(data)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames)
else:
_ensure_not_cancelled(db, task)
media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
frame_files, original_fps = parse_video(
local_path,
output_dir,
fps=int(parse_fps),
max_frames=max_frames,
target_width=target_width,
)
project.original_fps = original_fps
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project.id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
object_names = upload_frames_to_minio(frame_files, project.id)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=85, message="正在写入帧索引")
frames_out = []
for idx, obj_name in enumerate(object_names):
_ensure_not_cancelled(db, task)
local_frame = frame_files[idx]
try:
import cv2
img = cv2.imread(local_frame)
h, w = img.shape[:2] if img is not None else (None, None)
except Exception: # noqa: BLE001
h, w = None, None
sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps)
frame = Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
timestamp_ms=sequence_meta["timestamp_ms"],
source_frame_number=sequence_meta["source_frame_number"],
)
db.add(frame)
frames_out.append(frame)
project.status = PROJECT_STATUS_READY
db.commit()
result = {
"project_id": project.id,
"frames_extracted": len(frames_out),
"status": PROJECT_STATUS_READY,
"message": "Frame extraction completed successfully.",
"frame_sequence": {
"original_fps": project.original_fps,
"parse_fps": parse_fps,
"frame_count": len(frames_out),
"duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0,
"target_width": target_width,
"frame_width": frames_out[0].width if frames_out else None,
"frame_height": frames_out[0].height if frames_out else None,
"max_frames": max_frames,
"object_prefix": f"projects/{project.id}/frames",
},
}
_set_task_state(
db,
task,
status=TASK_STATUS_SUCCESS,
progress=100,
message="解析完成",
result=result,
finished=True,
)
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
return result
except TaskCancelled:
project.status = _project_status_after_stop(project)
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = task.message or "任务已取消"
task.error = task.error or "Cancelled by user"
task.finished_at = task.finished_at or _now()
db.commit()
db.refresh(task)
publish_task_progress_event(task)
logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id)
return {
"task_id": task.id,
"project_id": project.id,
"status": TASK_STATUS_CANCELLED,
"message": task.message,
}
except Exception as exc: # noqa: BLE001
project.status = PROJECT_STATUS_ERROR
_set_task_state(
db,
task,
status=TASK_STATUS_FAILED,
progress=100,
message="解析失败",
error=str(exc),
finished=True,
)
logger.error("Frame extraction failed: %s", exc)
raise
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)

View File

@@ -0,0 +1,842 @@
"""Background SAM video propagation runner used by Celery workers."""
import hashlib
import json
import logging
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from sqlalchemy.orm import Session
from minio_client import download_file
from models import Annotation, Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.sam_registry import ModelUnavailableError, sam_registry
from statuses import (
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
logger = logging.getLogger(__name__)
class PropagationTaskCancelled(RuntimeError):
    """Signal that the persisted propagation task was found in the cancelled state.

    Used purely for internal control flow: it lets the runner unwind out of a
    step as soon as a cancellation is observed.
    """
def _now() -> datetime:
return datetime.now(timezone.utc)
def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Apply the supplied field updates to ``task``, persist them and publish them.

    Only arguments that are not ``None`` overwrite the corresponding task
    attribute. ``progress`` is clamped to the 0..100 range. ``started`` and
    ``finished`` stamp ``started_at`` / ``finished_at`` with the current UTC
    time. The session is committed, the task refreshed, and a progress event
    is published on every call.
    """
    bounded_progress = None if progress is None else max(0, min(100, progress))
    field_updates = {
        "status": status,
        "progress": bounded_progress,
        "message": message,
        "result": result,
        "error": error,
    }
    for attribute, new_value in field_updates.items():
        if new_value is not None:
            setattr(task, attribute, new_value)

    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload ``task`` from the database and abort if it has been cancelled.

    Raises:
        PropagationTaskCancelled: If the refreshed task status is the
            cancelled status.
    """
    db.refresh(task)
    still_active = task.status != TASK_STATUS_CANCELLED
    if still_active:
        return
    raise PropagationTaskCancelled("Task was cancelled")
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
xs = [_clamp01(point[0]) for point in polygon]
ys = [_clamp01(point[1]) for point in polygon]
left, right = min(xs), max(xs)
top, bottom = min(ys), max(ys)
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
def _polygons_bbox(polygons: list[list[list[float]]]) -> list[float]:
    """Return the ``[left, top, width, height]`` box enclosing several polygons.

    Points with fewer than two coordinates are skipped. Coordinates are
    clamped to [0, 1]. A zero box is returned when no usable point exists.
    """
    usable_points = []
    for polygon in polygons:
        for point in polygon:
            if len(point) >= 2:
                usable_points.append(point)
    if not usable_points:
        return [0.0, 0.0, 0.0, 0.0]

    xs = [_clamp01(point[0]) for point in usable_points]
    ys = [_clamp01(point[1]) for point in usable_points]
    left = min(xs)
    top = min(ys)
    width = max(max(xs) - left, 0.0)
    height = max(max(ys) - top, 0.0)
    return [left, top, width, height]
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
    """Return ``polygon`` as clamped ``[x, y]`` pairs, dropping short points.

    Points with fewer than two coordinates are discarded; any coordinates
    beyond the first two are ignored.
    """
    cleaned: list[list[float]] = []
    for point in polygon:
        if len(point) < 2:
            continue
        cleaned.append([_clamp01(point[0]), _clamp01(point[1])])
    return cleaned
def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
if not isinstance(value, dict):
return None
try:
strength = max(0.0, min(float(value.get("strength") or 0.0), 100.0))
except (TypeError, ValueError):
strength = 0.0
if strength <= 0:
return None
method = str(value.get("method") or "chaikin").lower()
if method != "chaikin":
method = "chaikin"
return {"strength": round(strength, 2), "method": method}
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
return normalized ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
    """Apply Chaikin corner cutting to a closed polygon.

    The polygon is normalized first. Each pass replaces every edge (including
    the closing edge from the last point back to the first) with two points
    placed at ``corner_cut`` and ``1 - corner_cut`` along it. ``corner_cut`` is
    limited to the range 0.02..0.25. Passes stop early once fewer than three
    points remain.
    """
    ring = _normalize_polygon(polygon)
    near = max(0.02, min(float(corner_cut), 0.25))
    far = 1.0 - near
    passes = max(0, iterations)

    for _ in range(passes):
        point_count = len(ring)
        if point_count < 3:
            break
        refined: list[list[float]] = []
        for position, (cx, cy) in enumerate(ring):
            nx, ny = ring[(position + 1) % point_count]
            refined.append([_clamp01(far * cx + near * nx), _clamp01(far * cy + near * ny)])
            refined.append([_clamp01(near * cx + far * nx), _clamp01(near * cy + far * ny)])
        ring = refined
    return ring
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
if len(polygon) < 3:
return polygon
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
arc_length = cv2.arcLength(contour, True)
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
if len(approx) < 3:
return polygon
return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in approx]
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None) -> list[list[float]]:
    """Smooth one polygon according to the given smoothing options.

    Without options, or with a non-positive strength, only normalization is
    applied. Otherwise the polygon is lightly simplified, smoothed with
    Chaikin corner cutting, then simplified again. If the result has more
    vertices than the normalized input, it is simplified repeatedly with
    increasing strength until it does not. The normalized input is returned
    when the result would have fewer than three points.
    """
    if not smoothing:
        return _normalize_polygon(polygon)
    strength = float(smoothing.get("strength") or 0.0)
    if strength <= 0:
        return _normalize_polygon(polygon)

    effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
    # Stronger smoothing runs more corner-cutting passes.
    for threshold, passes in ((85, 4), (55, 3), (25, 2)):
        if effective_strength >= threshold:
            iterations = passes
            break
    else:
        iterations = 1
    corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22

    normalized = _normalize_polygon(polygon)
    vertex_budget = len(normalized)
    pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
    smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
    simplified = _simplify_polygon(smoothed, effective_strength)

    if len(simplified) > vertex_budget:
        for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
            simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
            if len(simplified) <= vertex_budget:
                break

    if len(simplified) >= 3:
        return simplified
    return _normalize_polygon(polygon)
def _bbox_area(bbox: list[float]) -> float:
return max(float(bbox[2]), 0.0) * max(float(bbox[3]), 0.0)
def _bbox_overlap_ratio(a: list[float], b: list[float]) -> float:
    """Return the intersection area divided by the smaller box's area.

    Both boxes are ``[x, y, width, height]``. The result is 0.0 when the
    smaller box has no area.
    """
    a_left, a_top, a_width, a_height = a
    b_left, b_top, b_width, b_height = b

    shared_width = min(a_left + a_width, b_left + b_width) - max(a_left, b_left)
    shared_height = min(a_top + a_height, b_top + b_height) - max(a_top, b_top)
    overlap_area = max(0.0, shared_width) * max(0.0, shared_height)

    smallest_area = min(_bbox_area(a), _bbox_area(b))
    if smallest_area > 0:
        return overlap_area / smallest_area
    return 0.0
def _stable_json(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
def _canonicalize_signature_value(value: Any) -> Any:
if isinstance(value, float):
return round(value, 6)
if isinstance(value, list):
return [_canonicalize_signature_value(item) for item in value]
if isinstance(value, dict):
return {key: _canonicalize_signature_value(value[key]) for key in sorted(value)}
return value
def _seed_signature(seed: dict[str, Any]) -> str:
"""Return a stable signature for seed geometry and semantic attrs."""
inherited_signature = seed.get("propagation_seed_signature")
if inherited_signature:
return str(inherited_signature)
signature_payload = {
"polygons": seed.get("polygons") or [],
"holes": seed.get("holes") or [],
"bbox": seed.get("bbox") or [],
"points": seed.get("points") or [],
"labels": seed.get("labels") or [],
"label": seed.get("label"),
"color": seed.get("color"),
"class_metadata": seed.get("class_metadata") or {},
"template_id": seed.get("template_id"),
"smoothing": _normalize_smoothing_options(seed.get("smoothing")),
}
return hashlib.sha256(_stable_json(_canonicalize_signature_value(signature_payload)).encode("utf-8")).hexdigest()
def _seed_key(seed: dict[str, Any]) -> str:
"""Prefer stable persisted ids; fall back to semantic attrs for legacy callers."""
source_instance_id = seed.get("source_instance_id")
if source_instance_id:
return f"instance:{source_instance_id}"
source_annotation_id = seed.get("source_annotation_id")
if source_annotation_id is not None:
return f"annotation:{source_annotation_id}"
source_mask_id = seed.get("source_mask_id")
if source_mask_id:
return f"mask:{source_mask_id}"
class_metadata = seed.get("class_metadata") or {}
class_id = class_metadata.get("id") or class_metadata.get("name")
return _stable_json({
"template_id": seed.get("template_id"),
"class_id": class_id,
"label": seed.get("label"),
"color": seed.get("color"),
})
def _semantic_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
"""Best-effort match when a manually edited replacement lacks old lineage ids."""
class_metadata = seed.get("class_metadata") or {}
previous_class = mask_data.get("class") or {}
previous_class_id = previous_class.get("id") or previous_class.get("name")
class_id = class_metadata.get("id") or class_metadata.get("name")
if previous_class_id and class_id and str(previous_class_id) != str(class_id):
return False
return (
mask_data.get("label") == seed.get("label")
and mask_data.get("color") == seed.get("color")
)
def _legacy_seed_matches(mask_data: dict[str, Any], seed: dict[str, Any]) -> bool:
"""Best-effort match for propagation annotations created before seed keys."""
class_metadata = seed.get("class_metadata") or {}
previous_class = mask_data.get("class") or {}
previous_class_id = previous_class.get("id") or previous_class.get("name")
class_id = class_metadata.get("id") or class_metadata.get("name")
return (
mask_data.get("label") == seed.get("label")
and mask_data.get("color") == seed.get("color")
and previous_class_id == class_id
)
def _source_model_matches(mask_data: dict[str, Any], model_id: str) -> bool:
return str(mask_data.get("source") or "") == f"{model_id}_propagation"
def _seed_identity_matches(mask_data: dict[str, Any], seed_key: str, seed: dict[str, Any]) -> bool:
previous_seed_key = mask_data.get("propagation_seed_key")
if previous_seed_key == seed_key:
return True
source_instance_id = seed.get("source_instance_id")
if source_instance_id and (
mask_data.get("source_instance_id") == source_instance_id
or mask_data.get("instance_id") == source_instance_id
):
return True
source_annotation_id = seed.get("source_annotation_id")
if source_annotation_id is not None and str(mask_data.get("source_annotation_id") or "") == str(source_annotation_id):
return True
source_mask_id = seed.get("source_mask_id")
if source_mask_id and mask_data.get("source_mask_id") == source_mask_id:
return True
has_persisted_seed_identity = bool(source_instance_id) or source_annotation_id is not None or bool(source_mask_id)
has_previous_identity = (
bool(previous_seed_key)
or mask_data.get("source_instance_id") is not None
or mask_data.get("instance_id") is not None
or mask_data.get("source_annotation_id") is not None
or bool(mask_data.get("source_mask_id"))
)
if has_persisted_seed_identity or has_previous_identity:
return False
return _legacy_seed_matches(mask_data, seed)
def _seed_identity_markers(seed: dict[str, Any]) -> set[str]:
    """Return every identity marker string that describes this seed.

    The set always contains ``seed:<seed key>`` and additionally one marker
    for each of instance id, annotation id and mask id that the seed carries.
    """
    markers = {f"seed:{_seed_key(seed)}"}
    if instance_id := seed.get("source_instance_id"):
        markers.add(f"instance:{instance_id}")
    if (annotation_id := seed.get("source_annotation_id")) is not None:
        markers.add(f"annotation:{annotation_id}")
    if mask_id := seed.get("source_mask_id"):
        markers.add(f"mask:{mask_id}")
    return markers
def _mask_identity_markers(mask_data: dict[str, Any]) -> set[str]:
markers: set[str] = set()
previous_seed_key = mask_data.get("propagation_seed_key")
if previous_seed_key:
markers.add(f"seed:{previous_seed_key}")
source_instance_id = mask_data.get("source_instance_id")
if source_instance_id:
markers.add(f"instance:{source_instance_id}")
instance_id = mask_data.get("instance_id")
if instance_id:
markers.add(f"instance:{instance_id}")
source_annotation_id = mask_data.get("source_annotation_id")
if source_annotation_id is not None:
markers.add(f"annotation:{source_annotation_id}")
source_mask_id = mask_data.get("source_mask_id")
if source_mask_id:
markers.add(f"mask:{source_mask_id}")
return markers
def _payload_seed_identity_markers(payload: dict[str, Any]) -> set[str]:
markers: set[str] = set()
for step in payload.get("steps") or []:
seed = step.get("seed") or {}
markers.update(_seed_identity_markers(seed))
return markers
def _is_propagation_annotation(annotation: Annotation, seed_key: str, seed: dict[str, Any]) -> bool:
mask_data = annotation.mask_data or {}
source = str(mask_data.get("source") or "")
if not source.endswith("_propagation"):
return False
return _seed_identity_matches(mask_data, seed_key, seed)
def _direction_matches(mask_data: dict[str, Any], direction: str) -> bool:
previous_direction = mask_data.get("propagation_direction")
return previous_direction in {None, direction}
def _annotation_spatially_matches(annotation: Annotation, polygon: list[list[float]]) -> bool:
    """Return whether any stored polygon overlaps the candidate polygon enough.

    Used as a final guard before replacing a same-object propagation result.
    Stored polygons with fewer than three points are ignored. A match requires
    a bounding-box overlap ratio of at least 0.15.
    """
    candidate_bbox = _polygon_bbox(polygon)
    stored_polygons = (annotation.mask_data or {}).get("polygons") or []
    return any(
        _bbox_overlap_ratio(_polygon_bbox(stored), candidate_bbox) >= 0.15
        for stored in stored_polygons
        if len(stored) >= 3
    )
def _delete_replaced_frame_annotations(
    db: Session,
    *,
    payload: dict[str, Any],
    frame_id: int,
    seed_key: str,
    seed: dict[str, Any],
    polygon: list[list[float]],
) -> int:
    """Delete old propagated masks of the same object on one frame.

    Called right before a new propagation result is written for ``frame_id``.
    Only annotations whose source ends with ``"_propagation"`` are considered.
    An annotation is removed when it shares lineage with ``seed``, or when it
    matches the seed semantically and overlaps ``polygon`` spatially.

    Returns:
        The number of deleted annotations. The session is committed only when
        at least one annotation was deleted.
    """
    frame_annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id == frame_id)
        .all()
    )
    own_markers = _seed_identity_markers(seed)
    task_markers = _payload_seed_identity_markers(payload)
    seed_instance_id = seed.get("source_instance_id")

    removed = 0
    for annotation in frame_annotations:
        mask_data = annotation.mask_data or {}
        if not str(mask_data.get("source") or "").endswith("_propagation"):
            continue

        # A mask tied to a different instance id is never this seed's output.
        known_instance_ids = {
            str(value)
            for value in (mask_data.get("source_instance_id"), mask_data.get("instance_id"))
            if value
        }
        if seed_instance_id and known_instance_ids and str(seed_instance_id) not in known_instance_ids:
            continue

        # Keep sibling seeds in the same propagation task from deleting each other.
        mask_markers = _mask_identity_markers(mask_data)
        owned_by_sibling = (
            bool(mask_markers)
            and mask_markers.isdisjoint(own_markers)
            and not mask_markers.isdisjoint(task_markers)
        )
        if owned_by_sibling:
            continue

        same_lineage = _seed_identity_matches(mask_data, seed_key, seed)
        same_manual_replacement = (
            _semantic_seed_matches(mask_data, seed)
            and _annotation_spatially_matches(annotation, polygon)
        )
        if not (same_lineage or same_manual_replacement):
            continue
        db.delete(annotation)
        removed += 1

    if removed:
        db.commit()
    return removed
def _prepare_seed_propagation(
    db: Session,
    *,
    payload: dict[str, Any],
    model_id: str,
    seed: dict[str, Any],
    direction: str,
    target_frame_ids: set[int],
) -> dict[str, Any]:
    """Decide whether a seed must be propagated again and clear stale results.

    The step is skipped when there are no target frames, or when every target
    frame already holds a propagated annotation of this seed with the same
    seed signature and the same model. Otherwise all previously propagated
    annotations of this seed (in a compatible direction) on the target frames
    are deleted and committed.

    Returns:
        A dict with ``skip``, ``seed_key``, ``seed_signature`` and
        ``deleted_annotation_count``.
    """
    seed_key = _seed_key(seed)
    seed_signature = _seed_signature(seed)

    def outcome(skip: bool, deleted: int) -> dict[str, Any]:
        return {
            "skip": skip,
            "seed_key": seed_key,
            "seed_signature": seed_signature,
            "deleted_annotation_count": deleted,
        }

    if not target_frame_ids:
        return outcome(True, 0)

    candidates = (
        db.query(Annotation)
        .filter(Annotation.project_id == int(payload["project_id"]))
        .filter(Annotation.frame_id.in_(target_frame_ids))
        .all()
    )
    matching = []
    for annotation in candidates:
        if not _is_propagation_annotation(annotation, seed_key, seed):
            continue
        if _direction_matches(annotation.mask_data or {}, direction):
            matching.append(annotation)
    covered_frame_ids = {int(annotation.frame_id) for annotation in matching}

    already_current = (
        bool(matching)
        and all(
            (annotation.mask_data or {}).get("propagation_seed_signature") == seed_signature
            and _source_model_matches(annotation.mask_data or {}, model_id)
            for annotation in matching
        )
        and target_frame_ids.issubset(covered_frame_ids)
    )
    if already_current:
        return outcome(True, 0)

    for annotation in matching:
        db.delete(annotation)
    if matching:
        db.commit()
    return outcome(False, len(matching))
def _frame_window(
frames: list[Frame],
source_position: int,
direction: str,
max_frames: int,
) -> tuple[list[Frame], int]:
count = max(1, min(max_frames, len(frames)))
if direction == "backward":
start = max(0, source_position - count + 1)
return frames[start:source_position + 1], source_position - start
end = min(len(frames), source_position + count)
return frames[source_position:end], 0
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download each frame image into ``directory`` and return the file paths.

    Files are named by their position in ``frames`` as zero-padded six-digit
    numbers with a ``.jpg`` suffix.
    """
    written: list[str] = []
    for position, frame in enumerate(frames):
        image_bytes = download_file(frame.image_url)
        # SAM2VideoPredictor sorts frames by converting the filename stem to int.
        target = directory / f"{position:06d}.jpg"
        target.write_bytes(image_bytes)
        written.append(str(target))
    return written
# Persist the propagation results of one seed as Annotation rows.
# Returns (created annotations, number of old annotations deleted while
# cleaning up frames that received a new result).
def _save_propagated_annotations(
db: Session,
*,
payload: dict[str, Any],
selected_frames: list[Frame],
source_frame: Frame,
propagated: list[dict[str, Any]],
seed: dict[str, Any],
) -> tuple[list[Annotation], int]:
created: list[Annotation] = []
# Nothing is written (and nothing deleted) when saving is explicitly disabled.
if payload.get("save_annotations", True) is False:
return created, 0
class_metadata = seed.get("class_metadata")
template_id = seed.get("template_id")
# Fallback label/color for seeds that do not carry their own.
label = seed.get("label") or "Propagated Mask"
color = seed.get("color") or "#06b6d4"
model_id = sam_registry.normalize_model_id(payload.get("model"))
include_source = bool(payload.get("include_source", False))
seed_key = _seed_key(seed)
seed_signature = _seed_signature(seed)
source_annotation_id = seed.get("source_annotation_id")
source_mask_id = seed.get("source_mask_id")
# Seeds without an instance id use their seed key as the instance id.
source_instance_id = seed.get("source_instance_id") or seed_key
smoothing = _normalize_smoothing_options(seed.get("smoothing"))
# "current_direction" is read from the payload passed in by the caller.
direction = str(payload.get("current_direction") or "")
deleted_count = 0
# Frames whose stale propagated masks were already removed in this call.
cleaned_frame_ids: set[int] = set()
for frame_result in propagated:
# frame_index is relative to selected_frames; out-of-range results are dropped.
relative_index = int(frame_result.get("frame_index", -1))
if relative_index < 0 or relative_index >= len(selected_frames):
continue
frame = selected_frames[relative_index]
# The source frame only receives a result when include_source is set.
if not include_source and frame.id == source_frame.id:
continue
result_polygons = frame_result.get("polygons") or []
result_holes = frame_result.get("holes") or []
scores = frame_result.get("scores") or []
# Keep the original polygon index so holes and scores stay aligned after filtering.
prepared_polygons = [
(polygon_index, _smooth_polygon(polygon, smoothing))
for polygon_index, polygon in enumerate(result_polygons)
if len(polygon) >= 3
]
# The first usable smoothed polygon is handed to the cleanup helper.
cleanup_polygon = next((polygon for _polygon_index, polygon in prepared_polygons if len(polygon) >= 3), None)
if cleanup_polygon is not None and frame.id not in cleaned_frame_ids:
deleted_count += _delete_replaced_frame_annotations(
db,
payload=payload,
frame_id=int(frame.id),
seed_key=seed_key,
seed=seed,
polygon=cleanup_polygon,
)
cleaned_frame_ids.add(int(frame.id))
polygons_to_save: list[list[list[float]]] = []
holes_to_save: list[list[list[list[float]]]] = []
score_values: list[float] = []
for polygon_index, polygon in prepared_polygons:
# Smoothing may have reduced a polygon below three points.
if len(polygon) < 3:
continue
polygons_to_save.append(polygon)
# A missing or non-list hole group is stored as an empty list.
hole_group = result_holes[polygon_index] if polygon_index < len(result_holes) and isinstance(result_holes[polygon_index], list) else []
holes_to_save.append(hole_group if isinstance(hole_group, list) else [])
if polygon_index < len(scores):
# Scores that cannot be converted to float are ignored.
try:
score_values.append(float(scores[polygon_index]))
except (TypeError, ValueError):
pass
if not polygons_to_save:
continue
annotation = Annotation(
project_id=int(payload["project_id"]),
frame_id=frame.id,
template_id=template_id,
mask_data={
"polygons": polygons_to_save,
# Hole data is only stored when at least one polygon has holes.
**({"holes": holes_to_save, "hasHoles": True} if any(holes_to_save) else {}),
"label": label,
"color": color,
"source": f"{model_id}_propagation",
"propagated_from_frame_id": source_frame.id,
"propagated_from_frame_index": source_frame.frame_index,
"propagation_seed_key": seed_key,
"propagation_seed_signature": seed_signature,
"propagation_direction": direction,
"instance_id": source_instance_id,
"source_instance_id": source_instance_id,
"source_annotation_id": source_annotation_id,
"source_mask_id": source_mask_id,
# The headline score is the best score among the saved polygons.
"score": max(score_values) if score_values else None,
**({"scores": score_values} if len(score_values) > 1 else {}),
**({"geometry_smoothing": smoothing} if smoothing else {}),
**({"class": class_metadata} if class_metadata else {}),
},
points=None,
bbox=_polygons_bbox(polygons_to_save),
)
db.add(annotation)
created.append(annotation)
# Commit all new annotations, then refresh each created row from the database.
db.commit()
for annotation in created:
db.refresh(annotation)
return created, deleted_count
def _run_one_step(
    db: Session,
    *,
    payload: dict[str, Any],
    frames: list[Frame],
    source_frame: Frame,
    source_position: int,
    step: dict[str, Any],
) -> dict[str, Any]:
    """Propagate one seed in one direction and store the resulting masks.

    The frame window is selected around the source frame, stale results are
    cleared (or the step is skipped when nothing changed), the frames are
    written to a temporary directory for the video predictor, and the returned
    masks are saved as annotations.

    Returns:
        A summary dict with the model, direction, counters, seed label and
        seed key of this step.

    Raises:
        ValueError: If the direction is not ``forward``/``backward`` or the
            seed has no polygons, bbox or points.
    """
    direction = str(step.get("direction") or "forward").lower()
    if direction not in {"forward", "backward"}:
        raise ValueError("direction must be forward or backward")

    requested_frames = int(step.get("max_frames") or payload.get("max_frames") or 30)
    max_frames = max(1, min(requested_frames, 500))

    seed = step.get("seed") or {}
    has_prompt = seed.get("polygons") or seed.get("bbox") or seed.get("points")
    if not has_prompt:
        raise ValueError("Propagation requires seed polygons, bbox, or points")

    model_id = sam_registry.normalize_model_id(payload.get("model"))
    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    include_source = bool(payload.get("include_source", False))
    target_frame_ids: set[int] = set()
    for frame in selected_frames:
        if include_source or frame.id != source_frame.id:
            target_frame_ids.add(int(frame.id))

    seed_state = _prepare_seed_propagation(
        db,
        payload=payload,
        model_id=model_id,
        seed=seed,
        direction=direction,
        target_frame_ids=target_frame_ids,
    )

    def summarize(processed: int, created_total: int, deleted_total: int, skipped: int) -> dict[str, Any]:
        return {
            "model": model_id,
            "direction": direction,
            "processed_frame_count": processed,
            "created_annotation_count": created_total,
            "deleted_annotation_count": deleted_total,
            "skipped_seed_count": skipped,
            "seed_label": seed.get("label"),
            "seed_key": seed_state["seed_key"],
        }

    if seed_state["skip"]:
        return summarize(0, 0, 0, 1)

    project_prefix = f"seg_propagate_{payload['project_id']}_"
    with tempfile.TemporaryDirectory(prefix=project_prefix) as tmpdir:
        frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
        propagated = sam_registry.propagate_video(
            model_id,
            frame_paths,
            source_relative_index,
            seed,
            direction,
            len(selected_frames),
        )

    # The direction of this step is recorded on every saved annotation.
    save_payload = {**payload, "current_direction": direction}
    created, write_cleanup_count = _save_propagated_annotations(
        db,
        payload=save_payload,
        selected_frames=selected_frames,
        source_frame=source_frame,
        propagated=propagated,
        seed=seed,
    )
    total_deleted = int(seed_state["deleted_annotation_count"]) + write_cleanup_count
    return summarize(len(selected_frames), len(created), total_deleted, 0)
def run_propagate_project_task(db: Session, task_id: int) -> dict[str, Any]:
    """Run one queued SAM propagation task and update persisted progress.

    The task payload supplies the project id, the reference frame id, the
    model and a list of steps (one seed and one direction each). Steps run
    sequentially; before and after each step the aggregated counters are
    written to the task result so progress can be observed.

    The cancellation check that precedes the RUNNING transition runs inside
    the same ``try`` as the steps, so a task cancelled during validation
    returns the cancelled result instead of leaking
    ``PropagationTaskCancelled`` to the caller.

    Args:
        db: Open SQLAlchemy session used for all reads and writes.
        task_id: Primary key of the ``ProcessingTask`` to execute.

    Returns:
        The aggregated result dict on success, or a short status dict when
        the task was cancelled before or during execution.

    Raises:
        ValueError: If the task, project, reference frame or steps are
            missing, or the model id is not supported.
        Exception: Any error raised while running a step is re-raised after
            the task has been marked as failed.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")
    if task.status == TASK_STATUS_CANCELLED:
        return {"task_id": task.id, "status": TASK_STATUS_CANCELLED, "message": task.message or "任务已取消"}

    def mark_failed(message: str, error: str) -> None:
        # Every failure path finishes the task at 100% with an error text.
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message=message,
            error=error,
            finished=True,
        )

    payload = task.payload or {}
    project_id = int(payload.get("project_id") or task.project_id or 0)
    source_frame_id = int(payload.get("frame_id") or 0)
    try:
        model_id = sam_registry.normalize_model_id(payload.get("model"))
    except ValueError as exc:
        mark_failed("自动传播失败", str(exc))
        raise

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        mark_failed("项目不存在", "Project not found")
        raise ValueError(f"Project not found: {project_id}")

    source_frame = db.query(Frame).filter(Frame.id == source_frame_id, Frame.project_id == project_id).first()
    if not source_frame:
        mark_failed("参考帧不存在", "Frame not found")
        raise ValueError(f"Frame not found: {source_frame_id}")

    frames = db.query(Frame).filter(Frame.project_id == project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        mark_failed("参考帧不在项目帧序列中", "Source frame is not in project frame sequence")
        raise ValueError("Source frame is not in project frame sequence")

    steps = payload.get("steps") or []
    if not steps:
        mark_failed("传播任务缺少步骤", "Propagation task has no steps")
        raise ValueError("Propagation task has no steps")

    step_results: list[dict[str, Any]] = []
    created_count = 0
    processed_count = 0
    deleted_count = 0
    skipped_count = 0
    total_steps = len(steps)

    def snapshot(completed_steps: int) -> dict[str, Any]:
        # Reads the running counters at call time; used for every progress
        # update and for the final result so the three stay in sync.
        return {
            "project_id": project_id,
            "source_frame_id": source_frame_id,
            "model": model_id,
            "total_steps": total_steps,
            "completed_steps": completed_steps,
            "processed_frame_count": processed_count,
            "created_annotation_count": created_count,
            "deleted_annotation_count": deleted_count,
            "skipped_seed_count": skipped_count,
            "steps": step_results,
        }

    try:
        # Inside the try so a cancel that arrived during validation is
        # handled like any later cancel.
        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="自动传播任务已启动", started=True)

        for index, step in enumerate(steps, start=1):
            _ensure_not_cancelled(db, task)
            seed_label = (step.get("seed") or {}).get("label") or "mask"
            direction_label = "向前传播" if step.get("direction") == "backward" else "向后传播"
            # Steps share the 5..95 progress range evenly.
            _set_task_state(
                db,
                task,
                progress=5 + int(((index - 1) / total_steps) * 90),
                message=f"{direction_label} {seed_label} ({index}/{total_steps})",
                result=snapshot(index - 1),
            )
            result = _run_one_step(
                db,
                payload=payload,
                frames=frames,
                source_frame=source_frame,
                source_position=source_position,
                step=step,
            )
            step_results.append(result)
            created_count += int(result["created_annotation_count"])
            processed_count += int(result["processed_frame_count"])
            deleted_count += int(result.get("deleted_annotation_count") or 0)
            skipped_count += int(result.get("skipped_seed_count") or 0)
            _set_task_state(
                db,
                task,
                progress=5 + int((index / total_steps) * 90),
                message=f"{direction_label} {seed_label} 完成 ({index}/{total_steps})",
                result=snapshot(index),
            )

        result = snapshot(total_steps)
        if created_count > 0:
            final_message = "自动传播完成"
        elif skipped_count > 0:
            final_message = "自动传播完成,未改变的 mask 已跳过"
        else:
            final_message = "自动传播完成,但没有生成新的 mask"
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message=final_message,
            result=result,
            finished=True,
        )
        return result
    except PropagationTaskCancelled:
        # Keep any message/error/finish time already present on the task.
        task.status = TASK_STATUS_CANCELLED
        task.progress = 100
        task.message = task.message or "任务已取消"
        task.error = task.error or "Cancelled by user"
        task.finished_at = task.finished_at or _now()
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        return {"task_id": task.id, "project_id": project_id, "status": TASK_STATUS_CANCELLED, "message": task.message}
    except (ModelUnavailableError, NotImplementedError, ValueError) as exc:
        mark_failed("自动传播失败", str(exc))
        raise
    except Exception as exc:  # noqa: BLE001
        logger.exception("Propagation task failed: task_id=%s", task.id)
        mark_failed("自动传播失败", str(exc))
        raise

View File

@@ -0,0 +1,690 @@
"""SAM 2 engine wrapper with lazy loading and explicit runtime status."""
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
from config import settings
logger = logging.getLogger(__name__)
DEFAULT_SAM2_MODEL_ID = "sam2.1_hiera_tiny"
@dataclass(frozen=True)
class SAM2Variant:
    """Immutable description of one selectable SAM 2.1 model variant."""

    # Canonical model identifier, e.g. "sam2.1_hiera_tiny".
    id: str
    # Human-readable name, e.g. "SAM 2.1 Tiny".
    label: str
    # Compact size name, e.g. "tiny".
    short_label: str
    # Config path under "configs/sam2.1/".
    config: str
    # Config path under "configs/sam2/" (presumably for older SAM 2 installs).
    legacy_config: str
    # Checkpoint file name with the "sam2.1_" prefix.
    checkpoint_filename: str
    # Checkpoint file name with the older "sam2_" prefix.
    legacy_checkpoint_filename: str
# Supported SAM 2.1 variants, keyed by canonical model id.
# Each entry pairs the 2.1 config/checkpoint names with their older
# "sam2" counterparts.
SAM2_VARIANTS: dict[str, SAM2Variant] = {
"sam2.1_hiera_tiny": SAM2Variant(
id="sam2.1_hiera_tiny",
label="SAM 2.1 Tiny",
short_label="tiny",
config="configs/sam2.1/sam2.1_hiera_t.yaml",
legacy_config="configs/sam2/sam2_hiera_t.yaml",
checkpoint_filename="sam2.1_hiera_tiny.pt",
legacy_checkpoint_filename="sam2_hiera_tiny.pt",
),
"sam2.1_hiera_small": SAM2Variant(
id="sam2.1_hiera_small",
label="SAM 2.1 Small",
short_label="small",
config="configs/sam2.1/sam2.1_hiera_s.yaml",
legacy_config="configs/sam2/sam2_hiera_s.yaml",
checkpoint_filename="sam2.1_hiera_small.pt",
legacy_checkpoint_filename="sam2_hiera_small.pt",
),
"sam2.1_hiera_base_plus": SAM2Variant(
id="sam2.1_hiera_base_plus",
label="SAM 2.1 Base+",
short_label="base+",
config="configs/sam2.1/sam2.1_hiera_b+.yaml",
legacy_config="configs/sam2/sam2_hiera_b+.yaml",
checkpoint_filename="sam2.1_hiera_base_plus.pt",
legacy_checkpoint_filename="sam2_hiera_base_plus.pt",
),
"sam2.1_hiera_large": SAM2Variant(
id="sam2.1_hiera_large",
label="SAM 2.1 Large",
short_label="large",
config="configs/sam2.1/sam2.1_hiera_l.yaml",
legacy_config="configs/sam2/sam2_hiera_l.yaml",
checkpoint_filename="sam2.1_hiera_large.pt",
legacy_checkpoint_filename="sam2_hiera_large.pt",
),
}
# Shorthand model ids accepted from callers; each one resolves to the
# default SAM 2 model id.
SAM2_MODEL_ALIASES = {
"sam2": DEFAULT_SAM2_MODEL_ID,
"sam2.1": DEFAULT_SAM2_MODEL_ID,
"sam2_tiny": DEFAULT_SAM2_MODEL_ID,
}
# ---------------------------------------------------------------------------
# Optional heavy dependencies. Importing either one may fail on a machine
# without the ML stack; the module must still import, so failures are recorded
# in the *_AVAILABLE flags instead of propagating.
# ---------------------------------------------------------------------------
try:
    import torch
except Exception as exc:  # noqa: BLE001
    torch = None  # type: ignore[assignment]
    TORCH_AVAILABLE = False
    logger.warning("PyTorch import failed (%s). SAM2 will be unavailable.", exc)
else:
    TORCH_AVAILABLE = True

try:
    from sam2.build_sam import build_sam2
    from sam2.build_sam import build_sam2_video_predictor
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except Exception as exc:  # noqa: BLE001
    SAM2_AVAILABLE = False
    logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
else:
    SAM2_AVAILABLE = True
    logger.info("SAM2 library imported successfully.")
class SAM2Engine:
    """Lazy-loaded SAM 2 inference engine.

    Image predictors and video predictors are built per variant on first use
    and kept for the lifetime of the instance. A failed load attempt is
    recorded together with its error message so it is not retried on every
    request.
    """

    def __init__(self) -> None:
        # Every registry below is keyed by canonical variant id.
        self._predictors: dict[str, Optional[SAM2ImagePredictor]] = {}
        self._video_predictors: dict[str, object | None] = {}
        # True once a load has been *attempted* (successfully or not).
        self._model_loaded: dict[str, bool] = {}
        self._video_model_loaded: dict[str, bool] = {}
        self._loaded_device: dict[str, str] = {}
        self._last_error: dict[str, str | None] = {}
        self._video_last_error: dict[str, str | None] = {}

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def variant_ids(self) -> list[str]:
        """Return the canonical ids of all supported variants."""
        return list(SAM2_VARIANTS)

    def normalize_model_id(self, model_id: str | None) -> str:
        """Resolve *model_id* to a canonical variant id.

        Falls back to the configured default and then to
        ``DEFAULT_SAM2_MODEL_ID`` when *model_id* is empty, lower-cases the
        value and expands known aliases.

        Raises:
            ValueError: If the resolved id is not a known SAM 2 variant.
        """
        selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
        selected = SAM2_MODEL_ALIASES.get(selected, selected)
        if selected not in SAM2_VARIANTS:
            raise ValueError(f"Unsupported SAM2 model: {model_id}")
        return selected

    def is_sam2_model(self, model_id: str | None) -> bool:
        """Return whether *model_id* resolves to a supported SAM 2 variant."""
        try:
            self.normalize_model_id(model_id)
        except ValueError:
            return False
        return True

    @staticmethod
    def _checkpoint_dir(checkpoint_path: str) -> Path:
        """Return the directory part of *checkpoint_path*, or ``models``.

        ``Path.parent`` is never falsy (a bare file name yields ``Path('.')``),
        so the directory is taken from the raw string instead: only a path
        with no directory component at all falls back to ``models``. An
        explicit ``./file.pt`` still resolves to the current directory.
        """
        directory = os.path.dirname((checkpoint_path or "").strip())
        return Path(directory) if directory else Path("models")

    def _models_dir(self) -> Path:
        """Return the directory searched for variant checkpoints."""
        return self._checkpoint_dir(settings.sam_model_path)

    def _variant(self, model_id: str | None) -> SAM2Variant:
        """Return the variant record for *model_id*."""
        return SAM2_VARIANTS[self.normalize_model_id(model_id)]

    def _checkpoint_config(self, model_id: str | None) -> tuple[str, str]:
        """Pick the ``(config, checkpoint_path)`` pair to load for a variant.

        Candidates are tried in order: the explicitly configured checkpoint
        (default variant only, and only if that file exists), the SAM 2.1
        checkpoint in the models directory, then the legacy SAM 2 checkpoint.
        The first candidate whose file exists wins; if none exists the first
        candidate is returned so callers can report its path as missing.
        """
        variant_id = self.normalize_model_id(model_id)
        variant = SAM2_VARIANTS[variant_id]
        models_dir = self._models_dir()

        candidates: list[tuple[str, str]] = []
        configured_path = Path(settings.sam_model_path)
        if variant_id == DEFAULT_SAM2_MODEL_ID and configured_path.is_file():
            candidates.append((settings.sam_model_config, str(configured_path)))
        candidates.append((variant.config, str(models_dir / variant.checkpoint_filename)))
        candidates.append((variant.legacy_config, str(models_dir / variant.legacy_checkpoint_filename)))

        for config, checkpoint_path in candidates:
            if os.path.isfile(checkpoint_path):
                return config, checkpoint_path
        return candidates[0]

    def _load_model(self, model_id: str | None = None) -> None:
        """Load the SAM 2 model and predictor on first use."""
        variant_id = self.normalize_model_id(model_id)
        if self._model_loaded.get(variant_id):
            return

        if not TORCH_AVAILABLE:
            self._last_error[variant_id] = "PyTorch is not installed."
            logger.warning("PyTorch not available; skipping SAM2 model load.")
            self._model_loaded[variant_id] = True
            return

        if not SAM2_AVAILABLE:
            self._last_error[variant_id] = "sam2 package is not installed."
            logger.warning("SAM2 not available; skipping model load.")
            self._model_loaded[variant_id] = True
            return

        config, checkpoint_path = self._checkpoint_config(variant_id)
        if not os.path.isfile(checkpoint_path):
            self._last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
            logger.error("SAM checkpoint not found at %s", checkpoint_path)
            self._model_loaded[variant_id] = True
            return

        try:
            device = self._best_device()
            model = build_sam2(
                config,
                checkpoint_path,
                device=device,
            )
            self._predictors[variant_id] = SAM2ImagePredictor(model)
            self._model_loaded[variant_id] = True
            self._loaded_device[variant_id] = device
            self._last_error[variant_id] = None
            logger.info("SAM 2 model %s loaded from %s on %s", variant_id, checkpoint_path, device)
        except Exception as exc:  # noqa: BLE001
            self._last_error[variant_id] = str(exc)
            logger.error("Failed to load SAM 2 model %s: %s", variant_id, exc)
            self._model_loaded[variant_id] = True  # Prevent repeated load attempts

    def _load_video_model(self, model_id: str | None = None) -> None:
        """Load the SAM 2 video predictor on first propagation use."""
        variant_id = self.normalize_model_id(model_id)
        if self._video_model_loaded.get(variant_id):
            return

        if not TORCH_AVAILABLE:
            self._video_last_error[variant_id] = "PyTorch is not installed."
            self._video_model_loaded[variant_id] = True
            return

        if not SAM2_AVAILABLE:
            self._video_last_error[variant_id] = "sam2 package is not installed."
            self._video_model_loaded[variant_id] = True
            return

        config, checkpoint_path = self._checkpoint_config(variant_id)
        if not os.path.isfile(checkpoint_path):
            self._video_last_error[variant_id] = f"SAM2 checkpoint not found: {checkpoint_path}"
            self._video_model_loaded[variant_id] = True
            return

        try:
            device = self._best_device()
            self._video_predictors[variant_id] = build_sam2_video_predictor(
                config,
                checkpoint_path,
                device=device,
            )
            self._video_model_loaded[variant_id] = True
            self._loaded_device[variant_id] = device
            self._video_last_error[variant_id] = None
            logger.info("SAM 2 video predictor %s loaded from %s on %s", variant_id, checkpoint_path, device)
        except Exception as exc:  # noqa: BLE001
            self._video_last_error[variant_id] = str(exc)
            self._video_model_loaded[variant_id] = True  # Prevent repeated load attempts
            logger.error("Failed to load SAM 2 video predictor %s: %s", variant_id, exc)

    def _best_device(self) -> str:
        """Return ``"cuda"`` when PyTorch reports a CUDA device, else ``"cpu"``."""
        if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
            return "cuda"
        return "cpu"

    def _ensure_ready(self, model_id: str | None = None) -> bool:
        """Ensure the model is loaded; return whether it is usable."""
        variant_id = self.normalize_model_id(model_id)
        self._load_model(variant_id)
        return SAM2_AVAILABLE and self._predictors.get(variant_id) is not None

    def _ensure_video_ready(self, model_id: str | None = None) -> bool:
        """Ensure the video predictor is loaded; return whether it is usable."""
        variant_id = self.normalize_model_id(model_id)
        self._load_video_model(variant_id)
        return SAM2_AVAILABLE and self._video_predictors.get(variant_id) is not None

    def status(self, model_id: str | None = None) -> dict:
        """Return lightweight, real runtime status without forcing model load."""
        variant_id = self.normalize_model_id(model_id)
        variant = SAM2_VARIANTS[variant_id]
        _, checkpoint_path = self._checkpoint_config(variant_id)
        checkpoint_exists = os.path.isfile(checkpoint_path)
        using_legacy_checkpoint = Path(checkpoint_path).name == variant.legacy_checkpoint_filename
        predictor = self._predictors.get(variant_id)
        device = self._loaded_device.get(variant_id) or self._best_device()
        available = bool(TORCH_AVAILABLE and SAM2_AVAILABLE and checkpoint_exists)

        if predictor is not None:
            message = f"{variant.label} model loaded and ready."
        elif available:
            message = f"{variant.label} dependencies and checkpoint are present; model will load on first inference."
            if using_legacy_checkpoint:
                message += " Using legacy SAM 2 checkpoint fallback."
        else:
            missing = []
            if not TORCH_AVAILABLE:
                missing.append("PyTorch")
            if not SAM2_AVAILABLE:
                missing.append("sam2 package")
            if not checkpoint_exists:
                missing.append("checkpoint")
            message = f"{variant.label} unavailable: missing {', '.join(missing)}."

        # A recorded load failure is more specific than the generic text above.
        last_error = self._last_error.get(variant_id)
        if last_error and not predictor:
            message = last_error

        return {
            "id": variant.id,
            "label": variant.label,
            "available": available,
            "loaded": predictor is not None,
            "device": device,
            "supports": ["point", "box", "interactive", "auto", "propagate"],
            "message": message,
            "package_available": SAM2_AVAILABLE,
            "checkpoint_exists": checkpoint_exists,
            "checkpoint_path": checkpoint_path,
            "python_ok": True,
            "torch_ok": TORCH_AVAILABLE,
            "cuda_required": False,
        }

    def _fallback_result(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Return the placeholder rectangle and score used when SAM 2 cannot run."""
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    def _masks_to_polygons(self, masks) -> list[list[list[float]]]:
        """Convert each mask to its largest outer polygon, dropping empty ones."""
        return [polygon for mask in masks if (polygon := self._mask_to_polygon(mask))]

    @staticmethod
    def _scale_points(points: list[list[float]], width: int, height: int) -> np.ndarray:
        """Scale normalized ``[x, y]`` points to pixel coordinates (float32)."""
        return np.array([[p[0] * width, p[1] * height] for p in points], dtype=np.float32)

    @staticmethod
    def _scale_box(box: list[float], width: int, height: int) -> np.ndarray:
        """Scale a normalized ``[x1, y1, x2, y2]`` box to pixel coordinates (float32)."""
        return np.array(
            [box[0] * width, box[1] * height, box[2] * width, box[3] * height],
            dtype=np.float32,
        )

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------

    def predict_points(
        self,
        model_id: str | None,
        image: np.ndarray,
        points: list[list[float]],
        labels: list[int],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run point-prompt segmentation.

        Args:
            model_id: Variant id or alias; ``None`` selects the default.
            image: HWC numpy array (uint8).
            points: List of [x, y] normalized coordinates (0-1).
            labels: 1 for foreground, 0 for background.

        Returns:
            Tuple of (polygons, scores). When the model is unavailable or
            inference fails, a placeholder rectangle with score 0.5 is
            returned instead of raising.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            height, width = image.shape[:2]
            point_coords = self._scale_points(points, width, height)
            point_labels = np.array(labels, dtype=np.int32)

            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                masks, scores, _ = predictor.predict(
                    point_coords=point_coords,
                    point_labels=point_labels,
                    multimask_output=False,
                )
            return self._masks_to_polygons(masks), scores.tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 point prediction failed: %s", exc)
            return self._fallback_result(image)

    def predict_box(
        self,
        model_id: str | None,
        image: np.ndarray,
        box: list[float],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run box-prompt segmentation.

        Args:
            model_id: Variant id or alias; ``None`` selects the default.
            image: HWC numpy array (uint8).
            box: [x1, y1, x2, y2] normalized coordinates.

        Returns:
            Tuple of (polygons, scores); the placeholder result on failure.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            height, width = image.shape[:2]
            bbox = self._scale_box(box, width, height)

            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                masks, scores, _ = predictor.predict(
                    box=bbox[None, :],
                    multimask_output=False,
                )
            return self._masks_to_polygons(masks), scores.tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 box prediction failed: %s", exc)
            return self._fallback_result(image)

    def predict_interactive(
        self,
        model_id: str | None,
        image: np.ndarray,
        box: list[float] | None,
        points: list[list[float]],
        labels: list[int],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run combined box and point prompt segmentation for refinement.

        Either prompt may be absent: an empty/``None`` box or an empty point
        list is passed to the predictor as ``None``.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            height, width = image.shape[:2]
            bbox = self._scale_box(box, width, height) if box else None
            point_coords = None
            point_labels = None
            if points:
                point_coords = self._scale_points(points, width, height)
                point_labels = np.array(labels, dtype=np.int32)

            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                masks, scores, _ = predictor.predict(
                    point_coords=point_coords,
                    point_labels=point_labels,
                    box=bbox,
                    multimask_output=False,
                )
            return self._masks_to_polygons(masks), scores.tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 interactive prediction failed: %s", exc)
            return self._fallback_result(image)

    def predict_auto(self, model_id: str | None, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Run automatic segmentation from a uniform grid of foreground points.

        Args:
            model_id: Variant id or alias; ``None`` selects the default.
            image: HWC numpy array (uint8).

        Returns:
            Tuple of (polygons, scores) for the first predicted mask only;
            the placeholder result on failure.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_ready(variant_id):
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback_result(image)

        try:
            predictor = self._predictors[variant_id]
            with torch.inference_mode():  # type: ignore[name-defined]
                predictor.set_image(image)
                # 17 samples per axis over [0, 1] inclusive: a 17x17 grid
                # (289 points, image borders included), all sent as one
                # positive-label prompt.
                height, width = image.shape[:2]
                grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
                point_coords = grid * np.array([width, height])
                point_labels = np.ones(point_coords.shape[0], dtype=np.int32)
                masks, scores, _ = predictor.predict(
                    point_coords=point_coords,
                    point_labels=point_labels,
                    multimask_output=False,
                )
            return self._masks_to_polygons(masks[:1]), scores[:1].tolist()
        except Exception as exc:  # noqa: BLE001
            logger.error("SAM2 auto prediction failed: %s", exc)
            return self._fallback_result(image)

    def _seed_mask(self, seed: dict, width: int, height: int) -> np.ndarray:
        """Rasterize the seed's polygons, falling back to its bbox.

        Raises:
            ValueError: If neither polygons nor bbox produce a non-empty mask.
        """
        seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height, seed.get("holes") or [])
        if not seed_mask.any():
            bbox = seed.get("bbox")
            if isinstance(bbox, list) and len(bbox) == 4:
                seed_mask = self._bbox_to_mask(bbox, width, height)
        if not seed_mask.any():
            raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
        return seed_mask

    def _frame_result(self, out_frame_idx, out_obj_ids, out_mask_logits) -> dict:
        """Convert one item yielded by the video predictor into a result dict."""
        masks = out_mask_logits
        if hasattr(masks, "detach"):
            masks = masks.detach().cpu().numpy()
        masks = np.asarray(masks)
        if masks.ndim == 4:
            # Drop the second axis so each entry is one 2-D mask.
            masks = masks[:, 0]

        polygons = []
        holes = []
        scores = []
        for mask in masks:
            # Positive logits are treated as foreground.
            mask_polygons, mask_holes = self._mask_to_polygon_data(mask > 0)
            for polygon_index, polygon in enumerate(mask_polygons):
                polygons.append(polygon)
                holes.append(mask_holes[polygon_index] if polygon_index < len(mask_holes) else [])
                scores.append(1.0)
        return {
            "frame_index": int(out_frame_idx),
            "polygons": polygons,
            "holes": holes,
            "scores": scores,
            "object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
        }

    def propagate_video(
        self,
        model_id: str | None,
        frame_paths: list[str],
        source_frame_index: int,
        seed: dict,
        direction: str = "forward",
        max_frames: int | None = None,
    ) -> list[dict]:
        """Propagate one seed mask across a prepared frame directory with SAM 2 video.

        Args:
            model_id: Variant id or alias; ``None`` selects the default.
            frame_paths: Frame image paths; the directory of the first path is
                handed to the video predictor.
            source_frame_index: Index in *frame_paths* of the seeded frame.
            seed: Dict with normalized ``polygons`` (and optional ``holes``)
                or a normalized ``bbox`` as ``[x, y, w, h]``.
            direction: ``"forward"``, ``"backward"`` or ``"both"``
                (case-insensitive). Any other value yields no frames.
            max_frames: Forwarded as ``max_frame_num_to_track``.

        Returns:
            One dict per tracked frame, sorted by frame index.

        Raises:
            RuntimeError: If the video predictor is unavailable or the source
                frame cannot be decoded.
            ValueError: If *source_frame_index* is out of range or the seed is
                empty.
        """
        variant_id = self.normalize_model_id(model_id)
        if not self._ensure_video_ready(variant_id):
            raise RuntimeError(self._video_last_error.get(variant_id) or self.status(variant_id)["message"])
        video_predictor = self._video_predictors[variant_id]
        if not frame_paths:
            return []
        if source_frame_index < 0 or source_frame_index >= len(frame_paths):
            raise ValueError("source_frame_index is outside the frame sequence.")

        import cv2

        source_image = cv2.imread(frame_paths[source_frame_index])
        if source_image is None:
            raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
        height, width = source_image.shape[:2]
        seed_mask = self._seed_mask(seed, width, height)

        inference_state = video_predictor.init_state(
            video_path=os.path.dirname(frame_paths[0]),
            offload_video_to_cpu=True,
            offload_state_to_cpu=True,
        )
        results: dict[int, dict] = {}
        try:
            video_predictor.add_new_mask(
                inference_state,
                frame_idx=source_frame_index,
                obj_id=1,
                mask=seed_mask,
            )

            def collect(reverse: bool) -> None:
                for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(
                    inference_state,
                    start_frame_idx=source_frame_index,
                    max_frame_num_to_track=max_frames,
                    reverse=reverse,
                ):
                    results[int(out_frame_idx)] = self._frame_result(out_frame_idx, out_obj_ids, out_mask_logits)

            normalized_direction = direction.lower()
            if normalized_direction in {"forward", "both"}:
                collect(reverse=False)
            if normalized_direction in {"backward", "both"}:
                collect(reverse=True)
        finally:
            # Always release the tracking state, also when propagation fails.
            # Cleanup is best-effort: a failure here must not mask the result
            # or the original exception, but it is logged rather than dropped.
            try:
                video_predictor.reset_state(inference_state)
            except Exception as exc:  # noqa: BLE001
                logger.debug("SAM 2 reset_state failed after propagation: %s", exc)

        return [results[index] for index in sorted(results)]

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------

    @staticmethod
    def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
        """Convert a binary mask to its largest normalized outer polygon.

        Returns an empty list when the mask has no usable contour.
        """
        polygons, _holes = SAM2Engine._mask_to_polygon_data(mask)
        return polygons[0] if polygons else []

    @staticmethod
    def _mask_to_polygon_data(mask: np.ndarray) -> tuple[list[list[list[float]]], list[list[list[list[float]]]]]:
        """Convert a binary mask to normalized outer polygons and aligned hole rings.

        Returns:
            ``(polygons, holes)`` where ``holes[i]`` lists the hole rings of
            ``polygons[i]``. Outer polygons are ordered by contour area,
            largest first. Coordinates are divided by the mask width/height.
        """
        import cv2

        if mask.dtype != np.uint8:
            mask = (mask > 0).astype(np.uint8)
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
        h, w = mask.shape[:2]
        if hierarchy is None:
            return [], []

        def contour_to_polygon(contour: np.ndarray) -> list[list[float]]:
            # Fewer than three points cannot form a polygon.
            if len(contour) < 3:
                return []
            return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in contour]

        # Row fields used below: [0] next sibling, [2] first child, [3] parent.
        hierarchy_rows = hierarchy[0]
        outer_indices = [
            index for index, row in enumerate(hierarchy_rows)
            if int(row[3]) < 0 and len(contours[index]) >= 3
        ]
        outer_indices.sort(key=lambda index: cv2.contourArea(contours[index]), reverse=True)

        polygons: list[list[list[float]]] = []
        holes: list[list[list[list[float]]]] = []
        for outer_index in outer_indices:
            outer = contour_to_polygon(contours[outer_index])
            if not outer:
                continue
            # Walk the sibling chain of this contour's children: its holes.
            hole_group: list[list[list[float]]] = []
            child_index = int(hierarchy_rows[outer_index][2])
            while child_index >= 0:
                hole = contour_to_polygon(contours[child_index])
                if hole:
                    hole_group.append(hole)
                child_index = int(hierarchy_rows[child_index][0])
            polygons.append(outer)
            holes.append(hole_group)
        return polygons, holes

    @staticmethod
    def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
        """Return a dummy rectangle polygon for fallback mode.

        The rectangle is expressed in normalized coordinates, so *w* and *h*
        do not influence the result.
        """
        return [
            [
                [0.25, 0.25],
                [0.75, 0.25],
                [0.75, 0.75],
                [0.25, 0.75],
            ]
        ]

    @staticmethod
    def _polygons_to_mask(
        polygons: list[list[list[float]]],
        width: int,
        height: int,
        holes_by_polygon: list[list[list[list[float]]]] | None = None,
    ) -> np.ndarray:
        """Rasterize normalized polygons (minus their holes) into a boolean mask.

        Args:
            polygons: Outer rings as lists of normalized ``[x, y]`` points.
            width: Mask width in pixels.
            height: Mask height in pixels.
            holes_by_polygon: Optional hole rings, aligned with *polygons* by
                index.

        Returns:
            Boolean array of shape ``(height, width)``.
        """
        import cv2

        x_scale = max(width - 1, 1)
        y_scale = max(height - 1, 1)

        def to_pixels(ring: list[list[float]]) -> np.ndarray:
            # Clamp to [0, 1] before scaling so stray coordinates stay inside.
            return np.array(
                [
                    [
                        int(round(min(max(float(x), 0.0), 1.0) * x_scale)),
                        int(round(min(max(float(y), 0.0), 1.0) * y_scale)),
                    ]
                    for x, y in ring
                ],
                dtype=np.int32,
            )

        mask = np.zeros((height, width), dtype=np.uint8)
        for polygon_index, polygon in enumerate(polygons):
            if len(polygon) < 3:
                continue
            cv2.fillPoly(mask, [to_pixels(polygon)], 1)
            holes = holes_by_polygon[polygon_index] if holes_by_polygon and polygon_index < len(holes_by_polygon) else []
            for hole in holes:
                if len(hole) < 3:
                    continue
                cv2.fillPoly(mask, [to_pixels(hole)], 0)
        return mask.astype(bool)

    @staticmethod
    def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
        """Rasterize a normalized ``[x, y, w, h]`` box into a boolean mask.

        Each value is clamped to [0, 1]; the filled region always covers at
        least one pixel.
        """
        x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
        x_scale = max(width - 1, 1)
        y_scale = max(height - 1, 1)
        left = int(round(x * x_scale))
        top = int(round(y * y_scale))
        right = int(round(min(x + w, 1.0) * x_scale))
        bottom = int(round(min(y + h, 1.0) * y_scale))
        mask = np.zeros((height, width), dtype=bool)
        mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
        return mask
# Process-wide engine instance shared by all importers; constructing it does
# not load any model.
sam_engine = SAM2Engine()

View File

@@ -0,0 +1,447 @@
"""SAM 3 engine adapter and runtime status.
The official facebookresearch/sam3 package currently targets Python 3.12+
and CUDA-capable PyTorch. This adapter reports those requirements honestly and
only performs inference when the local runtime can actually import and execute
the package.
"""
from __future__ import annotations
import importlib.util
import json
import logging
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from config import settings
from services.sam2_engine import SAM2Engine
logger = logging.getLogger(__name__)

# PyTorch is optional for this process: when it cannot be imported the module
# still loads and the flag records the failure.
try:
    import torch
except Exception as exc:  # noqa: BLE001
    torch = None  # type: ignore[assignment]
    TORCH_AVAILABLE = False
    logger.warning("PyTorch import failed (%s). SAM3 will be unavailable.", exc)
else:
    TORCH_AVAILABLE = True

# Only checks that the package can be located; nothing is imported here.
SAM3_PACKAGE_AVAILABLE = importlib.util.find_spec("sam3") is not None
class SAM3Engine:
    """Lazy SAM 3 image inference adapter.

    Inference runs in-process when this interpreter satisfies every SAM 3
    requirement (see ``_can_load``); otherwise it is delegated to the helper
    script executed by the configured external Python, if that runtime
    reports itself available.
    """

    def __init__(self) -> None:
        self._model: Any | None = None
        self._processor: Any | None = None
        # True once an in-process load has been attempted (successful or not).
        self._model_loaded = False
        self._last_error: str | None = None
        # Last external status report and the monotonic time it was taken.
        self._external_status_cache: dict[str, Any] | None = None
        self._external_status_checked_at = 0.0

    # -----------------------------------------------------------------------
    # In-process runtime checks
    # -----------------------------------------------------------------------

    def _python_ok(self) -> bool:
        """Return whether this interpreter is Python 3.12 or newer."""
        return sys.version_info >= (3, 12)

    def _gpu_ok(self) -> bool:
        """Return whether PyTorch is importable and reports a CUDA device."""
        return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())

    def _checkpoint_path(self) -> str | None:
        """Return the configured checkpoint path, or ``None`` when blank."""
        path = settings.sam3_checkpoint_path.strip()
        return path if path else None

    def _checkpoint_exists(self) -> bool:
        """Return whether the configured checkpoint is an existing file."""
        path = self._checkpoint_path()
        return bool(path and os.path.isfile(path))

    def _can_load(self) -> bool:
        """Return whether the model can be loaded inside this process."""
        return bool(
            SAM3_PACKAGE_AVAILABLE
            and TORCH_AVAILABLE
            and self._python_ok()
            and self._gpu_ok()
            and self._checkpoint_exists()
        )

    # -----------------------------------------------------------------------
    # External helper process
    # -----------------------------------------------------------------------

    def _worker_path(self) -> Path:
        """Return the path of the helper script located next to this module."""
        return Path(__file__).with_name("sam3_external_worker.py")

    def _external_python_exists(self) -> bool:
        """Return whether the external runtime is enabled and its Python exists."""
        return bool(settings.sam3_external_enabled and os.path.isfile(settings.sam3_external_python))

    def _worker_env(self) -> dict[str, str]:
        """Build the helper's environment: ours plus the SAM 3 settings."""
        env = os.environ.copy()
        env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
        checkpoint_path = self._checkpoint_path()
        if checkpoint_path:
            env["SAM3_CHECKPOINT_PATH"] = checkpoint_path
        return env

    def _run_worker(self, *args: str, timeout: float) -> subprocess.CompletedProcess[str]:
        """Run the helper script with *args* and capture its text output.

        A non-zero exit status is not raised here (``check=False``); callers
        inspect ``returncode``. ``subprocess.TimeoutExpired`` propagates.
        """
        return subprocess.run(
            [settings.sam3_external_python, str(self._worker_path()), *args],
            capture_output=True,
            text=True,
            timeout=timeout,
            check=False,
            env=self._worker_env(),
        )

    def _run_request(self, request_path: Path, request: dict[str, Any], failure_label: str) -> dict[str, Any]:
        """Write *request* to *request_path*, run the helper on it, return its JSON reply.

        Raises:
            RuntimeError: If the helper exits non-zero (message prefixed with
                *failure_label*) or its reply carries an ``error`` entry.
        """
        request_path.write_text(json.dumps(request, ensure_ascii=False), encoding="utf-8")
        completed = self._run_worker("--request", str(request_path), timeout=settings.sam3_timeout_seconds)
        if completed.returncode != 0:
            detail = completed.stderr.strip() or completed.stdout.strip()
            # Best effort: if the helper reported a JSON error object, surface
            # just its "error" text; otherwise keep the raw output as detail.
            try:
                parsed = json.loads(detail)
                detail = parsed.get("error", detail)
            except Exception:  # noqa: BLE001
                pass
            raise RuntimeError(f"{failure_label}: {detail}")
        payload = json.loads(completed.stdout)
        if payload.get("error"):
            raise RuntimeError(str(payload["error"]))
        return payload

    @staticmethod
    def _unavailable_status(message: str) -> dict[str, Any]:
        """Return a fresh external status report describing an unusable runtime."""
        return {
            "available": False,
            "package_available": False,
            "python_ok": False,
            "torch_ok": False,
            "cuda_available": False,
            "device": "unavailable",
            "message": message,
        }

    @staticmethod
    def _parse_status_output(stdout: str) -> dict[str, Any]:
        """Decode the helper's ``--status`` output.

        Raises:
            ValueError: If the output is not valid JSON or is valid JSON but
                not an object. Callers use ``.get()`` on the report, so any
                other JSON value must be rejected before it is cached.
        """
        parsed = json.loads(stdout)
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        return parsed

    def _probe_external_runtime(self) -> dict[str, Any]:
        """Ask the helper for its status; never raises."""
        try:
            completed = self._run_worker("--status", timeout=min(settings.sam3_timeout_seconds, 30))
            if completed.returncode != 0:
                detail = completed.stderr.strip() or completed.stdout.strip()
                return self._unavailable_status(f"SAM 3 external status failed: {detail}")
            return self._parse_status_output(completed.stdout)
        except Exception as exc:  # noqa: BLE001
            return self._unavailable_status(f"SAM 3 external status failed: {exc}")

    def _external_status(self, force: bool = False) -> dict[str, Any]:
        """Return the external runtime's status report.

        The last report is reused for ``settings.sam3_status_cache_seconds``
        unless *force* is set. Failures are reported (and cached) as an
        "unavailable" report rather than raised.
        """
        now = time.monotonic()
        cached = self._external_status_cache
        if (
            not force
            and cached is not None
            and now - self._external_status_checked_at < settings.sam3_status_cache_seconds
        ):
            return cached

        if not settings.sam3_external_enabled:
            status = self._unavailable_status("SAM 3 external runtime is disabled.")
        elif not self._external_python_exists():
            status = self._unavailable_status(
                f"SAM 3 external Python not found: {settings.sam3_external_python}"
            )
        else:
            status = self._probe_external_runtime()

        self._external_status_cache = status
        self._external_status_checked_at = now
        return status

    # -----------------------------------------------------------------------
    # In-process model
    # -----------------------------------------------------------------------

    def _load_model(self) -> None:
        """Build the in-process model and processor on first use.

        The attempt is made once; a failure is stored in ``_last_error``.
        """
        if self._model_loaded:
            return
        if not self._can_load():
            self._last_error = self._status_message()
            self._model_loaded = True
            return

        try:
            from sam3.model.sam3_image_processor import Sam3Processor
            from sam3.model_builder import build_sam3_image_model

            self._model = build_sam3_image_model(
                checkpoint_path=self._checkpoint_path(),
                load_from_HF=False,
            )
            self._processor = Sam3Processor(self._model)
            self._model_loaded = True
            self._last_error = None
            logger.info("SAM 3 image model loaded with version setting %s", settings.sam3_model_version)
        except Exception as exc:  # noqa: BLE001
            self._last_error = str(exc)
            self._model_loaded = True
            logger.error("Failed to load SAM 3 model: %s", exc)

    def _ensure_ready(self) -> bool:
        """Attempt the in-process load and return whether a processor exists."""
        self._load_model()
        return self._processor is not None

    def _status_message(self) -> str:
        """Describe which in-process requirements are missing, if any."""
        missing = []
        if not SAM3_PACKAGE_AVAILABLE:
            missing.append("sam3 package")
        if not self._python_ok():
            missing.append("Python 3.12+ runtime")
        if not TORCH_AVAILABLE:
            missing.append("PyTorch")
        if not self._gpu_ok():
            missing.append("CUDA GPU")
        if not self._checkpoint_exists():
            missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
        if missing:
            return f"SAM 3 unavailable: missing {', '.join(missing)}."
        return "SAM 3 dependencies are present; model will load on first inference."

    def status(self) -> dict:
        """Return the combined in-process / external runtime status."""
        external_status = self._external_status()
        available = bool(self._can_load() or external_status.get("available"))
        external_ready = bool(external_status.get("available"))

        message = self._last_error or self._status_message()
        if self._processor is not None:
            message = "SAM 3 model loaded and ready."
        elif external_ready:
            message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
        elif external_status.get("message") and not self._can_load():
            message = str(external_status["message"])

        return {
            "id": "sam3",
            "label": "SAM 3",
            "available": available,
            "loaded": self._processor is not None,
            "device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
            "supports": ["semantic", "box", "video_track"],
            "message": message,
            "package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
            "checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
            "checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
            "python_ok": bool(self._python_ok() or external_status.get("python_ok")),
            "torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
            "cuda_required": True,
            "external_available": external_ready,
            "external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
        }

    # -----------------------------------------------------------------------
    # Conversions
    # -----------------------------------------------------------------------

    def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
        """Convert a normalized corner box to ``[cx, cy, width, height]``.

        Values are clamped to [0, 1], corners may be given in either order,
        and width/height are kept at least ``1e-6``.

        Raises:
            ValueError: If *box* does not have exactly four values.
        """
        if len(box) != 4:
            raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
        x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
        left, right = sorted([x1, x2])
        top, bottom = sorted([y1, y2])
        width = max(right - left, 1e-6)
        height = max(bottom - top, 1e-6)
        return [left + width / 2, top + height / 2, width, height]

    def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
        """Convert a processor output mapping into ``(polygons, scores)``.

        Masks that yield no polygon are dropped; scores are returned as given.
        """
        masks = output.get("masks", [])
        scores = output.get("scores", [])

        polygons = []
        for mask in masks:
            if hasattr(mask, "detach"):
                mask = mask.detach().cpu().numpy()
            if mask.ndim == 3:
                mask = mask[0]
            poly = SAM2Engine._mask_to_polygon(mask)
            if poly:
                polygons.append(poly)

        if hasattr(scores, "detach"):
            scores = scores.detach().cpu().tolist()
        elif hasattr(scores, "tolist"):
            scores = scores.tolist()
        return polygons, list(scores)

    # -----------------------------------------------------------------------
    # External inference
    # -----------------------------------------------------------------------

    def _predict_external(
        self,
        image: np.ndarray,
        prompt_type: str,
        *,
        text: str = "",
        box: list[float] | None = None,
        confidence_threshold: float | None = None,
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run one image prompt through the helper process.

        The image and request are exchanged through a temporary directory
        that is removed afterwards.

        Raises:
            RuntimeError: If the external runtime is unavailable or the helper
                reports a failure.
        """
        status = self._external_status(force=True)
        if not status.get("available"):
            raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")

        if confidence_threshold is None:
            confidence_threshold = settings.sam3_confidence_threshold

        with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir:
            tmp_path = Path(tmpdir)
            image_path = tmp_path / "image.png"
            Image.fromarray(image).save(image_path)
            payload = self._run_request(
                tmp_path / "request.json",
                {
                    "image_path": str(image_path),
                    "prompt_type": prompt_type,
                    "text": text.strip(),
                    "box": box,
                    "model_version": settings.sam3_model_version,
                    "checkpoint_path": self._checkpoint_path(),
                    "confidence_threshold": confidence_threshold,
                },
                "SAM 3 external inference failed",
            )
        return payload.get("polygons", []), payload.get("scores", [])

    def _predict_semantic_external(
        self,
        image: np.ndarray,
        text: str,
        confidence_threshold: float | None = None,
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run a text prompt through the helper process."""
        return self._predict_external(
            image,
            "semantic",
            text=text,
            confidence_threshold=confidence_threshold,
        )

    def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
        """Run a box prompt through the helper process."""
        return self._predict_external(image, "box", box=box)

    def _propagate_video_external(
        self,
        frame_paths: list[str],
        source_frame_index: int,
        seed: dict[str, Any],
        direction: str,
        max_frames: int | None,
    ) -> list[dict[str, Any]]:
        """Run video tracking through the helper process.

        The helper receives the directory of the first frame path. An empty
        *frame_paths* returns ``[]`` (after the availability check).

        Raises:
            RuntimeError: If the external runtime is unavailable or the helper
                reports a failure.
        """
        status = self._external_status(force=True)
        if not status.get("available"):
            raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
        if not frame_paths:
            return []

        with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
            payload = self._run_request(
                Path(tmpdir) / "request.json",
                {
                    "prompt_type": "video_track",
                    "frame_dir": str(Path(frame_paths[0]).parent),
                    "source_frame_index": source_frame_index,
                    "seed": seed,
                    "direction": direction,
                    "max_frames": max_frames,
                    "model_version": settings.sam3_model_version,
                    "checkpoint_path": self._checkpoint_path(),
                    "confidence_threshold": settings.sam3_confidence_threshold,
                },
                "SAM 3 external video tracking failed",
            )
        return payload.get("frames", [])

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------

    def predict_semantic(
        self,
        image: np.ndarray,
        text: str,
        confidence_threshold: float | None = None,
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Segment *image* from a text prompt.

        Uses the external helper when this process cannot load the model and
        the helper is available; otherwise runs in-process.

        Raises:
            ValueError: If *text* is blank.
            RuntimeError: If no runtime is able to serve the request.
        """
        if not text.strip():
            raise ValueError("SAM 3 semantic prompt requires non-empty text.")
        if not self._can_load() and self._external_status().get("available"):
            return self._predict_semantic_external(image, text, confidence_threshold=confidence_threshold)
        if not self._ensure_ready():
            raise RuntimeError(self.status()["message"])

        pil_image = Image.fromarray(image)
        with torch.inference_mode():  # type: ignore[union-attr]
            state = self._processor.set_image(pil_image)
            output = self._processor.set_text_prompt(state=state, prompt=text.strip())
        return self._prediction_to_polygons(output)

    def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
        """Point prompts are not offered by this adapter; always raises."""
        raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")

    def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
        """Segment *image* from a normalized ``[x1, y1, x2, y2]`` box prompt.

        Raises:
            RuntimeError: If no runtime is able to serve the request.
            ValueError: If *box* does not have four values (in-process path).
        """
        if not self._can_load() and self._external_status().get("available"):
            return self._predict_box_external(image, box)
        if not self._ensure_ready():
            raise RuntimeError(self.status()["message"])

        pil_image = Image.fromarray(image)
        with torch.inference_mode():  # type: ignore[union-attr]
            state = self._processor.set_image(pil_image)
            output = self._processor.add_geometric_prompt(
                state=state,
                box=self._xyxy_to_cxcywh(box),
                label=True,
            )
        return self._prediction_to_polygons(output)

    def propagate_video(
        self,
        frame_paths: list[str],
        source_frame_index: int,
        seed: dict[str, Any],
        direction: str = "forward",
        max_frames: int | None = None,
    ) -> list[dict[str, Any]]:
        """Track the seed across frames; always served by the external helper."""
        return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
# Process-wide adapter instance shared by all importers; constructing it does
# not load the model or start the helper process.
sam3_engine = SAM3Engine()

View File

@@ -0,0 +1,343 @@
"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime.
The main FastAPI backend can keep running in the existing Python 3.11/SAM 2
environment while this helper is executed with a separate conda env that meets
SAM 3's stricter runtime requirements.
"""
from __future__ import annotations
import argparse
import importlib.util
import json
import os
import sys
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
def _torch_status() -> tuple[bool, str | None, str | None, str | None]:
try:
import torch
cuda_available = bool(torch.cuda.is_available())
return (
cuda_available,
getattr(torch, "__version__", None),
getattr(torch.version, "cuda", None),
torch.cuda.get_device_name(0) if cuda_available else None,
)
except Exception: # noqa: BLE001
return False, None, None, None
def _compact_error(exc: Exception) -> str:
lines = [line.strip() for line in str(exc).splitlines() if line.strip()]
for line in lines:
if "Access to model" in line or "Cannot access gated repo" in line:
return line
return lines[0] if lines else exc.__class__.__name__
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip()
if checkpoint_path:
path = Path(checkpoint_path)
if path.is_file():
return True, None
return False, f"local checkpoint not found: {checkpoint_path}"
try:
from huggingface_hub import hf_hub_download
repo_id = "facebook/sam3.1" if model_version == "sam3.1" else "facebook/sam3"
hf_hub_download(repo_id=repo_id, filename="config.json")
return True, None
except Exception as exc: # noqa: BLE001
return False, _compact_error(exc)
def runtime_status() -> dict[str, Any]:
    """Describe whether this interpreter can run SAM 3.

    The runtime counts as available only when the ``sam3`` package imports,
    the interpreter is Python 3.12 or newer, CUDA is available and the
    checkpoint is reachable. The checkpoint check is skipped when the
    package itself is unavailable.

    Returns:
        A JSON-serialisable dict with the individual probe results plus a
        human-readable ``message`` listing whatever is missing.
    """
    model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
    checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None

    # Finding the spec is not enough: the import itself may still fail.
    package_error = None
    package_available = importlib.util.find_spec("sam3") is not None
    if package_available:
        try:
            import sam3  # noqa: F401
        except Exception as exc:  # noqa: BLE001
            package_available, package_error = False, str(exc)

    cuda_available, torch_version, cuda_version, device_name = _torch_status()
    python_ok = sys.version_info >= (3, 12)

    if package_available:
        checkpoint_access, checkpoint_error = _checkpoint_access(model_version)
    else:
        checkpoint_access, checkpoint_error = False, None

    available = bool(package_available and python_ok and cuda_available and checkpoint_access)

    # (is_missing, label) pairs, in the order they appear in the message.
    package_label = f"sam3 package ({package_error})" if package_error else "sam3 package"
    requirements = [
        (not python_ok, "Python 3.12+ runtime"),
        (not package_available, package_label),
        (torch_version is None, "PyTorch"),
        (not cuda_available, "CUDA GPU"),
        (
            package_available and not checkpoint_access,
            f"Hugging Face checkpoint access ({checkpoint_error})",
        ),
    ]
    missing = [label for is_missing, label in requirements if is_missing]

    if available:
        message = "SAM 3 external runtime is ready."
    else:
        message = f"SAM 3 external runtime unavailable: missing {', '.join(missing)}."

    return {
        "available": available,
        "package_available": package_available,
        "checkpoint_access": checkpoint_access,
        "checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})",
        "python_ok": python_ok,
        "torch_ok": torch_version is not None,
        "torch_version": torch_version,
        "cuda_version": cuda_version,
        "cuda_available": cuda_available,
        "device": "cuda" if cuda_available else "unavailable",
        "device_name": device_name,
        "message": message,
    }
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
    """Convert a mask into one polygon with coordinates divided by width/height.

    Non-``uint8`` masks are binarised with ``mask > 0`` first. Only external
    contours are extracted, and the contour with the most vertices is kept.

    Returns:
        ``[[x / width, y / height], ...]`` for the chosen contour, or ``[]``
        when no contour has at least three vertices.
    """
    import cv2

    binary = mask if mask.dtype == np.uint8 else (mask > 0).astype(np.uint8)
    contours, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    height, width = binary.shape[:2]

    # NOTE(review): selection is by vertex count, not enclosed area - confirm
    # that is the intended meaning of "largest".
    chosen = max(contours, key=len, default=())
    if len(chosen) < 3:
        return []
    return [[float(vertex[0][0]) / width, float(vertex[0][1]) / height] for vertex in chosen]
def _to_numpy(value: Any) -> np.ndarray:
if hasattr(value, "detach"):
value = value.detach()
if hasattr(value, "is_floating_point") and value.is_floating_point():
value = value.float()
value = value.cpu().numpy()
elif hasattr(value, "cpu"):
value = value.cpu()
if hasattr(value, "is_floating_point") and value.is_floating_point():
value = value.float()
value = value.numpy()
return np.asarray(value)
def _xyxy_to_cxcywh(box: list[float]) -> list[float]:
if len(box) != 4:
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
left, right = sorted([x1, x2])
top, bottom = sorted([y1, y2])
width = max(right - left, 1e-6)
height = max(bottom - top, 1e-6)
return [left + width / 2, top + height / 2, width, height]
def _bbox_from_seed(seed: dict[str, Any]) -> list[float]:
bbox = seed.get("bbox")
if isinstance(bbox, list) and len(bbox) == 4:
return [min(max(float(value), 0.0), 1.0) for value in bbox]
polygons = seed.get("polygons") or []
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
if not points:
raise ValueError("SAM 3 video tracking requires seed bbox or polygons.")
xs = [min(max(float(point[0]), 0.0), 1.0) for point in points]
ys = [min(max(float(point[1]), 0.0), 1.0) for point in points]
left, right = min(xs), max(xs)
top, bottom = min(ys), max(ys)
return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)]
def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
    """Turn one frame of video-predictor output into polygons, scores and ids.

    Reads ``out_binary_masks``, ``out_probs`` and ``out_obj_ids`` from
    ``outputs``. Masks that produce no polygon are dropped together with
    their score and id, so the three returned lists stay index-aligned.
    A missing score defaults to ``1.0`` and a missing id to ``index + 1``.
    """
    mask_stack = _to_numpy(outputs.get("out_binary_masks", []))
    probs = _to_numpy(outputs.get("out_probs", []))
    object_ids = _to_numpy(outputs.get("out_obj_ids", []))

    # Bring the masks to a (count, H, W) layout before iterating.
    if mask_stack.ndim == 4:
        mask_stack = mask_stack[:, 0]
    elif mask_stack.ndim == 2:
        mask_stack = mask_stack[None, ...]

    response: dict[str, Any] = {"polygons": [], "scores": [], "object_ids": []}
    for position, single_mask in enumerate(mask_stack):
        outline = _mask_to_polygon(single_mask)
        if not outline:
            continue
        response["polygons"].append(outline)
        response["scores"].append(float(probs[position]) if probs.size > position else 1.0)
        response["object_ids"].append(
            int(object_ids[position]) if object_ids.size > position else position + 1
        )
    return response
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
    """Turn image-predictor output into index-aligned polygons and scores.

    Reads ``masks`` and ``scores`` from ``output``. A mask that yields no
    polygon is dropped; when there is exactly one score per mask, the
    matching score is dropped with it so ``polygons[i]`` and ``scores[i]``
    always describe the same object. If the number of scores does not match
    the number of masks, the correspondence is unknown and the scores are
    returned unfiltered.

    Returns:
        ``{"polygons": [...], "scores": [...]}``.
    """
    masks = _to_numpy(output.get("masks", []))
    scores = _to_numpy(output.get("scores", []))

    # Bring the masks to a (count, H, W) layout; 3-D input already has it.
    if masks.ndim == 2:
        masks = masks[None, :, :]
    elif masks.ndim == 4:
        masks = masks[:, 0]

    flat_scores = scores.astype(float).reshape(-1)
    one_score_per_mask = flat_scores.size == len(masks)

    polygons = []
    kept_scores = []
    for index, mask in enumerate(masks):
        polygon = _mask_to_polygon(mask)
        if not polygon:
            continue
        polygons.append(polygon)
        if one_score_per_mask:
            kept_scores.append(float(flat_scores[index]))

    if not one_score_per_mask:
        # Cannot pair scores with masks; pass them through untouched.
        kept_scores = scores.astype(float).tolist() if scores.size else []

    return {"polygons": polygons, "scores": kept_scores}
def predict_video(request_path: Path) -> dict[str, Any]:
    """Track a seeded object through a frame directory with the SAM 3 video predictor.

    The request is a JSON file. Keys read from it:
        ``frame_dir`` (required), ``source_frame_index`` (default 0),
        ``seed`` (bbox or polygons), ``direction`` (``forward``,
        ``backward`` or ``both``; default ``forward``), ``max_frames``
        (falsy means no limit), ``checkpoint_path`` (falls back to the
        ``SAM3_CHECKPOINT_PATH`` environment variable) and
        ``confidence_threshold`` (default 0.5).
    JSON ``null`` for ``source_frame_index`` or ``confidence_threshold`` is
    treated the same as an absent key.

    The whole request, including the seed box, is validated before the model
    libraries are imported, so a malformed request fails quickly.

    Returns:
        ``{"frames": [...]}`` with one entry per streamed frame, each holding
        ``polygons``, ``scores``, ``object_ids`` and ``frame_index``.

    Raises:
        ValueError: Unsupported ``direction`` or a seed with no usable box.
    """
    payload = json.loads(request_path.read_text(encoding="utf-8"))
    frame_dir = Path(payload["frame_dir"])
    source_frame_index = int(payload.get("source_frame_index") or 0)
    seed = payload.get("seed") or {}
    direction = str(payload.get("direction") or "forward").lower()
    max_frames = payload.get("max_frames")
    max_frames = int(max_frames) if max_frames else None
    checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
    raw_threshold = payload.get("confidence_threshold")
    threshold = 0.5 if raw_threshold is None else float(raw_threshold)

    if direction not in {"forward", "backward", "both"}:
        raise ValueError(f"Unsupported propagation direction: {direction}")
    seed_box = _bbox_from_seed(seed)

    # Heavy imports are deferred until the request is known to be valid.
    import torch
    from sam3.model_builder import build_sam3_video_predictor

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    predictor = build_sam3_video_predictor(
        checkpoint_path=checkpoint_path or None,
        async_loading_frames=False,
    )

    frames = []
    session_id = None
    try:
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            session = predictor.handle_request(
                {
                    "type": "start_session",
                    "resource_path": str(frame_dir),
                    "offload_video_to_cpu": True,
                    "offload_state_to_cpu": True,
                }
            )
            session_id = session["session_id"]
            predictor.handle_request(
                {
                    "type": "add_prompt",
                    "session_id": session_id,
                    "frame_index": source_frame_index,
                    "bounding_boxes": [seed_box],
                    "bounding_box_labels": [1],
                    "output_prob_thresh": threshold,
                    "rel_coordinates": True,
                }
            )
            for item in predictor.handle_stream_request(
                {
                    "type": "propagate_in_video",
                    "session_id": session_id,
                    "propagation_direction": direction,
                    "start_frame_index": source_frame_index,
                    "max_frame_num_to_track": max_frames,
                    "output_prob_thresh": threshold,
                }
            ):
                frame_response = _video_outputs_to_response(item.get("outputs") or {})
                frame_response["frame_index"] = int(item["frame_index"])
                frames.append(frame_response)
    finally:
        # Always close the session once it was opened, even if tracking failed.
        if session_id:
            predictor.handle_request({"type": "close_session", "session_id": session_id})
    return {"frames": frames}
def predict(request_path: Path) -> dict[str, Any]:
    """Run one SAM 3 request described by a JSON file.

    ``prompt_type`` selects the behaviour: ``video_track`` is delegated to
    ``predict_video``; ``semantic`` (the default) uses the ``text`` prompt;
    ``box`` uses the ``box`` corner coordinates. Other keys read:
    ``image_path`` (required for image prompts), ``confidence_threshold``
    (default 0.5, JSON ``null`` treated as absent) and ``checkpoint_path``
    (falls back to the ``SAM3_CHECKPOINT_PATH`` environment variable).

    The request is validated before the model libraries are imported, so a
    malformed request fails quickly and a video request never imports the
    image model modules.

    Returns:
        The dict produced by ``_prediction_to_response`` (or by
        ``predict_video`` for video requests).

    Raises:
        ValueError: Empty semantic text, an unsupported prompt type, or a
            box that ``_xyxy_to_cxcywh`` rejects.
    """
    payload = json.loads(request_path.read_text(encoding="utf-8"))
    prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower()
    if prompt_type == "video_track":
        return predict_video(request_path)

    image_path = Path(payload["image_path"])
    text = str(payload.get("text") or "").strip()
    raw_threshold = payload.get("confidence_threshold")
    threshold = 0.5 if raw_threshold is None else float(raw_threshold)
    checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()

    if prompt_type == "semantic" and not text:
        raise ValueError("SAM 3 semantic prompt requires non-empty text.")
    if prompt_type not in {"semantic", "box"}:
        raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}")
    box_prompt = _xyxy_to_cxcywh(payload.get("box") or []) if prompt_type == "box" else None

    # Heavy imports are deferred until the request is known to be valid.
    import torch
    from sam3.model.sam3_image_processor import Sam3Processor
    from sam3.model_builder import build_sam3_image_model

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Close the file handle as soon as the pixels have been converted.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        model = build_sam3_image_model(
            checkpoint_path=checkpoint_path or None,
            load_from_HF=not bool(checkpoint_path),
        )
        processor = Sam3Processor(model, confidence_threshold=threshold)
        state = processor.set_image(image)
        if prompt_type == "box":
            output = processor.add_geometric_prompt(
                state=state,
                box=box_prompt,
                label=True,
            )
        else:
            output = processor.set_text_prompt(state=state, prompt=text)
    return _prediction_to_response(output)
def main() -> int:
    """Command-line entry point for the SAM 3 helper.

    ``--status`` prints the runtime status as JSON on stdout; ``--request
    PATH`` runs the request file and prints the result as JSON on stdout.
    With neither flag, ``parser.error`` reports a usage error.

    Returns:
        ``0`` on success, or ``1`` after printing ``{"error": ...}`` as JSON
        on stderr when an exception was raised.
    """
    parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
    parser.add_argument("--status", action="store_true")
    parser.add_argument("--request", type=Path)
    args = parser.parse_args()

    try:
        if args.status:
            result = runtime_status()
        elif args.request:
            result = predict(args.request)
        else:
            parser.error("Use --status or --request")
            return 2
        print(json.dumps(result, ensure_ascii=False))
        return 0
    except Exception as exc:  # noqa: BLE001
        failure = {"error": str(exc)}
        print(json.dumps(failure, ensure_ascii=False), file=sys.stderr)
        return 1
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,130 @@
"""Model registry for SAM runtimes and GPU status."""
from __future__ import annotations
from typing import Any
from config import settings
from services.sam2_engine import DEFAULT_SAM2_MODEL_ID, TORCH_AVAILABLE, sam_engine as sam2_engine
# SAM 3 integration is intentionally disabled for the current product flow.
# The source files are kept in the repository so the integration can be
# restored later, but the active registry only exposes SAM 2.
# from services.sam3_engine import sam3_engine
try:
import torch
except Exception: # noqa: BLE001
torch = None # type: ignore[assignment]
class ModelUnavailableError(RuntimeError):
    """Signal that the requested model cannot be used in the current environment."""
class SAMRegistry:
"""Dispatch predictions to the selected SAM backend."""
def __init__(self) -> None:
self._engines = {
"sam2": sam2_engine,
# "sam3": sam3_engine,
}
def normalize_model_id(self, model_id: str | None) -> str:
selected = (model_id or settings.sam_default_model or DEFAULT_SAM2_MODEL_ID).lower()
if self._engines["sam2"].is_sam2_model(selected):
return self._engines["sam2"].normalize_model_id(selected)
if selected not in self._engines:
raise ValueError(f"Unsupported model: {model_id}")
return selected
def runtime_status(self, selected_model: str | None = None) -> dict[str, Any]:
selected = self.normalize_model_id(selected_model)
return {
"selected_model": selected,
"gpu": self.gpu_status(),
"models": [sam2_engine.status(model_id) for model_id in sam2_engine.variant_ids()],
}
def gpu_status(self) -> dict[str, Any]:
cuda_available = bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
return {
"available": cuda_available,
"device": "cuda" if cuda_available else "cpu",
"name": torch.cuda.get_device_name(0) if cuda_available else None,
"torch_available": bool(TORCH_AVAILABLE),
"torch_version": getattr(torch, "__version__", None) if torch is not None else None,
"cuda_version": getattr(torch.version, "cuda", None) if torch is not None else None,
}
def _engine(self, model_id: str | None) -> Any:
normalized = self.normalize_model_id(model_id)
if self._engines["sam2"].is_sam2_model(normalized):
return self._engines["sam2"]
return self._engines[normalized]
def _ensure_available(self, model_id: str | None) -> Any:
normalized = self.normalize_model_id(model_id)
engine = self._engine(model_id)
status = engine.status(normalized) if engine is sam2_engine else engine.status()
if not status["available"]:
raise ModelUnavailableError(status["message"])
return engine
def predict_points(self, model_id: str | None, image: Any, points: list[list[float]], labels: list[int]):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_points(model, image, points, labels)
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_box(model, image, box)
def predict_interactive(
self,
model_id: str | None,
image: Any,
box: list[float] | None,
points: list[list[float]],
labels: list[int],
):
model = self.normalize_model_id(model_id)
if not sam2_engine.is_sam2_model(model):
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
return self._ensure_available(model).predict_interactive(model, image, box, points, labels)
def predict_auto(self, model_id: str | None, image: Any):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).predict_auto(model, image)
def predict_semantic(
self,
model_id: str | None,
image: Any,
text: str,
confidence_threshold: float | None = None,
):
self.normalize_model_id(model_id)
raise NotImplementedError("Semantic text prompting is disabled; use SAM 2 point or box prompts.")
def propagate_video(
self,
model_id: str | None,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str,
max_frames: int | None,
):
model = self.normalize_model_id(model_id)
return self._ensure_available(model).propagate_video(
model,
frame_paths,
source_frame_index,
seed,
direction=direction,
max_frames=max_frames,
)
# Module-level SAMRegistry instance, constructed when this module is imported.
sam_registry = SAMRegistry()

15
backend/statuses.py Normal file
View File

@@ -0,0 +1,15 @@
"""Shared status constants used across backend project/task flows."""
# Project status values.
PROJECT_STATUS_PENDING = "pending"
PROJECT_STATUS_PARSING = "parsing"
PROJECT_STATUS_READY = "ready"
PROJECT_STATUS_ERROR = "error"
# Task status values.
TASK_STATUS_QUEUED = "queued"
TASK_STATUS_RUNNING = "running"
TASK_STATUS_SUCCESS = "success"
TASK_STATUS_FAILED = "failed"
TASK_STATUS_CANCELLED = "cancelled"
# Task statuses grouped by phase: still queued or running, versus finished
# (succeeded, failed, or cancelled).
TASK_ACTIVE_STATUSES = {TASK_STATUS_QUEUED, TASK_STATUS_RUNNING}
TASK_TERMINAL_STATUSES = {TASK_STATUS_SUCCESS, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}

36
backend/worker_tasks.py Normal file
View File

@@ -0,0 +1,36 @@
"""Celery task definitions."""
import logging
from celery_app import celery_app
from database import SessionLocal
from services.media_task_runner import run_parse_media_task
from services.propagation_task_runner import run_propagate_project_task
logger = logging.getLogger(__name__)
@celery_app.task(name="media.parse_project")
def parse_project_media(task_id: int) -> dict:
    """Run media parsing for one queued task.

    Opens a database session for the duration of the task and always closes
    it afterwards.

    Args:
        task_id: Identifier passed through to ``run_parse_media_task``.

    Returns:
        The dict returned by ``run_parse_media_task``.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised
            unchanged.
    """
    db = SessionLocal()
    try:
        return run_parse_media_task(db, task_id)
    except Exception:  # noqa: BLE001
        logger.exception("Parse media task failed: task_id=%s", task_id)
        # Bare raise re-raises the active exception with its original traceback.
        raise
    finally:
        db.close()
@celery_app.task(name="ai.propagate_project")
def propagate_project_masks(task_id: int) -> dict:
    """Run SAM video propagation for one queued task.

    Opens a database session for the duration of the task and always closes
    it afterwards.

    Args:
        task_id: Identifier passed through to ``run_propagate_project_task``.

    Returns:
        The dict returned by ``run_propagate_project_task``.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised
            unchanged.
    """
    db = SessionLocal()
    try:
        return run_propagate_project_task(db, task_id)
    except Exception:  # noqa: BLE001
        logger.exception("Propagation task failed: task_id=%s", task_id)
        # Bare raise re-raises the active exception with its original traceback.
        raise
    finally:
        db.close()