feat: 完善分割工作区导入导出与管理流程
- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。 - 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。 - 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。 - 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。 - 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。 - 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。 - 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。 - 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。 - 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。 - 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。 - 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
270
backend/routers/admin.py
Normal file
270
backend/routers/admin.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Administrator-only user and audit management endpoints."""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from minio_client import upload_file
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import ensure_default_admin, hash_password, require_admin, write_audit_log
|
||||
from schemas import (
|
||||
AdminUserCreate,
|
||||
AdminUserUpdate,
|
||||
AuditLogOut,
|
||||
DemoFactoryResetOut,
|
||||
DemoFactoryResetRequest,
|
||||
UserOut,
|
||||
)
|
||||
from statuses import PROJECT_STATUS_PENDING
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
VALID_ROLES = {"admin", "annotator", "viewer"}
|
||||
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
|
||||
DEMO_PROJECT_NAME = "Data_MyVideo_1"
|
||||
|
||||
|
||||
def _normalize_role(role: str | None) -> str:
|
||||
normalized = (role or "annotator").strip().lower()
|
||||
if normalized not in VALID_ROLES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
|
||||
return normalized
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserOut], summary="List users")
|
||||
def list_users(
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[User]:
|
||||
"""Return all users for the administrator console."""
|
||||
_ = admin_user
|
||||
return db.query(User).order_by(User.id).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users",
|
||||
response_model=UserOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create user",
|
||||
)
|
||||
def create_user(
|
||||
payload: AdminUserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Create a user with an initial password and role."""
|
||||
username = payload.username.strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
if len(payload.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user = User(
|
||||
username=username,
|
||||
password_hash=hash_password(payload.password),
|
||||
role=_normalize_role(payload.role),
|
||||
is_active=1 if payload.is_active else 0,
|
||||
)
|
||||
db.add(user)
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_created",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username, "role": user.role, "is_active": bool(user.is_active)},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
|
||||
def update_user(
|
||||
user_id: int,
|
||||
payload: AdminUserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Update username, password, role or active state."""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
|
||||
if "username" in updates:
|
||||
username = (updates["username"] or "").strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
user.username = username
|
||||
if "password" in updates:
|
||||
password = updates["password"] or ""
|
||||
if len(password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user.password_hash = hash_password(password)
|
||||
if "role" in updates:
|
||||
next_role = _normalize_role(updates["role"])
|
||||
if user.id == admin_user.id and next_role != "admin":
|
||||
raise HTTPException(status_code=400, detail="Cannot remove your own admin role")
|
||||
user.role = next_role
|
||||
if "is_active" in updates:
|
||||
if user.id == admin_user.id and not updates["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
|
||||
user.is_active = 1 if updates["is_active"] else 0
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
|
||||
audit_detail["password_changed"] = "password" in updates
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_updated",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail=audit_detail,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> None:
|
||||
"""Delete a user when it is safe to remove the account."""
|
||||
if user_id == admin_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
owned_project_count = db.query(Project).filter(Project.owner_user_id == user_id).count()
|
||||
if owned_project_count:
|
||||
raise HTTPException(status_code=409, detail="User owns projects; deactivate or migrate projects first")
|
||||
username = user.username
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_deleted",
|
||||
target_type="user",
|
||||
target_id=user_id,
|
||||
detail={"username": username},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
|
||||
def list_audit_logs(
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[AuditLog]:
|
||||
"""Return recent audit events for administrators."""
|
||||
_ = admin_user
|
||||
safe_limit = min(max(int(limit or 100), 1), 500)
|
||||
return db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc()).limit(safe_limit).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/demo-factory-reset",
|
||||
response_model=DemoFactoryResetOut,
|
||||
summary="Reset demo data to factory defaults",
|
||||
)
|
||||
def reset_demo_factory(
|
||||
payload: DemoFactoryResetRequest,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> dict:
|
||||
"""Reset a demo deployment to one admin account and one unparsed demo video project."""
|
||||
if payload.confirmation != DEMO_RESET_CONFIRMATION:
|
||||
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
|
||||
|
||||
if not os.path.exists(settings.demo_video_path):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo video not found: {settings.demo_video_path}",
|
||||
)
|
||||
|
||||
requested_by = admin_user.username
|
||||
preserved_admin = ensure_default_admin(db)
|
||||
preserved_admin.username = settings.default_admin_username
|
||||
preserved_admin.password_hash = hash_password(settings.default_admin_password)
|
||||
preserved_admin.role = "admin"
|
||||
preserved_admin.is_active = 1
|
||||
db.flush()
|
||||
|
||||
deleted_counts = {
|
||||
"masks": db.query(Mask).delete(synchronize_session=False),
|
||||
"annotations": db.query(Annotation).delete(synchronize_session=False),
|
||||
"frames": db.query(Frame).delete(synchronize_session=False),
|
||||
"tasks": db.query(ProcessingTask).delete(synchronize_session=False),
|
||||
"projects": db.query(Project).delete(synchronize_session=False),
|
||||
"user_templates": db.query(Template).filter(Template.owner_user_id.is_not(None)).delete(synchronize_session=False),
|
||||
"audit_logs": db.query(AuditLog).delete(synchronize_session=False),
|
||||
"users": db.query(User).filter(User.id != preserved_admin.id).delete(synchronize_session=False),
|
||||
}
|
||||
db.flush()
|
||||
db.expunge_all()
|
||||
|
||||
preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if not preserved_admin:
|
||||
raise HTTPException(status_code=500, detail="Default admin was not preserved")
|
||||
|
||||
project = Project(
|
||||
name=DEMO_PROJECT_NAME,
|
||||
description="默认演示视频,尚未生成帧",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
owner_user_id=preserved_admin.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.flush()
|
||||
|
||||
with open(settings.demo_video_path, "rb") as file_obj:
|
||||
data = file_obj.read()
|
||||
object_name = f"uploads/{project.id}/{os.path.basename(settings.demo_video_path)}"
|
||||
upload_file(object_name, data, content_type="video/mp4", length=len(data))
|
||||
project.video_path = object_name
|
||||
project.thumbnail_url = None
|
||||
project.original_fps = None
|
||||
db.commit()
|
||||
db.refresh(preserved_admin)
|
||||
db.refresh(project)
|
||||
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=preserved_admin,
|
||||
action="admin.demo_factory_reset",
|
||||
target_type="project",
|
||||
target_id=project.id,
|
||||
detail={
|
||||
"project_name": project.name,
|
||||
"video_path": project.video_path,
|
||||
"deleted_counts": deleted_counts,
|
||||
"requested_by": requested_by,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"admin_user": preserved_admin,
|
||||
"project": project,
|
||||
"deleted_counts": deleted_counts,
|
||||
"message": "演示环境已恢复出厂设置",
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
@@ -8,11 +9,13 @@ from typing import Any, List
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
MaskAnalysisRequest,
|
||||
@@ -38,6 +41,102 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
|
||||
query = (
|
||||
db.query(Frame)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(Frame.id == frame_id, Project.owner_user_id == current_user.id)
|
||||
)
|
||||
if project_id is not None:
|
||||
query = query.filter(Frame.project_id == project_id)
|
||||
frame = query.first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
|
||||
|
||||
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
return template
|
||||
|
||||
|
||||
def _normalize_hex_color(value: Any) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
text = value.strip().lower()
|
||||
if not text:
|
||||
return None
|
||||
if not text.startswith("#"):
|
||||
text = f"#{text}"
|
||||
if len(text) == 4:
|
||||
text = "#" + "".join(char * 2 for char in text[1:])
|
||||
if len(text) != 7:
|
||||
return None
|
||||
try:
|
||||
int(text[1:], 16)
|
||||
except ValueError:
|
||||
return None
|
||||
return text
|
||||
|
||||
|
||||
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
|
||||
values = []
|
||||
for channel in rgb:
|
||||
value = int(channel)
|
||||
if value > 255:
|
||||
value = int(round(value / 257))
|
||||
values.append(min(max(value, 0), 255))
|
||||
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
|
||||
|
||||
|
||||
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]]]:
|
||||
classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
|
||||
by_maskid: dict[int, dict[str, Any]] = {}
|
||||
by_color: dict[str, dict[str, Any]] = {}
|
||||
for index, item in enumerate(classes):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
|
||||
try:
|
||||
maskid = int(maskid_value)
|
||||
except (TypeError, ValueError):
|
||||
maskid = index + 1
|
||||
color = _normalize_hex_color(item.get("color")) or "#22c55e"
|
||||
class_meta = {
|
||||
"id": str(item.get("id") or f"maskid-{maskid}"),
|
||||
"name": str(item.get("name") or f"类别 {maskid}"),
|
||||
"color": color,
|
||||
"zIndex": int(item.get("zIndex", item.get("z_index", index * 10))),
|
||||
"maskId": maskid,
|
||||
**({"category": item.get("category")} if item.get("category") else {}),
|
||||
}
|
||||
if maskid > 0:
|
||||
by_maskid[maskid] = class_meta
|
||||
by_color[color] = class_meta
|
||||
return by_maskid, by_color
|
||||
|
||||
|
||||
def _gt_unknown_label(token: int | str) -> str:
|
||||
if isinstance(token, int):
|
||||
return f"未定义类别 {token}"
|
||||
return f"未定义颜色 {token}"
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
"""Download a frame from MinIO and decode it to an RGB numpy array."""
|
||||
try:
|
||||
@@ -106,16 +205,20 @@ def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[flo
|
||||
return [polygon for polygon in (_normalize_polygon(polygon) for polygon in polygons) if len(polygon) >= 3]
|
||||
|
||||
|
||||
def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
|
||||
if points:
|
||||
return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
|
||||
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
|
||||
if len(anchors) <= limit:
|
||||
return anchors
|
||||
step = max(1, math.ceil(len(anchors) / limit))
|
||||
return anchors[::step][:limit]
|
||||
|
||||
|
||||
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
|
||||
anchors: list[list[float]] = []
|
||||
for polygon in polygons:
|
||||
if not polygon:
|
||||
continue
|
||||
step = max(1, len(polygon) // 12)
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon[::step]])
|
||||
return anchors[:32]
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon])
|
||||
return len(anchors), _sample_anchor_points(anchors)
|
||||
|
||||
|
||||
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
|
||||
@@ -129,8 +232,14 @@ def _normalize_smoothing_options(strength: float | int | None, method: str | Non
|
||||
}
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = polygon
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
@@ -138,12 +247,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01(0.75 * current[0] + 0.25 * following[0]),
|
||||
_clamp01(0.75 * current[1] + 0.25 * following[1]),
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(0.25 * current[0] + 0.75 * following[0]),
|
||||
_clamp01(0.25 * current[1] + 0.75 * following[1]),
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
@@ -154,7 +263,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
@@ -165,9 +274,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> li
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
iterations = max(1, min(3, int(strength // 35) + 1))
|
||||
smoothed = _chaikin_smooth_polygon(_normalize_polygon(polygon), iterations)
|
||||
simplified = _simplify_polygon(smoothed, strength)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
@@ -321,7 +446,11 @@ def _filter_predictions(
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def predict(
|
||||
payload: PredictRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Execute selected SAM segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
@@ -330,9 +459,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
|
||||
- **semantic**: disabled in the current SAM 2.1 point/box product flow.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(payload.image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
@@ -478,7 +605,10 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=AiRuntimeStatus,
|
||||
summary="Get SAM model and GPU runtime status",
|
||||
)
|
||||
def model_status(selected_model: str | None = None) -> dict:
|
||||
def model_status(
|
||||
selected_model: str | None = None,
|
||||
_current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return real runtime availability for GPU and the currently enabled SAM model."""
|
||||
try:
|
||||
return sam_registry.runtime_status(selected_model)
|
||||
@@ -491,12 +621,14 @@ def model_status(selected_model: str | None = None) -> dict:
|
||||
response_model=MaskAnalysisResponse,
|
||||
summary="Analyze mask geometry and prompt anchors",
|
||||
)
|
||||
def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def analyze_mask(
|
||||
payload: MaskAnalysisRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return backend-computed mask properties for the frontend inspector."""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
mask_data = payload.mask_data or {}
|
||||
polygons = mask_data.get("polygons") or []
|
||||
@@ -521,13 +653,13 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
else:
|
||||
confidence_source = "manual_or_imported"
|
||||
|
||||
anchors = _analysis_anchors(valid_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
|
||||
message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
|
||||
|
||||
return {
|
||||
"confidence": confidence,
|
||||
"confidence_source": confidence_source,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -541,16 +673,18 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
response_model=SmoothMaskResponse,
|
||||
summary="Smooth editable mask polygons with backend geometry rules",
|
||||
)
|
||||
def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def smooth_mask(
|
||||
payload: SmoothMaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Return a smoothed polygon mask without persisting it.
|
||||
|
||||
The frontend keeps this as an explicit edit operation: users preview/apply it
|
||||
to the current mask, then save through the normal annotation endpoint.
|
||||
"""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
polygons = payload.mask_data.get("polygons") or []
|
||||
valid_polygons = _normalize_polygons(polygons)
|
||||
@@ -564,10 +698,10 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
|
||||
area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
|
||||
bbox = _polygon_bbox(smoothed_polygons[0])
|
||||
anchors = _analysis_anchors(smoothed_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
|
||||
return {
|
||||
"polygons": smoothed_polygons,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -581,7 +715,11 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
response_model=PropagateResponse,
|
||||
summary="Propagate one current-frame region across a video frame segment",
|
||||
)
|
||||
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def propagate(
|
||||
payload: PropagateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Track one selected region from the current frame across nearby frames.
|
||||
|
||||
SAM 2 uses the official video predictor with the selected mask as the seed.
|
||||
@@ -592,16 +730,8 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
|
||||
max_frames = max(1, min(int(payload.max_frames or 30), 500))
|
||||
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
seed = payload.seed.model_dump(exclude_none=True)
|
||||
polygons = seed.get("polygons") or []
|
||||
@@ -709,18 +839,14 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Queue a background video propagation task",
|
||||
)
|
||||
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
|
||||
def queue_propagate_task(
|
||||
payload: PropagateTaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTaskOut:
|
||||
"""Queue multiple seed/direction propagation steps as one background task."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
if not payload.steps:
|
||||
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
|
||||
@@ -768,11 +894,13 @@ def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(ge
|
||||
response_model=PredictResponse,
|
||||
summary="Run automatic segmentation",
|
||||
)
|
||||
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def auto_segment(
|
||||
image_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Run automatic mask generation on a frame using a grid of point prompts."""
|
||||
frame = db.query(Frame).filter(Frame.id == image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
try:
|
||||
@@ -792,16 +920,15 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def save_annotation(
|
||||
payload: AnnotationCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
|
||||
if payload.frame_id:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
if payload.template_id:
|
||||
_visible_template_or_404(payload.template_id, db, current_user)
|
||||
|
||||
annotation = Annotation(**payload.model_dump())
|
||||
db.add(annotation)
|
||||
@@ -823,8 +950,10 @@ async def import_gt_mask(
|
||||
template_id: int | None = Form(None),
|
||||
label: str = Form("GT Mask"),
|
||||
color: str = Form("#22c55e"),
|
||||
unknown_color_policy: str = Form("undefined"),
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> List[Annotation]:
|
||||
"""Convert a binary/label mask image into persisted polygon annotations.
|
||||
|
||||
@@ -833,62 +962,122 @@ async def import_gt_mask(
|
||||
the frontend an editable point-region representation instead of a static
|
||||
bitmap layer.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
frame = _owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
|
||||
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
if unknown_color_policy not in {"discard", "undefined"}:
|
||||
raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
|
||||
|
||||
template: Template | None = None
|
||||
if template_id is not None:
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
template = _visible_template_or_404(template_id, db, current_user)
|
||||
|
||||
data = await file.read()
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||
|
||||
if image.ndim == 2:
|
||||
label_image = image
|
||||
elif image.ndim == 3 and image.shape[2] >= 3:
|
||||
channels = image[:, :, :3]
|
||||
# GT label images are maskid maps: either grayscale or RGB/BGR where
|
||||
# all three color channels contain the same maskid value [X, X, X].
|
||||
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
label_image = channels[:, :, 0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
|
||||
width = int(frame.width or image.shape[1])
|
||||
height = int(frame.height or image.shape[0])
|
||||
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
|
||||
if not label_values:
|
||||
original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
|
||||
resized_to_frame = original_width != width or original_height != height
|
||||
if resized_to_frame:
|
||||
label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
by_maskid, _by_color = _template_class_maps(template)
|
||||
has_template_classes = bool(by_maskid)
|
||||
fallback_color = _normalize_hex_color(color) or "#22c55e"
|
||||
|
||||
import_items: list[dict[str, Any]] = []
|
||||
skipped_unknown = 0
|
||||
label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
|
||||
for label_value in label_values:
|
||||
class_meta = by_maskid.get(label_value)
|
||||
is_unknown = has_template_classes and class_meta is None
|
||||
if is_unknown and unknown_color_policy == "discard":
|
||||
skipped_unknown += 1
|
||||
continue
|
||||
if class_meta:
|
||||
annotation_label = class_meta["name"]
|
||||
annotation_color = class_meta["color"]
|
||||
elif is_unknown:
|
||||
annotation_label = _gt_unknown_label(label_value)
|
||||
annotation_color = fallback_color
|
||||
else:
|
||||
annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
|
||||
annotation_color = fallback_color
|
||||
import_items.append({
|
||||
"token": label_value,
|
||||
"binary": np.where(label_image == label_value, 255, 0).astype(np.uint8),
|
||||
"label": annotation_label,
|
||||
"color": annotation_color,
|
||||
"class": class_meta,
|
||||
"unknown": is_unknown,
|
||||
"metadata": {
|
||||
"gt_label_value": label_value,
|
||||
"gt_original_size": {"width": original_width, "height": original_height},
|
||||
"gt_resized_to_frame": resized_to_frame,
|
||||
},
|
||||
})
|
||||
|
||||
if not import_items:
|
||||
if skipped_unknown > 0:
|
||||
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
has_multiple_labels = len(label_values) > 1
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for label_value in label_values:
|
||||
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
|
||||
for item in import_items:
|
||||
binary = item["binary"]
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
|
||||
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 1:
|
||||
continue
|
||||
|
||||
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
|
||||
polygon = _normalized_contour(contour, binary.shape[1], binary.shape[0])
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
|
||||
component = np.zeros_like(binary, dtype=np.uint8)
|
||||
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
|
||||
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
|
||||
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
|
||||
seed_point = _component_seed_point(component, binary.shape[1], binary.shape[0])
|
||||
bbox = _contour_bbox(contour, binary.shape[1], binary.shape[0])
|
||||
mask_data = {
|
||||
"polygons": [polygon],
|
||||
"label": item["label"],
|
||||
"color": item["color"],
|
||||
"source": "gt_mask",
|
||||
"image_size": {"width": width, "height": height},
|
||||
**item["metadata"],
|
||||
}
|
||||
if item["class"]:
|
||||
mask_data["class"] = item["class"]
|
||||
if item["unknown"]:
|
||||
mask_data["gt_unknown_class"] = True
|
||||
|
||||
annotation = Annotation(
|
||||
project_id=project_id,
|
||||
frame_id=frame_id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": annotation_label,
|
||||
"color": color,
|
||||
"source": "gt_mask",
|
||||
"gt_label_value": label_value,
|
||||
"image_size": {"width": width, "height": height},
|
||||
},
|
||||
mask_data=mask_data,
|
||||
points=[seed_point],
|
||||
bbox=bbox,
|
||||
)
|
||||
@@ -914,14 +1103,14 @@ def list_annotations(
|
||||
project_id: int,
|
||||
frame_id: int | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Annotation]:
|
||||
"""Return persisted annotations for a project, optionally scoped to one frame."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
|
||||
query = db.query(Annotation).filter(Annotation.project_id == project_id)
|
||||
if frame_id is not None:
|
||||
_owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
query = query.filter(Annotation.frame_id == frame_id)
|
||||
return query.order_by(Annotation.id).all()
|
||||
|
||||
@@ -935,17 +1124,21 @@ def update_annotation(
|
||||
annotation_id: int,
|
||||
payload: AnnotationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Update mutable annotation fields persisted in the database."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "template_id" in updates and updates["template_id"] is not None:
|
||||
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_visible_template_or_404(updates["template_id"], db, current_user)
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(annotation, field, value)
|
||||
@@ -964,9 +1157,15 @@ def update_annotation(
|
||||
def delete_annotation(
|
||||
annotation_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Response:
|
||||
"""Delete an annotation and its derived mask rows through ORM cascade."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
|
||||
@@ -1,9 +1,23 @@
|
||||
"""Authentication endpoints."""
|
||||
"""Authentication endpoints and dependencies."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from models import AuditLog, User
|
||||
from schemas import LoginResponse, UserOut
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Auth"])
|
||||
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@@ -11,14 +25,151 @@ class LoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
username: str
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a plain password for storage."""
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a plain password against a stored hash."""
|
||||
return password_context.verify(password, password_hash)
|
||||
|
||||
|
||||
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
|
||||
"""Create a signed JWT access token for a user."""
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def ensure_default_admin(db: Session) -> User:
|
||||
"""Create the default development admin if the user table is empty."""
|
||||
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if existing:
|
||||
return existing
|
||||
user = User(
|
||||
username=settings.default_admin_username,
|
||||
password_hash=hash_password(settings.default_admin_password),
|
||||
role="admin",
|
||||
is_active=1,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Resolve and validate the current user from the Bearer token."""
|
||||
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
credentials.credentials,
|
||||
settings.jwt_secret_key,
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
user_id = int(payload.get("sub"))
|
||||
except (JWTError, TypeError, ValueError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive or missing user",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require the current user to have the administrator role."""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def require_editor(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require a user role that can modify segmentation data."""
|
||||
if current_user.role not in {"admin", "annotator"}:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def write_audit_log(
|
||||
db: Session,
|
||||
*,
|
||||
actor: User | None,
|
||||
action: str,
|
||||
target_type: str | None = None,
|
||||
target_id: str | int | None = None,
|
||||
detail: dict[str, Any] | None = None,
|
||||
) -> AuditLog:
|
||||
"""Persist a compact audit event."""
|
||||
log = AuditLog(
|
||||
actor_user_id=actor.id if actor else None,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=str(target_id) if target_id is not None else None,
|
||||
detail=detail or {},
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
db.refresh(log)
|
||||
return log
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest) -> dict:
|
||||
"""Simple login for development."""
|
||||
if payload.username == "admin" and payload.password == "123456":
|
||||
return {"token": "fake-jwt-token-for-admin", "username": payload.username}
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Authenticate a user and return a signed JWT."""
|
||||
ensure_default_admin(db)
|
||||
user = db.query(User).filter(User.username == payload.username).first()
|
||||
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=None,
|
||||
action="auth.login_failed",
|
||||
target_type="user",
|
||||
target_id=payload.username,
|
||||
detail={"username": payload.username},
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=user,
|
||||
action="auth.login_success",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username},
|
||||
)
|
||||
return {
|
||||
"token": create_access_token(user),
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Return the authenticated user profile."""
|
||||
return current_user
|
||||
|
||||
@@ -5,11 +5,12 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -52,22 +53,45 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
|
||||
|
||||
@router.get("/overview", summary="Get dashboard overview")
|
||||
def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
def get_dashboard_overview(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Return live dashboard data derived from persisted backend records."""
|
||||
project_count = db.query(func.count(Project.id)).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).scalar() or 0
|
||||
annotation_count = db.query(func.count(Annotation.id)).scalar() or 0
|
||||
template_count = db.query(func.count(Template.id)).scalar() or 0
|
||||
owned_project_ids_query = db.query(Project.id).filter(Project.owner_user_id == current_user.id)
|
||||
project_count = db.query(func.count(Project.id)).filter(Project.owner_user_id == current_user.id).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(owned_project_ids_query)).scalar() or 0
|
||||
annotation_count = (
|
||||
db.query(func.count(Annotation.id))
|
||||
.filter(Annotation.project_id.in_(owned_project_ids_query))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
template_count = (
|
||||
db.query(func.count(Template.id))
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
active_task_count = (
|
||||
db.query(func.count(ProcessingTask.id))
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
|
||||
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
projects = db.query(Project).order_by(Project.updated_at.desc()).all()
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.filter(Project.owner_user_id == current_user.id)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
recent_tasks = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter((ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id))
|
||||
.order_by(ProcessingTask.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
@@ -96,6 +120,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
|
||||
recent_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id.in_(owned_project_ids_query))
|
||||
.order_by(Annotation.updated_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
@@ -112,6 +137,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
|
||||
recent_templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.order_by(Template.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
|
||||
@@ -4,17 +4,22 @@ import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import quote
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Annotation, Frame, Template
|
||||
from minio_client import download_file
|
||||
from models import Project, Annotation, Frame, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||||
@@ -49,6 +54,30 @@ def _annotation_z_index(annotation: Annotation) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
for key in ("maskId", "maskid", "mask_id"):
|
||||
if class_meta.get(key) is None:
|
||||
continue
|
||||
try:
|
||||
value = int(class_meta[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value > 0:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _annotation_category_name(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("category"):
|
||||
return str(class_meta["category"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return ""
|
||||
|
||||
|
||||
def _annotation_class_key(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
@@ -85,38 +114,162 @@ def _annotation_color(annotation: Annotation) -> str:
|
||||
return "#ffffff"
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
def _hex_to_rgb(color: str) -> list[int]:
|
||||
value = str(color or "").strip()
|
||||
if value.startswith("#"):
|
||||
value = value[1:]
|
||||
if len(value) == 3:
|
||||
value = "".join(part * 2 for part in value)
|
||||
if len(value) != 6:
|
||||
return [255, 255, 255]
|
||||
try:
|
||||
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
|
||||
except ValueError:
|
||||
return [255, 255, 255]
|
||||
|
||||
|
||||
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
text = fallback
|
||||
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
|
||||
text = re.sub(r"_+", "_", text).strip("._")
|
||||
return text or fallback
|
||||
|
||||
|
||||
def _project_video_name(project: Project) -> str:
|
||||
if project.video_path:
|
||||
stem = Path(project.video_path).name
|
||||
if "." in stem:
|
||||
stem = ".".join(stem.split(".")[:-1])
|
||||
if stem:
|
||||
return _safe_filename_part(stem, f"project_{project.id}")
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _project_export_name(project: Project) -> str:
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
|
||||
if frame.timestamp_ms is not None:
|
||||
return float(frame.timestamp_ms)
|
||||
fps = project.parse_fps or project.original_fps or 30.0
|
||||
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
|
||||
|
||||
|
||||
def _project_frame_number(frame: Frame) -> int:
|
||||
return int(frame.frame_index) + 1
|
||||
|
||||
|
||||
def _format_timestamp_ms(value: float) -> str:
|
||||
total_ms = max(0, int(round(float(value))))
|
||||
hours = total_ms // 3_600_000
|
||||
minutes = (total_ms % 3_600_000) // 60_000
|
||||
seconds = (total_ms % 60_000) // 1_000
|
||||
milliseconds = total_ms % 1_000
|
||||
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
|
||||
|
||||
|
||||
def _frame_export_stem(project: Project, frame: Frame) -> str:
|
||||
return "_".join([
|
||||
_project_video_name(project),
|
||||
_format_timestamp_ms(_frame_timestamp_ms(frame, project)),
|
||||
f"frame{_project_frame_number(frame):06d}",
|
||||
])
|
||||
|
||||
|
||||
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
|
||||
if not frames:
|
||||
return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
|
||||
first_frame = frames[0]
|
||||
last_frame = frames[-1]
|
||||
return (
|
||||
f"{_project_export_name(project)}"
|
||||
f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
|
||||
f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
|
||||
f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
|
||||
)
|
||||
|
||||
|
||||
def _download_content_disposition(filename: str) -> str:
|
||||
ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
|
||||
ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
|
||||
if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
|
||||
ascii_fallback = f"{ascii_fallback}.zip"
|
||||
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
|
||||
def _frame_image_extension(frame: Frame) -> str:
|
||||
suffix = Path(frame.image_url or "").suffix.lower()
|
||||
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
|
||||
|
||||
|
||||
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
|
||||
def _project_frames(project_id: int, db: Session) -> list[Frame]:
|
||||
return (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
templates = db.query(Template).all()
|
||||
|
||||
# Build COCO structure
|
||||
|
||||
def _filter_frames(
|
||||
frames: list[Frame],
|
||||
*,
|
||||
scope: str = "all",
|
||||
start_frame: int | None = None,
|
||||
end_frame: int | None = None,
|
||||
frame_id: int | None = None,
|
||||
) -> list[Frame]:
|
||||
if scope == "current":
|
||||
if frame_id is None:
|
||||
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
|
||||
selected = [frame for frame in frames if frame.id == frame_id]
|
||||
if not selected:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return selected
|
||||
|
||||
if scope == "range":
|
||||
if start_frame is None or end_frame is None:
|
||||
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
|
||||
start = max(1, min(int(start_frame), int(end_frame)))
|
||||
end = max(1, max(int(start_frame), int(end_frame)))
|
||||
return frames[start - 1:end]
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
|
||||
if not frame_ids:
|
||||
return []
|
||||
return (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.filter(Annotation.frame_id.in_(frame_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
|
||||
images = []
|
||||
for idx, frame in enumerate(frames):
|
||||
for frame in frames:
|
||||
images.append({
|
||||
"id": frame.id,
|
||||
"file_name": frame.image_url,
|
||||
"width": frame.width or 1920,
|
||||
"height": frame.height or 1080,
|
||||
"frame_index": idx,
|
||||
"frame_index": frame.frame_index,
|
||||
})
|
||||
|
||||
categories = []
|
||||
@@ -131,14 +284,14 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
|
||||
coco_annotations = []
|
||||
ann_id = 1
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
# Use first polygon for bbox / area approximation
|
||||
first_poly = polygons[0]
|
||||
xs = [p[0] for p in first_poly]
|
||||
ys = [p[1] for p in first_poly]
|
||||
@@ -171,7 +324,7 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
})
|
||||
ann_id += 1
|
||||
|
||||
coco = {
|
||||
return {
|
||||
"info": {
|
||||
"description": f"Annotations for {project.name}",
|
||||
"version": "1.0",
|
||||
@@ -183,39 +336,235 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
|
||||
return {
|
||||
"key": _annotation_class_key(annotation),
|
||||
"className": _annotation_label(annotation),
|
||||
"chineseName": _annotation_label(annotation),
|
||||
"categoryName": _annotation_category_name(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"internalPriority": _annotation_z_index(annotation),
|
||||
"maskidHint": _annotation_mask_id(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
|
||||
entries_by_key: dict[str, dict[str, Any]] = {}
|
||||
for annotation in annotations:
|
||||
if not annotation.mask_data or not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
entry = _class_mapping_entry(annotation)
|
||||
entries_by_key.setdefault(entry["key"], entry)
|
||||
|
||||
ordered = sorted(
|
||||
entries_by_key.values(),
|
||||
key=lambda item: (
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] > 0 else 10_000_000,
|
||||
str(item["className"]),
|
||||
str(item["key"]),
|
||||
),
|
||||
)
|
||||
key_to_value: dict[str, int] = {}
|
||||
classes: list[dict[str, Any]] = []
|
||||
used_maskids: set[int] = set()
|
||||
next_maskid = 1
|
||||
|
||||
def next_available_maskid() -> int:
|
||||
nonlocal next_maskid
|
||||
while next_maskid in used_maskids:
|
||||
next_maskid += 1
|
||||
value = next_maskid
|
||||
used_maskids.add(value)
|
||||
next_maskid += 1
|
||||
return value
|
||||
|
||||
for entry in ordered:
|
||||
hinted_maskid = entry.get("maskidHint")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 0 and hinted_maskid not in used_maskids:
|
||||
maskid = hinted_maskid
|
||||
used_maskids.add(maskid)
|
||||
else:
|
||||
maskid = next_available_maskid()
|
||||
key_to_value[entry["key"]] = maskid
|
||||
classes.append({
|
||||
"gt_pixel_value": maskid,
|
||||
"maskid": maskid,
|
||||
"chineseName": entry["chineseName"],
|
||||
"className": entry["className"],
|
||||
"categoryName": entry["categoryName"],
|
||||
"rgb": _hex_to_rgb(entry["color"]),
|
||||
"color": entry["color"],
|
||||
"key": entry["key"],
|
||||
"template_id": entry["template_id"],
|
||||
})
|
||||
return key_to_value, classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
|
||||
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
|
||||
if outputs:
|
||||
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
|
||||
invalid = parsed - allowed
|
||||
if invalid:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
|
||||
return parsed or allowed
|
||||
|
||||
if mask_type == "separate":
|
||||
return {"separate"}
|
||||
if mask_type == "gt_label":
|
||||
return {"gt_label"}
|
||||
if mask_type == "pro_label":
|
||||
return {"pro_label"}
|
||||
if mask_type == "mix_label":
|
||||
return {"mix_label"}
|
||||
return allowed
|
||||
|
||||
|
||||
def _write_original_frames(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
) -> dict[int, bytes]:
|
||||
image_bytes_by_frame: dict[int, bytes] = {}
|
||||
for frame in frames:
|
||||
image_bytes = download_file(frame.image_url)
|
||||
image_bytes_by_frame[frame.id] = image_bytes
|
||||
zf.writestr(
|
||||
f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
|
||||
image_bytes,
|
||||
)
|
||||
return image_bytes_by_frame
|
||||
|
||||
|
||||
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
if image_bytes:
|
||||
decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if decoded is not None:
|
||||
if decoded.shape[1] != width or decoded.shape[0] != height:
|
||||
decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
|
||||
return decoded
|
||||
return np.zeros((height, width, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def _write_result_mask_outputs(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
outputs: set[str],
|
||||
class_values: dict[str, int],
|
||||
class_mapping: list[dict[str, Any]],
|
||||
original_images: dict[int, bytes],
|
||||
mix_opacity: float,
|
||||
) -> None:
|
||||
import cv2
|
||||
|
||||
include_individual = "separate" in outputs
|
||||
include_semantic = "gt_label" in outputs
|
||||
include_pro_label = "pro_label" in outputs
|
||||
include_mix_label = "mix_label" in outputs
|
||||
class_rgb_by_key = {
|
||||
item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
|
||||
for item in class_mapping
|
||||
}
|
||||
annotations_by_frame: dict[int, list[Annotation]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for annotation in annotations:
|
||||
if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
|
||||
continue
|
||||
if not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)
|
||||
|
||||
for frame in frames:
|
||||
frame_annotations = annotations_by_frame.get(frame.id, [])
|
||||
if not frame_annotations:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
frame_stem = _frame_export_stem(project, frame)
|
||||
|
||||
if include_individual:
|
||||
class_masks: dict[str, np.ndarray] = {}
|
||||
class_meta: dict[str, dict[str, Any]] = {}
|
||||
for annotation in frame_annotations:
|
||||
key = _annotation_class_key(annotation)
|
||||
combined = class_masks.setdefault(key, np.zeros((height, width), dtype=np.uint8))
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined[:] = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
class_meta.setdefault(key, _class_mapping_entry(annotation))
|
||||
|
||||
folder = f"分开Mask分割结果/{frame_stem}_分别导出"
|
||||
for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
|
||||
meta = class_meta[key]
|
||||
maskid = class_values.get(key)
|
||||
if maskid is None:
|
||||
continue
|
||||
_, encoded = cv2.imencode(".png", mask)
|
||||
class_name = _safe_filename_part(meta["className"], "class")
|
||||
zf.writestr(
|
||||
f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
|
||||
encoded.tobytes(),
|
||||
)
|
||||
|
||||
needs_fused_output = include_semantic or include_pro_label or include_mix_label
|
||||
semantic = np.zeros((height, width), dtype=np.uint16) if needs_fused_output else None
|
||||
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
|
||||
|
||||
if needs_fused_output:
|
||||
for annotation in sorted(frame_annotations, key=_annotation_z_index):
|
||||
key = _annotation_class_key(annotation)
|
||||
value = class_values.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
if semantic is not None:
|
||||
semantic[combined > 0] = value
|
||||
if pro_label is not None:
|
||||
rgb = class_rgb_by_key.get(key, [255, 255, 255])
|
||||
bgr = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)
|
||||
pro_label[combined > 0] = bgr
|
||||
|
||||
if include_semantic and semantic is not None:
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_pro_label and pro_label is not None:
|
||||
_, encoded = cv2.imencode(".png", pro_label)
|
||||
zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_mix_label and pro_label is not None:
|
||||
original = _decode_original_image(original_images.get(frame.id), width, height)
|
||||
mask_pixels = np.any(pro_label > 0, axis=2)
|
||||
mixed = original.copy()
|
||||
opacity = min(max(float(mix_opacity), 0.0), 1.0)
|
||||
mixed[mask_pixels] = (
|
||||
original[mask_pixels].astype(np.float32) * (1.0 - opacity)
|
||||
+ pro_label[mask_pixels].astype(np.float32) * opacity
|
||||
).clip(0, 255).astype(np.uint8)
|
||||
_, encoded = cv2.imencode(".png", mixed)
|
||||
zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
|
||||
def _write_mask_pngs(
|
||||
zf: zipfile.ZipFile,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
mask_type: str,
|
||||
individual_prefix: str = "",
|
||||
semantic_prefix: str = "",
|
||||
semantic_file_stem: str = "semantic_frame",
|
||||
semantic_dtype: Any = np.uint8,
|
||||
) -> list[dict[str, Any]]:
|
||||
import cv2
|
||||
|
||||
class_values: dict[str, int] = {}
|
||||
semantic_classes: list[dict[str, Any]] = []
|
||||
@@ -235,46 +584,102 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
||||
})
|
||||
return class_values[key]
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
include_individual = mask_type in {"separate", "both"}
|
||||
include_semantic = mask_type in {"gt_label", "both"}
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for ann in annotations:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
if include_individual:
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
fname = f"mask_{ann.id:06d}.png"
|
||||
zf.writestr(fname, encoded.tobytes())
|
||||
if ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
|
||||
if include_semantic and ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
|
||||
if include_semantic:
|
||||
for frame in frames:
|
||||
entries = frame_masks.get(frame.id, [])
|
||||
if not entries:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
semantic = np.zeros((height, width), dtype=np.uint8)
|
||||
semantic = np.zeros((height, width), dtype=semantic_dtype)
|
||||
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
|
||||
semantic[mask > 0] = class_value(ann)
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
|
||||
if include_semantic:
|
||||
zf.writestr(
|
||||
"semantic_classes.json",
|
||||
f"{semantic_prefix}semantic_classes.json",
|
||||
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
return semantic_classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
_project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
_write_mask_pngs(
|
||||
zf,
|
||||
frames,
|
||||
annotations,
|
||||
mask_type="both",
|
||||
semantic_prefix="",
|
||||
individual_prefix="",
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
@@ -284,3 +689,71 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/results",
|
||||
summary="Export segmentation results as a ZIP archive",
|
||||
)
|
||||
def export_results(
|
||||
project_id: int,
|
||||
scope: str = Query("all", pattern="^(all|range|current)$"),
|
||||
mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
|
||||
outputs: str | None = Query(None),
|
||||
mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
|
||||
start_frame: int | None = Query(None, ge=1),
|
||||
end_frame: int | None = Query(None, ge=1),
|
||||
frame_id: int | None = Query(None, ge=1),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export JSON annotations plus selected PNG mask outputs inside one ZIP.
|
||||
|
||||
`scope=all` exports the whole video. `scope=range` uses 1-based frame
|
||||
numbers from the sorted project frame sequence. `scope=current` uses the
|
||||
concrete backend `frame_id`.
|
||||
"""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _filter_frames(
|
||||
_project_frames(project_id, db),
|
||||
scope=scope,
|
||||
start_frame=start_frame,
|
||||
end_frame=end_frame,
|
||||
frame_id=frame_id,
|
||||
)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
class_values, class_mapping = _build_gt_class_mapping(annotations)
|
||||
selected_outputs = _parse_result_outputs(mask_type, outputs)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr(
|
||||
"annotations_coco.json",
|
||||
json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
zf.writestr(
|
||||
"maskid_GT像素值_类别映射.json",
|
||||
json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
original_images = _write_original_frames(zf, project, frames)
|
||||
_write_result_mask_outputs(
|
||||
zf,
|
||||
project,
|
||||
frames,
|
||||
annotations,
|
||||
outputs=selected_outputs,
|
||||
class_values=class_values,
|
||||
class_mapping=class_mapping,
|
||||
original_images=original_images,
|
||||
mix_opacity=mix_opacity,
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = _segmentation_results_filename(project, frames)
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": _download_content_disposition(filename)},
|
||||
)
|
||||
|
||||
@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import ProcessingTask, Project
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
|
||||
from worker_tasks import parse_project_media
|
||||
@@ -34,6 +35,7 @@ async def upload_media(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Accept a video, image, or DICOM file and store it in MinIO.
|
||||
|
||||
@@ -62,13 +64,15 @@ async def upload_media(
|
||||
file_url = get_presigned_url(object_name, expires=3600)
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if project:
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
logger.warning("Project id=%s not found for upload linkage", project_id)
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
# Auto-create a project named after the file
|
||||
project = Project(
|
||||
@@ -77,6 +81,7 @@ async def upload_media(
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
video_path=object_name,
|
||||
source_type="video",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -109,6 +114,7 @@ async def upload_dicom_batch(
|
||||
files: List[UploadFile] = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Upload multiple .dcm files for a DICOM series.
|
||||
|
||||
@@ -121,7 +127,10 @@ async def upload_dicom_batch(
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
@@ -132,6 +141,7 @@ async def upload_dicom_batch(
|
||||
description=f"DICOM series with {len(files)} files",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -173,13 +183,17 @@ def parse_media(
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
|
||||
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
|
||||
updates the persisted task record as it progresses.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame
|
||||
from models import Project, Frame, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from minio_client import get_presigned_url
|
||||
|
||||
@@ -24,9 +25,13 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new project",
|
||||
)
|
||||
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
|
||||
def create_project(
|
||||
payload: ProjectCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Create a new segmentation project."""
|
||||
project = Project(**payload.model_dump())
|
||||
project = Project(**payload.model_dump(), owner_user_id=current_user.id)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
@@ -39,9 +44,20 @@ def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Pro
|
||||
response_model=List[ProjectOut],
|
||||
summary="List all projects",
|
||||
)
|
||||
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
|
||||
def list_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Project]:
|
||||
"""Retrieve a paginated list of projects."""
|
||||
projects = db.query(Project).offset(skip).limit(limit).all()
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.filter(Project.owner_user_id == current_user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for p in projects:
|
||||
p.frame_count = len(p.frames)
|
||||
if p.thumbnail_url:
|
||||
@@ -54,9 +70,16 @@ def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
|
||||
response_model=ProjectOut,
|
||||
summary="Get a single project",
|
||||
)
|
||||
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
|
||||
def get_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Project:
|
||||
"""Retrieve a project by its ID."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.frame_count = len(project.frames)
|
||||
@@ -74,9 +97,13 @@ def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Update project fields partially."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -94,9 +121,16 @@ def update_project(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a project",
|
||||
)
|
||||
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
|
||||
def delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a project and all related frames and annotations."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -118,9 +152,13 @@ def create_frame(
|
||||
project_id: int,
|
||||
payload: FrameCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Frame:
|
||||
"""Register a new frame under a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -141,9 +179,13 @@ def list_frames(
|
||||
skip: int = 0,
|
||||
limit: int = 1000,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Frame]:
|
||||
"""Retrieve all frames belonging to a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
@@ -165,11 +207,21 @@ def list_frames(
|
||||
response_model=FrameOut,
|
||||
summary="Get a single frame",
|
||||
)
|
||||
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
|
||||
def get_frame(
|
||||
project_id: int,
|
||||
frame_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Frame:
|
||||
"""Retrieve a specific frame by ID."""
|
||||
frame = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id, Frame.id == frame_id)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(
|
||||
Frame.project_id == project_id,
|
||||
Frame.id == frame_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not frame:
|
||||
|
||||
@@ -9,8 +9,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import get_db
|
||||
from models import ProcessingTask, Project
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PARSING,
|
||||
@@ -31,8 +32,16 @@ def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
|
||||
task = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter(
|
||||
ProcessingTask.id == task_id,
|
||||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
@@ -48,9 +57,12 @@ def list_tasks(
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[ProcessingTask]:
|
||||
"""Return recent background processing tasks."""
|
||||
query = db.query(ProcessingTask)
|
||||
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id).filter(
|
||||
(ProcessingTask.project_id.is_(None)) | (Project.owner_user_id == current_user.id)
|
||||
)
|
||||
if project_id is not None:
|
||||
query = query.filter(ProcessingTask.project_id == project_id)
|
||||
if status is not None:
|
||||
@@ -59,15 +71,23 @@ def list_tasks(
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||||
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def get_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ProcessingTask:
|
||||
"""Return one background task by id."""
|
||||
return _get_task_or_404(task_id, db)
|
||||
return _get_task_or_404(task_id, db, current_user)
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||||
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def cancel_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||||
task = _get_task_or_404(task_id, db)
|
||||
task = _get_task_or_404(task_id, db, current_user)
|
||||
if task.status not in TASK_ACTIVE_STATUSES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -95,9 +115,13 @@ def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
|
||||
|
||||
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||||
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
def retry_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a fresh queued task from a failed or cancelled task."""
|
||||
previous = _get_task_or_404(task_id, db)
|
||||
previous = _get_task_or_404(task_id, db, current_user)
|
||||
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -106,7 +130,10 @@ def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
|
||||
if previous.project_id is None:
|
||||
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||
project = db.query(Project).filter(
|
||||
Project.id == previous.project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
is_propagation_task = previous.task_type == "propagate_masks"
|
||||
|
||||
@@ -4,10 +4,12 @@ import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template
|
||||
from models import Template, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,11 +42,15 @@ def _unpack_template(template: Template) -> Template:
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new template",
|
||||
)
|
||||
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
|
||||
def create_template(
|
||||
payload: TemplateCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Create a new ontology template / segmentation class."""
|
||||
data = payload.model_dump()
|
||||
data = _pack_mapping_rules(data)
|
||||
template = Template(**data)
|
||||
template = Template(**data, owner_user_id=current_user.id)
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
@@ -62,9 +68,16 @@ def list_templates(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Template]:
|
||||
"""Retrieve all ontology templates."""
|
||||
templates = db.query(Template).offset(skip).limit(limit).all()
|
||||
templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for t in templates:
|
||||
_unpack_template(t)
|
||||
return templates
|
||||
@@ -75,9 +88,16 @@ def list_templates(
|
||||
response_model=TemplateOut,
|
||||
summary="Get a single template",
|
||||
)
|
||||
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
|
||||
def get_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Template:
|
||||
"""Retrieve a template by its ID."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_unpack_template(template)
|
||||
@@ -93,9 +113,13 @@ def update_template(
|
||||
template_id: int,
|
||||
payload: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Update template fields partially."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
@@ -118,9 +142,16 @@ def update_template(
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a template",
|
||||
)
|
||||
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
|
||||
def delete_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a template. Associated annotations will have template_id set to NULL."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user