feat: 完善分割工作区导入导出与管理流程

- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。

- 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。

- 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。

- 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。

- 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。

- 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。

- 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。

- 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。

- 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。

- 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。

- 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
2026-05-03 03:52:32 +08:00
parent 4c1d3dba73
commit afcddfaeb9
62 changed files with 6572 additions and 849 deletions

View File

@@ -33,6 +33,12 @@ class Settings(BaseSettings):
# App
app_env: str = "development"
cors_origins: list[str] = ["http://localhost:3000", "http://192.168.3.11:3000"]
jwt_secret_key: str = "seg-server-dev-secret-change-me"
jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 60 * 24
default_admin_username: str = "admin"
default_admin_password: str = "123456"
demo_video_path: str = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
class Config:
env_file = ".env"

View File

@@ -20,7 +20,7 @@ from progress_events import PROGRESS_CHANNEL
from redis_client import get_redis_client, ping as redis_ping
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
from routers import projects, templates, media, ai, export, auth, dashboard, tasks
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
logging.basicConfig(
level=logging.INFO,
@@ -28,7 +28,7 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
DEFAULT_VIDEO_PATH = settings.demo_video_path
def _ensure_runtime_schema_columns() -> None:
@@ -36,25 +36,56 @@ def _ensure_runtime_schema_columns() -> None:
try:
inspector = inspect(engine)
frame_columns = {column["name"] for column in inspector.get_columns("frames")}
project_columns = {column["name"] for column in inspector.get_columns("projects")}
template_columns = {column["name"] for column in inspector.get_columns("templates")}
with engine.begin() as connection:
if "timestamp_ms" not in frame_columns:
connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"))
if "source_frame_number" not in frame_columns:
connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"))
if "owner_user_id" not in project_columns:
connection.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id INTEGER"))
if "owner_user_id" not in template_columns:
connection.execute(text("ALTER TABLE templates ADD COLUMN owner_user_id INTEGER"))
except Exception as exc: # noqa: BLE001
logger.warning("Runtime schema column check failed: %s", exc)
def _seed_default_admin_and_ownership_sync() -> None:
"""Ensure the default admin exists and owns legacy unassigned projects."""
from models import Project
from routers.auth import ensure_default_admin
db = SessionLocal()
try:
admin = ensure_default_admin(db)
db.query(Project).filter(Project.owner_user_id.is_(None)).update(
{"owner_user_id": admin.id},
synchronize_session=False,
)
db.commit()
logger.info("Default admin ready; legacy projects assigned to user id=%s", admin.id)
except Exception as exc: # noqa: BLE001
logger.error("Failed to seed default admin or ownership: %s", exc)
finally:
db.close()
def _seed_default_project_sync() -> None:
"""Synchronously seed the default video project on first startup."""
import cv2
from models import Project, Frame
from routers.auth import ensure_default_admin
from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail
db = SessionLocal()
try:
admin = ensure_default_admin(db)
existing = db.query(Project).filter(Project.name == "Data_MyVideo_1").first()
if existing is not None:
if existing.owner_user_id is None:
existing.owner_user_id = admin.id
db.commit()
return
if not os.path.exists(DEFAULT_VIDEO_PATH):
@@ -67,6 +98,7 @@ def _seed_default_project_sync() -> None:
status=PROJECT_STATUS_PENDING,
source_type="video",
parse_fps=30.0,
owner_user_id=admin.id,
)
db.add(project)
db.commit()
@@ -196,6 +228,7 @@ async def lifespan(app: FastAPI):
try:
Base.metadata.create_all(bind=engine)
_ensure_runtime_schema_columns()
_seed_default_admin_and_ownership_sync()
logger.info("Database tables initialized.")
except Exception as exc: # noqa: BLE001
logger.error("Database initialization failed: %s", exc)
@@ -265,6 +298,7 @@ app.include_router(ai.router)
app.include_router(export.router)
app.include_router(dashboard.router)
app.include_router(tasks.router)
app.include_router(admin.router)
@app.get("/health", tags=["Health"])

View File

@@ -17,6 +17,25 @@ from database import Base
from statuses import PROJECT_STATUS_PENDING
class User(Base):
"""Application user used for authentication and data ownership."""
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String(150), unique=True, index=True, nullable=False)
password_hash = Column(String(255), nullable=False)
role = Column(String(50), default="admin", nullable=False)
is_active = Column(Integer, default=1, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
projects = relationship("Project", back_populates="owner")
templates = relationship("Template", back_populates="owner")
class Project(Base):
"""Project model representing a segmentation project."""
@@ -31,11 +50,13 @@ class Project(Base):
source_type = Column(String(20), default="video", nullable=False) # video | dicom
original_fps = Column(Float, nullable=True)
parse_fps = Column(Float, default=30.0, nullable=False)
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
owner = relationship("User", back_populates="projects")
frames = relationship("Frame", back_populates="project", cascade="all, delete-orphan")
annotations = relationship(
"Annotation", back_populates="project", cascade="all, delete-orphan"
@@ -77,8 +98,10 @@ class Template(Base):
color = Column(String(50), nullable=False)
z_index = Column(Integer, default=0, nullable=False)
mapping_rules = Column(JSON, nullable=True)
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
owner = relationship("User", back_populates="templates")
annotations = relationship(
"Annotation", back_populates="template", cascade="all, delete-orphan"
)
@@ -129,6 +152,22 @@ class Mask(Base):
annotation = relationship("Annotation", back_populates="masks")
class AuditLog(Base):
"""Audit trail for security and administrative actions."""
__tablename__ = "audit_logs"
id = Column(Integer, primary_key=True, index=True)
actor_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
action = Column(String(120), nullable=False)
target_type = Column(String(80), nullable=True)
target_id = Column(String(120), nullable=True)
detail = Column(JSON, nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
actor = relationship("User")
class ProcessingTask(Base):
"""Background task state persisted for dashboard and polling."""

270
backend/routers/admin.py Normal file
View File

@@ -0,0 +1,270 @@
"""Administrator-only user and audit management endpoints."""
import os
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from config import settings
from database import get_db
from minio_client import upload_file
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
from schemas import (
AdminUserCreate,
AdminUserUpdate,
AuditLogOut,
DemoFactoryResetOut,
DemoFactoryResetRequest,
UserOut,
)
from statuses import PROJECT_STATUS_PENDING
router = APIRouter(prefix="/api/admin", tags=["Admin"])
VALID_ROLES = {"admin", "annotator", "viewer"}
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
DEMO_PROJECT_NAME = "Data_MyVideo_1"
def _normalize_role(role: str | None) -> str:
normalized = (role or "annotator").strip().lower()
if normalized not in VALID_ROLES:
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
return normalized
@router.get("/users", response_model=List[UserOut], summary="List users")
def list_users(
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> List[User]:
"""Return all users for the administrator console."""
_ = admin_user
return db.query(User).order_by(User.id).all()
@router.post(
"/users",
response_model=UserOut,
status_code=status.HTTP_201_CREATED,
summary="Create user",
)
def create_user(
payload: AdminUserCreate,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> User:
"""Create a user with an initial password and role."""
username = payload.username.strip()
if not username:
raise HTTPException(status_code=400, detail="Username is required")
if len(payload.password) < 6:
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
user = User(
username=username,
password_hash=hash_password(payload.password),
role=_normalize_role(payload.role),
is_active=1 if payload.is_active else 0,
)
db.add(user)
try:
db.commit()
except IntegrityError as exc:
db.rollback()
raise HTTPException(status_code=409, detail="Username already exists") from exc
db.refresh(user)
write_audit_log(
db,
actor=admin_user,
action="admin.user_created",
target_type="user",
target_id=user.id,
detail={"username": user.username, "role": user.role, "is_active": bool(user.is_active)},
)
return user
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
def update_user(
user_id: int,
payload: AdminUserUpdate,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> User:
"""Update username, password, role or active state."""
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
updates = payload.model_dump(exclude_unset=True)
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
if "username" in updates:
username = (updates["username"] or "").strip()
if not username:
raise HTTPException(status_code=400, detail="Username is required")
user.username = username
if "password" in updates:
password = updates["password"] or ""
if len(password) < 6:
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
user.password_hash = hash_password(password)
if "role" in updates:
next_role = _normalize_role(updates["role"])
if user.id == admin_user.id and next_role != "admin":
raise HTTPException(status_code=400, detail="Cannot remove your own admin role")
user.role = next_role
if "is_active" in updates:
if user.id == admin_user.id and not updates["is_active"]:
raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
user.is_active = 1 if updates["is_active"] else 0
try:
db.commit()
except IntegrityError as exc:
db.rollback()
raise HTTPException(status_code=409, detail="Username already exists") from exc
db.refresh(user)
audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
audit_detail["password_changed"] = "password" in updates
write_audit_log(
db,
actor=admin_user,
action="admin.user_updated",
target_type="user",
target_id=user.id,
detail=audit_detail,
)
return user
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
def delete_user(
user_id: int,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> None:
"""Delete a user when it is safe to remove the account."""
if user_id == admin_user.id:
raise HTTPException(status_code=400, detail="Cannot delete yourself")
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
owned_project_count = db.query(Project).filter(Project.owner_user_id == user_id).count()
if owned_project_count:
raise HTTPException(status_code=409, detail="User owns projects; deactivate or migrate projects first")
username = user.username
db.delete(user)
db.commit()
write_audit_log(
db,
actor=admin_user,
action="admin.user_deleted",
target_type="user",
target_id=user_id,
detail={"username": username},
)
return None
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
def list_audit_logs(
limit: int = 100,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> List[AuditLog]:
"""Return recent audit events for administrators."""
_ = admin_user
safe_limit = min(max(int(limit or 100), 1), 500)
return db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc()).limit(safe_limit).all()
@router.post(
"/demo-factory-reset",
response_model=DemoFactoryResetOut,
summary="Reset demo data to factory defaults",
)
def reset_demo_factory(
payload: DemoFactoryResetRequest,
db: Session = Depends(get_db),
admin_user: User = Depends(require_admin),
) -> dict:
"""Reset a demo deployment to one admin account and one unparsed demo video project."""
if payload.confirmation != DEMO_RESET_CONFIRMATION:
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
if not os.path.exists(settings.demo_video_path):
raise HTTPException(
status_code=409,
detail=f"Demo video not found: {settings.demo_video_path}",
)
requested_by = admin_user.username
preserved_admin = ensure_default_admin(db)
preserved_admin.username = settings.default_admin_username
preserved_admin.password_hash = hash_password(settings.default_admin_password)
preserved_admin.role = "admin"
preserved_admin.is_active = 1
db.flush()
deleted_counts = {
"masks": db.query(Mask).delete(synchronize_session=False),
"annotations": db.query(Annotation).delete(synchronize_session=False),
"frames": db.query(Frame).delete(synchronize_session=False),
"tasks": db.query(ProcessingTask).delete(synchronize_session=False),
"projects": db.query(Project).delete(synchronize_session=False),
"user_templates": db.query(Template).filter(Template.owner_user_id.is_not(None)).delete(synchronize_session=False),
"audit_logs": db.query(AuditLog).delete(synchronize_session=False),
"users": db.query(User).filter(User.id != preserved_admin.id).delete(synchronize_session=False),
}
db.flush()
db.expunge_all()
preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
if not preserved_admin:
raise HTTPException(status_code=500, detail="Default admin was not preserved")
project = Project(
name=DEMO_PROJECT_NAME,
description="默认演示视频,尚未生成帧",
status=PROJECT_STATUS_PENDING,
source_type="video",
parse_fps=30.0,
owner_user_id=preserved_admin.id,
)
db.add(project)
db.flush()
with open(settings.demo_video_path, "rb") as file_obj:
data = file_obj.read()
object_name = f"uploads/{project.id}/{os.path.basename(settings.demo_video_path)}"
upload_file(object_name, data, content_type="video/mp4", length=len(data))
project.video_path = object_name
project.thumbnail_url = None
project.original_fps = None
db.commit()
db.refresh(preserved_admin)
db.refresh(project)
write_audit_log(
db,
actor=preserved_admin,
action="admin.demo_factory_reset",
target_type="project",
target_id=project.id,
detail={
"project_name": project.name,
"video_path": project.video_path,
"deleted_counts": deleted_counts,
"requested_by": requested_by,
},
)
return {
"admin_user": preserved_admin,
"project": project,
"deleted_counts": deleted_counts,
"message": "演示环境已恢复出厂设置",
}

View File

@@ -1,6 +1,7 @@
"""AI inference endpoints using selectable SAM runtimes."""
import logging
import math
import tempfile
from pathlib import Path
from typing import Any, List
@@ -8,11 +9,13 @@ from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Project, Frame, Template, Annotation, ProcessingTask
from models import Project, Frame, Template, Annotation, ProcessingTask, User
from routers.auth import get_current_user, require_editor
from schemas import (
AiRuntimeStatus,
MaskAnalysisRequest,
@@ -38,6 +41,102 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return project
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
query = (
db.query(Frame)
.join(Project, Project.id == Frame.project_id)
.filter(Frame.id == frame_id, Project.owner_user_id == current_user.id)
)
if project_id is not None:
query = query.filter(Frame.project_id == project_id)
frame = query.first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
return frame
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
template = db.query(Template).filter(
Template.id == template_id,
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
return template
def _normalize_hex_color(value: Any) -> str | None:
if not isinstance(value, str):
return None
text = value.strip().lower()
if not text:
return None
if not text.startswith("#"):
text = f"#{text}"
if len(text) == 4:
text = "#" + "".join(char * 2 for char in text[1:])
if len(text) != 7:
return None
try:
int(text[1:], 16)
except ValueError:
return None
return text
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
values = []
for channel in rgb:
value = int(channel)
if value > 255:
value = int(round(value / 257))
values.append(min(max(value, 0), 255))
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]]]:
classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
by_maskid: dict[int, dict[str, Any]] = {}
by_color: dict[str, dict[str, Any]] = {}
for index, item in enumerate(classes):
if not isinstance(item, dict):
continue
maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
try:
maskid = int(maskid_value)
except (TypeError, ValueError):
maskid = index + 1
color = _normalize_hex_color(item.get("color")) or "#22c55e"
class_meta = {
"id": str(item.get("id") or f"maskid-{maskid}"),
"name": str(item.get("name") or f"类别 {maskid}"),
"color": color,
"zIndex": int(item.get("zIndex", item.get("z_index", index * 10))),
"maskId": maskid,
**({"category": item.get("category")} if item.get("category") else {}),
}
if maskid > 0:
by_maskid[maskid] = class_meta
by_color[color] = class_meta
return by_maskid, by_color
def _gt_unknown_label(token: int | str) -> str:
if isinstance(token, int):
return f"未定义类别 {token}"
return f"未定义颜色 {token}"
def _load_frame_image(frame: Frame) -> np.ndarray:
"""Download a frame from MinIO and decode it to an RGB numpy array."""
try:
@@ -106,16 +205,20 @@ def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[flo
return [polygon for polygon in (_normalize_polygon(polygon) for polygon in polygons) if len(polygon) >= 3]
def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
if points:
return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
if len(anchors) <= limit:
return anchors
step = max(1, math.ceil(len(anchors) / limit))
return anchors[::step][:limit]
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
anchors: list[list[float]] = []
for polygon in polygons:
if not polygon:
continue
step = max(1, len(polygon) // 12)
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon[::step]])
return anchors[:32]
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon])
return len(anchors), _sample_anchor_points(anchors)
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
@@ -129,8 +232,14 @@ def _normalize_smoothing_options(strength: float | int | None, method: str | Non
}
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
return normalized ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
points = polygon
q = max(0.02, min(float(corner_cut), 0.25))
for _ in range(max(0, iterations)):
if len(points) < 3:
break
@@ -138,12 +247,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
for index, current in enumerate(points):
following = points[(index + 1) % len(points)]
next_points.append([
_clamp01(0.75 * current[0] + 0.25 * following[0]),
_clamp01(0.75 * current[1] + 0.25 * following[1]),
_clamp01((1.0 - q) * current[0] + q * following[0]),
_clamp01((1.0 - q) * current[1] + q * following[1]),
])
next_points.append([
_clamp01(0.25 * current[0] + 0.75 * following[0]),
_clamp01(0.25 * current[1] + 0.75 * following[1]),
_clamp01(q * current[0] + (1.0 - q) * following[0]),
_clamp01(q * current[1] + (1.0 - q) * following[1]),
])
points = next_points
return points
@@ -154,7 +263,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
return polygon
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
arc_length = cv2.arcLength(contour, True)
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
if len(approx) < 3:
return polygon
@@ -165,9 +274,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> li
strength = float(smoothing.get("strength") or 0.0)
if strength <= 0:
return _normalize_polygon(polygon)
iterations = max(1, min(3, int(strength // 35) + 1))
smoothed = _chaikin_smooth_polygon(_normalize_polygon(polygon), iterations)
simplified = _simplify_polygon(smoothed, strength)
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
if effective_strength >= 85:
iterations = 4
elif effective_strength >= 55:
iterations = 3
elif effective_strength >= 25:
iterations = 2
else:
iterations = 1
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
normalized = _normalize_polygon(polygon)
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
simplified = _simplify_polygon(smoothed, effective_strength)
if len(simplified) > len(normalized):
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
if len(simplified) <= len(normalized):
break
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
@@ -321,7 +446,11 @@ def _filter_predictions(
response_model=PredictResponse,
summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
def predict(
payload: PredictRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Execute selected SAM segmentation given an image and a prompt.
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
@@ -330,9 +459,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
- **semantic**: disabled in the current SAM 2.1 point/box product flow.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
frame = _owned_frame_or_404(payload.image_id, db, current_user)
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
@@ -478,7 +605,10 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
response_model=AiRuntimeStatus,
summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
def model_status(
selected_model: str | None = None,
_current_user: User = Depends(get_current_user),
) -> dict:
"""Return real runtime availability for GPU and the currently enabled SAM model."""
try:
return sam_registry.runtime_status(selected_model)
@@ -491,12 +621,14 @@ def model_status(selected_model: str | None = None) -> dict:
response_model=MaskAnalysisResponse,
summary="Analyze mask geometry and prompt anchors",
)
def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) -> dict:
def analyze_mask(
payload: MaskAnalysisRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> dict:
"""Return backend-computed mask properties for the frontend inspector."""
if payload.frame_id is not None:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_frame_or_404(payload.frame_id, db, current_user)
mask_data = payload.mask_data or {}
polygons = mask_data.get("polygons") or []
@@ -521,13 +653,13 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
else:
confidence_source = "manual_or_imported"
anchors = _analysis_anchors(valid_polygons, payload.points)
anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
return {
"confidence": confidence,
"confidence_source": confidence_source,
"topology_anchor_count": len(anchors),
"topology_anchor_count": anchor_count,
"topology_anchors": anchors,
"area": area,
"bbox": bbox,
@@ -541,16 +673,18 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
response_model=SmoothMaskResponse,
summary="Smooth editable mask polygons with backend geometry rules",
)
def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> dict:
def smooth_mask(
payload: SmoothMaskRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Return a smoothed polygon mask without persisting it.
The frontend keeps this as an explicit edit operation: users preview/apply it
to the current mask, then save through the normal annotation endpoint.
"""
if payload.frame_id is not None:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_frame_or_404(payload.frame_id, db, current_user)
polygons = payload.mask_data.get("polygons") or []
valid_polygons = _normalize_polygons(polygons)
@@ -564,10 +698,10 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
bbox = _polygon_bbox(smoothed_polygons[0])
anchors = _analysis_anchors(smoothed_polygons, payload.points)
anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
return {
"polygons": smoothed_polygons,
"topology_anchor_count": len(anchors),
"topology_anchor_count": anchor_count,
"topology_anchors": anchors,
"area": area,
"bbox": bbox,
@@ -581,7 +715,11 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
response_model=PropagateResponse,
summary="Propagate one current-frame region across a video frame segment",
)
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
def propagate(
payload: PropagateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Track one selected region from the current frame across nearby frames.
SAM 2 uses the official video predictor with the selected mask as the seed.
@@ -592,16 +730,8 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
max_frames = max(1, min(int(payload.max_frames or 30), 500))
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
source_frame = db.query(Frame).filter(
Frame.id == payload.frame_id,
Frame.project_id == payload.project_id,
).first()
if not source_frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_project_or_404(payload.project_id, db, current_user)
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
seed = payload.seed.model_dump(exclude_none=True)
polygons = seed.get("polygons") or []
@@ -709,18 +839,14 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
response_model=ProcessingTaskOut,
summary="Queue a background video propagation task",
)
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
def queue_propagate_task(
payload: PropagateTaskRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTaskOut:
"""Queue multiple seed/direction propagation steps as one background task."""
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
source_frame = db.query(Frame).filter(
Frame.id == payload.frame_id,
Frame.project_id == payload.project_id,
).first()
if not source_frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_project_or_404(payload.project_id, db, current_user)
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
if not payload.steps:
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
@@ -768,11 +894,13 @@ def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(ge
response_model=PredictResponse,
summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
def auto_segment(
image_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Run automatic mask generation on a frame using a grid of point prompts."""
frame = db.query(Frame).filter(Frame.id == image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
frame = _owned_frame_or_404(image_id, db, current_user)
image = _load_frame_image(frame)
try:
@@ -792,16 +920,15 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
def save_annotation(
payload: AnnotationCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_owned_project_or_404(payload.project_id, db, current_user)
if payload.frame_id:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
if payload.template_id:
_visible_template_or_404(payload.template_id, db, current_user)
annotation = Annotation(**payload.model_dump())
db.add(annotation)
@@ -823,8 +950,10 @@ async def import_gt_mask(
template_id: int | None = Form(None),
label: str = Form("GT Mask"),
color: str = Form("#22c55e"),
unknown_color_policy: str = Form("undefined"),
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> List[Annotation]:
"""Convert a binary/label mask image into persisted polygon annotations.
@@ -833,62 +962,122 @@ async def import_gt_mask(
the frontend an editable point-region representation instead of a static
bitmap layer.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_owned_project_or_404(project_id, db, current_user)
frame = _owned_frame_or_404(frame_id, db, current_user, project_id)
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
if unknown_color_policy not in {"discard", "undefined"}:
raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
template: Template | None = None
if template_id is not None:
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
template = _visible_template_or_404(template_id, db, current_user)
data = await file.read()
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
if image is None:
raise HTTPException(status_code=400, detail="Invalid mask image")
if image.ndim == 2:
label_image = image
elif image.ndim == 3 and image.shape[2] >= 3:
channels = image[:, :, :3]
# GT label images are maskid maps: either grayscale or RGB/BGR where
# all three color channels contain the same maskid value [X, X, X].
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
raise HTTPException(
status_code=400,
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0像素值为 maskid",
)
label_image = channels[:, :, 0]
else:
raise HTTPException(
status_code=400,
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0像素值为 maskid",
)
width = int(frame.width or image.shape[1])
height = int(frame.height or image.shape[0])
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
if not label_values:
original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
resized_to_frame = original_width != width or original_height != height
if resized_to_frame:
label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)
by_maskid, _by_color = _template_class_maps(template)
has_template_classes = bool(by_maskid)
fallback_color = _normalize_hex_color(color) or "#22c55e"
import_items: list[dict[str, Any]] = []
skipped_unknown = 0
label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
for label_value in label_values:
class_meta = by_maskid.get(label_value)
is_unknown = has_template_classes and class_meta is None
if is_unknown and unknown_color_policy == "discard":
skipped_unknown += 1
continue
if class_meta:
annotation_label = class_meta["name"]
annotation_color = class_meta["color"]
elif is_unknown:
annotation_label = _gt_unknown_label(label_value)
annotation_color = fallback_color
else:
annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
annotation_color = fallback_color
import_items.append({
"token": label_value,
"binary": np.where(label_image == label_value, 255, 0).astype(np.uint8),
"label": annotation_label,
"color": annotation_color,
"class": class_meta,
"unknown": is_unknown,
"metadata": {
"gt_label_value": label_value,
"gt_original_size": {"width": original_width, "height": original_height},
"gt_resized_to_frame": resized_to_frame,
},
})
if not import_items:
if skipped_unknown > 0:
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
raise HTTPException(status_code=400, detail="No foreground mask regions found")
has_multiple_labels = len(label_values) > 1
annotations: list[Annotation] = []
for label_value in label_values:
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
for item in import_items:
binary = item["binary"]
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
for contour in contours:
if cv2.contourArea(contour) < 1:
continue
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
polygon = _normalized_contour(contour, binary.shape[1], binary.shape[0])
if len(polygon) < 3:
continue
component = np.zeros_like(binary, dtype=np.uint8)
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
seed_point = _component_seed_point(component, binary.shape[1], binary.shape[0])
bbox = _contour_bbox(contour, binary.shape[1], binary.shape[0])
mask_data = {
"polygons": [polygon],
"label": item["label"],
"color": item["color"],
"source": "gt_mask",
"image_size": {"width": width, "height": height},
**item["metadata"],
}
if item["class"]:
mask_data["class"] = item["class"]
if item["unknown"]:
mask_data["gt_unknown_class"] = True
annotation = Annotation(
project_id=project_id,
frame_id=frame_id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": annotation_label,
"color": color,
"source": "gt_mask",
"gt_label_value": label_value,
"image_size": {"width": width, "height": height},
},
mask_data=mask_data,
points=[seed_point],
bbox=bbox,
)
@@ -914,14 +1103,14 @@ def list_annotations(
project_id: int,
frame_id: int | None = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Annotation]:
"""Return persisted annotations for a project, optionally scoped to one frame."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_owned_project_or_404(project_id, db, current_user)
query = db.query(Annotation).filter(Annotation.project_id == project_id)
if frame_id is not None:
_owned_frame_or_404(frame_id, db, current_user, project_id)
query = query.filter(Annotation.frame_id == frame_id)
return query.order_by(Annotation.id).all()
@@ -935,17 +1124,21 @@ def update_annotation(
annotation_id: int,
payload: AnnotationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Annotation:
"""Update mutable annotation fields persisted in the database."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
annotation = (
db.query(Annotation)
.join(Project, Project.id == Annotation.project_id)
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
.first()
)
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
updates = payload.model_dump(exclude_unset=True)
if "template_id" in updates and updates["template_id"] is not None:
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
_visible_template_or_404(updates["template_id"], db, current_user)
for field, value in updates.items():
setattr(annotation, field, value)
@@ -964,9 +1157,15 @@ def update_annotation(
def delete_annotation(
annotation_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Response:
"""Delete an annotation and its derived mask rows through ORM cascade."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
annotation = (
db.query(Annotation)
.join(Project, Project.id == Annotation.project_id)
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
.first()
)
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")

View File

@@ -1,9 +1,23 @@
"""Authentication endpoints."""
"""Authentication endpoints and dependencies."""
from fastapi import APIRouter, HTTPException
from datetime import datetime, timedelta, timezone
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session
from config import settings
from database import get_db
from models import AuditLog, User
from schemas import LoginResponse, UserOut
router = APIRouter(prefix="/api/auth", tags=["Auth"])
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
bearer_scheme = HTTPBearer(auto_error=False)
class LoginRequest(BaseModel):
@@ -11,14 +25,151 @@ class LoginRequest(BaseModel):
password: str
class LoginResponse(BaseModel):
token: str
username: str
def hash_password(password: str) -> str:
"""Hash a plain password for storage."""
return password_context.hash(password)
def verify_password(password: str, password_hash: str) -> bool:
"""Verify a plain password against a stored hash."""
return password_context.verify(password, password_hash)
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
"""Create a signed JWT access token for a user."""
expire = datetime.now(timezone.utc) + (
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
)
payload: dict[str, Any] = {
"sub": str(user.id),
"username": user.username,
"role": user.role,
"exp": expire,
}
return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
def ensure_default_admin(db: Session) -> User:
"""Create the default development admin if the user table is empty."""
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
if existing:
return existing
user = User(
username=settings.default_admin_username,
password_hash=hash_password(settings.default_admin_password),
role="admin",
is_active=1,
)
db.add(user)
db.commit()
db.refresh(user)
return user
def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
db: Session = Depends(get_db),
) -> User:
"""Resolve and validate the current user from the Bearer token."""
if credentials is None or credentials.scheme.lower() != "bearer":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
credentials.credentials,
settings.jwt_secret_key,
algorithms=[settings.jwt_algorithm],
)
user_id = int(payload.get("sub"))
except (JWTError, TypeError, ValueError) as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
) from exc
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Inactive or missing user",
headers={"WWW-Authenticate": "Bearer"},
)
return user
def require_admin(current_user: User = Depends(get_current_user)) -> User:
"""Require the current user to have the administrator role."""
if current_user.role != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
return current_user
def require_editor(current_user: User = Depends(get_current_user)) -> User:
"""Require a user role that can modify segmentation data."""
if current_user.role not in {"admin", "annotator"}:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
return current_user
def write_audit_log(
db: Session,
*,
actor: User | None,
action: str,
target_type: str | None = None,
target_id: str | int | None = None,
detail: dict[str, Any] | None = None,
) -> AuditLog:
"""Persist a compact audit event."""
log = AuditLog(
actor_user_id=actor.id if actor else None,
action=action,
target_type=target_type,
target_id=str(target_id) if target_id is not None else None,
detail=detail or {},
)
db.add(log)
db.commit()
db.refresh(log)
return log
@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest) -> dict:
"""Simple login for development."""
if payload.username == "admin" and payload.password == "123456":
return {"token": "fake-jwt-token-for-admin", "username": payload.username}
raise HTTPException(status_code=401, detail="Invalid credentials")
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
"""Authenticate a user and return a signed JWT."""
ensure_default_admin(db)
user = db.query(User).filter(User.username == payload.username).first()
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
write_audit_log(
db,
actor=None,
action="auth.login_failed",
target_type="user",
target_id=payload.username,
detail={"username": payload.username},
)
raise HTTPException(status_code=401, detail="Invalid credentials")
write_audit_log(
db,
actor=user,
action="auth.login_success",
target_type="user",
target_id=user.id,
detail={"username": user.username},
)
return {
"token": create_access_token(user),
"token_type": "bearer",
"username": user.username,
"user": user,
}
@router.get("/me", response_model=UserOut)
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
"""Return the authenticated user profile."""
return current_user

View File

@@ -5,11 +5,12 @@ from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy import func
from sqlalchemy import func, or_
from sqlalchemy.orm import Session
from database import get_db
from models import Annotation, Frame, ProcessingTask, Project, Template
from models import Annotation, Frame, ProcessingTask, Project, Template, User
from routers.auth import get_current_user
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
@@ -52,22 +53,45 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
@router.get("/overview", summary="Get dashboard overview")
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
def get_dashboard_overview(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> dict[str, Any]:
"""Return live dashboard data derived from persisted backend records."""
project_count = db.query(func.count(Project.id)).scalar() or 0
frame_count = db.query(func.count(Frame.id)).scalar() or 0
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
template_count = db.query(func.count(Template.id)).scalar() or 0
owned_project_ids_query = db.query(Project.id).filter(Project.owner_user_id == current_user.id)
project_count = db.query(func.count(Project.id)).filter(Project.owner_user_id == current_user.id).scalar() or 0
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(owned_project_ids_query)).scalar() or 0
annotation_count = (
db.query(func.count(Annotation.id))
.filter(Annotation.project_id.in_(owned_project_ids_query))
.scalar()
or 0
)
template_count = (
db.query(func.count(Template.id))
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
.scalar()
or 0
)
active_task_count = (
db.query(func.count(ProcessingTask.id))
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
.scalar()
or 0
)
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
projects = (
db.query(Project)
.filter(Project.owner_user_id == current_user.id)
.order_by(Project.updated_at.desc())
.all()
)
recent_tasks = (
db.query(ProcessingTask)
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
.order_by(ProcessingTask.created_at.desc())
.limit(50)
.all()
@@ -96,6 +120,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
recent_annotations = (
db.query(Annotation)
.filter(Annotation.project_id.in_(owned_project_ids_query))
.order_by(Annotation.updated_at.desc())
.limit(10)
.all()
@@ -112,6 +137,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
recent_templates = (
db.query(Template)
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
.order_by(Template.created_at.desc())
.limit(10)
.all()

View File

@@ -4,17 +4,22 @@ import io
import json
import logging
import os
import re
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from urllib.parse import quote
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Annotation, Frame, Template
from minio_client import download_file
from models import Project, Annotation, Frame, Template, User
from routers.auth import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/export", tags=["Export"])
@@ -49,6 +54,30 @@ def _annotation_z_index(annotation: Annotation) -> int:
return 0
def _annotation_mask_id(annotation: Annotation) -> int | None:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict):
for key in ("maskId", "maskid", "mask_id"):
if class_meta.get(key) is None:
continue
try:
value = int(class_meta[key])
except (TypeError, ValueError):
continue
if value > 0:
return value
return None
def _annotation_category_name(annotation: Annotation) -> str:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("category"):
return str(class_meta["category"])
if annotation.template and annotation.template.name:
return str(annotation.template.name)
return ""
def _annotation_class_key(annotation: Annotation) -> str:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict):
@@ -85,38 +114,162 @@ def _annotation_color(annotation: Annotation) -> str:
return "#ffffff"
@router.get(
"/{project_id}/coco",
summary="Export annotations in COCO format",
)
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
"""Export all annotations for a project as a COCO-format JSON file."""
project = db.query(Project).filter(Project.id == project_id).first()
def _hex_to_rgb(color: str) -> list[int]:
value = str(color or "").strip()
if value.startswith("#"):
value = value[1:]
if len(value) == 3:
value = "".join(part * 2 for part in value)
if len(value) != 6:
return [255, 255, 255]
try:
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
except ValueError:
return [255, 255, 255]
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
text = str(value or "").strip()
if not text:
text = fallback
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
text = re.sub(r"_+", "_", text).strip("._")
return text or fallback
def _project_video_name(project: Project) -> str:
if project.video_path:
stem = Path(project.video_path).name
if "." in stem:
stem = ".".join(stem.split(".")[:-1])
if stem:
return _safe_filename_part(stem, f"project_{project.id}")
return _safe_filename_part(project.name, f"project_{project.id}")
def _project_export_name(project: Project) -> str:
return _safe_filename_part(project.name, f"project_{project.id}")
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
if frame.timestamp_ms is not None:
return float(frame.timestamp_ms)
fps = project.parse_fps or project.original_fps or 30.0
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
def _project_frame_number(frame: Frame) -> int:
return int(frame.frame_index) + 1
def _format_timestamp_ms(value: float) -> str:
total_ms = max(0, int(round(float(value))))
hours = total_ms // 3_600_000
minutes = (total_ms % 3_600_000) // 60_000
seconds = (total_ms % 60_000) // 1_000
milliseconds = total_ms % 1_000
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
def _frame_export_stem(project: Project, frame: Frame) -> str:
return "_".join([
_project_video_name(project),
_format_timestamp_ms(_frame_timestamp_ms(frame, project)),
f"frame{_project_frame_number(frame):06d}",
])
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
if not frames:
return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
first_frame = frames[0]
last_frame = frames[-1]
return (
f"{_project_export_name(project)}"
f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
)
def _download_content_disposition(filename: str) -> str:
ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
ascii_fallback = f"{ascii_fallback}.zip"
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename)}"
def _frame_image_extension(frame: Frame) -> str:
suffix = Path(frame.image_url or "").suffix.lower()
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return project
annotations = (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.all()
)
frames = (
def _project_frames(project_id: int, db: Session) -> list[Frame]:
return (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.all()
)
templates = db.query(Template).all()
# Build COCO structure
def _filter_frames(
frames: list[Frame],
*,
scope: str = "all",
start_frame: int | None = None,
end_frame: int | None = None,
frame_id: int | None = None,
) -> list[Frame]:
if scope == "current":
if frame_id is None:
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
selected = [frame for frame in frames if frame.id == frame_id]
if not selected:
raise HTTPException(status_code=404, detail="Frame not found")
return selected
if scope == "range":
if start_frame is None or end_frame is None:
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
start = max(1, min(int(start_frame), int(end_frame)))
end = max(1, max(int(start_frame), int(end_frame)))
return frames[start - 1:end]
return frames
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
if not frame_ids:
return []
return (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.filter(Annotation.frame_id.in_(frame_ids))
.all()
)
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
images = []
for idx, frame in enumerate(frames):
for frame in frames:
images.append({
"id": frame.id,
"file_name": frame.image_url,
"width": frame.width or 1920,
"height": frame.height or 1080,
"frame_index": idx,
"frame_index": frame.frame_index,
})
categories = []
@@ -131,14 +284,14 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
coco_annotations = []
ann_id = 1
selected_frame_ids = {frame.id for frame in frames}
for ann in annotations:
if not ann.mask_data:
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
continue
polygons = ann.mask_data.get("polygons", [])
if not polygons:
continue
# Use first polygon for bbox / area approximation
first_poly = polygons[0]
xs = [p[0] for p in first_poly]
ys = [p[1] for p in first_poly]
@@ -171,7 +324,7 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
})
ann_id += 1
coco = {
return {
"info": {
"description": f"Annotations for {project.name}",
"version": "1.0",
@@ -183,39 +336,235 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
"categories": categories,
}
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
filename = f"project_{project_id}_coco.json"
return StreamingResponse(
io.BytesIO(data),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
return {
"key": _annotation_class_key(annotation),
"className": _annotation_label(annotation),
"chineseName": _annotation_label(annotation),
"categoryName": _annotation_category_name(annotation),
"color": _annotation_color(annotation),
"internalPriority": _annotation_z_index(annotation),
"maskidHint": _annotation_mask_id(annotation),
"template_id": annotation.template_id,
}
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
entries_by_key: dict[str, dict[str, Any]] = {}
for annotation in annotations:
if not annotation.mask_data or not annotation.mask_data.get("polygons"):
continue
entry = _class_mapping_entry(annotation)
entries_by_key.setdefault(entry["key"], entry)
ordered = sorted(
entries_by_key.values(),
key=lambda item: (
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] > 0 else 10_000_000,
str(item["className"]),
str(item["key"]),
),
)
key_to_value: dict[str, int] = {}
classes: list[dict[str, Any]] = []
used_maskids: set[int] = set()
next_maskid = 1
def next_available_maskid() -> int:
nonlocal next_maskid
while next_maskid in used_maskids:
next_maskid += 1
value = next_maskid
used_maskids.add(value)
next_maskid += 1
return value
for entry in ordered:
hinted_maskid = entry.get("maskidHint")
if isinstance(hinted_maskid, int) and hinted_maskid > 0 and hinted_maskid not in used_maskids:
maskid = hinted_maskid
used_maskids.add(maskid)
else:
maskid = next_available_maskid()
key_to_value[entry["key"]] = maskid
classes.append({
"gt_pixel_value": maskid,
"maskid": maskid,
"chineseName": entry["chineseName"],
"className": entry["className"],
"categoryName": entry["categoryName"],
"rgb": _hex_to_rgb(entry["color"]),
"color": entry["color"],
"key": entry["key"],
"template_id": entry["template_id"],
})
return key_to_value, classes
@router.get(
"/{project_id}/masks",
summary="Export PNG masks as a ZIP archive",
)
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
if outputs:
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
invalid = parsed - allowed
if invalid:
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
return parsed or allowed
if mask_type == "separate":
return {"separate"}
if mask_type == "gt_label":
return {"gt_label"}
if mask_type == "pro_label":
return {"pro_label"}
if mask_type == "mix_label":
return {"mix_label"}
return allowed
def _write_original_frames(
zf: zipfile.ZipFile,
project: Project,
frames: list[Frame],
) -> dict[int, bytes]:
image_bytes_by_frame: dict[int, bytes] = {}
for frame in frames:
image_bytes = download_file(frame.image_url)
image_bytes_by_frame[frame.id] = image_bytes
zf.writestr(
f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
image_bytes,
)
return image_bytes_by_frame
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
import cv2
annotations = (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.all()
)
frames = (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.all()
)
if image_bytes:
decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
if decoded is not None:
if decoded.shape[1] != width or decoded.shape[0] != height:
decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
return decoded
return np.zeros((height, width, 3), dtype=np.uint8)
def _write_result_mask_outputs(
zf: zipfile.ZipFile,
project: Project,
frames: list[Frame],
annotations: list[Annotation],
*,
outputs: set[str],
class_values: dict[str, int],
class_mapping: list[dict[str, Any]],
original_images: dict[int, bytes],
mix_opacity: float,
) -> None:
import cv2
include_individual = "separate" in outputs
include_semantic = "gt_label" in outputs
include_pro_label = "pro_label" in outputs
include_mix_label = "mix_label" in outputs
class_rgb_by_key = {
item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
for item in class_mapping
}
annotations_by_frame: dict[int, list[Annotation]] = {}
selected_frame_ids = {frame.id for frame in frames}
for annotation in annotations:
if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
continue
if not annotation.mask_data.get("polygons"):
continue
annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)
for frame in frames:
frame_annotations = annotations_by_frame.get(frame.id, [])
if not frame_annotations:
continue
width = frame.width or 1920
height = frame.height or 1080
frame_stem = _frame_export_stem(project, frame)
if include_individual:
class_masks: dict[str, np.ndarray] = {}
class_meta: dict[str, dict[str, Any]] = {}
for annotation in frame_annotations:
key = _annotation_class_key(annotation)
combined = class_masks.setdefault(key, np.zeros((height, width), dtype=np.uint8))
for poly in (annotation.mask_data or {}).get("polygons", []):
combined[:] = np.maximum(combined, _mask_from_polygon(poly, width, height))
class_meta.setdefault(key, _class_mapping_entry(annotation))
folder = f"分开Mask分割结果/{frame_stem}_分别导出"
for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
meta = class_meta[key]
maskid = class_values.get(key)
if maskid is None:
continue
_, encoded = cv2.imencode(".png", mask)
class_name = _safe_filename_part(meta["className"], "class")
zf.writestr(
f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
encoded.tobytes(),
)
needs_fused_output = include_semantic or include_pro_label or include_mix_label
semantic = np.zeros((height, width), dtype=np.uint16) if needs_fused_output else None
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
if needs_fused_output:
for annotation in sorted(frame_annotations, key=_annotation_z_index):
key = _annotation_class_key(annotation)
value = class_values.get(key)
if value is None:
continue
combined = np.zeros((height, width), dtype=np.uint8)
for poly in (annotation.mask_data or {}).get("polygons", []):
combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
if semantic is not None:
semantic[combined > 0] = value
if pro_label is not None:
rgb = class_rgb_by_key.get(key, [255, 255, 255])
bgr = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)
pro_label[combined > 0] = bgr
if include_semantic and semantic is not None:
_, encoded = cv2.imencode(".png", semantic)
zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
if include_pro_label and pro_label is not None:
_, encoded = cv2.imencode(".png", pro_label)
zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
if include_mix_label and pro_label is not None:
original = _decode_original_image(original_images.get(frame.id), width, height)
mask_pixels = np.any(pro_label > 0, axis=2)
mixed = original.copy()
opacity = min(max(float(mix_opacity), 0.0), 1.0)
mixed[mask_pixels] = (
original[mask_pixels].astype(np.float32) * (1.0 - opacity)
+ pro_label[mask_pixels].astype(np.float32) * opacity
).clip(0, 255).astype(np.uint8)
_, encoded = cv2.imencode(".png", mixed)
zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())
def _write_mask_pngs(
zf: zipfile.ZipFile,
frames: list[Frame],
annotations: list[Annotation],
*,
mask_type: str,
individual_prefix: str = "",
semantic_prefix: str = "",
semantic_file_stem: str = "semantic_frame",
semantic_dtype: Any = np.uint8,
) -> list[dict[str, Any]]:
import cv2
class_values: dict[str, int] = {}
semantic_classes: list[dict[str, Any]] = []
@@ -235,46 +584,102 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
})
return class_values[key]
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
for ann in annotations:
if not ann.mask_data:
continue
polygons = ann.mask_data.get("polygons", [])
if not polygons:
continue
include_individual = mask_type in {"separate", "both"}
include_semantic = mask_type in {"gt_label", "both"}
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
selected_frame_ids = {frame.id for frame in frames}
width = ann.frame.width if ann.frame else 1920
height = ann.frame.height if ann.frame else 1080
combined = np.zeros((height, width), dtype=np.uint8)
for ann in annotations:
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
continue
polygons = ann.mask_data.get("polygons", [])
if not polygons:
continue
for poly in polygons:
mask = _mask_from_polygon(poly, width, height)
combined = np.maximum(combined, mask)
width = ann.frame.width if ann.frame else 1920
height = ann.frame.height if ann.frame else 1080
combined = np.zeros((height, width), dtype=np.uint8)
for poly in polygons:
mask = _mask_from_polygon(poly, width, height)
combined = np.maximum(combined, mask)
if include_individual:
_, encoded = cv2.imencode(".png", combined)
fname = f"mask_{ann.id:06d}.png"
zf.writestr(fname, encoded.tobytes())
if ann.frame_id is not None:
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
if include_semantic and ann.frame_id is not None:
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
if include_semantic:
for frame in frames:
entries = frame_masks.get(frame.id, [])
if not entries:
continue
width = frame.width or 1920
height = frame.height or 1080
semantic = np.zeros((height, width), dtype=np.uint8)
semantic = np.zeros((height, width), dtype=semantic_dtype)
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
semantic[mask > 0] = class_value(ann)
_, encoded = cv2.imencode(".png", semantic)
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
if include_semantic:
zf.writestr(
"semantic_classes.json",
f"{semantic_prefix}semantic_classes.json",
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
)
return semantic_classes
@router.get(
"/{project_id}/coco",
summary="Export annotations in COCO format",
)
def export_coco(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
"""Export all annotations for a project as a COCO-format JSON file."""
project = _project_or_404(project_id, db, current_user)
frames = _project_frames(project_id, db)
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
templates = db.query(Template).all()
coco = _build_coco(project, frames, annotations, templates)
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
filename = f"project_{project_id}_coco.json"
return StreamingResponse(
io.BytesIO(data),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@router.get(
"/{project_id}/masks",
summary="Export PNG masks as a ZIP archive",
)
def export_masks(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
_project_or_404(project_id, db, current_user)
frames = _project_frames(project_id, db)
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
_write_mask_pngs(
zf,
frames,
annotations,
mask_type="both",
semantic_prefix="",
individual_prefix="",
)
zip_buffer.seek(0)
filename = f"project_{project_id}_masks.zip"
@@ -284,3 +689,71 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@router.get(
"/{project_id}/results",
summary="Export segmentation results as a ZIP archive",
)
def export_results(
project_id: int,
scope: str = Query("all", pattern="^(all|range|current)$"),
mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
outputs: str | None = Query(None),
mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
start_frame: int | None = Query(None, ge=1),
end_frame: int | None = Query(None, ge=1),
frame_id: int | None = Query(None, ge=1),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
"""Export JSON annotations plus selected PNG mask outputs inside one ZIP.
`scope=all` exports the whole video. `scope=range` uses 1-based frame
numbers from the sorted project frame sequence. `scope=current` uses the
concrete backend `frame_id`.
"""
project = _project_or_404(project_id, db, current_user)
frames = _filter_frames(
_project_frames(project_id, db),
scope=scope,
start_frame=start_frame,
end_frame=end_frame,
frame_id=frame_id,
)
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
templates = db.query(Template).all()
coco = _build_coco(project, frames, annotations, templates)
class_values, class_mapping = _build_gt_class_mapping(annotations)
selected_outputs = _parse_result_outputs(mask_type, outputs)
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
zf.writestr(
"annotations_coco.json",
json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
)
zf.writestr(
"maskid_GT像素值_类别映射.json",
json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
)
original_images = _write_original_frames(zf, project, frames)
_write_result_mask_outputs(
zf,
project,
frames,
annotations,
outputs=selected_outputs,
class_values=class_values,
class_mapping=class_mapping,
original_images=original_images,
mix_opacity=mix_opacity,
)
zip_buffer.seek(0)
filename = _segmentation_results_filename(project, frames)
return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={"Content-Disposition": _download_content_disposition(filename)},
)

View File

@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url
from models import ProcessingTask, Project
from models import ProcessingTask, Project, User
from progress_events import publish_task_progress_event
from routers.auth import require_editor
from schemas import ProcessingTaskOut
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
from worker_tasks import parse_project_media
@@ -34,6 +35,7 @@ async def upload_media(
file: UploadFile = File(...),
project_id: Optional[int] = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Accept a video, image, or DICOM file and store it in MinIO.
@@ -62,13 +64,15 @@ async def upload_media(
file_url = get_presigned_url(object_name, expires=3600)
if project_id:
project = db.query(Project).filter(Project.id == project_id).first()
if project:
project.video_path = object_name
db.commit()
logger.info("Linked upload to project_id=%s", project_id)
else:
logger.warning("Project id=%s not found for upload linkage", project_id)
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
project.video_path = object_name
db.commit()
logger.info("Linked upload to project_id=%s", project_id)
else:
# Auto-create a project named after the file
project = Project(
@@ -77,6 +81,7 @@ async def upload_media(
status=PROJECT_STATUS_PENDING,
video_path=object_name,
source_type="video",
owner_user_id=current_user.id,
)
db.add(project)
db.commit()
@@ -109,6 +114,7 @@ async def upload_dicom_batch(
files: List[UploadFile] = File(...),
project_id: Optional[int] = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Upload multiple .dcm files for a DICOM series.
@@ -121,7 +127,10 @@ async def upload_dicom_batch(
uploaded = []
if project_id:
project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
else:
@@ -132,6 +141,7 @@ async def upload_dicom_batch(
description=f"DICOM series with {len(files)} files",
status=PROJECT_STATUS_PENDING,
source_type="dicom",
owner_user_id=current_user.id,
)
db.add(project)
db.commit()
@@ -173,13 +183,17 @@ def parse_media(
max_frames: Optional[int] = Query(None, gt=0),
target_width: int = Query(640, ge=64, le=4096),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTask:
"""Create a background task for media frame extraction.
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
updates the persisted task record as it progresses.
"""
project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")

View File

@@ -7,7 +7,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Frame
from models import Project, Frame, User
from routers.auth import get_current_user, require_editor
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
from minio_client import get_presigned_url
@@ -24,9 +25,13 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
status_code=status.HTTP_201_CREATED,
summary="Create a new project",
)
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
def create_project(
payload: ProjectCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Project:
"""Create a new segmentation project."""
project = Project(**payload.model_dump())
project = Project(**payload.model_dump(), owner_user_id=current_user.id)
db.add(project)
db.commit()
db.refresh(project)
@@ -39,9 +44,20 @@ def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Pro
response_model=List[ProjectOut],
summary="List all projects",
)
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
def list_projects(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Project]:
"""Retrieve a paginated list of projects."""
projects = db.query(Project).offset(skip).limit(limit).all()
projects = (
db.query(Project)
.filter(Project.owner_user_id == current_user.id)
.offset(skip)
.limit(limit)
.all()
)
for p in projects:
p.frame_count = len(p.frames)
if p.thumbnail_url:
@@ -54,9 +70,16 @@ def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
response_model=ProjectOut,
summary="Get a single project",
)
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
def get_project(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Project:
"""Retrieve a project by its ID."""
project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
project.frame_count = len(project.frames)
@@ -74,9 +97,13 @@ def update_project(
project_id: int,
payload: ProjectUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Project:
"""Update project fields partially."""
project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -94,9 +121,16 @@ def update_project(
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a project",
)
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
def delete_project(
project_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> None:
"""Delete a project and all related frames and annotations."""
project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -118,9 +152,13 @@ def create_frame(
project_id: int,
payload: FrameCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Frame:
"""Register a new frame under a project."""
project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -141,9 +179,13 @@ def list_frames(
skip: int = 0,
limit: int = 1000,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Frame]:
"""Retrieve all frames belonging to a project."""
project = db.query(Project).filter(Project.id == project_id).first()
project = db.query(Project).filter(
Project.id == project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
@@ -165,11 +207,21 @@ def list_frames(
response_model=FrameOut,
summary="Get a single frame",
)
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
def get_frame(
project_id: int,
frame_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Frame:
"""Retrieve a specific frame by ID."""
frame = (
db.query(Frame)
.filter(Frame.project_id == project_id, Frame.id == frame_id)
.join(Project, Project.id == Frame.project_id)
.filter(
Frame.project_id == project_id,
Frame.id == frame_id,
Project.owner_user_id == current_user.id,
)
.first()
)
if not frame:

View File

@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
from celery_app import celery_app
from database import get_db
from models import ProcessingTask, Project
from models import ProcessingTask, Project, User
from progress_events import publish_task_progress_event
from routers.auth import get_current_user, require_editor
from schemas import ProcessingTaskOut
from statuses import (
PROJECT_STATUS_PARSING,
@@ -31,8 +32,16 @@ def _now() -> datetime:
return datetime.now(timezone.utc)
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
task = (
db.query(ProcessingTask)
.outerjoin(Project, Project.id == ProcessingTask.project_id)
.filter(
ProcessingTask.id == task_id,
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id),
)
.first()
)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
@@ -48,9 +57,12 @@ def list_tasks(
status: str | None = None,
limit: int = 50,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[ProcessingTask]:
"""Return recent background processing tasks."""
query = db.query(ProcessingTask)
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id).filter(
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id)
)
if project_id is not None:
query = query.filter(ProcessingTask.project_id == project_id)
if status is not None:
@@ -59,15 +71,23 @@ def list_tasks(
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
def get_task(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> ProcessingTask:
"""Return one background task by id."""
return _get_task_or_404(task_id, db)
return _get_task_or_404(task_id, db, current_user)
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
def cancel_task(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTask:
"""Cancel a queued/running background task and revoke the Celery job when possible."""
task = _get_task_or_404(task_id, db)
task = _get_task_or_404(task_id, db, current_user)
if task.status not in TASK_ACTIVE_STATUSES:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -95,9 +115,13 @@ def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
def retry_task(
task_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTask:
"""Create a fresh queued task from a failed or cancelled task."""
previous = _get_task_or_404(task_id, db)
previous = _get_task_or_404(task_id, db, current_user)
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -106,7 +130,10 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
if previous.project_id is None:
raise HTTPException(status_code=400, detail="Task has no project_id")
project = db.query(Project).filter(Project.id == previous.project_id).first()
project = db.query(Project).filter(
Project.id == previous.project_id,
Project.owner_user_id == current_user.id,
).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
is_propagation_task = previous.task_type == "propagate_masks"

View File

@@ -4,10 +4,12 @@ import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from database import get_db
from models import Template
from models import Template, User
from routers.auth import get_current_user, require_editor
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
logger = logging.getLogger(__name__)
@@ -40,11 +42,15 @@ def _unpack_template(template: Template) -> Template:
status_code=status.HTTP_201_CREATED,
summary="Create a new template",
)
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
def create_template(
payload: TemplateCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Template:
"""Create a new ontology template / segmentation class."""
data = payload.model_dump()
data = _pack_mapping_rules(data)
template = Template(**data)
template = Template(**data, owner_user_id=current_user.id)
db.add(template)
db.commit()
db.refresh(template)
@@ -62,9 +68,16 @@ def list_templates(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Template]:
"""Retrieve all ontology templates."""
templates = db.query(Template).offset(skip).limit(limit).all()
templates = (
db.query(Template)
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
.offset(skip)
.limit(limit)
.all()
)
for t in templates:
_unpack_template(t)
return templates
@@ -75,9 +88,16 @@ def list_templates(
response_model=TemplateOut,
summary="Get a single template",
)
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
def get_template(
template_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> Template:
"""Retrieve a template by its ID."""
template = db.query(Template).filter(Template.id == template_id).first()
template = db.query(Template).filter(
Template.id == template_id,
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
_unpack_template(template)
@@ -93,9 +113,13 @@ def update_template(
template_id: int,
payload: TemplateUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Template:
"""Update template fields partially."""
template = db.query(Template).filter(Template.id == template_id).first()
template = db.query(Template).filter(
Template.id == template_id,
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
@@ -118,9 +142,16 @@ def update_template(
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a template",
)
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
def delete_template(
template_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> None:
"""Delete a template. Associated annotations will have template_id set to NULL."""
template = db.query(Template).filter(Template.id == template_id).first()
template = db.query(Template).filter(
Template.id == template_id,
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")

View File

@@ -5,6 +5,55 @@ from typing import Optional, Any
from pydantic import BaseModel, ConfigDict
# ---------------------------------------------------------------------------
# Auth / user schemas
# ---------------------------------------------------------------------------
class UserOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
username: str
role: str
is_active: int
class LoginResponse(BaseModel):
token: str
token_type: str = "bearer"
username: str
user: UserOut
class AdminUserCreate(BaseModel):
username: str
password: str
role: str = "annotator"
is_active: bool = True
class AdminUserUpdate(BaseModel):
username: Optional[str] = None
password: Optional[str] = None
role: Optional[str] = None
is_active: Optional[bool] = None
class AuditLogOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
actor_user_id: Optional[int] = None
action: str
target_type: Optional[str] = None
target_id: Optional[str] = None
detail: Optional[dict[str, Any]] = None
created_at: datetime
class DemoFactoryResetRequest(BaseModel):
confirmation: str
# ---------------------------------------------------------------------------
# Project schemas
# ---------------------------------------------------------------------------
@@ -38,11 +87,19 @@ class ProjectOut(ProjectBase):
model_config = ConfigDict(from_attributes=True)
id: int
owner_user_id: Optional[int] = None
created_at: datetime
updated_at: datetime
frame_count: int = 0
class DemoFactoryResetOut(BaseModel):
admin_user: UserOut
project: ProjectOut
deleted_counts: dict[str, int]
message: str
# ---------------------------------------------------------------------------
# Frame schemas
# ---------------------------------------------------------------------------
@@ -98,6 +155,7 @@ class TemplateOut(TemplateBase):
model_config = ConfigDict(from_attributes=True)
id: int
owner_user_id: Optional[int] = None
created_at: datetime

View File

@@ -102,8 +102,14 @@ def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
return {"strength": round(strength, 2), "method": method}
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
return normalized ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
points = _normalize_polygon(polygon)
q = max(0.02, min(float(corner_cut), 0.25))
for _ in range(max(0, iterations)):
if len(points) < 3:
break
@@ -111,12 +117,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
for index, current in enumerate(points):
following = points[(index + 1) % len(points)]
next_points.append([
_clamp01(0.75 * current[0] + 0.25 * following[0]),
_clamp01(0.75 * current[1] + 0.25 * following[1]),
_clamp01((1.0 - q) * current[0] + q * following[0]),
_clamp01((1.0 - q) * current[1] + q * following[1]),
])
next_points.append([
_clamp01(0.25 * current[0] + 0.75 * following[0]),
_clamp01(0.25 * current[1] + 0.75 * following[1]),
_clamp01(q * current[0] + (1.0 - q) * following[0]),
_clamp01(q * current[1] + (1.0 - q) * following[1]),
])
points = next_points
return points
@@ -127,7 +133,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
return polygon
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
arc_length = cv2.arcLength(contour, True)
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
if len(approx) < 3:
return polygon
@@ -140,9 +146,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None
strength = float(smoothing.get("strength") or 0.0)
if strength <= 0:
return _normalize_polygon(polygon)
iterations = max(1, min(3, int(strength // 35) + 1))
smoothed = _chaikin_smooth_polygon(polygon, iterations)
simplified = _simplify_polygon(smoothed, strength)
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
if effective_strength >= 85:
iterations = 4
elif effective_strength >= 55:
iterations = 3
elif effective_strength >= 25:
iterations = 2
else:
iterations = 1
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
normalized = _normalize_polygon(polygon)
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
simplified = _simplify_polygon(smoothed, effective_strength)
if len(simplified) > len(normalized):
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
if len(simplified) <= len(normalized):
break
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)

View File

@@ -19,7 +19,7 @@ if str(BACKEND_DIR) not in sys.path:
from database import Base, get_db # noqa: E402
from main import websocket_progress # noqa: E402
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
from routers import admin, ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
@pytest.fixture()
@@ -32,6 +32,7 @@ def db_session() -> Iterator[Session]:
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
session = TestingSessionLocal()
auth.ensure_default_admin(session)
try:
yield session
finally:
@@ -56,6 +57,7 @@ def app(db_session: Session) -> FastAPI:
test_app.include_router(export.router)
test_app.include_router(dashboard.router)
test_app.include_router(tasks.router)
test_app.include_router(admin.router)
@test_app.get("/health")
def health_check() -> dict[str, str]:
@@ -67,6 +69,10 @@ def app(db_session: Session) -> FastAPI:
@pytest.fixture()
def client(app: FastAPI) -> Iterator[TestClient]:
def client(app: FastAPI, db_session: Session) -> Iterator[TestClient]:
with TestClient(app) as test_client:
admin = auth.ensure_default_admin(db_session)
test_client.headers.update({
"Authorization": f"Bearer {auth.create_access_token(admin)}"
})
yield test_client

158
backend/tests/test_admin.py Normal file
View File

@@ -0,0 +1,158 @@
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
from routers.auth import create_access_token, hash_password
from statuses import PROJECT_STATUS_PENDING
def test_admin_user_management_and_audit_logs(client, db_session):
created = client.post("/api/admin/users", json={
"username": "doctor",
"password": "secret123",
"role": "annotator",
"is_active": True,
})
assert created.status_code == 201
user_id = created.json()["id"]
updated = client.patch(f"/api/admin/users/{user_id}", json={
"role": "viewer",
"password": "newsecret",
"is_active": False,
})
assert updated.status_code == 200
assert updated.json()["role"] == "viewer"
assert updated.json()["is_active"] == 0
users = client.get("/api/admin/users")
assert users.status_code == 200
assert any(user["username"] == "doctor" for user in users.json())
deleted = client.delete(f"/api/admin/users/{user_id}")
assert deleted.status_code == 204
logs = client.get("/api/admin/audit-logs")
assert logs.status_code == 200
actions = [log["action"] for log in logs.json()]
assert "admin.user_created" in actions
assert "admin.user_updated" in actions
assert "admin.user_deleted" in actions
def test_admin_routes_require_admin_role(client, db_session):
user = User(username="viewer", password_hash=hash_password("secret123"), role="viewer", is_active=1)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
original_auth = client.headers["Authorization"]
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
try:
response = client.get("/api/admin/users")
assert response.status_code == 403
finally:
client.headers.update({"Authorization": original_auth})
def test_viewer_role_is_read_only_for_business_mutations(client, db_session):
project = client.post("/api/projects", json={"name": "Readonly Check"}).json()
user = User(username="readonly", password_hash=hash_password("secret123"), role="viewer", is_active=1)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
original_auth = client.headers["Authorization"]
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
try:
assert client.get("/api/projects").status_code == 200
assert client.post("/api/projects", json={"name": "Nope"}).status_code == 403
assert client.patch(f"/api/projects/{project['id']}", json={"name": "Nope"}).status_code == 403
assert client.post("/api/ai/annotate", json={"project_id": project["id"]}).status_code == 403
finally:
client.headers.update({"Authorization": original_auth})
def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
me = client.get("/api/auth/me").json()
assert client.delete(f"/api/admin/users/{me['id']}").status_code == 400
user = User(username="owner", password_hash=hash_password("secret123"), role="annotator", is_active=1)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
db_session.add(Project(name="Owned", owner_user_id=user.id))
db_session.commit()
response = client.delete(f"/api/admin/users/{user.id}")
assert response.status_code == 409
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
video_path = tmp_path / "Data_MyVideo_1.mp4"
video_path.write_bytes(b"demo-video")
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
uploaded = []
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
"object_name": object_name,
"data": data,
"content_type": content_type,
"length": length,
}))
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
db_session.add(extra_user)
db_session.commit()
db_session.refresh(extra_user)
old_project = Project(name="Old", owner_user_id=extra_user.id, video_path="uploads/old.mp4")
db_session.add(old_project)
db_session.commit()
db_session.refresh(old_project)
frame = Frame(project_id=old_project.id, frame_index=0, image_url="frames/old.jpg")
db_session.add(frame)
task = ProcessingTask(task_type="parse_video", project_id=old_project.id)
private_template = Template(
name="Private",
description="private",
color="#fff",
z_index=1,
owner_user_id=extra_user.id,
)
db_session.add_all([task, private_template])
db_session.commit()
db_session.refresh(frame)
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
db_session.add(annotation)
db_session.commit()
db_session.refresh(annotation)
db_session.add(Mask(annotation_id=annotation.id, mask_url="masks/old.png"))
db_session.add(AuditLog(actor_user_id=extra_user.id, action="old.audit"))
db_session.commit()
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "RESET_DEMO_FACTORY"})
assert response.status_code == 200
data = response.json()
assert data["message"] == "演示环境已恢复出厂设置"
assert data["admin_user"]["username"] == "admin"
assert data["project"]["name"] == "Data_MyVideo_1"
assert data["project"]["status"] == PROJECT_STATUS_PENDING
assert data["project"]["frame_count"] == 0
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
assert uploaded == [{
"object_name": data["project"]["video_path"],
"data": b"demo-video",
"content_type": "video/mp4",
"length": len(b"demo-video"),
}]
assert [user.username for user in db_session.query(User).all()] == ["admin"]
assert db_session.query(Project).count() == 1
assert db_session.query(Frame).count() == 0
assert db_session.query(Annotation).count() == 0
assert db_session.query(Mask).count() == 0
assert db_session.query(ProcessingTask).count() == 0
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
assert db_session.query(AuditLog).count() == 1
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"
def test_demo_factory_reset_requires_exact_confirmation(client):
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "reset"})
assert response.status_code == 400

View File

@@ -223,6 +223,88 @@ def test_analyze_mask_returns_backend_geometry_properties(client):
assert body["message"] == "已从后端重新提取几何拓扑锚点"
def test_analyze_mask_reports_actual_polygon_anchor_count(client):
_, frame, _ = _create_project_and_frame(client)
polygon = [[0.1 + index * 0.005, 0.1 + (0.01 if index % 2 else 0)] for index in range(80)]
response = client.post("/api/ai/analyze-mask", json={
"frame_id": frame["id"],
"mask_data": {
"polygons": [polygon],
"label": "AI Mask",
"color": "#06b6d4",
},
"points": [[0.2, 0.2]],
})
assert response.status_code == 200
body = response.json()
assert body["topology_anchor_count"] == len(polygon)
assert len(body["topology_anchors"]) <= 64
def test_smooth_mask_simplifies_noisy_ai_polygon(client):
_, frame, _ = _create_project_and_frame(client)
polygon = []
for index in range(20):
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
for index in range(20):
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
response = client.post("/api/ai/smooth-mask", json={
"frame_id": frame["id"],
"mask_data": {
"polygons": [polygon],
"label": "AI Mask",
"color": "#06b6d4",
},
"strength": 80,
})
assert response.status_code == 200
body = response.json()
assert body["topology_anchor_count"] == len(body["polygons"][0])
assert len(body["polygons"][0]) < len(polygon)
def test_smooth_mask_uses_eased_strength_curve(client):
_, frame, _ = _create_project_and_frame(client)
polygon = []
for index in range(20):
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
for index in range(20):
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
for index in range(20):
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
def smoothed_count(strength: int) -> int:
response = client.post("/api/ai/smooth-mask", json={
"frame_id": frame["id"],
"mask_data": {
"polygons": [polygon],
"label": "AI Mask",
"color": "#06b6d4",
},
"strength": strength,
})
assert response.status_code == 200
return len(response.json()["polygons"][0])
low_count = smoothed_count(20)
mid_count = smoothed_count(70)
high_count = smoothed_count(95)
assert low_count <= len(polygon)
assert mid_count < low_count
assert high_count < mid_count
def test_smooth_mask_returns_backend_smoothed_geometry(client):
_, frame, _ = _create_project_and_frame(client)
@@ -311,6 +393,7 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
"color": "#ff0000",
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
"template_id": None,
"smoothing": {"strength": 45, "method": "chaikin"},
},
})
@@ -327,6 +410,9 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
assert saved["mask_data"]["class"]["name"] == "胆囊"
assert saved["mask_data"]["score"] == 0.8
assert saved["mask_data"]["geometry_smoothing"] == {"strength": 45.0, "method": "chaikin"}
assert saved["mask_data"]["polygons"][0] != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
assert len(saved["mask_data"]["polygons"][0]) > 3
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert len(listing.json()) == 1
@@ -490,8 +576,10 @@ def test_propagation_task_runner_saves_annotations_and_progress(client, db_sessi
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
assert listing.json()[0]["frame_id"] == frames[1]["id"]
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
stored_polygon = listing.json()[0]["mask_data"]["polygons"][0]
assert listing.json()[0]["mask_data"]["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"}
assert len(listing.json()[0]["mask_data"]["polygons"][0]) > 3
assert stored_polygon != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
assert len(stored_polygon) > 3
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
@@ -1084,3 +1172,156 @@ def test_import_gt_mask_splits_label_values(client):
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
assert all(len(item["points"]) == 1 for item in body)
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "GTLabel Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [
{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1},
],
"rules": [],
}).json()
mask = np.zeros((360, 640), dtype=np.uint16)
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("GT_label.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = response.json()
assert len(body) == 1
assert body[0]["mask_data"]["gt_label_value"] == 1
assert body[0]["mask_data"]["class"]["name"] == "肿瘤"
assert body[0]["mask_data"]["class"]["maskId"] == 1
def test_import_gt_mask_rejects_rgb_color_masks(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "Color Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [
{"id": "known", "name": "已知类别", "color": "#ff0000", "zIndex": 10, "maskId": 1},
],
"rules": [],
}).json()
mask = np.zeros((80, 120, 3), dtype=np.uint8)
mask[10:40, 10:40] = [0, 0, 255] # BGR red -> #ff0000
mask[40:70, 70:110] = [0, 255, 0] # BGR green -> unknown #00ff00
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("color-mask.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 400
assert "RGB 三通道完全相同" in response.json()["detail"]
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "Label Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
"rules": [],
}).json()
mask = np.zeros((360, 640), dtype=np.uint16)
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = response.json()
assert len(body) == 1
assert body[0]["mask_data"]["gt_label_value"] == 1
assert body[0]["mask_data"]["label"] == "肿瘤"
assert body[0]["mask_data"]["class"]["maskId"] == 1
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
project, frame, _ = _create_project_and_frame(client)
template = client.post("/api/templates", json={
"name": "Color Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [{"id": "known", "name": "已定义", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
"rules": [],
}).json()
mask = np.zeros((90, 160, 3), dtype=np.uint8)
cv2.rectangle(mask, (5, 5), (40, 40), (1, 1, 1), thickness=-1)
cv2.rectangle(mask, (80, 5), (120, 40), (2, 2, 2), thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
discard_response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
)
assert discard_response.status_code == 201
assert [item["mask_data"]["label"] for item in discard_response.json()] == ["已定义"]
assert discard_response.json()[0]["mask_data"]["gt_original_size"] == {"width": 160, "height": 90}
assert discard_response.json()[0]["mask_data"]["gt_resized_to_frame"] is True
assert discard_response.json()[0]["mask_data"]["image_size"] == {"width": 640, "height": 360}
undefined_response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "undefined",
},
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
)
assert undefined_response.status_code == 201
labels = {item["mask_data"]["label"] for item in undefined_response.json()}
assert labels == {"已定义", "未定义类别 2"}
unknown = next(item for item in undefined_response.json() if item["mask_data"]["label"].startswith("未定义"))
assert unknown["mask_data"]["gt_unknown_class"] is True
assert unknown["mask_data"]["gt_label_value"] == 2
assert unknown["mask_data"]["gt_resized_to_frame"] is True

View File

@@ -2,10 +2,11 @@ def test_login_success(client):
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
assert response.status_code == 200
assert response.json() == {
"token": "fake-jwt-token-for-admin",
"username": "admin",
}
body = response.json()
assert body["token"]
assert body["token_type"] == "bearer"
assert body["username"] == "admin"
assert body["user"]["username"] == "admin"
def test_login_rejects_invalid_credentials(client):
@@ -13,3 +14,19 @@ def test_login_rejects_invalid_credentials(client):
assert response.status_code == 401
assert response.json()["detail"] == "Invalid credentials"
def test_me_returns_current_user(client):
response = client.get("/api/auth/me")
assert response.status_code == 200
assert response.json()["username"] == "admin"
def test_business_routes_require_auth(app):
from fastapi.testclient import TestClient
with TestClient(app) as unauthenticated:
response = unauthenticated.get("/api/projects")
assert response.status_code == 401

View File

@@ -1,19 +1,31 @@
import zipfile
import json
from io import BytesIO
from urllib.parse import unquote
import cv2
import numpy as np
def _fake_image_bytes(width=100, height=50, color=(255, 255, 255)):
image = np.full((height, width, 3), color, dtype=np.uint8)
_, encoded = cv2.imencode(".jpg", image)
return encoded.tobytes()
def _seed_export_data(client):
project = client.post("/api/projects", json={"name": "Export Project"}).json()
project = client.post("/api/projects", json={
"name": "Export Project",
"video_path": "uploads/1/clip.mp4",
}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 100,
"height": 50,
"timestamp_ms": 1250.0,
"source_frame_number": 37,
}).json()
template = client.post("/api/templates", json={
"name": "Category",
@@ -113,6 +125,328 @@ def test_export_masks_uses_z_index_for_semantic_fusion(client):
assert semantic[10, 10] == high_value
def test_export_results_zip_contains_coco_original_images_and_selected_mask_outputs(client, monkeypatch):
project, _, _, annotation = _seed_export_data(client)
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes())
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
assert response.status_code == 200
assert response.headers["content-type"].startswith("application/zip")
with zipfile.ZipFile(BytesIO(response.content)) as archive:
names = archive.namelist()
frame_stem = "clip_0h00m01s250ms_frame000001"
assert "annotations_coco.json" in names
assert "maskid_GT像素值_类别映射.json" in names
assert f"原始图片/{frame_stem}.jpg" in names
assert f"分开Mask分割结果/{frame_stem}_分别导出/{frame_stem}_Category_maskid1.png" in names
assert f"GT_label图/{frame_stem}.png" in names
assert f"Pro_label彩色分割结果/{frame_stem}.png" in names
assert f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png" in names
coco = json.loads(archive.read("annotations_coco.json"))
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
label_bytes = np.frombuffer(archive.read(f"GT_label图/{frame_stem}.png"), dtype=np.uint8)
gt_label = cv2.imdecode(label_bytes, cv2.IMREAD_UNCHANGED)
pro_label = cv2.imdecode(
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
mix_label = cv2.imdecode(
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
assert coco["images"][0]["frame_index"] == 0
assert coco["annotations"][0]["image_id"] == annotation["frame_id"]
assert mapping["classes"] == [{
"gt_pixel_value": 1,
"maskid": 1,
"chineseName": "Category",
"className": "Category",
"categoryName": "Category",
"rgb": [6, 182, 212],
"color": "#06b6d4",
"key": f"template:{annotation['template_id']}",
"template_id": annotation["template_id"],
}]
assert gt_label[0, 0] == 0
assert gt_label[20, 50] == 1
assert pro_label[20, 50].tolist() == [212, 182, 6]
assert pro_label[0, 0].tolist() == [0, 0, 0]
assert mix_label[20, 50].tolist() != [255, 255, 255]
def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "Layered Export Project",
"video_path": "uploads/2/layered.mp4",
}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/layered.jpg",
"width": 20,
"height": 20,
"timestamp_ms": 0,
"source_frame_number": 0,
}).json()
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Low",
"color": "#00ff00",
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
},
})
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
"label": "High",
"color": "#ff0000",
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
},
})
response = client.get(
f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label,mix_label&mix_opacity=0.5",
)
assert response.status_code == 200
with zipfile.ZipFile(BytesIO(response.content)) as archive:
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
high_value = next(item["maskid"] for item in mapping["classes"] if item["key"] == "class:high")
stem = "layered_0h00m00s000ms_frame000001"
gt_label = cv2.imdecode(
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_UNCHANGED,
)
pro_label = cv2.imdecode(
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
mix_label = cv2.imdecode(
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{stem}.png"), dtype=np.uint8),
cv2.IMREAD_COLOR,
)
assert gt_label[10, 10] == high_value
assert pro_label[10, 10].tolist() == [0, 0, 255]
assert mix_label[10, 10].tolist() == [127, 127, 255]
def test_export_results_supports_range_and_current_scope(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "Scoped Export Project",
"video_path": "uploads/9/scope.mp4",
"parse_fps": 2,
}).json()
template = client.post("/api/templates", json={
"name": "Scoped Category",
"color": "#06b6d4",
"z_index": 0,
"classes": [],
"rules": [],
}).json()
frames = []
annotations = []
for idx in range(3):
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 20,
"height": 20,
"timestamp_ms": idx * 500.0,
"source_frame_number": idx * 10,
}).json()
frames.append(frame)
annotations.append(client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"template_id": template["id"],
"mask_data": {"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]]},
}).json())
range_response = client.get(
f"/api/export/{project['id']}/results?scope=range&start_frame=2&end_frame=3&mask_type=gt_label",
)
current_response = client.get(
f"/api/export/{project['id']}/results?scope=current&frame_id={frames[1]['id']}&mask_type=separate",
)
assert range_response.status_code == 200
assert "Scoped_Export_Project_seg_T_0h00m00s500ms-0h00m01s000ms_P_2-3.zip" in unquote(
range_response.headers["content-disposition"],
)
with zipfile.ZipFile(BytesIO(range_response.content)) as archive:
names = archive.namelist()
coco = json.loads(archive.read("annotations_coco.json"))
assert "原始图片/scope_0h00m00s500ms_frame000002.jpg" in names
assert "原始图片/scope_0h00m01s000ms_frame000003.jpg" in names
assert "原始图片/scope_0h00m00s000ms_frame000001.jpg" not in names
assert "GT_label图/scope_0h00m00s500ms_frame000002.png" in names
assert "GT_label图/scope_0h00m01s000ms_frame000003.png" in names
assert "GT_label图/scope_0h00m00s000ms_frame000001.png" not in names
assert not any(name.startswith("分开Mask分割结果/") for name in names)
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
assert [image["frame_index"] for image in coco["images"]] == [1, 2]
assert current_response.status_code == 200
with zipfile.ZipFile(BytesIO(current_response.content)) as archive:
names = archive.namelist()
coco = json.loads(archive.read("annotations_coco.json"))
current_stem = "scope_0h00m00s500ms_frame000002"
assert f"原始图片/{current_stem}.jpg" in names
assert f"分开Mask分割结果/{current_stem}_分别导出/{current_stem}_Scoped_Category_maskid1.png" in names
assert f"分开Mask分割结果/scope_0h00m00s000ms_frame000001_分别导出/scope_0h00m00s000ms_frame000001_Scoped_Category_maskid1.png" not in names
assert not any(name.startswith("GT_label图/") for name in names)
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
assert [image["id"] for image in coco["images"]] == [frames[1]["id"]]
def test_export_results_preserves_template_maskid_consistently_across_frames(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "MaskId Export Project",
"video_path": "uploads/8/maskid-demo.mp4",
"parse_fps": 1,
}).json()
frames = []
for idx in range(2):
frames.append(client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": idx,
"image_url": f"frames/{idx}.jpg",
"width": 20,
"height": 20,
"timestamp_ms": idx * 1000.0,
"source_frame_number": idx,
}).json())
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frames[-1]["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Tumor",
"color": "#ff0000",
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
},
})
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
assert response.status_code == 200
with zipfile.ZipFile(BytesIO(response.content)) as archive:
names = archive.namelist()
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
first_stem = "maskid-demo_0h00m00s000ms_frame000001"
second_stem = "maskid-demo_0h00m01s000ms_frame000002"
assert f"分开Mask分割结果/{first_stem}_分别导出/{first_stem}_Tumor_maskid7.png" in names
assert f"分开Mask分割结果/{second_stem}_分别导出/{second_stem}_Tumor_maskid7.png" in names
first_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{first_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
second_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{second_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
assert mapping["classes"] == [{
"gt_pixel_value": 7,
"maskid": 7,
"chineseName": "Tumor",
"className": "Tumor",
"categoryName": "",
"rgb": [255, 0, 0],
"color": "#ff0000",
"key": "class:tumor",
"template_id": None,
}]
assert first_label[5, 5] == 7
assert second_label[5, 5] == 7
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
project = client.post("/api/projects", json={
"name": "GT Roundtrip Project",
"video_path": "uploads/8/roundtrip.mp4",
}).json()
template = client.post("/api/templates", json={
"name": "Roundtrip Template",
"color": "#06b6d4",
"z_index": 0,
"classes": [
{"id": "tumor", "name": "Tumor", "color": "#ff0000", "zIndex": 30, "maskId": 7},
],
"rules": [],
}).json()
source_frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/source.jpg",
"width": 20,
"height": 20,
"timestamp_ms": 0,
}).json()
target_frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 1,
"image_url": "frames/target.jpg",
"width": 20,
"height": 20,
"timestamp_ms": 1000,
}).json()
client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": source_frame["id"],
"template_id": template["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Tumor",
"color": "#ff0000",
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
},
})
export_response = client.get(
f"/api/export/{project['id']}/results?scope=current&frame_id={source_frame['id']}&outputs=gt_label",
)
assert export_response.status_code == 200
with zipfile.ZipFile(BytesIO(export_response.content)) as archive:
stem = "roundtrip_0h00m00s000ms_frame000001"
exported_gt_label = archive.read(f"GT_label图/{stem}.png")
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
assert gt_label[5, 5] == 7
assert mapping["classes"][0]["maskid"] == 7
import_response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(target_frame["id"]),
"template_id": str(template["id"]),
"unknown_color_policy": "discard",
},
files={"file": ("exported_gt_label.png", exported_gt_label, "image/png")},
)
assert import_response.status_code == 201
imported = import_response.json()
assert len(imported) == 1
assert imported[0]["frame_id"] == target_frame["id"]
assert imported[0]["mask_data"]["gt_label_value"] == 7
assert imported[0]["mask_data"]["label"] == "Tumor"
assert imported[0]["mask_data"]["class"]["maskId"] == 7
def test_export_missing_project_returns_404(client):
assert client.get("/api/export/999/coco").status_code == 404
assert client.get("/api/export/999/masks").status_code == 404
assert client.get("/api/export/999/results").status_code == 404

View File

@@ -1,4 +1,5 @@
from models import Annotation, Frame, Mask, ProcessingTask, Project
from models import Annotation, Frame, Mask, ProcessingTask, Project, User
from routers.auth import create_access_token, hash_password
def test_project_crud_and_frames(client, monkeypatch):
@@ -93,3 +94,33 @@ def test_project_and_frame_404s(client):
}).status_code == 404
assert client.get("/api/projects/999/frames").status_code == 404
assert client.get("/api/projects/999/frames/1").status_code == 404
def test_projects_are_scoped_to_authenticated_owner(client, db_session):
owner_project = client.post("/api/projects", json={"name": "Owner Project"}).json()
other_user = User(
username="other",
password_hash=hash_password("pass"),
role="annotator",
is_active=1,
)
db_session.add(other_user)
db_session.commit()
db_session.refresh(other_user)
other_project = Project(name="Other Project", owner_user_id=other_user.id)
db_session.add(other_project)
db_session.commit()
db_session.refresh(other_project)
listing = client.get("/api/projects")
assert [project["id"] for project in listing.json()] == [owner_project["id"]]
assert client.get(f"/api/projects/{other_project.id}").status_code == 404
original_auth = client.headers["Authorization"]
client.headers.update({"Authorization": f"Bearer {create_access_token(other_user)}"})
try:
other_listing = client.get("/api/projects")
assert [project["id"] for project in other_listing.json()] == [other_project.id]
assert client.get(f"/api/projects/{owner_project['id']}").status_code == 404
finally:
client.headers.update({"Authorization": original_auth})