feat: 完善分割工作区导入导出与管理流程
- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。 - 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。 - 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。 - 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。 - 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。 - 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。 - 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。 - 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。 - 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。 - 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。 - 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
@@ -33,6 +33,12 @@ class Settings(BaseSettings):
|
||||
# App
|
||||
app_env: str = "development"
|
||||
cors_origins: list[str] = ["http://localhost:3000", "http://192.168.3.11:3000"]
|
||||
jwt_secret_key: str = "seg-server-dev-secret-change-me"
|
||||
jwt_algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 60 * 24
|
||||
default_admin_username: str = "admin"
|
||||
default_admin_password: str = "123456"
|
||||
demo_video_path: str = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@@ -20,7 +20,7 @@ from progress_events import PROGRESS_CHANNEL
|
||||
from redis_client import get_redis_client, ping as redis_ping
|
||||
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
|
||||
|
||||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks
|
||||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -28,7 +28,7 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||||
DEFAULT_VIDEO_PATH = settings.demo_video_path
|
||||
|
||||
|
||||
def _ensure_runtime_schema_columns() -> None:
|
||||
@@ -36,25 +36,56 @@ def _ensure_runtime_schema_columns() -> None:
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
frame_columns = {column["name"] for column in inspector.get_columns("frames")}
|
||||
project_columns = {column["name"] for column in inspector.get_columns("projects")}
|
||||
template_columns = {column["name"] for column in inspector.get_columns("templates")}
|
||||
with engine.begin() as connection:
|
||||
if "timestamp_ms" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"))
|
||||
if "source_frame_number" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"))
|
||||
if "owner_user_id" not in project_columns:
|
||||
connection.execute(text("ALTER TABLE projects ADD COLUMN owner_user_id INTEGER"))
|
||||
if "owner_user_id" not in template_columns:
|
||||
connection.execute(text("ALTER TABLE templates ADD COLUMN owner_user_id INTEGER"))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Runtime schema column check failed: %s", exc)
|
||||
|
||||
|
||||
def _seed_default_admin_and_ownership_sync() -> None:
|
||||
"""Ensure the default admin exists and owns legacy unassigned projects."""
|
||||
from models import Project
|
||||
from routers.auth import ensure_default_admin
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = ensure_default_admin(db)
|
||||
db.query(Project).filter(Project.owner_user_id.is_(None)).update(
|
||||
{"owner_user_id": admin.id},
|
||||
synchronize_session=False,
|
||||
)
|
||||
db.commit()
|
||||
logger.info("Default admin ready; legacy projects assigned to user id=%s", admin.id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to seed default admin or ownership: %s", exc)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _seed_default_project_sync() -> None:
|
||||
"""Synchronously seed the default video project on first startup."""
|
||||
import cv2
|
||||
from models import Project, Frame
|
||||
from routers.auth import ensure_default_admin
|
||||
from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
admin = ensure_default_admin(db)
|
||||
existing = db.query(Project).filter(Project.name == "Data_MyVideo_1").first()
|
||||
if existing is not None:
|
||||
if existing.owner_user_id is None:
|
||||
existing.owner_user_id = admin.id
|
||||
db.commit()
|
||||
return
|
||||
|
||||
if not os.path.exists(DEFAULT_VIDEO_PATH):
|
||||
@@ -67,6 +98,7 @@ def _seed_default_project_sync() -> None:
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=admin.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -196,6 +228,7 @@ async def lifespan(app: FastAPI):
|
||||
try:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_ensure_runtime_schema_columns()
|
||||
_seed_default_admin_and_ownership_sync()
|
||||
logger.info("Database tables initialized.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Database initialization failed: %s", exc)
|
||||
@@ -265,6 +298,7 @@ app.include_router(ai.router)
|
||||
app.include_router(export.router)
|
||||
app.include_router(dashboard.router)
|
||||
app.include_router(tasks.router)
|
||||
app.include_router(admin.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
|
||||
@@ -17,6 +17,25 @@ from database import Base
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Application user used for authentication and data ownership."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String(150), unique=True, index=True, nullable=False)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
role = Column(String(50), default="admin", nullable=False)
|
||||
is_active = Column(Integer, default=1, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
projects = relationship("Project", back_populates="owner")
|
||||
templates = relationship("Template", back_populates="owner")
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""Project model representing a segmentation project."""
|
||||
|
||||
@@ -31,11 +50,13 @@ class Project(Base):
|
||||
source_type = Column(String(20), default="video", nullable=False) # video | dicom
|
||||
original_fps = Column(Float, nullable=True)
|
||||
parse_fps = Column(Float, default=30.0, nullable=False)
|
||||
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
owner = relationship("User", back_populates="projects")
|
||||
frames = relationship("Frame", back_populates="project", cascade="all, delete-orphan")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="project", cascade="all, delete-orphan"
|
||||
@@ -77,8 +98,10 @@ class Template(Base):
|
||||
color = Column(String(50), nullable=False)
|
||||
z_index = Column(Integer, default=0, nullable=False)
|
||||
mapping_rules = Column(JSON, nullable=True)
|
||||
owner_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
owner = relationship("User", back_populates="templates")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="template", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -129,6 +152,22 @@ class Mask(Base):
|
||||
annotation = relationship("Annotation", back_populates="masks")
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""Audit trail for security and administrative actions."""
|
||||
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
actor_user_id = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True)
|
||||
action = Column(String(120), nullable=False)
|
||||
target_type = Column(String(80), nullable=True)
|
||||
target_id = Column(String(120), nullable=True)
|
||||
detail = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
actor = relationship("User")
|
||||
|
||||
|
||||
class ProcessingTask(Base):
|
||||
"""Background task state persisted for dashboard and polling."""
|
||||
|
||||
|
||||
270
backend/routers/admin.py
Normal file
270
backend/routers/admin.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Administrator-only user and audit management endpoints."""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from minio_client import upload_file
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
|
||||
from schemas import (
|
||||
AdminUserCreate,
|
||||
AdminUserUpdate,
|
||||
AuditLogOut,
|
||||
DemoFactoryResetOut,
|
||||
DemoFactoryResetRequest,
|
||||
UserOut,
|
||||
)
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
VALID_ROLES = {"admin", "annotator", "viewer"}
|
||||
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
|
||||
DEMO_PROJECT_NAME = "Data_MyVideo_1"
|
||||
|
||||
|
||||
def _normalize_role(role: str | None) -> str:
|
||||
normalized = (role or "annotator").strip().lower()
|
||||
if normalized not in VALID_ROLES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
|
||||
return normalized
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserOut], summary="List users")
|
||||
def list_users(
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[User]:
|
||||
"""Return all users for the administrator console."""
|
||||
_ = admin_user
|
||||
return db.query(User).order_by(User.id).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users",
|
||||
response_model=UserOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create user",
|
||||
)
|
||||
def create_user(
|
||||
payload: AdminUserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Create a user with an initial password and role."""
|
||||
username = payload.username.strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
if len(payload.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user = User(
|
||||
username=username,
|
||||
password_hash=hash_password(payload.password),
|
||||
role=_normalize_role(payload.role),
|
||||
is_active=1 if payload.is_active else 0,
|
||||
)
|
||||
db.add(user)
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_created",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username, "role": user.role, "is_active": bool(user.is_active)},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
|
||||
def update_user(
|
||||
user_id: int,
|
||||
payload: AdminUserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Update username, password, role or active state."""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
|
||||
if "username" in updates:
|
||||
username = (updates["username"] or "").strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
user.username = username
|
||||
if "password" in updates:
|
||||
password = updates["password"] or ""
|
||||
if len(password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user.password_hash = hash_password(password)
|
||||
if "role" in updates:
|
||||
next_role = _normalize_role(updates["role"])
|
||||
if user.id == admin_user.id and next_role != "admin":
|
||||
raise HTTPException(status_code=400, detail="Cannot remove your own admin role")
|
||||
user.role = next_role
|
||||
if "is_active" in updates:
|
||||
if user.id == admin_user.id and not updates["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
|
||||
user.is_active = 1 if updates["is_active"] else 0
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
|
||||
audit_detail["password_changed"] = "password" in updates
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_updated",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail=audit_detail,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> None:
|
||||
"""Delete a user when it is safe to remove the account."""
|
||||
if user_id == admin_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
owned_project_count = db.query(Project).filter(Project.owner_user_id == user_id).count()
|
||||
if owned_project_count:
|
||||
raise HTTPException(status_code=409, detail="User owns projects; deactivate or migrate projects first")
|
||||
username = user.username
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_deleted",
|
||||
target_type="user",
|
||||
target_id=user_id,
|
||||
detail={"username": username},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
|
||||
def list_audit_logs(
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[AuditLog]:
|
||||
"""Return recent audit events for administrators."""
|
||||
_ = admin_user
|
||||
safe_limit = min(max(int(limit or 100), 1), 500)
|
||||
return db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc()).limit(safe_limit).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/demo-factory-reset",
|
||||
response_model=DemoFactoryResetOut,
|
||||
summary="Reset demo data to factory defaults",
|
||||
)
|
||||
def reset_demo_factory(
|
||||
payload: DemoFactoryResetRequest,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> dict:
|
||||
"""Reset a demo deployment to one admin account and one unparsed demo video project."""
|
||||
if payload.confirmation != DEMO_RESET_CONFIRMATION:
|
||||
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
|
||||
|
||||
if not os.path.exists(settings.demo_video_path):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo video not found: {settings.demo_video_path}",
|
||||
)
|
||||
|
||||
requested_by = admin_user.username
|
||||
preserved_admin = ensure_default_admin(db)
|
||||
preserved_admin.username = settings.default_admin_username
|
||||
preserved_admin.password_hash = hash_password(settings.default_admin_password)
|
||||
preserved_admin.role = "admin"
|
||||
preserved_admin.is_active = 1
|
||||
db.flush()
|
||||
|
||||
deleted_counts = {
|
||||
"masks": db.query(Mask).delete(synchronize_session=False),
|
||||
"annotations": db.query(Annotation).delete(synchronize_session=False),
|
||||
"frames": db.query(Frame).delete(synchronize_session=False),
|
||||
"tasks": db.query(ProcessingTask).delete(synchronize_session=False),
|
||||
"projects": db.query(Project).delete(synchronize_session=False),
|
||||
"user_templates": db.query(Template).filter(Template.owner_user_id.is_not(None)).delete(synchronize_session=False),
|
||||
"audit_logs": db.query(AuditLog).delete(synchronize_session=False),
|
||||
"users": db.query(User).filter(User.id != preserved_admin.id).delete(synchronize_session=False),
|
||||
}
|
||||
db.flush()
|
||||
db.expunge_all()
|
||||
|
||||
preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if not preserved_admin:
|
||||
raise HTTPException(status_code=500, detail="Default admin was not preserved")
|
||||
|
||||
project = Project(
|
||||
name=DEMO_PROJECT_NAME,
|
||||
description="默认演示视频,尚未生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=preserved_admin.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
with open(settings.demo_video_path, "rb") as file_obj:
|
||||
data = file_obj.read()
|
||||
object_name = f"uploads/{project.id}/{os.path.basename(settings.demo_video_path)}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
project.thumbnail_url = None
|
||||
project.original_fps = None
|
||||
db.commit()
|
||||
db.refresh(preserved_admin)
|
||||
db.refresh(project)
|
||||
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=preserved_admin,
|
||||
action="admin.demo_factory_reset",
|
||||
target_type="project",
|
||||
target_id=project.id,
|
||||
detail={
|
||||
"project_name": project.name,
|
||||
"video_path": project.video_path,
|
||||
"deleted_counts": deleted_counts,
|
||||
"requested_by": requested_by,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"admin_user": preserved_admin,
|
||||
"project": project,
|
||||
"deleted_counts": deleted_counts,
|
||||
"message": "演示环境已恢复出厂设置",
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
@@ -8,11 +9,13 @@ from typing import Any, List
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
MaskAnalysisRequest,
|
||||
@@ -38,6 +41,102 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
|
||||
query = (
|
||||
db.query(Frame)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(Frame.id == frame_id, Project.owner_user_id == current_user.id)
|
||||
)
|
||||
if project_id is not None:
|
||||
query = query.filter(Frame.project_id == project_id)
|
||||
frame = query.first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
|
||||
|
||||
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
return template
|
||||
|
||||
|
||||
def _normalize_hex_color(value: Any) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
text = value.strip().lower()
|
||||
if not text:
|
||||
return None
|
||||
if not text.startswith("#"):
|
||||
text = f"#{text}"
|
||||
if len(text) == 4:
|
||||
text = "#" + "".join(char * 2 for char in text[1:])
|
||||
if len(text) != 7:
|
||||
return None
|
||||
try:
|
||||
int(text[1:], 16)
|
||||
except ValueError:
|
||||
return None
|
||||
return text
|
||||
|
||||
|
||||
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
|
||||
values = []
|
||||
for channel in rgb:
|
||||
value = int(channel)
|
||||
if value > 255:
|
||||
value = int(round(value / 257))
|
||||
values.append(min(max(value, 0), 255))
|
||||
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
|
||||
|
||||
|
||||
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]]]:
|
||||
classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
|
||||
by_maskid: dict[int, dict[str, Any]] = {}
|
||||
by_color: dict[str, dict[str, Any]] = {}
|
||||
for index, item in enumerate(classes):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
|
||||
try:
|
||||
maskid = int(maskid_value)
|
||||
except (TypeError, ValueError):
|
||||
maskid = index + 1
|
||||
color = _normalize_hex_color(item.get("color")) or "#22c55e"
|
||||
class_meta = {
|
||||
"id": str(item.get("id") or f"maskid-{maskid}"),
|
||||
"name": str(item.get("name") or f"类别 {maskid}"),
|
||||
"color": color,
|
||||
"zIndex": int(item.get("zIndex", item.get("z_index", index * 10))),
|
||||
"maskId": maskid,
|
||||
**({"category": item.get("category")} if item.get("category") else {}),
|
||||
}
|
||||
if maskid > 0:
|
||||
by_maskid[maskid] = class_meta
|
||||
by_color[color] = class_meta
|
||||
return by_maskid, by_color
|
||||
|
||||
|
||||
def _gt_unknown_label(token: int | str) -> str:
|
||||
if isinstance(token, int):
|
||||
return f"未定义类别 {token}"
|
||||
return f"未定义颜色 {token}"
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
"""Download a frame from MinIO and decode it to an RGB numpy array."""
|
||||
try:
|
||||
@@ -106,16 +205,20 @@ def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[flo
|
||||
return [polygon for polygon in (_normalize_polygon(polygon) for polygon in polygons) if len(polygon) >= 3]
|
||||
|
||||
|
||||
def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
|
||||
if points:
|
||||
return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
|
||||
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
|
||||
if len(anchors) <= limit:
|
||||
return anchors
|
||||
step = max(1, math.ceil(len(anchors) / limit))
|
||||
return anchors[::step][:limit]
|
||||
|
||||
|
||||
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
|
||||
anchors: list[list[float]] = []
|
||||
for polygon in polygons:
|
||||
if not polygon:
|
||||
continue
|
||||
step = max(1, len(polygon) // 12)
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon[::step]])
|
||||
return anchors[:32]
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon])
|
||||
return len(anchors), _sample_anchor_points(anchors)
|
||||
|
||||
|
||||
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
|
||||
@@ -129,8 +232,14 @@ def _normalize_smoothing_options(strength: float | int | None, method: str | Non
|
||||
}
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = polygon
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
@@ -138,12 +247,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01(0.75 * current[0] + 0.25 * following[0]),
|
||||
_clamp01(0.75 * current[1] + 0.25 * following[1]),
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(0.25 * current[0] + 0.75 * following[0]),
|
||||
_clamp01(0.25 * current[1] + 0.75 * following[1]),
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
@@ -154,7 +263,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
@@ -165,9 +274,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> li
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
iterations = max(1, min(3, int(strength // 35) + 1))
|
||||
smoothed = _chaikin_smooth_polygon(_normalize_polygon(polygon), iterations)
|
||||
simplified = _simplify_polygon(smoothed, strength)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
@@ -321,7 +446,11 @@ def _filter_predictions(
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def predict(
|
||||
payload: PredictRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Execute selected SAM segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
@@ -330,9 +459,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
|
||||
- **semantic**: disabled in the current SAM 2.1 point/box product flow.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(payload.image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
@@ -478,7 +605,10 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=AiRuntimeStatus,
|
||||
summary="Get SAM model and GPU runtime status",
|
||||
)
|
||||
def model_status(selected_model: str | None = None) -> dict:
|
||||
def model_status(
|
||||
selected_model: str | None = None,
|
||||
_current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return real runtime availability for GPU and the currently enabled SAM model."""
|
||||
try:
|
||||
return sam_registry.runtime_status(selected_model)
|
||||
@@ -491,12 +621,14 @@ def model_status(selected_model: str | None = None) -> dict:
|
||||
response_model=MaskAnalysisResponse,
|
||||
summary="Analyze mask geometry and prompt anchors",
|
||||
)
|
||||
def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def analyze_mask(
|
||||
payload: MaskAnalysisRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return backend-computed mask properties for the frontend inspector."""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
mask_data = payload.mask_data or {}
|
||||
polygons = mask_data.get("polygons") or []
|
||||
@@ -521,13 +653,13 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
else:
|
||||
confidence_source = "manual_or_imported"
|
||||
|
||||
anchors = _analysis_anchors(valid_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
|
||||
message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
|
||||
|
||||
return {
|
||||
"confidence": confidence,
|
||||
"confidence_source": confidence_source,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -541,16 +673,18 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
response_model=SmoothMaskResponse,
|
||||
summary="Smooth editable mask polygons with backend geometry rules",
|
||||
)
|
||||
def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def smooth_mask(
|
||||
payload: SmoothMaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Return a smoothed polygon mask without persisting it.
|
||||
|
||||
The frontend keeps this as an explicit edit operation: users preview/apply it
|
||||
to the current mask, then save through the normal annotation endpoint.
|
||||
"""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
polygons = payload.mask_data.get("polygons") or []
|
||||
valid_polygons = _normalize_polygons(polygons)
|
||||
@@ -564,10 +698,10 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
|
||||
area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
|
||||
bbox = _polygon_bbox(smoothed_polygons[0])
|
||||
anchors = _analysis_anchors(smoothed_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
|
||||
return {
|
||||
"polygons": smoothed_polygons,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -581,7 +715,11 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
response_model=PropagateResponse,
|
||||
summary="Propagate one current-frame region across a video frame segment",
|
||||
)
|
||||
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def propagate(
|
||||
payload: PropagateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Track one selected region from the current frame across nearby frames.
|
||||
|
||||
SAM 2 uses the official video predictor with the selected mask as the seed.
|
||||
@@ -592,16 +730,8 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
|
||||
max_frames = max(1, min(int(payload.max_frames or 30), 500))
|
||||
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
seed = payload.seed.model_dump(exclude_none=True)
|
||||
polygons = seed.get("polygons") or []
|
||||
@@ -709,18 +839,14 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Queue a background video propagation task",
|
||||
)
|
||||
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
|
||||
def queue_propagate_task(
|
||||
payload: PropagateTaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTaskOut:
|
||||
"""Queue multiple seed/direction propagation steps as one background task."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
if not payload.steps:
|
||||
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
|
||||
@@ -768,11 +894,13 @@ def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(ge
|
||||
response_model=PredictResponse,
|
||||
summary="Run automatic segmentation",
|
||||
)
|
||||
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def auto_segment(
|
||||
image_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Run automatic mask generation on a frame using a grid of point prompts."""
|
||||
frame = db.query(Frame).filter(Frame.id == image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
try:
|
||||
@@ -792,16 +920,15 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def save_annotation(
|
||||
payload: AnnotationCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
|
||||
if payload.frame_id:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
if payload.template_id:
|
||||
_visible_template_or_404(payload.template_id, db, current_user)
|
||||
|
||||
annotation = Annotation(**payload.model_dump())
|
||||
db.add(annotation)
|
||||
@@ -823,8 +950,10 @@ async def import_gt_mask(
|
||||
template_id: int | None = Form(None),
|
||||
label: str = Form("GT Mask"),
|
||||
color: str = Form("#22c55e"),
|
||||
unknown_color_policy: str = Form("undefined"),
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> List[Annotation]:
|
||||
"""Convert a binary/label mask image into persisted polygon annotations.
|
||||
|
||||
@@ -833,62 +962,122 @@ async def import_gt_mask(
|
||||
the frontend an editable point-region representation instead of a static
|
||||
bitmap layer.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
frame = _owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
|
||||
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
if unknown_color_policy not in {"discard", "undefined"}:
|
||||
raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
|
||||
|
||||
template: Template | None = None
|
||||
if template_id is not None:
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
template = _visible_template_or_404(template_id, db, current_user)
|
||||
|
||||
data = await file.read()
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||
|
||||
if image.ndim == 2:
|
||||
label_image = image
|
||||
elif image.ndim == 3 and image.shape[2] >= 3:
|
||||
channels = image[:, :, :3]
|
||||
# GT label images are maskid maps: either grayscale or RGB/BGR where
|
||||
# all three color channels contain the same maskid value [X, X, X].
|
||||
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
label_image = channels[:, :, 0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
|
||||
width = int(frame.width or image.shape[1])
|
||||
height = int(frame.height or image.shape[0])
|
||||
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
|
||||
if not label_values:
|
||||
original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
|
||||
resized_to_frame = original_width != width or original_height != height
|
||||
if resized_to_frame:
|
||||
label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
by_maskid, _by_color = _template_class_maps(template)
|
||||
has_template_classes = bool(by_maskid)
|
||||
fallback_color = _normalize_hex_color(color) or "#22c55e"
|
||||
|
||||
import_items: list[dict[str, Any]] = []
|
||||
skipped_unknown = 0
|
||||
label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
|
||||
for label_value in label_values:
|
||||
class_meta = by_maskid.get(label_value)
|
||||
is_unknown = has_template_classes and class_meta is None
|
||||
if is_unknown and unknown_color_policy == "discard":
|
||||
skipped_unknown += 1
|
||||
continue
|
||||
if class_meta:
|
||||
annotation_label = class_meta["name"]
|
||||
annotation_color = class_meta["color"]
|
||||
elif is_unknown:
|
||||
annotation_label = _gt_unknown_label(label_value)
|
||||
annotation_color = fallback_color
|
||||
else:
|
||||
annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
|
||||
annotation_color = fallback_color
|
||||
import_items.append({
|
||||
"token": label_value,
|
||||
"binary": np.where(label_image == label_value, 255, 0).astype(np.uint8),
|
||||
"label": annotation_label,
|
||||
"color": annotation_color,
|
||||
"class": class_meta,
|
||||
"unknown": is_unknown,
|
||||
"metadata": {
|
||||
"gt_label_value": label_value,
|
||||
"gt_original_size": {"width": original_width, "height": original_height},
|
||||
"gt_resized_to_frame": resized_to_frame,
|
||||
},
|
||||
})
|
||||
|
||||
if not import_items:
|
||||
if skipped_unknown > 0:
|
||||
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
has_multiple_labels = len(label_values) > 1
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for label_value in label_values:
|
||||
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
|
||||
for item in import_items:
|
||||
binary = item["binary"]
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
|
||||
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 1:
|
||||
continue
|
||||
|
||||
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
|
||||
polygon = _normalized_contour(contour, binary.shape[1], binary.shape[0])
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
|
||||
component = np.zeros_like(binary, dtype=np.uint8)
|
||||
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
|
||||
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
|
||||
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
|
||||
seed_point = _component_seed_point(component, binary.shape[1], binary.shape[0])
|
||||
bbox = _contour_bbox(contour, binary.shape[1], binary.shape[0])
|
||||
mask_data = {
|
||||
"polygons": [polygon],
|
||||
"label": item["label"],
|
||||
"color": item["color"],
|
||||
"source": "gt_mask",
|
||||
"image_size": {"width": width, "height": height},
|
||||
**item["metadata"],
|
||||
}
|
||||
if item["class"]:
|
||||
mask_data["class"] = item["class"]
|
||||
if item["unknown"]:
|
||||
mask_data["gt_unknown_class"] = True
|
||||
|
||||
annotation = Annotation(
|
||||
project_id=project_id,
|
||||
frame_id=frame_id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": annotation_label,
|
||||
"color": color,
|
||||
"source": "gt_mask",
|
||||
"gt_label_value": label_value,
|
||||
"image_size": {"width": width, "height": height},
|
||||
},
|
||||
mask_data=mask_data,
|
||||
points=[seed_point],
|
||||
bbox=bbox,
|
||||
)
|
||||
@@ -914,14 +1103,14 @@ def list_annotations(
|
||||
project_id: int,
|
||||
frame_id: int | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Annotation]:
|
||||
"""Return persisted annotations for a project, optionally scoped to one frame."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
|
||||
query = db.query(Annotation).filter(Annotation.project_id == project_id)
|
||||
if frame_id is not None:
|
||||
_owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
query = query.filter(Annotation.frame_id == frame_id)
|
||||
return query.order_by(Annotation.id).all()
|
||||
|
||||
@@ -935,17 +1124,21 @@ def update_annotation(
|
||||
annotation_id: int,
|
||||
payload: AnnotationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Update mutable annotation fields persisted in the database."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "template_id" in updates and updates["template_id"] is not None:
|
||||
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_visible_template_or_404(updates["template_id"], db, current_user)
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(annotation, field, value)
|
||||
@@ -964,9 +1157,15 @@ def update_annotation(
|
||||
def delete_annotation(
|
||||
annotation_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Response:
|
||||
"""Delete an annotation and its derived mask rows through ORM cascade."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
|
||||
@@ -1,9 +1,23 @@
|
||||
"""Authentication endpoints."""
|
||||
"""Authentication endpoints and dependencies."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from models import AuditLog, User
|
||||
from schemas import LoginResponse, UserOut
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Auth"])
|
||||
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@@ -11,14 +25,151 @@ class LoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
username: str
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a plain password for storage."""
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a plain password against a stored hash."""
|
||||
return password_context.verify(password, password_hash)
|
||||
|
||||
|
||||
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
|
||||
"""Create a signed JWT access token for a user."""
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def ensure_default_admin(db: Session) -> User:
|
||||
"""Create the default development admin if the user table is empty."""
|
||||
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if existing:
|
||||
return existing
|
||||
user = User(
|
||||
username=settings.default_admin_username,
|
||||
password_hash=hash_password(settings.default_admin_password),
|
||||
role="admin",
|
||||
is_active=1,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Resolve and validate the current user from the Bearer token."""
|
||||
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
credentials.credentials,
|
||||
settings.jwt_secret_key,
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
user_id = int(payload.get("sub"))
|
||||
except (JWTError, TypeError, ValueError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive or missing user",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require the current user to have the administrator role."""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def require_editor(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require a user role that can modify segmentation data."""
|
||||
if current_user.role not in {"admin", "annotator"}:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def write_audit_log(
|
||||
db: Session,
|
||||
*,
|
||||
actor: User | None,
|
||||
action: str,
|
||||
target_type: str | None = None,
|
||||
target_id: str | int | None = None,
|
||||
detail: dict[str, Any] | None = None,
|
||||
) -> AuditLog:
|
||||
"""Persist a compact audit event."""
|
||||
log = AuditLog(
|
||||
actor_user_id=actor.id if actor else None,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=str(target_id) if target_id is not None else None,
|
||||
detail=detail or {},
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
db.refresh(log)
|
||||
return log
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest) -> dict:
|
||||
"""Simple login for development."""
|
||||
if payload.username == "admin" and payload.password == "123456":
|
||||
return {"token": "fake-jwt-token-for-admin", "username": payload.username}
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Authenticate a user and return a signed JWT."""
|
||||
ensure_default_admin(db)
|
||||
user = db.query(User).filter(User.username == payload.username).first()
|
||||
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=None,
|
||||
action="auth.login_failed",
|
||||
target_type="user",
|
||||
target_id=payload.username,
|
||||
detail={"username": payload.username},
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=user,
|
||||
action="auth.login_success",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username},
|
||||
)
|
||||
return {
|
||||
"token": create_access_token(user),
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Return the authenticated user profile."""
|
||||
return current_user
|
||||
|
||||
@@ -5,11 +5,12 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -52,22 +53,45 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
|
||||
|
||||
@router.get("/overview", summary="Get dashboard overview")
|
||||
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
def get_dashboard_overview(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Return live dashboard data derived from persisted backend records."""
|
||||
project_count = db.query(func.count(Project.id)).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).scalar() or 0
|
||||
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
|
||||
template_count = db.query(func.count(Template.id)).scalar() or 0
|
||||
owned_project_ids_query = db.query(Project.id).filter(Project.owner_user_id == current_user.id)
|
||||
project_count = db.query(func.count(Project.id)).filter(Project.owner_user_id == current_user.id).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(owned_project_ids_query)).scalar() or 0
|
||||
annotation_count = (
|
||||
db.query(func.count(Annotation.id))
|
||||
.filter(Annotation.project_id.in_(owned_project_ids_query))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
template_count = (
|
||||
db.query(func.count(Template.id))
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
active_task_count = (
|
||||
db.query(func.count(ProcessingTask.id))
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
|
||||
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.filter(Project.owner_user_id == current_user.id)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
recent_tasks = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
|
||||
.order_by(ProcessingTask.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
@@ -96,6 +120,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
|
||||
recent_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id.in_(owned_project_ids_query))
|
||||
.order_by(Annotation.updated_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
@@ -112,6 +137,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
|
||||
recent_templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.order_by(Template.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
|
||||
@@ -4,17 +4,22 @@ import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import quote
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Annotation, Frame, Template
|
||||
from minio_client import download_file
|
||||
from models import Project, Annotation, Frame, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||||
@@ -49,6 +54,30 @@ def _annotation_z_index(annotation: Annotation) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
for key in ("maskId", "maskid", "mask_id"):
|
||||
if class_meta.get(key) is None:
|
||||
continue
|
||||
try:
|
||||
value = int(class_meta[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value > 0:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _annotation_category_name(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("category"):
|
||||
return str(class_meta["category"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return ""
|
||||
|
||||
|
||||
def _annotation_class_key(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
@@ -85,38 +114,162 @@ def _annotation_color(annotation: Annotation) -> str:
|
||||
return "#ffffff"
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
def _hex_to_rgb(color: str) -> list[int]:
|
||||
value = str(color or "").strip()
|
||||
if value.startswith("#"):
|
||||
value = value[1:]
|
||||
if len(value) == 3:
|
||||
value = "".join(part * 2 for part in value)
|
||||
if len(value) != 6:
|
||||
return [255, 255, 255]
|
||||
try:
|
||||
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
|
||||
except ValueError:
|
||||
return [255, 255, 255]
|
||||
|
||||
|
||||
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
text = fallback
|
||||
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
|
||||
text = re.sub(r"_+", "_", text).strip("._")
|
||||
return text or fallback
|
||||
|
||||
|
||||
def _project_video_name(project: Project) -> str:
|
||||
if project.video_path:
|
||||
stem = Path(project.video_path).name
|
||||
if "." in stem:
|
||||
stem = ".".join(stem.split(".")[:-1])
|
||||
if stem:
|
||||
return _safe_filename_part(stem, f"project_{project.id}")
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _project_export_name(project: Project) -> str:
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
|
||||
if frame.timestamp_ms is not None:
|
||||
return float(frame.timestamp_ms)
|
||||
fps = project.parse_fps or project.original_fps or 30.0
|
||||
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
|
||||
|
||||
|
||||
def _project_frame_number(frame: Frame) -> int:
|
||||
return int(frame.frame_index) + 1
|
||||
|
||||
|
||||
def _format_timestamp_ms(value: float) -> str:
|
||||
total_ms = max(0, int(round(float(value))))
|
||||
hours = total_ms // 3_600_000
|
||||
minutes = (total_ms % 3_600_000) // 60_000
|
||||
seconds = (total_ms % 60_000) // 1_000
|
||||
milliseconds = total_ms % 1_000
|
||||
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
|
||||
|
||||
|
||||
def _frame_export_stem(project: Project, frame: Frame) -> str:
|
||||
return "_".join([
|
||||
_project_video_name(project),
|
||||
_format_timestamp_ms(_frame_timestamp_ms(frame, project)),
|
||||
f"frame{_project_frame_number(frame):06d}",
|
||||
])
|
||||
|
||||
|
||||
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
|
||||
if not frames:
|
||||
return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
|
||||
first_frame = frames[0]
|
||||
last_frame = frames[-1]
|
||||
return (
|
||||
f"{_project_export_name(project)}"
|
||||
f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
|
||||
f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
|
||||
f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
|
||||
)
|
||||
|
||||
|
||||
def _download_content_disposition(filename: str) -> str:
|
||||
ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
|
||||
ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
|
||||
if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
|
||||
ascii_fallback = f"{ascii_fallback}.zip"
|
||||
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
|
||||
def _frame_image_extension(frame: Frame) -> str:
|
||||
suffix = Path(frame.image_url or "").suffix.lower()
|
||||
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
|
||||
|
||||
|
||||
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
|
||||
def _project_frames(project_id: int, db: Session) -> list[Frame]:
|
||||
return (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
templates = db.query(Template).all()
|
||||
|
||||
# Build COCO structure
|
||||
|
||||
def _filter_frames(
|
||||
frames: list[Frame],
|
||||
*,
|
||||
scope: str = "all",
|
||||
start_frame: int | None = None,
|
||||
end_frame: int | None = None,
|
||||
frame_id: int | None = None,
|
||||
) -> list[Frame]:
|
||||
if scope == "current":
|
||||
if frame_id is None:
|
||||
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
|
||||
selected = [frame for frame in frames if frame.id == frame_id]
|
||||
if not selected:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return selected
|
||||
|
||||
if scope == "range":
|
||||
if start_frame is None or end_frame is None:
|
||||
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
|
||||
start = max(1, min(int(start_frame), int(end_frame)))
|
||||
end = max(1, max(int(start_frame), int(end_frame)))
|
||||
return frames[start - 1:end]
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
|
||||
if not frame_ids:
|
||||
return []
|
||||
return (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.filter(Annotation.frame_id.in_(frame_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
|
||||
images = []
|
||||
for idx, frame in enumerate(frames):
|
||||
for frame in frames:
|
||||
images.append({
|
||||
"id": frame.id,
|
||||
"file_name": frame.image_url,
|
||||
"width": frame.width or 1920,
|
||||
"height": frame.height or 1080,
|
||||
"frame_index": idx,
|
||||
"frame_index": frame.frame_index,
|
||||
})
|
||||
|
||||
categories = []
|
||||
@@ -131,14 +284,14 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
|
||||
coco_annotations = []
|
||||
ann_id = 1
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
# Use first polygon for bbox / area approximation
|
||||
first_poly = polygons[0]
|
||||
xs = [p[0] for p in first_poly]
|
||||
ys = [p[1] for p in first_poly]
|
||||
@@ -171,7 +324,7 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
})
|
||||
ann_id += 1
|
||||
|
||||
coco = {
|
||||
return {
|
||||
"info": {
|
||||
"description": f"Annotations for {project.name}",
|
||||
"version": "1.0",
|
||||
@@ -183,39 +336,235 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
|
||||
return {
|
||||
"key": _annotation_class_key(annotation),
|
||||
"className": _annotation_label(annotation),
|
||||
"chineseName": _annotation_label(annotation),
|
||||
"categoryName": _annotation_category_name(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"internalPriority": _annotation_z_index(annotation),
|
||||
"maskidHint": _annotation_mask_id(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
|
||||
entries_by_key: dict[str, dict[str, Any]] = {}
|
||||
for annotation in annotations:
|
||||
if not annotation.mask_data or not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
entry = _class_mapping_entry(annotation)
|
||||
entries_by_key.setdefault(entry["key"], entry)
|
||||
|
||||
ordered = sorted(
|
||||
entries_by_key.values(),
|
||||
key=lambda item: (
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] > 0 else 10_000_000,
|
||||
str(item["className"]),
|
||||
str(item["key"]),
|
||||
),
|
||||
)
|
||||
key_to_value: dict[str, int] = {}
|
||||
classes: list[dict[str, Any]] = []
|
||||
used_maskids: set[int] = set()
|
||||
next_maskid = 1
|
||||
|
||||
def next_available_maskid() -> int:
|
||||
nonlocal next_maskid
|
||||
while next_maskid in used_maskids:
|
||||
next_maskid += 1
|
||||
value = next_maskid
|
||||
used_maskids.add(value)
|
||||
next_maskid += 1
|
||||
return value
|
||||
|
||||
for entry in ordered:
|
||||
hinted_maskid = entry.get("maskidHint")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 0 and hinted_maskid not in used_maskids:
|
||||
maskid = hinted_maskid
|
||||
used_maskids.add(maskid)
|
||||
else:
|
||||
maskid = next_available_maskid()
|
||||
key_to_value[entry["key"]] = maskid
|
||||
classes.append({
|
||||
"gt_pixel_value": maskid,
|
||||
"maskid": maskid,
|
||||
"chineseName": entry["chineseName"],
|
||||
"className": entry["className"],
|
||||
"categoryName": entry["categoryName"],
|
||||
"rgb": _hex_to_rgb(entry["color"]),
|
||||
"color": entry["color"],
|
||||
"key": entry["key"],
|
||||
"template_id": entry["template_id"],
|
||||
})
|
||||
return key_to_value, classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
|
||||
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
|
||||
if outputs:
|
||||
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
|
||||
invalid = parsed - allowed
|
||||
if invalid:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
|
||||
return parsed or allowed
|
||||
|
||||
if mask_type == "separate":
|
||||
return {"separate"}
|
||||
if mask_type == "gt_label":
|
||||
return {"gt_label"}
|
||||
if mask_type == "pro_label":
|
||||
return {"pro_label"}
|
||||
if mask_type == "mix_label":
|
||||
return {"mix_label"}
|
||||
return allowed
|
||||
|
||||
|
||||
def _write_original_frames(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
) -> dict[int, bytes]:
|
||||
image_bytes_by_frame: dict[int, bytes] = {}
|
||||
for frame in frames:
|
||||
image_bytes = download_file(frame.image_url)
|
||||
image_bytes_by_frame[frame.id] = image_bytes
|
||||
zf.writestr(
|
||||
f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
|
||||
image_bytes,
|
||||
)
|
||||
return image_bytes_by_frame
|
||||
|
||||
|
||||
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
if image_bytes:
|
||||
decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if decoded is not None:
|
||||
if decoded.shape[1] != width or decoded.shape[0] != height:
|
||||
decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
|
||||
return decoded
|
||||
return np.zeros((height, width, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def _write_result_mask_outputs(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
outputs: set[str],
|
||||
class_values: dict[str, int],
|
||||
class_mapping: list[dict[str, Any]],
|
||||
original_images: dict[int, bytes],
|
||||
mix_opacity: float,
|
||||
) -> None:
|
||||
import cv2
|
||||
|
||||
include_individual = "separate" in outputs
|
||||
include_semantic = "gt_label" in outputs
|
||||
include_pro_label = "pro_label" in outputs
|
||||
include_mix_label = "mix_label" in outputs
|
||||
class_rgb_by_key = {
|
||||
item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
|
||||
for item in class_mapping
|
||||
}
|
||||
annotations_by_frame: dict[int, list[Annotation]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for annotation in annotations:
|
||||
if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
|
||||
continue
|
||||
if not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)
|
||||
|
||||
for frame in frames:
|
||||
frame_annotations = annotations_by_frame.get(frame.id, [])
|
||||
if not frame_annotations:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
frame_stem = _frame_export_stem(project, frame)
|
||||
|
||||
if include_individual:
|
||||
class_masks: dict[str, np.ndarray] = {}
|
||||
class_meta: dict[str, dict[str, Any]] = {}
|
||||
for annotation in frame_annotations:
|
||||
key = _annotation_class_key(annotation)
|
||||
combined = class_masks.setdefault(key, np.zeros((height, width), dtype=np.uint8))
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined[:] = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
class_meta.setdefault(key, _class_mapping_entry(annotation))
|
||||
|
||||
folder = f"分开Mask分割结果/{frame_stem}_分别导出"
|
||||
for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
|
||||
meta = class_meta[key]
|
||||
maskid = class_values.get(key)
|
||||
if maskid is None:
|
||||
continue
|
||||
_, encoded = cv2.imencode(".png", mask)
|
||||
class_name = _safe_filename_part(meta["className"], "class")
|
||||
zf.writestr(
|
||||
f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
|
||||
encoded.tobytes(),
|
||||
)
|
||||
|
||||
needs_fused_output = include_semantic or include_pro_label or include_mix_label
|
||||
semantic = np.zeros((height, width), dtype=np.uint16) if needs_fused_output else None
|
||||
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
|
||||
|
||||
if needs_fused_output:
|
||||
for annotation in sorted(frame_annotations, key=_annotation_z_index):
|
||||
key = _annotation_class_key(annotation)
|
||||
value = class_values.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
if semantic is not None:
|
||||
semantic[combined > 0] = value
|
||||
if pro_label is not None:
|
||||
rgb = class_rgb_by_key.get(key, [255, 255, 255])
|
||||
bgr = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)
|
||||
pro_label[combined > 0] = bgr
|
||||
|
||||
if include_semantic and semantic is not None:
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_pro_label and pro_label is not None:
|
||||
_, encoded = cv2.imencode(".png", pro_label)
|
||||
zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_mix_label and pro_label is not None:
|
||||
original = _decode_original_image(original_images.get(frame.id), width, height)
|
||||
mask_pixels = np.any(pro_label > 0, axis=2)
|
||||
mixed = original.copy()
|
||||
opacity = min(max(float(mix_opacity), 0.0), 1.0)
|
||||
mixed[mask_pixels] = (
|
||||
original[mask_pixels].astype(np.float32) * (1.0 - opacity)
|
||||
+ pro_label[mask_pixels].astype(np.float32) * opacity
|
||||
).clip(0, 255).astype(np.uint8)
|
||||
_, encoded = cv2.imencode(".png", mixed)
|
||||
zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
|
||||
def _write_mask_pngs(
|
||||
zf: zipfile.ZipFile,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
mask_type: str,
|
||||
individual_prefix: str = "",
|
||||
semantic_prefix: str = "",
|
||||
semantic_file_stem: str = "semantic_frame",
|
||||
semantic_dtype: Any = np.uint8,
|
||||
) -> list[dict[str, Any]]:
|
||||
import cv2
|
||||
|
||||
class_values: dict[str, int] = {}
|
||||
semantic_classes: list[dict[str, Any]] = []
|
||||
@@ -235,46 +584,102 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
||||
})
|
||||
return class_values[key]
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
include_individual = mask_type in {"separate", "both"}
|
||||
include_semantic = mask_type in {"gt_label", "both"}
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for ann in annotations:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
if include_individual:
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
fname = f"mask_{ann.id:06d}.png"
|
||||
zf.writestr(fname, encoded.tobytes())
|
||||
if ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
|
||||
if include_semantic and ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
|
||||
if include_semantic:
|
||||
for frame in frames:
|
||||
entries = frame_masks.get(frame.id, [])
|
||||
if not entries:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
semantic = np.zeros((height, width), dtype=np.uint8)
|
||||
semantic = np.zeros((height, width), dtype=semantic_dtype)
|
||||
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
|
||||
semantic[mask > 0] = class_value(ann)
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
|
||||
if include_semantic:
|
||||
zf.writestr(
|
||||
"semantic_classes.json",
|
||||
f"{semantic_prefix}semantic_classes.json",
|
||||
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
return semantic_classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
_project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
_write_mask_pngs(
|
||||
zf,
|
||||
frames,
|
||||
annotations,
|
||||
mask_type="both",
|
||||
semantic_prefix="",
|
||||
individual_prefix="",
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
@@ -284,3 +689,71 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/results",
|
||||
summary="Export segmentation results as a ZIP archive",
|
||||
)
|
||||
def export_results(
|
||||
project_id: int,
|
||||
scope: str = Query("all", pattern="^(all|range|current)$"),
|
||||
mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
|
||||
outputs: str | None = Query(None),
|
||||
mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
|
||||
start_frame: int | None = Query(None, ge=1),
|
||||
end_frame: int | None = Query(None, ge=1),
|
||||
frame_id: int | None = Query(None, ge=1),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export JSON annotations plus selected PNG mask outputs inside one ZIP.
|
||||
|
||||
`scope=all` exports the whole video. `scope=range` uses 1-based frame
|
||||
numbers from the sorted project frame sequence. `scope=current` uses the
|
||||
concrete backend `frame_id`.
|
||||
"""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _filter_frames(
|
||||
_project_frames(project_id, db),
|
||||
scope=scope,
|
||||
start_frame=start_frame,
|
||||
end_frame=end_frame,
|
||||
frame_id=frame_id,
|
||||
)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
class_values, class_mapping = _build_gt_class_mapping(annotations)
|
||||
selected_outputs = _parse_result_outputs(mask_type, outputs)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr(
|
||||
"annotations_coco.json",
|
||||
json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
zf.writestr(
|
||||
"maskid_GT像素值_类别映射.json",
|
||||
json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
original_images = _write_original_frames(zf, project, frames)
|
||||
_write_result_mask_outputs(
|
||||
zf,
|
||||
project,
|
||||
frames,
|
||||
annotations,
|
||||
outputs=selected_outputs,
|
||||
class_values=class_values,
|
||||
class_mapping=class_mapping,
|
||||
original_images=original_images,
|
||||
mix_opacity=mix_opacity,
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = _segmentation_results_filename(project, frames)
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": _download_content_disposition(filename)},
|
||||
)
|
||||
|
||||
@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import ProcessingTask, Project
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
|
||||
from worker_tasks import parse_project_media
|
||||
@@ -34,6 +35,7 @@ async def upload_media(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Accept a video, image, or DICOM file and store it in MinIO.
|
||||
|
||||
@@ -62,13 +64,15 @@ async def upload_media(
|
||||
file_url = get_presigned_url(object_name, expires=3600)
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if project:
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
logger.warning("Project id=%s not found for upload linkage", project_id)
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
# Auto-create a project named after the file
|
||||
project = Project(
|
||||
@@ -77,6 +81,7 @@ async def upload_media(
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
video_path=object_name,
|
||||
source_type="video",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -109,6 +114,7 @@ async def upload_dicom_batch(
|
||||
files: List[UploadFile] = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Upload multiple .dcm files for a DICOM series.
|
||||
|
||||
@@ -121,7 +127,10 @@ async def upload_dicom_batch(
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
@@ -132,6 +141,7 @@ async def upload_dicom_batch(
|
||||
description=f"DICOM series with {len(files)} files",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -173,13 +183,17 @@ def parse_media(
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
|
||||
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
|
||||
updates the persisted task record as it progresses.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame
|
||||
from models import Project, Frame, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from minio_client import get_presigned_url
|
||||
|
||||
@@ -24,9 +25,13 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new project",
|
||||
)
|
||||
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
|
||||
def create_project(
|
||||
payload: ProjectCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Create a new segmentation project."""
|
||||
project = Project(**payload.model_dump())
|
||||
project = Project(**payload.model_dump(), owner_user_id=current_user.id)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
@@ -39,9 +44,20 @@ def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Pro
|
||||
response_model=List[ProjectOut],
|
||||
summary="List all projects",
|
||||
)
|
||||
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
|
||||
def list_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Project]:
|
||||
"""Retrieve a paginated list of projects."""
|
||||
projects = db.query(Project).offset(skip).limit(limit).all()
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.filter(Project.owner_user_id == current_user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for p in projects:
|
||||
p.frame_count = len(p.frames)
|
||||
if p.thumbnail_url:
|
||||
@@ -54,9 +70,16 @@ def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
|
||||
response_model=ProjectOut,
|
||||
summary="Get a single project",
|
||||
)
|
||||
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
|
||||
def get_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Project:
|
||||
"""Retrieve a project by its ID."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.frame_count = len(project.frames)
|
||||
@@ -74,9 +97,13 @@ def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Update project fields partially."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -94,9 +121,16 @@ def update_project(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a project",
|
||||
)
|
||||
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
|
||||
def delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a project and all related frames and annotations."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -118,9 +152,13 @@ def create_frame(
|
||||
project_id: int,
|
||||
payload: FrameCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Frame:
|
||||
"""Register a new frame under a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -141,9 +179,13 @@ def list_frames(
|
||||
skip: int = 0,
|
||||
limit: int = 1000,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Frame]:
|
||||
"""Retrieve all frames belonging to a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -165,11 +207,21 @@ def list_frames(
|
||||
response_model=FrameOut,
|
||||
summary="Get a single frame",
|
||||
)
|
||||
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
|
||||
def get_frame(
|
||||
project_id: int,
|
||||
frame_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Frame:
|
||||
"""Retrieve a specific frame by ID."""
|
||||
frame = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id, Frame.id == frame_id)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(
|
||||
Frame.project_id == project_id,
|
||||
Frame.id == frame_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not frame:
|
||||
|
||||
@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import get_db
|
||||
from models import ProcessingTask, Project
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PARSING,
|
||||
@@ -31,8 +32,16 @@ def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
|
||||
task = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter(
|
||||
ProcessingTask.id == task_id,
|
||||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
@@ -48,9 +57,12 @@ def list_tasks(
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[ProcessingTask]:
|
||||
"""Return recent background processing tasks."""
|
||||
query = db.query(ProcessingTask)
|
||||
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id).filter(
|
||||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id)
|
||||
)
|
||||
if project_id is not None:
|
||||
query = query.filter(ProcessingTask.project_id == project_id)
|
||||
if status is not None:
|
||||
@@ -59,15 +71,23 @@ def list_tasks(
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||||
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def get_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ProcessingTask:
|
||||
"""Return one background task by id."""
|
||||
return _get_task_or_404(task_id, db)
|
||||
return _get_task_or_404(task_id, db, current_user)
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||||
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def cancel_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||||
task = _get_task_or_404(task_id, db)
|
||||
task = _get_task_or_404(task_id, db, current_user)
|
||||
if task.status not in TASK_ACTIVE_STATUSES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -95,9 +115,13 @@ def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
|
||||
|
||||
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||||
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def retry_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a fresh queued task from a failed or cancelled task."""
|
||||
previous = _get_task_or_404(task_id, db)
|
||||
previous = _get_task_or_404(task_id, db, current_user)
|
||||
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -106,7 +130,10 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
if previous.project_id is None:
|
||||
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == previous.project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
is_propagation_task = previous.task_type == "propagate_masks"
|
||||
|
||||
@@ -4,10 +4,12 @@ import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template
|
||||
from models import Template, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,11 +42,15 @@ def _unpack_template(template: Template) -> Template:
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new template",
|
||||
)
|
||||
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
|
||||
def create_template(
|
||||
payload: TemplateCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Create a new ontology template / segmentation class."""
|
||||
data = payload.model_dump()
|
||||
data = _pack_mapping_rules(data)
|
||||
template = Template(**data)
|
||||
template = Template(**data, owner_user_id=current_user.id)
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
@@ -62,9 +68,16 @@ def list_templates(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Template]:
|
||||
"""Retrieve all ontology templates."""
|
||||
templates = db.query(Template).offset(skip).limit(limit).all()
|
||||
templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for t in templates:
|
||||
_unpack_template(t)
|
||||
return templates
|
||||
@@ -75,9 +88,16 @@ def list_templates(
|
||||
response_model=TemplateOut,
|
||||
summary="Get a single template",
|
||||
)
|
||||
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
|
||||
def get_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Template:
|
||||
"""Retrieve a template by its ID."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_unpack_template(template)
|
||||
@@ -93,9 +113,13 @@ def update_template(
|
||||
template_id: int,
|
||||
payload: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Update template fields partially."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
@@ -118,9 +142,16 @@ def update_template(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a template",
|
||||
)
|
||||
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
|
||||
def delete_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a template. Associated annotations will have template_id set to NULL."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
|
||||
@@ -5,6 +5,55 @@ from typing import Optional, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth / user schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class UserOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
username: str
|
||||
role: str
|
||||
is_active: int
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
token_type: str = "bearer"
|
||||
username: str
|
||||
user: UserOut
|
||||
|
||||
|
||||
class AdminUserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
role: str = "annotator"
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class AdminUserUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class AuditLogOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
actor_user_id: Optional[int] = None
|
||||
action: str
|
||||
target_type: Optional[str] = None
|
||||
target_id: Optional[str] = None
|
||||
detail: Optional[dict[str, Any]] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class DemoFactoryResetRequest(BaseModel):
|
||||
confirmation: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Project schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -38,11 +87,19 @@ class ProjectOut(ProjectBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
owner_user_id: Optional[int] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
frame_count: int = 0
|
||||
|
||||
|
||||
class DemoFactoryResetOut(BaseModel):
|
||||
admin_user: UserOut
|
||||
project: ProjectOut
|
||||
deleted_counts: dict[str, int]
|
||||
message: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frame schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -98,6 +155,7 @@ class TemplateOut(TemplateBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
owner_user_id: Optional[int] = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
|
||||
@@ -102,8 +102,14 @@ def _normalize_smoothing_options(value: Any) -> dict[str, Any] | None:
|
||||
return {"strength": round(strength, 2), "method": method}
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = _normalize_polygon(polygon)
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
@@ -111,12 +117,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01(0.75 * current[0] + 0.25 * following[0]),
|
||||
_clamp01(0.75 * current[1] + 0.25 * following[1]),
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(0.25 * current[0] + 0.75 * following[0]),
|
||||
_clamp01(0.25 * current[1] + 0.75 * following[1]),
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
@@ -127,7 +133,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
@@ -140,9 +146,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any] | None
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
iterations = max(1, min(3, int(strength // 35) + 1))
|
||||
smoothed = _chaikin_smooth_polygon(polygon, iterations)
|
||||
simplified = _simplify_polygon(smoothed, strength)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ if str(BACKEND_DIR) not in sys.path:
|
||||
|
||||
from database import Base, get_db # noqa: E402
|
||||
from main import websocket_progress # noqa: E402
|
||||
from routers import ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
from routers import admin, ai, auth, dashboard, export, media, projects, tasks, templates # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -32,6 +32,7 @@ def db_session() -> Iterator[Session]:
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = TestingSessionLocal()
|
||||
auth.ensure_default_admin(session)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
@@ -56,6 +57,7 @@ def app(db_session: Session) -> FastAPI:
|
||||
test_app.include_router(export.router)
|
||||
test_app.include_router(dashboard.router)
|
||||
test_app.include_router(tasks.router)
|
||||
test_app.include_router(admin.router)
|
||||
|
||||
@test_app.get("/health")
|
||||
def health_check() -> dict[str, str]:
|
||||
@@ -67,6 +69,10 @@ def app(db_session: Session) -> FastAPI:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(app: FastAPI) -> Iterator[TestClient]:
|
||||
def client(app: FastAPI, db_session: Session) -> Iterator[TestClient]:
|
||||
with TestClient(app) as test_client:
|
||||
admin = auth.ensure_default_admin(db_session)
|
||||
test_client.headers.update({
|
||||
"Authorization": f"Bearer {auth.create_access_token(admin)}"
|
||||
})
|
||||
yield test_client
|
||||
|
||||
158
backend/tests/test_admin.py
Normal file
158
backend/tests/test_admin.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
def test_admin_user_management_and_audit_logs(client, db_session):
|
||||
created = client.post("/api/admin/users", json={
|
||||
"username": "doctor",
|
||||
"password": "secret123",
|
||||
"role": "annotator",
|
||||
"is_active": True,
|
||||
})
|
||||
assert created.status_code == 201
|
||||
user_id = created.json()["id"]
|
||||
|
||||
updated = client.patch(f"/api/admin/users/{user_id}", json={
|
||||
"role": "viewer",
|
||||
"password": "newsecret",
|
||||
"is_active": False,
|
||||
})
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["role"] == "viewer"
|
||||
assert updated.json()["is_active"] == 0
|
||||
|
||||
users = client.get("/api/admin/users")
|
||||
assert users.status_code == 200
|
||||
assert any(user["username"] == "doctor" for user in users.json())
|
||||
|
||||
deleted = client.delete(f"/api/admin/users/{user_id}")
|
||||
assert deleted.status_code == 204
|
||||
|
||||
logs = client.get("/api/admin/audit-logs")
|
||||
assert logs.status_code == 200
|
||||
actions = [log["action"] for log in logs.json()]
|
||||
assert "admin.user_created" in actions
|
||||
assert "admin.user_updated" in actions
|
||||
assert "admin.user_deleted" in actions
|
||||
|
||||
|
||||
def test_admin_routes_require_admin_role(client, db_session):
|
||||
user = User(username="viewer", password_hash=hash_password("secret123"), role="viewer", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
|
||||
try:
|
||||
response = client.get("/api/admin/users")
|
||||
assert response.status_code == 403
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
|
||||
def test_viewer_role_is_read_only_for_business_mutations(client, db_session):
|
||||
project = client.post("/api/projects", json={"name": "Readonly Check"}).json()
|
||||
user = User(username="readonly", password_hash=hash_password("secret123"), role="viewer", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(user)}"})
|
||||
try:
|
||||
assert client.get("/api/projects").status_code == 200
|
||||
assert client.post("/api/projects", json={"name": "Nope"}).status_code == 403
|
||||
assert client.patch(f"/api/projects/{project['id']}", json={"name": "Nope"}).status_code == 403
|
||||
assert client.post("/api/ai/annotate", json={"project_id": project["id"]}).status_code == 403
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
|
||||
def test_admin_cannot_delete_self_or_user_with_projects(client, db_session):
|
||||
me = client.get("/api/auth/me").json()
|
||||
assert client.delete(f"/api/admin/users/{me['id']}").status_code == 400
|
||||
|
||||
user = User(username="owner", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
db_session.refresh(user)
|
||||
db_session.add(Project(name="Owned", owner_user_id=user.id))
|
||||
db_session.commit()
|
||||
|
||||
response = client.delete(f"/api/admin/users/{user.id}")
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
def test_demo_factory_reset_leaves_admin_and_unparsed_demo_video(client, db_session, monkeypatch, tmp_path):
|
||||
video_path = tmp_path / "Data_MyVideo_1.mp4"
|
||||
video_path.write_bytes(b"demo-video")
|
||||
monkeypatch.setattr("routers.admin.settings.demo_video_path", str(video_path))
|
||||
uploaded = []
|
||||
monkeypatch.setattr("routers.admin.upload_file", lambda object_name, data, content_type, length: uploaded.append({
|
||||
"object_name": object_name,
|
||||
"data": data,
|
||||
"content_type": content_type,
|
||||
"length": length,
|
||||
}))
|
||||
|
||||
extra_user = User(username="doctor", password_hash=hash_password("secret123"), role="annotator", is_active=1)
|
||||
db_session.add(extra_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(extra_user)
|
||||
old_project = Project(name="Old", owner_user_id=extra_user.id, video_path="uploads/old.mp4")
|
||||
db_session.add(old_project)
|
||||
db_session.commit()
|
||||
db_session.refresh(old_project)
|
||||
frame = Frame(project_id=old_project.id, frame_index=0, image_url="frames/old.jpg")
|
||||
db_session.add(frame)
|
||||
task = ProcessingTask(task_type="parse_video", project_id=old_project.id)
|
||||
private_template = Template(
|
||||
name="Private",
|
||||
description="private",
|
||||
color="#fff",
|
||||
z_index=1,
|
||||
owner_user_id=extra_user.id,
|
||||
)
|
||||
db_session.add_all([task, private_template])
|
||||
db_session.commit()
|
||||
db_session.refresh(frame)
|
||||
annotation = Annotation(project_id=old_project.id, frame_id=frame.id, mask_data={"label": "old"})
|
||||
db_session.add(annotation)
|
||||
db_session.commit()
|
||||
db_session.refresh(annotation)
|
||||
db_session.add(Mask(annotation_id=annotation.id, mask_url="masks/old.png"))
|
||||
db_session.add(AuditLog(actor_user_id=extra_user.id, action="old.audit"))
|
||||
db_session.commit()
|
||||
|
||||
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "RESET_DEMO_FACTORY"})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["message"] == "演示环境已恢复出厂设置"
|
||||
assert data["admin_user"]["username"] == "admin"
|
||||
assert data["project"]["name"] == "Data_MyVideo_1"
|
||||
assert data["project"]["status"] == PROJECT_STATUS_PENDING
|
||||
assert data["project"]["frame_count"] == 0
|
||||
assert data["project"]["video_path"] == f"uploads/{data['project']['id']}/Data_MyVideo_1.mp4"
|
||||
assert uploaded == [{
|
||||
"object_name": data["project"]["video_path"],
|
||||
"data": b"demo-video",
|
||||
"content_type": "video/mp4",
|
||||
"length": len(b"demo-video"),
|
||||
}]
|
||||
|
||||
assert [user.username for user in db_session.query(User).all()] == ["admin"]
|
||||
assert db_session.query(Project).count() == 1
|
||||
assert db_session.query(Frame).count() == 0
|
||||
assert db_session.query(Annotation).count() == 0
|
||||
assert db_session.query(Mask).count() == 0
|
||||
assert db_session.query(ProcessingTask).count() == 0
|
||||
assert db_session.query(Template).filter(Template.owner_user_id.is_not(None)).count() == 0
|
||||
assert db_session.query(AuditLog).count() == 1
|
||||
assert db_session.query(AuditLog).first().action == "admin.demo_factory_reset"
|
||||
|
||||
|
||||
def test_demo_factory_reset_requires_exact_confirmation(client):
|
||||
response = client.post("/api/admin/demo-factory-reset", json={"confirmation": "reset"})
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -223,6 +223,88 @@ def test_analyze_mask_returns_backend_geometry_properties(client):
|
||||
assert body["message"] == "已从后端重新提取几何拓扑锚点"
|
||||
|
||||
|
||||
def test_analyze_mask_reports_actual_polygon_anchor_count(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = [[0.1 + index * 0.005, 0.1 + (0.01 if index % 2 else 0)] for index in range(80)]
|
||||
|
||||
response = client.post("/api/ai/analyze-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"points": [[0.2, 0.2]],
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["topology_anchor_count"] == len(polygon)
|
||||
assert len(body["topology_anchors"]) <= 64
|
||||
|
||||
|
||||
def test_smooth_mask_simplifies_noisy_ai_polygon(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = []
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
|
||||
|
||||
response = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"strength": 80,
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["topology_anchor_count"] == len(body["polygons"][0])
|
||||
assert len(body["polygons"][0]) < len(polygon)
|
||||
|
||||
|
||||
def test_smooth_mask_uses_eased_strength_curve(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
polygon = []
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + index * 0.02, 0.1 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 + (0.01 if index % 2 else 0), 0.1 + index * 0.02])
|
||||
for index in range(20):
|
||||
polygon.append([0.5 - index * 0.02, 0.5 + (0.01 if index % 2 else 0)])
|
||||
for index in range(20):
|
||||
polygon.append([0.1 + (0.01 if index % 2 else 0), 0.5 - index * 0.02])
|
||||
|
||||
def smoothed_count(strength: int) -> int:
|
||||
response = client.post("/api/ai/smooth-mask", json={
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [polygon],
|
||||
"label": "AI Mask",
|
||||
"color": "#06b6d4",
|
||||
},
|
||||
"strength": strength,
|
||||
})
|
||||
assert response.status_code == 200
|
||||
return len(response.json()["polygons"][0])
|
||||
|
||||
low_count = smoothed_count(20)
|
||||
mid_count = smoothed_count(70)
|
||||
high_count = smoothed_count(95)
|
||||
|
||||
assert low_count <= len(polygon)
|
||||
assert mid_count < low_count
|
||||
assert high_count < mid_count
|
||||
|
||||
|
||||
def test_smooth_mask_returns_backend_smoothed_geometry(client):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
|
||||
@@ -311,6 +393,7 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
"template_id": None,
|
||||
"smoothing": {"strength": 45, "method": "chaikin"},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -327,6 +410,9 @@ def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
assert saved["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
assert saved["mask_data"]["class"]["name"] == "胆囊"
|
||||
assert saved["mask_data"]["score"] == 0.8
|
||||
assert saved["mask_data"]["geometry_smoothing"] == {"strength": 45.0, "method": "chaikin"}
|
||||
assert saved["mask_data"]["polygons"][0] != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
assert len(saved["mask_data"]["polygons"][0]) > 3
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert len(listing.json()) == 1
|
||||
@@ -490,8 +576,10 @@ def test_propagation_task_runner_saves_annotations_and_progress(client, db_sessi
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert listing.json()[0]["frame_id"] == frames[1]["id"]
|
||||
assert listing.json()[0]["mask_data"]["source"] == "sam2.1_hiera_tiny_propagation"
|
||||
stored_polygon = listing.json()[0]["mask_data"]["polygons"][0]
|
||||
assert listing.json()[0]["mask_data"]["geometry_smoothing"] == {"strength": 40.0, "method": "chaikin"}
|
||||
assert len(listing.json()[0]["mask_data"]["polygons"][0]) > 3
|
||||
assert stored_polygon != [[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]
|
||||
assert len(stored_polygon) > 3
|
||||
|
||||
|
||||
def test_propagation_task_runner_skips_unchanged_seed_and_replaces_changed_seed(client, db_session, monkeypatch):
|
||||
@@ -1084,3 +1172,156 @@ def test_import_gt_mask_splits_label_values(client):
|
||||
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
|
||||
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
|
||||
assert all(len(item["points"]) == 1 for item in body)
|
||||
|
||||
|
||||
def test_import_gt_mask_preserves_low_value_gtlabel_png(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "GTLabel Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
cv2.rectangle(mask, (40, 40), (140, 140), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("GT_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["class"]["name"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
|
||||
|
||||
def test_import_gt_mask_rejects_rgb_color_masks(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Color Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "known", "name": "已知类别", "color": "#ff0000", "zIndex": 10, "maskId": 1},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((80, 120, 3), dtype=np.uint8)
|
||||
mask[10:40, 10:40] = [0, 0, 255] # BGR red -> #ff0000
|
||||
mask[40:70, 70:110] = [0, 255, 0] # BGR green -> unknown #00ff00
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("color-mask.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "RGB 三通道完全相同" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_import_gt_mask_reads_uint16_gt_label_and_maps_maskid_class(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Label Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "tumor", "name": "肿瘤", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((360, 640), dtype=np.uint16)
|
||||
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("gt_label.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
body = response.json()
|
||||
assert len(body) == 1
|
||||
assert body[0]["mask_data"]["gt_label_value"] == 1
|
||||
assert body[0]["mask_data"]["label"] == "肿瘤"
|
||||
assert body[0]["mask_data"]["class"]["maskId"] == 1
|
||||
assert body[0]["mask_data"]["class"]["color"] == "#ff0000"
|
||||
|
||||
|
||||
def test_import_gt_mask_handles_unknown_maskid_policy_and_resizes_to_frame(client):
|
||||
project, frame, _ = _create_project_and_frame(client)
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Color Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [{"id": "known", "name": "已定义", "color": "#ff0000", "zIndex": 10, "maskId": 1}],
|
||||
"rules": [],
|
||||
}).json()
|
||||
mask = np.zeros((90, 160, 3), dtype=np.uint8)
|
||||
cv2.rectangle(mask, (5, 5), (40, 40), (1, 1, 1), thickness=-1)
|
||||
cv2.rectangle(mask, (80, 5), (120, 40), (2, 2, 2), thickness=-1)
|
||||
ok, encoded = cv2.imencode(".png", mask)
|
||||
assert ok
|
||||
|
||||
discard_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert discard_response.status_code == 201
|
||||
assert [item["mask_data"]["label"] for item in discard_response.json()] == ["已定义"]
|
||||
assert discard_response.json()[0]["mask_data"]["gt_original_size"] == {"width": 160, "height": 90}
|
||||
assert discard_response.json()[0]["mask_data"]["gt_resized_to_frame"] is True
|
||||
assert discard_response.json()[0]["mask_data"]["image_size"] == {"width": 640, "height": 360}
|
||||
|
||||
undefined_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "undefined",
|
||||
},
|
||||
files={"file": ("colors.png", encoded.tobytes(), "image/png")},
|
||||
)
|
||||
|
||||
assert undefined_response.status_code == 201
|
||||
labels = {item["mask_data"]["label"] for item in undefined_response.json()}
|
||||
assert labels == {"已定义", "未定义类别 2"}
|
||||
unknown = next(item for item in undefined_response.json() if item["mask_data"]["label"].startswith("未定义"))
|
||||
assert unknown["mask_data"]["gt_unknown_class"] is True
|
||||
assert unknown["mask_data"]["gt_label_value"] == 2
|
||||
assert unknown["mask_data"]["gt_resized_to_frame"] is True
|
||||
|
||||
@@ -2,10 +2,11 @@ def test_login_success(client):
|
||||
response = client.post("/api/auth/login", json={"username": "admin", "password": "123456"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"token": "fake-jwt-token-for-admin",
|
||||
"username": "admin",
|
||||
}
|
||||
body = response.json()
|
||||
assert body["token"]
|
||||
assert body["token_type"] == "bearer"
|
||||
assert body["username"] == "admin"
|
||||
assert body["user"]["username"] == "admin"
|
||||
|
||||
|
||||
def test_login_rejects_invalid_credentials(client):
|
||||
@@ -13,3 +14,19 @@ def test_login_rejects_invalid_credentials(client):
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid credentials"
|
||||
|
||||
|
||||
def test_me_returns_current_user(client):
|
||||
response = client.get("/api/auth/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["username"] == "admin"
|
||||
|
||||
|
||||
def test_business_routes_require_auth(app):
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
with TestClient(app) as unauthenticated:
|
||||
response = unauthenticated.get("/api/projects")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
import zipfile
|
||||
import json
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _fake_image_bytes(width=100, height=50, color=(255, 255, 255)):
|
||||
image = np.full((height, width, 3), color, dtype=np.uint8)
|
||||
_, encoded = cv2.imencode(".jpg", image)
|
||||
return encoded.tobytes()
|
||||
|
||||
|
||||
def _seed_export_data(client):
|
||||
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Export Project",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/0.jpg",
|
||||
"width": 100,
|
||||
"height": 50,
|
||||
"timestamp_ms": 1250.0,
|
||||
"source_frame_number": 37,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Category",
|
||||
@@ -113,6 +125,328 @@ def test_export_masks_uses_z_index_for_semantic_fusion(client):
|
||||
assert semantic[10, 10] == high_value
|
||||
|
||||
|
||||
def test_export_results_zip_contains_coco_original_images_and_selected_mask_outputs(client, monkeypatch):
|
||||
project, _, _, annotation = _seed_export_data(client)
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes())
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/zip")
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
frame_stem = "clip_0h00m01s250ms_frame000001"
|
||||
assert "annotations_coco.json" in names
|
||||
assert "maskid_GT像素值_类别映射.json" in names
|
||||
assert f"原始图片/{frame_stem}.jpg" in names
|
||||
assert f"分开Mask分割结果/{frame_stem}_分别导出/{frame_stem}_Category_maskid1.png" in names
|
||||
assert f"GT_label图/{frame_stem}.png" in names
|
||||
assert f"Pro_label彩色分割结果/{frame_stem}.png" in names
|
||||
assert f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png" in names
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
label_bytes = np.frombuffer(archive.read(f"GT_label图/{frame_stem}.png"), dtype=np.uint8)
|
||||
gt_label = cv2.imdecode(label_bytes, cv2.IMREAD_UNCHANGED)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
mix_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert coco["images"][0]["frame_index"] == 0
|
||||
assert coco["annotations"][0]["image_id"] == annotation["frame_id"]
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 1,
|
||||
"maskid": 1,
|
||||
"chineseName": "Category",
|
||||
"className": "Category",
|
||||
"categoryName": "Category",
|
||||
"rgb": [6, 182, 212],
|
||||
"color": "#06b6d4",
|
||||
"key": f"template:{annotation['template_id']}",
|
||||
"template_id": annotation["template_id"],
|
||||
}]
|
||||
assert gt_label[0, 0] == 0
|
||||
assert gt_label[20, 50] == 1
|
||||
assert pro_label[20, 50].tolist() == [212, 182, 6]
|
||||
assert pro_label[0, 0].tolist() == [0, 0, 0]
|
||||
assert mix_label[20, 50].tolist() != [255, 255, 255]
|
||||
|
||||
|
||||
def test_export_results_uses_internal_layer_order_for_gt_pro_and_mix_outputs(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Layered Export Project",
|
||||
"video_path": "uploads/2/layered.mp4",
|
||||
}).json()
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/layered.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
"source_frame_number": 0,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Low",
|
||||
"color": "#00ff00",
|
||||
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
|
||||
},
|
||||
})
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
|
||||
"label": "High",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=all&outputs=gt_label,pro_label,mix_label&mix_opacity=0.5",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
high_value = next(item["maskid"] for item in mapping["classes"] if item["key"] == "class:high")
|
||||
stem = "layered_0h00m00s000ms_frame000001"
|
||||
gt_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"GT_label图/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_UNCHANGED,
|
||||
)
|
||||
pro_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Pro_label彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
mix_label = cv2.imdecode(
|
||||
np.frombuffer(archive.read(f"Mix_label重叠覆盖彩色分割结果/{stem}.png"), dtype=np.uint8),
|
||||
cv2.IMREAD_COLOR,
|
||||
)
|
||||
|
||||
assert gt_label[10, 10] == high_value
|
||||
assert pro_label[10, 10].tolist() == [0, 0, 255]
|
||||
assert mix_label[10, 10].tolist() == [127, 127, 255]
|
||||
|
||||
|
||||
def test_export_results_supports_range_and_current_scope(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Scoped Export Project",
|
||||
"video_path": "uploads/9/scope.mp4",
|
||||
"parse_fps": 2,
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Scoped Category",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [],
|
||||
"rules": [],
|
||||
}).json()
|
||||
frames = []
|
||||
annotations = []
|
||||
for idx in range(3):
|
||||
frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": idx * 500.0,
|
||||
"source_frame_number": idx * 10,
|
||||
}).json()
|
||||
frames.append(frame)
|
||||
annotations.append(client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]]},
|
||||
}).json())
|
||||
|
||||
range_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=range&start_frame=2&end_frame=3&mask_type=gt_label",
|
||||
)
|
||||
current_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=current&frame_id={frames[1]['id']}&mask_type=separate",
|
||||
)
|
||||
|
||||
assert range_response.status_code == 200
|
||||
assert "Scoped_Export_Project_seg_T_0h00m00s500ms-0h00m01s000ms_P_2-3.zip" in unquote(
|
||||
range_response.headers["content-disposition"],
|
||||
)
|
||||
with zipfile.ZipFile(BytesIO(range_response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
assert "原始图片/scope_0h00m00s500ms_frame000002.jpg" in names
|
||||
assert "原始图片/scope_0h00m01s000ms_frame000003.jpg" in names
|
||||
assert "原始图片/scope_0h00m00s000ms_frame000001.jpg" not in names
|
||||
assert "GT_label图/scope_0h00m00s500ms_frame000002.png" in names
|
||||
assert "GT_label图/scope_0h00m01s000ms_frame000003.png" in names
|
||||
assert "GT_label图/scope_0h00m00s000ms_frame000001.png" not in names
|
||||
assert not any(name.startswith("分开Mask分割结果/") for name in names)
|
||||
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
|
||||
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
|
||||
assert [image["frame_index"] for image in coco["images"]] == [1, 2]
|
||||
|
||||
assert current_response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(current_response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
coco = json.loads(archive.read("annotations_coco.json"))
|
||||
current_stem = "scope_0h00m00s500ms_frame000002"
|
||||
assert f"原始图片/{current_stem}.jpg" in names
|
||||
assert f"分开Mask分割结果/{current_stem}_分别导出/{current_stem}_Scoped_Category_maskid1.png" in names
|
||||
assert f"分开Mask分割结果/scope_0h00m00s000ms_frame000001_分别导出/scope_0h00m00s000ms_frame000001_Scoped_Category_maskid1.png" not in names
|
||||
assert not any(name.startswith("GT_label图/") for name in names)
|
||||
assert not any(name.startswith("Pro_label彩色分割结果/") for name in names)
|
||||
assert not any(name.startswith("Mix_label重叠覆盖彩色分割结果/") for name in names)
|
||||
assert [image["id"] for image in coco["images"]] == [frames[1]["id"]]
|
||||
|
||||
|
||||
def test_export_results_preserves_template_maskid_consistently_across_frames(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "MaskId Export Project",
|
||||
"video_path": "uploads/8/maskid-demo.mp4",
|
||||
"parse_fps": 1,
|
||||
}).json()
|
||||
frames = []
|
||||
for idx in range(2):
|
||||
frames.append(client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": idx * 1000.0,
|
||||
"source_frame_number": idx,
|
||||
}).json())
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[-1]["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Tumor",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
response = client.get(f"/api/export/{project['id']}/results?scope=all&mask_type=both")
|
||||
|
||||
assert response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
names = archive.namelist()
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
first_stem = "maskid-demo_0h00m00s000ms_frame000001"
|
||||
second_stem = "maskid-demo_0h00m01s000ms_frame000002"
|
||||
assert f"分开Mask分割结果/{first_stem}_分别导出/{first_stem}_Tumor_maskid7.png" in names
|
||||
assert f"分开Mask分割结果/{second_stem}_分别导出/{second_stem}_Tumor_maskid7.png" in names
|
||||
first_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{first_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
second_label = cv2.imdecode(np.frombuffer(archive.read(f"GT_label图/{second_stem}.png"), dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
|
||||
assert mapping["classes"] == [{
|
||||
"gt_pixel_value": 7,
|
||||
"maskid": 7,
|
||||
"chineseName": "Tumor",
|
||||
"className": "Tumor",
|
||||
"categoryName": "",
|
||||
"rgb": [255, 0, 0],
|
||||
"color": "#ff0000",
|
||||
"key": "class:tumor",
|
||||
"template_id": None,
|
||||
}]
|
||||
assert first_label[5, 5] == 7
|
||||
assert second_label[5, 5] == 7
|
||||
|
||||
|
||||
def test_exported_gtlabel_round_trips_through_gt_mask_import_with_template_maskid(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.export.download_file", lambda object_name: _fake_image_bytes(20, 20))
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "GT Roundtrip Project",
|
||||
"video_path": "uploads/8/roundtrip.mp4",
|
||||
}).json()
|
||||
template = client.post("/api/templates", json={
|
||||
"name": "Roundtrip Template",
|
||||
"color": "#06b6d4",
|
||||
"z_index": 0,
|
||||
"classes": [
|
||||
{"id": "tumor", "name": "Tumor", "color": "#ff0000", "zIndex": 30, "maskId": 7},
|
||||
],
|
||||
"rules": [],
|
||||
}).json()
|
||||
source_frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 0,
|
||||
"image_url": "frames/source.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 0,
|
||||
}).json()
|
||||
target_frame = client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": 1,
|
||||
"image_url": "frames/target.jpg",
|
||||
"width": 20,
|
||||
"height": 20,
|
||||
"timestamp_ms": 1000,
|
||||
}).json()
|
||||
client.post("/api/ai/annotate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": source_frame["id"],
|
||||
"template_id": template["id"],
|
||||
"mask_data": {
|
||||
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
|
||||
"label": "Tumor",
|
||||
"color": "#ff0000",
|
||||
"class": {"id": "tumor", "name": "Tumor", "color": "#ff0000", "maskId": 7, "zIndex": 30},
|
||||
},
|
||||
})
|
||||
|
||||
export_response = client.get(
|
||||
f"/api/export/{project['id']}/results?scope=current&frame_id={source_frame['id']}&outputs=gt_label",
|
||||
)
|
||||
|
||||
assert export_response.status_code == 200
|
||||
with zipfile.ZipFile(BytesIO(export_response.content)) as archive:
|
||||
stem = "roundtrip_0h00m00s000ms_frame000001"
|
||||
exported_gt_label = archive.read(f"GT_label图/{stem}.png")
|
||||
gt_label = cv2.imdecode(np.frombuffer(exported_gt_label, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
mapping = json.loads(archive.read("maskid_GT像素值_类别映射.json"))
|
||||
|
||||
assert gt_label[5, 5] == 7
|
||||
assert mapping["classes"][0]["maskid"] == 7
|
||||
|
||||
import_response = client.post(
|
||||
"/api/ai/import-gt-mask",
|
||||
data={
|
||||
"project_id": str(project["id"]),
|
||||
"frame_id": str(target_frame["id"]),
|
||||
"template_id": str(template["id"]),
|
||||
"unknown_color_policy": "discard",
|
||||
},
|
||||
files={"file": ("exported_gt_label.png", exported_gt_label, "image/png")},
|
||||
)
|
||||
|
||||
assert import_response.status_code == 201
|
||||
imported = import_response.json()
|
||||
assert len(imported) == 1
|
||||
assert imported[0]["frame_id"] == target_frame["id"]
|
||||
assert imported[0]["mask_data"]["gt_label_value"] == 7
|
||||
assert imported[0]["mask_data"]["label"] == "Tumor"
|
||||
assert imported[0]["mask_data"]["class"]["maskId"] == 7
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
assert client.get("/api/export/999/coco").status_code == 404
|
||||
assert client.get("/api/export/999/masks").status_code == 404
|
||||
assert client.get("/api/export/999/results").status_code == 404
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from models import Annotation, Frame, Mask, ProcessingTask, Project
|
||||
from models import Annotation, Frame, Mask, ProcessingTask, Project, User
|
||||
from routers.auth import create_access_token, hash_password
|
||||
|
||||
|
||||
def test_project_crud_and_frames(client, monkeypatch):
|
||||
@@ -93,3 +94,33 @@ def test_project_and_frame_404s(client):
|
||||
}).status_code == 404
|
||||
assert client.get("/api/projects/999/frames").status_code == 404
|
||||
assert client.get("/api/projects/999/frames/1").status_code == 404
|
||||
|
||||
|
||||
def test_projects_are_scoped_to_authenticated_owner(client, db_session):
|
||||
owner_project = client.post("/api/projects", json={"name": "Owner Project"}).json()
|
||||
other_user = User(
|
||||
username="other",
|
||||
password_hash=hash_password("pass"),
|
||||
role="annotator",
|
||||
is_active=1,
|
||||
)
|
||||
db_session.add(other_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(other_user)
|
||||
other_project = Project(name="Other Project", owner_user_id=other_user.id)
|
||||
db_session.add(other_project)
|
||||
db_session.commit()
|
||||
db_session.refresh(other_project)
|
||||
|
||||
listing = client.get("/api/projects")
|
||||
assert [project["id"] for project in listing.json()] == [owner_project["id"]]
|
||||
assert client.get(f"/api/projects/{other_project.id}").status_code == 404
|
||||
|
||||
original_auth = client.headers["Authorization"]
|
||||
client.headers.update({"Authorization": f"Bearer {create_access_token(other_user)}"})
|
||||
try:
|
||||
other_listing = client.get("/api/projects")
|
||||
assert [project["id"] for project in other_listing.json()] == [other_project.id]
|
||||
assert client.get(f"/api/projects/{owner_project['id']}").status_code == 404
|
||||
finally:
|
||||
client.headers.update({"Authorization": original_auth})
|
||||
|
||||
Reference in New Issue
Block a user