添加Docker自包含部署分支

- 新增 Seg_Server_Docker 自包含部署内容,包含前后端、FastAPI、Celery、PostgreSQL、Redis、MinIO、演示视频和 DICOM 数据。

- 保留 demo 数据以支持恢复演示出厂设置,排除 SAM 2.1 .pt 权重并在 README 中补充下载命令。

- 补充 GPU 部署、backend/worker 镜像复用、frpc/frps + NPM 公网域名反代部署说明。

- 在 .env/.env.example 中用 # XXXX 标注局域网和公网域名部署需要修改的配置项。

- 添加部署分支 .gitignore,忽略本地模型权重、构建产物、缓存和日志。
This commit is contained in:
2026-05-07 19:06:07 +08:00
commit b5413066a0
396 changed files with 32742 additions and 0 deletions

View File

299
backend/routers/admin.py Normal file
View File

@@ -0,0 +1,299 @@
"""Administrator-only user and audit management endpoints."""
import os
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from config import settings
from database import get_db
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import SUPPORTED_ROLES, ensure_default_admin, hash_password, normalize_user_role, require_admin, write_audit_log
from schemas import (
AdminUserCreate,
AdminUserUpdate,
AuditLogOut,
DemoFactoryResetOut,
DemoFactoryResetRequest,
UserOut,
)
from services.demo_media import (
DEMO_DICOM_PROJECT_NAME,
DEMO_VIDEO_PROJECT_NAME,
create_parsed_dicom_demo_project,
create_parsed_video_demo_project,
demo_dicom_files,
)
from services.default_templates import restore_default_templates
router = APIRouter(prefix="/api/admin", tags=["Admin"])
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
def _normalize_role(role: str | None) -> str:
normalized = (role or "annotator").strip().lower()
if normalized not in SUPPORTED_ROLES:
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
return normalized
def _assert_non_admin_role(role: str) -> None:
if role == "admin":
raise HTTPException(status_code=400, detail="Only the default admin account can have admin role")
@router.get("/users", response_model=List[UserOut], summary="List users")
def list_users(
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> List[User]:
"""Return all users for the administrator console."""
_ = admin_user
users = db.query(User).order_by(User.id).all()
return [normalize_user_role(db, user) for user in users]
@router.post(
"/users",
response_model=UserOut,
status_code=status.HTTP_201_CREATED,
summary="Create user",
)
def create_user(
payload: AdminUserCreate,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> User:
"""Create a user with an initial password and role."""
username = payload.username.strip()
if not username:
raise HTTPException(status_code=400, detail="Username is required")
if len(payload.password) < 6:
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
role = _normalize_role(payload.role)
_assert_non_admin_role(role)
user = User(
username=username,
password_hash=hash_password(payload.password),
role=role,
is_active=1 if payload.is_active else 0,
)
db.add(user)
try:
db.commit()
except IntegrityError as exc:
db.rollback()
raise HTTPException(status_code=409, detail="Username already exists") from exc
db.refresh(user)
write_audit_log(
db,
actor=admin_user,
action="admin.user_created",
target_type="user",
target_id=user.id,
detail={"username": user.username, "role": user.role, "is_active": bool(user.is_active)},
)
return user
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
def update_user(
user_id: int,
payload: AdminUserUpdate,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> User:
"""Update username, password, role or active state."""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
user = normalize_user_role(db, user)
updates = payload.model_dump(exclude_unset=True)
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
if "username" in updates:
username = (updates["username"] or "").strip()
if not username:
raise HTTPException(status_code=400, detail="Username is required")
if user.role == "admin" and username != settings.default_admin_username:
raise HTTPException(status_code=400, detail="Default admin username cannot be changed")
user.username = username
if "password" in updates:
password = updates["password"] or ""
if len(password) < 6:
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
user.password_hash = hash_password(password)
if "role" in updates:
next_role = _normalize_role(updates["role"])
if user.username == settings.default_admin_username:
if next_role != "admin":
raise HTTPException(status_code=400, detail="Cannot remove the default admin role")
else:
_assert_non_admin_role(next_role)
user.role = next_role
if "is_active" in updates:
if user.id == admin_user.id and not updates["is_active"]:
raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
user.is_active = 1 if updates["is_active"] else 0
try:
db.commit()
except IntegrityError as exc:
db.rollback()
raise HTTPException(status_code=409, detail="Username already exists") from exc
db.refresh(user)
audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
audit_detail["password_changed"] = "password" in updates
write_audit_log(
db,
actor=admin_user,
action="admin.user_updated",
target_type="user",
target_id=user.id,
detail=audit_detail,
)
return user
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
def delete_user(
user_id: int,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> None:
"""Delete a user when it is safe to remove the account."""
if user_id == admin_user.id:
raise HTTPException(status_code=400, detail="Cannot delete yourself")
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
user = normalize_user_role(db, user)
if user.role == "admin":
raise HTTPException(status_code=400, detail="Cannot delete the default admin account")
username = user.username
db.delete(user)
db.commit()
write_audit_log(
db,
actor=admin_user,
action="admin.user_deleted",
target_type="user",
target_id=user_id,
detail={"username": username},
)
return None
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
def list_audit_logs(
limit: int = 100,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> List[AuditLog]:
"""Return recent audit events for administrators."""
_ = admin_user
safe_limit = min(max(int(limit or 100), 1), 500)
return db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc()).limit(safe_limit).all()
@router.post(
"/demo-factory-reset",
response_model=DemoFactoryResetOut,
summary="Reset demo data to factory defaults",
)
def reset_demo_factory(
payload: DemoFactoryResetRequest,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> dict:
"""Reset a demo deployment to one admin account, the demo video, and the demo DICOM project."""
if payload.confirmation != DEMO_RESET_CONFIRMATION:
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
if not os.path.exists(settings.demo_video_path):
raise HTTPException(
status_code=409,
detail=f"Demo video not found: {settings.demo_video_path}",
)
if not demo_dicom_files(settings.demo_dicom_dir):
raise HTTPException(
status_code=409,
detail=f"Demo DICOM series not found: {settings.demo_dicom_dir}",
)
requested_by = admin_user.username
preserved_admin = ensure_default_admin(db)
preserved_admin.username = settings.default_admin_username
preserved_admin.password_hash = hash_password(settings.default_admin_password)
preserved_admin.role = "admin"
preserved_admin.is_active = 1
db.flush()
deleted_counts = {
"masks": db.query(Mask).delete(synchronize_session=False),
"annotations": db.query(Annotation).delete(synchronize_session=False),
"frames": db.query(Frame).delete(synchronize_session=False),
"tasks": db.query(ProcessingTask).delete(synchronize_session=False),
"projects": db.query(Project).delete(synchronize_session=False),
"user_templates": db.query(Template).filter(Template.owner_user_id.is_not(None)).delete(synchronize_session=False),
"audit_logs": db.query(AuditLog).delete(synchronize_session=False),
"users": db.query(User).filter(User.id != preserved_admin.id).delete(synchronize_session=False),
}
db.flush()
db.expunge_all()
preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
if not preserved_admin:
raise HTTPException(status_code=500, detail="Default admin was not preserved")
restored_templates = restore_default_templates(db)
video_project = create_parsed_video_demo_project(
db,
owner=preserved_admin,
video_path=settings.demo_video_path,
project_name=DEMO_VIDEO_PROJECT_NAME,
)
dicom_project = create_parsed_dicom_demo_project(
db,
owner=preserved_admin,
dicom_dir=settings.demo_dicom_dir,
project_name=DEMO_PROJECT_NAME,
)
db.refresh(preserved_admin)
db.refresh(video_project)
db.refresh(dicom_project)
video_project.frame_count = len(video_project.frames)
dicom_project.frame_count = len(dicom_project.frames)
projects = [video_project, dicom_project]
write_audit_log(
db,
actor=preserved_admin,
action="admin.demo_factory_reset",
target_type="project",
target_id=dicom_project.id,
detail={
"project_names": [project.name for project in projects],
"video_path": video_project.video_path,
"dicom_path": dicom_project.video_path,
"source_types": [project.source_type for project in projects],
"frame_counts": {project.name: len(project.frames) for project in projects},
"deleted_counts": deleted_counts,
"restored_templates": [template.name for template in restored_templates],
"requested_by": requested_by,
},
)
return {
"admin_user": preserved_admin,
"project": dicom_project,
"projects": projects,
"deleted_counts": deleted_counts,
"message": "演示环境已恢复出厂设置",
}

1228
backend/routers/ai.py Normal file

File diff suppressed because it is too large Load Diff

222
backend/routers/auth.py Normal file
View File

@@ -0,0 +1,222 @@
"""Authentication endpoints and dependencies."""
from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session
from config import settings
from database import get_db
from models import AuditLog, User
from schemas import LoginResponse, UserOut
router = APIRouter(prefix="/api/auth", tags=["Auth"])
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
bearer_scheme = HTTPBearer(auto_error=False)
SUPPORTED_ROLES = {"admin", "annotator"}
class LoginRequest(BaseModel):
username: str
password: str
def hash_password(password: str) -> str:
"""Hash a plain password for storage."""
return password_context.hash(password)
def verify_password(password: str, password_hash: str) -> bool:
"""Verify a plain password against a stored hash."""
return password_context.verify(password, password_hash)
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
"""Create a signed JWT access token for a user."""
expire = datetime.now(timezone.utc) + (
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
)
payload: dict[str, Any] = {
"sub": str(user.id),
"username": user.username,
"role": user.role,
"exp": expire,
}
return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
def ensure_default_admin(db: Session) -> User:
"""Create and enforce the single default administrator account."""
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
if existing:
changed = False
if existing.role != "admin":
existing.role = "admin"
changed = True
if not existing.is_active:
existing.is_active = 1
changed = True
extra_admins = db.query(User).filter(
User.role == "admin",
User.id != existing.id,
).all()
for user in extra_admins:
user.role = "annotator"
changed = True
if changed:
db.commit()
db.refresh(existing)
return existing
user = User(
username=settings.default_admin_username,
password_hash=hash_password(settings.default_admin_password),
role="admin",
is_active=1,
)
db.add(user)
db.commit()
db.refresh(user)
extra_admins = db.query(User).filter(
User.role == "admin",
User.id != user.id,
).all()
if extra_admins:
for extra_user in extra_admins:
extra_user.role = "annotator"
db.commit()
db.refresh(user)
return user
def normalize_user_role(db: Session, user: User) -> User:
"""Keep legacy accounts within the current two-role policy."""
desired_role = "admin" if user.username == settings.default_admin_username else "annotator"
changed = False
if user.role != desired_role:
user.role = desired_role
changed = True
if user.username == settings.default_admin_username and not user.is_active:
user.is_active = 1
changed = True
if changed:
db.commit()
db.refresh(user)
return user
def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
db: Session = Depends(get_db),
) -> User:
"""Resolve and validate the current user from the Bearer token."""
if credentials is None or credentials.scheme.lower() != "bearer":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
credentials.credentials,
settings.jwt_secret_key,
algorithms=[settings.jwt_algorithm],
)
user_id = int(payload.get("sub"))
except (JWTError, TypeError, ValueError) as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
) from exc
user = db.query(User).filter(User.id == user_id).first()
if user:
user = normalize_user_role(db, user)
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Inactive or missing user",
headers={"WWW-Authenticate": "Bearer"},
)
return user
def require_admin(current_user: User = Depends(get_current_user)) -> User:
"""Require the current user to have the administrator role."""
if current_user.role != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
return current_user
def require_editor(current_user: User = Depends(get_current_user)) -> User:
"""Require a user role that can modify segmentation data."""
if current_user.role not in SUPPORTED_ROLES:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
return current_user
def write_audit_log(
db: Session,
*,
actor: User | None,
action: str,
target_type: str | None = None,
target_id: str | int | None = None,
detail: dict[str, Any] | None = None,
) -> AuditLog:
"""Persist a compact audit event."""
log = AuditLog(
actor_user_id=actor.id if actor else None,
action=action,
target_type=target_type,
target_id=str(target_id) if target_id is not None else None,
detail=detail or {},
)
db.add(log)
db.commit()
db.refresh(log)
return log
@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
"""Authenticate a user and return a signed JWT."""
ensure_default_admin(db)
user = db.query(User).filter(User.username == payload.username).first()
if user:
user = normalize_user_role(db, user)
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
write_audit_log(
db,
actor=None,
action="auth.login_failed",
target_type="user",
target_id=payload.username,
detail={"username": payload.username},
)
raise HTTPException(status_code=401, detail="Invalid credentials")
write_audit_log(
db,
actor=user,
action="auth.login_success",
target_type="user",
target_id=user.id,
detail={"username": user.username},
)
return {
"token": create_access_token(user),
"token_type": "bearer",
"username": user.username,
"user": user,
}
@router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
"""Return the authenticated user profile."""
return current_user

View File

@@ -0,0 +1,164 @@
"""Dashboard overview endpoints."""
import os
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Frame, ProcessingTask, Project, Template, User
from routers.auth import get_current_user
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"}
MONITORED_TASK_STATUSES = {"queued", "running", "success", "failed", "cancelled"}
def _system_load_percent() -> int:
"""Return a real host load estimate without adding a psutil dependency."""
try:
load_1m = os.getloadavg()[0]
cpu_count = os.cpu_count() or 1
return min(100, max(0, round((load_1m / cpu_count) * 100)))
except (AttributeError, OSError):
return 0
def _iso_or_none(value: datetime | None) -> str | None:
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.isoformat()
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
result = task.result or {}
return {
"id": f"task-{task.id}",
"task_id": task.id,
"project_id": task.project_id or 0,
"name": task.project.name if task.project else f"任务 {task.id}",
"progress": task.progress,
"status": task.message or task.status,
"raw_status": task.status,
"frame_count": result.get("frames_extracted", result.get("processed_frame_count", 0)),
"error": task.error,
"updated_at": _iso_or_none(task.updated_at),
}
@router.get("/overview", summary="Get dashboard overview")
def get_dashboard_overview(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> dict[str, Any]:
"""Return live dashboard data derived from persisted backend records."""
shared_project_ids_query = db.query(Project.id)
project_count = db.query(func.count(Project.id)).scalar() or 0
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(shared_project_ids_query)).scalar() or 0
annotation_count = (
db.query(func.count(Annotation.id))
.filter(Annotation.project_id.in_(shared_project_ids_query))
.scalar()
or 0
)
template_count = (
db.query(func.count(Template.id))
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
.scalar()
or 0
)
active_task_count = (
db.query(func.count(ProcessingTask.id))
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
.scalar()
or 0
)
projects = (
db.query(Project)
.order_by(Project.updated_at.desc())
.all()
)
recent_tasks = (
db.query(ProcessingTask)
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.order_by(ProcessingTask.created_at.desc())
.limit(50)
.all()
)
tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
activities: list[dict[str, Any]] = []
for task in recent_tasks[:10]:
project_name = task.project.name if task.project else f"项目 {task.project_id}"
activities.append({
"id": f"task-{task.id}",
"kind": "task",
"time": _iso_or_none(task.updated_at),
"message": task.message or f"任务状态: {task.status}",
"project": project_name,
})
for project in projects[:10]:
activities.append({
"id": f"project-{project.id}",
"kind": "project",
"time": _iso_or_none(project.updated_at),
"message": f"项目状态: {project.status}",
"project": project.name,
})
recent_annotations = (
db.query(Annotation)
.filter(Annotation.project_id.in_(shared_project_ids_query))
.order_by(Annotation.updated_at.desc())
.limit(10)
.all()
)
for annotation in recent_annotations:
project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
activities.append({
"id": f"annotation-{annotation.id}",
"kind": "annotation",
"time": _iso_or_none(annotation.updated_at),
"message": f"标注已更新 #{annotation.id}",
"project": project_name,
})
recent_templates = (
db.query(Template)
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
.order_by(Template.created_at.desc())
.limit(10)
.all()
)
for template in recent_templates:
activities.append({
"id": f"template-{template.id}",
"kind": "template",
"time": _iso_or_none(template.created_at),
"message": f"模板可用: {template.name}",
"project": "系统",
})
activities.sort(key=lambda item: item["time"] or "", reverse=True)
return {
"summary": {
"project_count": project_count,
"parsing_task_count": active_task_count,
"annotation_count": annotation_count,
"frame_count": frame_count,
"template_count": template_count,
"system_load_percent": _system_load_percent(),
},
"tasks": tasks,
"activity": activities[:10],
}

764
backend/routers/export.py Normal file
View File

@@ -0,0 +1,764 @@
"""Annotation export endpoints (COCO, PNG masks)."""
import io
import json
import logging
import os
import re
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from urllib.parse import quote
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Project, Annotation, Frame, Template, User
from routers.auth import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/export", tags=["Export"])
def _mask_from_polygon(
polygon: List[List[float]],
width: int,
height: int,
) -> np.ndarray:
"""Render a normalized polygon to a binary mask."""
import cv2
pts = np.array(
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
dtype=np.int32,
)
mask = np.zeros((height, width), dtype=np.uint8)
cv2.fillPoly(mask, [pts], 255)
return mask
def _annotation_z_index(annotation: Annotation) -> int:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
try:
return int(class_meta["zIndex"])
except (TypeError, ValueError):
pass
if annotation.template and annotation.template.z_index is not None:
return int(annotation.template.z_index)
return 0
def _annotation_mask_id(annotation: Annotation) -> int | None:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict):
for key in ("maskId", "maskid", "mask_id"):
if class_meta.get(key) is None:
continue
try:
value = int(class_meta[key])
except (TypeError, ValueError):
continue
if value >= 0:
return value
return None
def _annotation_category_name(annotation: Annotation) -> str:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("category"):
return str(class_meta["category"])
if annotation.template and annotation.template.name:
return str(annotation.template.name)
return ""
def _annotation_class_key(annotation: Annotation) -> str:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict):
if class_meta.get("id"):
return f"class:{class_meta['id']}"
if class_meta.get("name"):
return f"name:{class_meta['name']}"
if annotation.template_id:
return f"template:{annotation.template_id}"
return f"annotation:{annotation.id}"
def _annotation_label(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("name"):
return str(class_meta["name"])
if mask_data.get("label"):
return str(mask_data["label"])
if annotation.template and annotation.template.name:
return str(annotation.template.name)
return f"Annotation {annotation.id}"
def _annotation_color(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("color"):
return str(class_meta["color"])
if mask_data.get("color"):
return str(mask_data["color"])
if annotation.template and annotation.template.color:
return str(annotation.template.color)
return "#ffffff"
def _hex_to_rgb(color: str) -> list[int]:
value = str(color or "").strip()
if value.startswith("#"):
value = value[1:]
if len(value) == 3:
value = "".join(part * 2 for part in value)
if len(value) != 6:
return [255, 255, 255]
try:
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
except ValueError:
return [255, 255, 255]
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
text = str(value or "").strip()
if not text:
text = fallback
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
text = re.sub(r"_+", "_", text).strip("._")
return text or fallback
def _project_video_name(project: Project) -> str:
if project.video_path:
stem = Path(project.video_path).name
if "." in stem:
stem = ".".join(stem.split(".")[:-1])
if stem:
return _safe_filename_part(stem, f"project_{project.id}")
return _safe_filename_part(project.name, f"project_{project.id}")
def _project_export_name(project: Project) -> str:
return _safe_filename_part(project.name, f"project_{project.id}")
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
if frame.timestamp_ms is not None:
return float(frame.timestamp_ms)
fps = project.parse_fps or project.original_fps or 30.0
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
def _project_frame_number(frame: Frame) -> int:
return int(frame.frame_index) + 1
def _format_timestamp_ms(value: float) -> str:
total_ms = max(0, int(round(float(value))))
hours = total_ms // 3_600_000
minutes = (total_ms % 3_600_000) // 60_000
seconds = (total_ms % 60_000) // 1_000
milliseconds = total_ms % 1_000
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
def _frame_export_stem(project: Project, frame: Frame) -> str:
return "_".join([
_project_video_name(project),
_format_timestamp_ms(_frame_timestamp_ms(frame, project)),
f"frame{_project_frame_number(frame):06d}",
])
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
if not frames:
return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
first_frame = frames[0]
last_frame = frames[-1]
return (
f"{_project_export_name(project)}"
f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
)
def _download_content_disposition(filename: str) -> str:
ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
ascii_fallback = f"{ascii_fallback}.zip"
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename)}"
def _frame_image_extension(frame: Frame) -> str:
suffix = Path(frame.image_url or "").suffix.lower()
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
_ = current_user
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return project
def _project_frames(project_id: int, db: Session) -> list[Frame]:
return (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.all()
)
def _filter_frames(
frames: list[Frame],
*,
scope: str = "all",
start_frame: int | None = None,
end_frame: int | None = None,
frame_id: int | None = None,
) -> list[Frame]:
if scope == "current":
if frame_id is None:
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
selected = [frame for frame in frames if frame.id == frame_id]
if not selected:
raise HTTPException(status_code=404, detail="Frame not found")
return selected
if scope == "range":
if start_frame is None or end_frame is None:
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
start = max(1, min(int(start_frame), int(end_frame)))
end = max(1, max(int(start_frame), int(end_frame)))
return frames[start - 1:end]
return frames
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
if not frame_ids:
return []
return (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.filter(Annotation.frame_id.in_(frame_ids))
.all()
)
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
images = []
for frame in frames:
images.append({
"id": frame.id,
"file_name": frame.image_url,
"width": frame.width or 1920,
"height": frame.height or 1080,
"frame_index": frame.frame_index,
})
categories = []
template_id_to_cat_id: Dict[int, int] = {}
for cat_idx, tmpl in enumerate(templates, start=1):
categories.append({
"id": cat_idx,
"name": tmpl.name,
"color": tmpl.color,
})
template_id_to_cat_id[tmpl.id] = cat_idx
coco_annotations = []
ann_id = 1
selected_frame_ids = {frame.id for frame in frames}
for ann in annotations:
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
continue
polygons = ann.mask_data.get("polygons", [])
if not polygons:
continue
first_poly = polygons[0]
xs = [p[0] for p in first_poly]
ys = [p[1] for p in first_poly]
width = ann.frame.width if ann.frame else 1920
height = ann.frame.height if ann.frame else 1080
bbox = [
min(xs) * width,
min(ys) * height,
(max(xs) - min(xs)) * width,
(max(ys) - min(ys)) * height,
]
area = bbox[2] * bbox[3]
segmentation = []
for poly in polygons:
flat = []
for p in poly:
flat.append(p[0] * width)
flat.append(p[1] * height)
segmentation.append(flat)
coco_annotations.append({
"id": ann_id,
"image_id": ann.frame_id,
"category_id": template_id_to_cat_id.get(ann.template_id, 0),
"segmentation": segmentation,
"area": area,
"bbox": bbox,
"iscrowd": 0,
})
ann_id += 1
return {
"info": {
"description": f"Annotations for {project.name}",
"version": "1.0",
"year": datetime.now().year,
"date_created": datetime.now().isoformat(),
},
"images": images,
"annotations": coco_annotations,
"categories": categories,
}
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
return {
"key": _annotation_class_key(annotation),
"className": _annotation_label(annotation),
"chineseName": _annotation_label(annotation),
"categoryName": _annotation_category_name(annotation),
"color": _annotation_color(annotation),
"internalPriority": _annotation_z_index(annotation),
"maskidHint": _annotation_mask_id(annotation),
"template_id": annotation.template_id,
}
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
entries_by_key: dict[str, dict[str, Any]] = {}
for annotation in annotations:
if not annotation.mask_data or not annotation.mask_data.get("polygons"):
continue
entry = _class_mapping_entry(annotation)
entries_by_key.setdefault(entry["key"], entry)
ordered = sorted(
entries_by_key.values(),
key=lambda item: (
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
str(item["className"]),
str(item["key"]),
),
)
key_to_value: dict[str, int] = {}
classes: list[dict[str, Any]] = []
used_maskids: set[int] = set()
next_maskid = 1
def next_available_maskid() -> int:
nonlocal next_maskid
while next_maskid in used_maskids:
next_maskid += 1
if next_maskid > 255:
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid类别值必须在 1-255 之间")
value = next_maskid
used_maskids.add(value)
next_maskid += 1
return value
for entry in ordered:
hinted_maskid = entry.get("maskidHint")
if isinstance(hinted_maskid, int) and hinted_maskid > 255:
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid类别值必须在 1-255 之间")
if isinstance(hinted_maskid, int) and hinted_maskid == 0:
maskid = 0
used_maskids.add(maskid)
elif isinstance(hinted_maskid, int) and 0 < hinted_maskid <= 255 and hinted_maskid not in used_maskids:
maskid = hinted_maskid
used_maskids.add(maskid)
else:
maskid = next_available_maskid()
key_to_value[entry["key"]] = maskid
classes.append({
"gt_pixel_value": maskid,
"maskid": maskid,
"chineseName": entry["chineseName"],
"className": entry["className"],
"categoryName": entry["categoryName"],
"rgb": _hex_to_rgb(entry["color"]),
"color": entry["color"],
"key": entry["key"],
"template_id": entry["template_id"],
})
return key_to_value, classes
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
if outputs:
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
invalid = parsed - allowed
if invalid:
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
return parsed or allowed
if mask_type == "separate":
return {"separate"}
if mask_type == "gt_label":
return {"gt_label"}
if mask_type == "pro_label":
return {"pro_label"}
if mask_type == "mix_label":
return {"mix_label"}
return allowed
def _write_original_frames(
zf: zipfile.ZipFile,
project: Project,
frames: list[Frame],
) -> dict[int, bytes]:
image_bytes_by_frame: dict[int, bytes] = {}
for frame in frames:
image_bytes = download_file(frame.image_url)
image_bytes_by_frame[frame.id] = image_bytes
zf.writestr(
f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
image_bytes,
)
return image_bytes_by_frame
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
import cv2
if image_bytes:
decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
if decoded is not None:
if decoded.shape[1] != width or decoded.shape[0] != height:
decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
return decoded
return np.zeros((height, width, 3), dtype=np.uint8)
def _write_result_mask_outputs(
zf: zipfile.ZipFile,
project: Project,
frames: list[Frame],
annotations: list[Annotation],
*,
outputs: set[str],
class_values: dict[str, int],
class_mapping: list[dict[str, Any]],
original_images: dict[int, bytes],
mix_opacity: float,
) -> None:
import cv2
include_individual = "separate" in outputs
include_semantic = "gt_label" in outputs
include_pro_label = "pro_label" in outputs
include_mix_label = "mix_label" in outputs
class_rgb_by_key = {
item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
for item in class_mapping
}
annotations_by_frame: dict[int, list[Annotation]] = {}
selected_frame_ids = {frame.id for frame in frames}
for annotation in annotations:
if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
continue
if not annotation.mask_data.get("polygons"):
continue
annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)
for frame in frames:
frame_annotations = annotations_by_frame.get(frame.id, [])
if not frame_annotations:
continue
width = frame.width or 1920
height = frame.height or 1080
frame_stem = _frame_export_stem(project, frame)
if include_individual:
class_masks: dict[str, np.ndarray] = {}
class_meta: dict[str, dict[str, Any]] = {}
for annotation in frame_annotations:
key = _annotation_class_key(annotation)
combined = class_masks.setdefault(key, np.zeros((height, width), dtype=np.uint8))
for poly in (annotation.mask_data or {}).get("polygons", []):
combined[:] = np.maximum(combined, _mask_from_polygon(poly, width, height))
class_meta.setdefault(key, _class_mapping_entry(annotation))
folder = f"分开Mask分割结果/{frame_stem}_分别导出"
for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
meta = class_meta[key]
maskid = class_values.get(key)
if maskid is None:
continue
_, encoded = cv2.imencode(".png", mask)
class_name = _safe_filename_part(meta["className"], "class")
zf.writestr(
f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
encoded.tobytes(),
)
needs_fused_output = include_semantic or include_pro_label or include_mix_label
semantic = np.zeros((height, width), dtype=np.uint8) if needs_fused_output else None
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
if needs_fused_output:
for annotation in sorted(frame_annotations, key=_annotation_z_index):
key = _annotation_class_key(annotation)
value = class_values.get(key)
if value is None:
continue
combined = np.zeros((height, width), dtype=np.uint8)
for poly in (annotation.mask_data or {}).get("polygons", []):
combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
if semantic is not None:
semantic[combined > 0] = value
if pro_label is not None:
rgb = class_rgb_by_key.get(key, [255, 255, 255])
bgr = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)
pro_label[combined > 0] = bgr
if include_semantic and semantic is not None:
_, encoded = cv2.imencode(".png", semantic)
zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
if include_pro_label and pro_label is not None:
_, encoded = cv2.imencode(".png", pro_label)
zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
if include_mix_label and pro_label is not None:
original = _decode_original_image(original_images.get(frame.id), width, height)
mask_pixels = np.any(pro_label > 0, axis=2)
mixed = original.copy()
opacity = min(max(float(mix_opacity), 0.0), 1.0)
mixed[mask_pixels] = (
original[mask_pixels].astype(np.float32) * (1.0 - opacity)
+ pro_label[mask_pixels].astype(np.float32) * opacity
).clip(0, 255).astype(np.uint8)
_, encoded = cv2.imencode(".png", mixed)
zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())
def _write_mask_pngs(
zf: zipfile.ZipFile,
frames: list[Frame],
annotations: list[Annotation],
*,
mask_type: str,
individual_prefix: str = "",
semantic_prefix: str = "",
semantic_file_stem: str = "semantic_frame",
semantic_dtype: Any = np.uint8,
) -> list[dict[str, Any]]:
import cv2
class_values: dict[str, int] = {}
semantic_classes: list[dict[str, Any]] = []
def class_value(annotation: Annotation) -> int:
key = _annotation_class_key(annotation)
if key not in class_values:
value = len(class_values) + 1
class_values[key] = value
semantic_classes.append({
"value": value,
"key": key,
"label": _annotation_label(annotation),
"color": _annotation_color(annotation),
"zIndex": _annotation_z_index(annotation),
"template_id": annotation.template_id,
})
return class_values[key]
include_individual = mask_type in {"separate", "both"}
include_semantic = mask_type in {"gt_label", "both"}
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
selected_frame_ids = {frame.id for frame in frames}
for ann in annotations:
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
continue
polygons = ann.mask_data.get("polygons", [])
if not polygons:
continue
width = ann.frame.width if ann.frame else 1920
height = ann.frame.height if ann.frame else 1080
combined = np.zeros((height, width), dtype=np.uint8)
for poly in polygons:
mask = _mask_from_polygon(poly, width, height)
combined = np.maximum(combined, mask)
if include_individual:
_, encoded = cv2.imencode(".png", combined)
zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
if include_semantic and ann.frame_id is not None:
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
if include_semantic:
for frame in frames:
entries = frame_masks.get(frame.id, [])
if not entries:
continue
width = frame.width or 1920
height = frame.height or 1080
semantic = np.zeros((height, width), dtype=semantic_dtype)
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
semantic[mask > 0] = class_value(ann)
_, encoded = cv2.imencode(".png", semantic)
zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
if include_semantic:
zf.writestr(
f"{semantic_prefix}semantic_classes.json",
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
)
return semantic_classes
@router.get(
"/{project_id}/coco",
summary="Export annotations in COCO format",
)
def export_coco(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
"""Export all annotations for a project as a COCO-format JSON file."""
project = _project_or_404(project_id, db, current_user)
frames = _project_frames(project_id, db)
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
templates = db.query(Template).all()
coco = _build_coco(project, frames, annotations, templates)
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
filename = f"project_{project_id}_coco.json"
return StreamingResponse(
io.BytesIO(data),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@router.get(
"/{project_id}/masks",
summary="Export PNG masks as a ZIP archive",
)
def export_masks(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
_project_or_404(project_id, db, current_user)
frames = _project_frames(project_id, db)
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
_write_mask_pngs(
zf,
frames,
annotations,
mask_type="both",
semantic_prefix="",
individual_prefix="",
)
zip_buffer.seek(0)
filename = f"project_{project_id}_masks.zip"
return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@router.get(
"/{project_id}/results",
summary="Export segmentation results as a ZIP archive",
)
def export_results(
project_id: int,
scope: str = Query("all", pattern="^(all|range|current)$"),
mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
outputs: str | None = Query(None),
mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
start_frame: int | None = Query(None, ge=1),
end_frame: int | None = Query(None, ge=1),
frame_id: int | None = Query(None, ge=1),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
"""Export JSON annotations plus selected PNG mask outputs inside one ZIP.
`scope=all` exports the whole video. `scope=range` uses 1-based frame
numbers from the sorted project frame sequence. `scope=current` uses the
concrete backend `frame_id`.
"""
project = _project_or_404(project_id, db, current_user)
frames = _filter_frames(
_project_frames(project_id, db),
scope=scope,
start_frame=start_frame,
end_frame=end_frame,
frame_id=frame_id,
)
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
templates = db.query(Template).all()
coco = _build_coco(project, frames, annotations, templates)
class_values, class_mapping = _build_gt_class_mapping(annotations)
selected_outputs = _parse_result_outputs(mask_type, outputs)
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
zf.writestr(
"annotations_coco.json",
json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
)
zf.writestr(
"maskid_GT像素值_类别映射.json",
json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
)
original_images = _write_original_frames(zf, project, frames)
_write_result_mask_outputs(
zf,
project,
frames,
annotations,
outputs=selected_outputs,
class_values=class_values,
class_mapping=class_mapping,
original_images=original_images,
mix_opacity=mix_opacity,
)
zip_buffer.seek(0)
filename = _segmentation_results_filename(project, frames)
return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={"Content-Disposition": _download_content_disposition(filename)},
)

234
backend/routers/media.py Normal file
View File

@@ -0,0 +1,234 @@
"""Media upload and parsing endpoints."""
import logging
import re
from pathlib import Path
from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url
from models import ProcessingTask, Project, User
from progress_events import publish_task_progress_event
from routers.auth import require_editor
from schemas import ProcessingTaskOut
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
from worker_tasks import parse_project_media
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["Media"])
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
def natural_filename_key(filename: str) -> tuple[object, ...]:
return tuple(
int(part) if part.isdigit() else part.casefold()
for part in re.split(r"(\d+)", Path(filename).name)
)
def _get_ext(filename: str) -> str:
return Path(filename).suffix.lower()
@router.post(
"/upload",
status_code=status.HTTP_201_CREATED,
summary="Upload a media file",
)
async def upload_media(
file: UploadFile = File(...),
project_id: Optional[int] = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Accept a video, image, or DICOM file and store it in MinIO.
If project_id is provided, the video_path of the project is updated.
Returns the presigned URL of the uploaded object.
"""
if not file.filename:
raise HTTPException(status_code=400, detail="Missing filename")
ext = _get_ext(file.filename)
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {ext}",
)
data = await file.read()
object_name = f"uploads/{project_id or 'general'}/{file.filename}"
try:
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
except Exception as exc: # noqa: BLE001
logger.error("Upload failed: %s", exc)
raise HTTPException(status_code=500, detail="Upload to storage failed") from exc
file_url = get_presigned_url(object_name, expires=3600)
if project_id:
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
project.video_path = object_name
db.commit()
logger.info("Linked upload to project_id=%s", project_id)
else:
# Auto-create a project named after the file
project = Project(
name=file.filename,
description="Auto-created from upload",
status=PROJECT_STATUS_PENDING,
video_path=object_name,
source_type="video",
owner_user_id=current_user.id,
)
db.add(project)
db.commit()
db.refresh(project)
project_id = project.id
object_name = f"uploads/{project_id}/{file.filename}"
# Re-upload with corrected path
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
project.video_path = object_name
db.commit()
logger.info("Auto-created project id=%s for upload %s", project_id, file.filename)
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
return {
"object_name": object_name,
"file_url": file_url,
"size": len(data),
"project_id": project_id,
"message": "Upload successful. Parsing job queued.",
}
@router.post(
"/upload/dicom",
status_code=status.HTTP_201_CREATED,
summary="Upload multiple DICOM files",
)
async def upload_dicom_batch(
files: List[UploadFile] = File(...),
project_id: Optional[int] = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Upload multiple .dcm files for a DICOM series.
If project_id is provided, files are added to the existing project.
Otherwise a new DICOM project is created.
"""
if not files:
raise HTTPException(status_code=400, detail="No files uploaded")
sorted_files = sorted(
[file for file in files if file.filename and file.filename.lower().endswith(".dcm")],
key=lambda file: natural_filename_key(file.filename or ""),
)
if not sorted_files:
raise HTTPException(status_code=400, detail="No valid DICOM files uploaded")
uploaded = []
if project_id:
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
else:
# Create new DICOM project
first_name = sorted_files[0].filename or "DICOM_Series"
project = Project(
name=first_name,
description=f"DICOM series with {len(sorted_files)} files",
status=PROJECT_STATUS_PENDING,
source_type="dicom",
owner_user_id=current_user.id,
)
db.add(project)
db.commit()
db.refresh(project)
project_id = project.id
logger.info("Auto-created DICOM project id=%s", project_id)
for file in sorted_files:
data = await file.read()
object_name = f"uploads/{project_id}/dicom/{file.filename}"
try:
upload_file(object_name, data, content_type="application/dicom", length=len(data))
uploaded.append(object_name)
except Exception as exc: # noqa: BLE001
logger.error("Failed to upload DICOM %s: %s", file.filename, exc)
project.video_path = f"uploads/{project_id}/dicom"
db.commit()
return {
"project_id": project_id,
"uploaded_count": len(uploaded),
"message": f"Uploaded {len(uploaded)} DICOM files. Parsing job queued.",
}
@router.post(
"/parse",
status_code=status.HTTP_202_ACCEPTED,
response_model=ProcessingTaskOut,
summary="Trigger frame extraction",
)
def parse_media(
project_id: int,
source_type: Optional[str] = None,
parse_fps: Optional[float] = Query(None, gt=0, le=120),
max_frames: Optional[int] = Query(None, gt=0),
target_width: int = Query(640, ge=64, le=4096),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTask:
"""Create a background task for media frame extraction.
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
updates the persisted task record as it progresses.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
effective_source = source_type or project.source_type or "video"
effective_parse_fps = parse_fps or project.parse_fps or 30.0
task = ProcessingTask(
task_type=f"parse_{effective_source}",
status=TASK_STATUS_QUEUED,
progress=0,
message="解析任务已入队",
project_id=project_id,
payload={
"source_type": effective_source,
"parse_fps": effective_parse_fps,
"max_frames": max_frames,
"target_width": target_width,
},
)
project.parse_fps = effective_parse_fps
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
async_result = parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, async_result.id)
return task

310
backend/routers/projects.py Normal file
View File

@@ -0,0 +1,310 @@
"""Project and Frame CRUD endpoints."""
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Mask, Project, Frame, User
from routers.auth import get_current_user, require_editor
from schemas import ProjectCopyRequest, ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
from minio_client import get_presigned_url
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects", tags=["Projects"])
def _next_project_copy_name(db: Session, source_name: str) -> str:
base_name = f"{source_name} 副本"
existing_names = {
row[0]
for row in db.query(Project.name)
.filter(Project.name.like(f"{base_name}%"))
.all()
}
if base_name not in existing_names:
return base_name
suffix = 2
while f"{base_name} {suffix}" in existing_names:
suffix += 1
return f"{base_name} {suffix}"
def _prepare_project_response(project: Project) -> Project:
project.frame_count = len(project.frames)
if project.thumbnail_url:
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
return project
# ---------------------------------------------------------------------------
# Projects
# ---------------------------------------------------------------------------
@router.post(
"",
response_model=ProjectOut,
status_code=status.HTTP_201_CREATED,
summary="Create a new project",
)
def create_project(
payload: ProjectCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Project:
"""Create a new segmentation project."""
project = Project(**payload.model_dump(), owner_user_id=current_user.id)
db.add(project)
db.commit()
db.refresh(project)
logger.info("Created project id=%s name=%s", project.id, project.name)
return project
@router.get(
"",
response_model=List[ProjectOut],
summary="List all projects",
)
def list_projects(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Project]:
"""Retrieve a paginated list of projects."""
projects = (
db.query(Project)
.offset(skip)
.limit(limit)
.all()
)
for p in projects:
_prepare_project_response(p)
return projects
@router.get(
"/{project_id}",
response_model=ProjectOut,
summary="Get a single project",
)
def get_project(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Project:
"""Retrieve a project by its ID."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return _prepare_project_response(project)
@router.post(
"/{project_id}/copy",
response_model=ProjectOut,
status_code=status.HTTP_201_CREATED,
summary="Copy a project",
)
def copy_project(
project_id: int,
payload: ProjectCopyRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Project:
"""Copy a project. Reset copies media/frame sequence; full also copies annotations and mask metadata."""
source = db.query(Project).filter(Project.id == project_id).first()
if not source:
raise HTTPException(status_code=404, detail="Project not found")
next_name = (payload.name or "").strip() if payload.name is not None else ""
if not next_name:
next_name = _next_project_copy_name(db, source.name)
copied = Project(
name=next_name,
description=source.description,
video_path=source.video_path,
thumbnail_url=source.thumbnail_url,
status=source.status,
source_type=source.source_type,
original_fps=source.original_fps,
parse_fps=source.parse_fps,
owner_user_id=current_user.id,
)
db.add(copied)
db.flush()
frame_id_map: dict[int, int] = {}
for frame in sorted(source.frames, key=lambda item: item.frame_index):
copied_frame = Frame(
project_id=copied.id,
frame_index=frame.frame_index,
image_url=frame.image_url,
width=frame.width,
height=frame.height,
timestamp_ms=frame.timestamp_ms,
source_frame_number=frame.source_frame_number,
)
db.add(copied_frame)
db.flush()
frame_id_map[frame.id] = copied_frame.id
if payload.mode == "full":
for annotation in sorted(source.annotations, key=lambda item: item.id):
copied_annotation = Annotation(
project_id=copied.id,
frame_id=frame_id_map.get(annotation.frame_id) if annotation.frame_id is not None else None,
template_id=annotation.template_id,
mask_data=annotation.mask_data,
points=annotation.points,
bbox=annotation.bbox,
)
db.add(copied_annotation)
db.flush()
for mask in annotation.masks:
db.add(Mask(
annotation_id=copied_annotation.id,
mask_url=mask.mask_url,
format=mask.format,
))
db.commit()
db.refresh(copied)
logger.info("Copied project id=%s to id=%s mode=%s", project_id, copied.id, payload.mode)
return _prepare_project_response(copied)
@router.patch(
"/{project_id}",
response_model=ProjectOut,
summary="Update a project",
)
def update_project(
project_id: int,
payload: ProjectUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Project:
"""Update project fields partially."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
for key, value in payload.model_dump(exclude_unset=True).items():
if key == "name":
value = (value or "").strip()
if not value:
raise HTTPException(status_code=400, detail="Project name is required")
setattr(project, key, value)
db.commit()
db.refresh(project)
logger.info("Updated project id=%s", project_id)
return project
@router.delete(
"/{project_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a project",
)
def delete_project(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> None:
"""Delete a project and all related frames and annotations."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
db.delete(project)
db.commit()
logger.info("Deleted project id=%s", project_id)
# ---------------------------------------------------------------------------
# Frames
# ---------------------------------------------------------------------------
@router.post(
"/{project_id}/frames",
response_model=FrameOut,
status_code=status.HTTP_201_CREATED,
summary="Add a frame to a project",
)
def create_frame(
project_id: int,
payload: FrameCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Frame:
"""Register a new frame under a project."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
frame = Frame(project_id=project_id, **payload.model_dump(exclude={"project_id"}))
db.add(frame)
db.commit()
db.refresh(frame)
return frame
@router.get(
"/{project_id}/frames",
response_model=List[FrameOut],
summary="List frames for a project",
)
def list_frames(
project_id: int,
skip: int = 0,
limit: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Frame]:
"""Retrieve all frames belonging to a project."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
query = (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.offset(skip)
)
if limit is not None:
query = query.limit(limit)
frames = query.all()
for frame in frames:
frame.image_url = get_presigned_url(frame.image_url, expires=3600)
return frames
@router.get(
"/{project_id}/frames/{frame_id}",
response_model=FrameOut,
summary="Get a single frame",
)
def get_frame(
project_id: int,
frame_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Frame:
"""Retrieve a specific frame by ID."""
frame = (
db.query(Frame)
.join(Project, Project.id == Frame.project_id)
.filter(
Frame.project_id == project_id,
Frame.id == frame_id,
)
.first()
)
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
return frame

161
backend/routers/tasks.py Normal file
View File

@@ -0,0 +1,161 @@
"""Processing task query endpoints."""
import logging
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from celery_app import celery_app
from database import get_db
from models import ProcessingTask, Project, User
from progress_events import publish_task_progress_event
from routers.auth import get_current_user, require_editor
from schemas import ProcessingTaskOut
from statuses import (
PROJECT_STATUS_PARSING,
PROJECT_STATUS_PENDING,
PROJECT_STATUS_READY,
TASK_ACTIVE_STATUSES,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_QUEUED,
)
from worker_tasks import parse_project_media, propagate_project_masks
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
_ = current_user
task = (
db.query(ProcessingTask)
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter(ProcessingTask.id == task_id)
.first()
)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
def list_tasks(
project_id: int | None = None,
status: str | None = None,
limit: int = 50,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[ProcessingTask]:
"""Return recent background processing tasks."""
_ = current_user
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id)
if project_id is not None:
query = query.filter(ProcessingTask.project_id == project_id)
if status is not None:
query = query.filter(ProcessingTask.status == status)
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> ProcessingTask:
"""Return one background task by id."""
return _get_task_or_404(task_id, db, current_user)
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
def cancel_task(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTask:
"""Cancel a queued/running background task and revoke the Celery job when possible."""
task = _get_task_or_404(task_id, db, current_user)
if task.status not in TASK_ACTIVE_STATUSES:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not cancellable in status: {task.status}",
)
if task.celery_task_id:
try:
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
except Exception as exc: # noqa: BLE001
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = "任务已取消"
task.error = "Cancelled by user"
task.finished_at = _now()
if task.project:
task.project.status = _project_status_after_stop(task.project)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
def retry_task(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTask:
"""Create a fresh queued task from a failed or cancelled task."""
previous = _get_task_or_404(task_id, db, current_user)
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not retryable in status: {previous.status}",
)
if previous.project_id is None:
raise HTTPException(status_code=400, detail="Task has no project_id")
project = db.query(Project).filter(Project.id == previous.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
is_propagation_task = previous.task_type == "propagate_masks"
if not is_propagation_task and not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
payload = dict(previous.payload or {})
payload.setdefault("source_type", project.source_type or "video")
payload["retry_of"] = previous.id
task = ProcessingTask(
task_type=previous.task_type,
status=TASK_STATUS_QUEUED,
progress=0,
message=f"重试任务已入队(源任务 #{previous.id}",
project_id=project.id,
payload=payload,
)
if not is_propagation_task:
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task

View File

@@ -0,0 +1,183 @@
"""Template (Ontology) CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from database import get_db
from models import Template, User
from routers.auth import get_current_user, require_editor
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/templates", tags=["Templates"])
RESERVED_UNCLASSIFIED_CLASS = {
"id": "reserved-unclassified",
"name": "待分类",
"color": "#000000",
"zIndex": 0,
"maskId": 0,
"category": "系统保留",
}
def _is_reserved_class(item: dict) -> bool:
return (
item.get("id") == RESERVED_UNCLASSIFIED_CLASS["id"]
or item.get("name") == RESERVED_UNCLASSIFIED_CLASS["name"]
or item.get("maskId") == 0
)
def _normalize_template_classes(classes: list[dict] | None) -> list[dict]:
normalized = [item for item in (classes or []) if not _is_reserved_class(item)]
return [*normalized, dict(RESERVED_UNCLASSIFIED_CLASS)]
def _pack_mapping_rules(data: dict) -> dict:
"""Pack classes/rules into mapping_rules for DB storage."""
mapping = data.get("mapping_rules") or {}
if "classes" in data and data["classes"] is not None:
mapping["classes"] = _normalize_template_classes(data.pop("classes"))
if "rules" in data and data["rules"] is not None:
mapping["rules"] = data.pop("rules")
if "classes" in mapping:
mapping["classes"] = _normalize_template_classes(mapping.get("classes"))
data["mapping_rules"] = mapping
return data
def _unpack_template(template: Template) -> Template:
"""Unpack mapping_rules into classes/rules for response."""
mapping = template.mapping_rules or {}
# Set as attributes so Pydantic from_attributes can pick them up
template.classes = _normalize_template_classes(mapping.get("classes", []))
template.rules = mapping.get("rules", [])
return template
@router.post(
"",
response_model=TemplateOut,
status_code=status.HTTP_201_CREATED,
summary="Create a new template",
)
def create_template(
payload: TemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Template:
"""Create a new ontology template / segmentation class."""
data = payload.model_dump()
data = _pack_mapping_rules(data)
template = Template(**data, owner_user_id=current_user.id)
db.add(template)
db.commit()
db.refresh(template)
_unpack_template(template)
logger.info("Created template id=%s name=%s", template.id, template.name)
return template
@router.get(
"",
response_model=List[TemplateOut],
summary="List all templates",
)
def list_templates(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Template]:
"""Retrieve all ontology templates."""
templates = (
db.query(Template)
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
.offset(skip)
.limit(limit)
.all()
)
for t in templates:
_unpack_template(t)
return templates
@router.get(
"/{template_id}",
response_model=TemplateOut,
summary="Get a single template",
)
def get_template(
template_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Template:
"""Retrieve a template by its ID."""
template = db.query(Template).filter(
Template.id == template_id,
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
_unpack_template(template)
return template
@router.patch(
"/{template_id}",
response_model=TemplateOut,
summary="Update a template",
)
def update_template(
template_id: int,
payload: TemplateUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Template:
"""Update template fields partially."""
template = db.query(Template).filter(
Template.id == template_id,
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
data = payload.model_dump(exclude_unset=True)
if "classes" in data or "rules" in data:
data = _pack_mapping_rules(data)
for key, value in data.items():
setattr(template, key, value)
db.commit()
db.refresh(template)
_unpack_template(template)
logger.info("Updated template id=%s", template_id)
return template
@router.delete(
"/{template_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a template",
)
def delete_template(
template_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> None:
"""Delete a template. Associated annotations will have template_id set to NULL."""
template = db.query(Template).filter(
Template.id == template_id,
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
db.delete(template)
db.commit()
logger.info("Deleted template id=%s", template_id)