添加Docker自包含部署分支
- 新增 Seg_Server_Docker 自包含部署内容,包含前后端、FastAPI、Celery、PostgreSQL、Redis、MinIO、演示视频和 DICOM 数据。 - 保留 demo 数据以支持恢复演示出厂设置,排除 SAM 2.1 .pt 权重并在 README 中补充下载命令。 - 补充 GPU 部署、backend/worker 镜像复用、frpc/frps + NPM 公网域名反代部署说明。 - 在 .env/.env.example 中用 # XXXX 标注局域网和公网域名部署需要修改的配置项。 - 添加部署分支 .gitignore,忽略本地模型权重、构建产物、缓存和日志。
This commit is contained in:
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
299
backend/routers/admin.py
Normal file
299
backend/routers/admin.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""Administrator-only user and audit management endpoints."""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from models import Annotation, AuditLog, Frame, Mask, ProcessingTask, Project, Template, User
|
||||
from routers.auth import SUPPORTED_ROLES, ensure_default_admin, hash_password, normalize_user_role, require_admin, write_audit_log
|
||||
from schemas import (
|
||||
AdminUserCreate,
|
||||
AdminUserUpdate,
|
||||
AuditLogOut,
|
||||
DemoFactoryResetOut,
|
||||
DemoFactoryResetRequest,
|
||||
UserOut,
|
||||
)
|
||||
from services.demo_media import (
|
||||
DEMO_DICOM_PROJECT_NAME,
|
||||
DEMO_VIDEO_PROJECT_NAME,
|
||||
create_parsed_dicom_demo_project,
|
||||
create_parsed_video_demo_project,
|
||||
demo_dicom_files,
|
||||
)
|
||||
from services.default_templates import restore_default_templates
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["Admin"])
|
||||
|
||||
DEMO_RESET_CONFIRMATION = "RESET_DEMO_FACTORY"
|
||||
DEMO_PROJECT_NAME = DEMO_DICOM_PROJECT_NAME
|
||||
|
||||
|
||||
def _normalize_role(role: str | None) -> str:
|
||||
normalized = (role or "annotator").strip().lower()
|
||||
if normalized not in SUPPORTED_ROLES:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported role: {role}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _assert_non_admin_role(role: str) -> None:
|
||||
if role == "admin":
|
||||
raise HTTPException(status_code=400, detail="Only the default admin account can have admin role")
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserOut], summary="List users")
|
||||
def list_users(
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[User]:
|
||||
"""Return all users for the administrator console."""
|
||||
_ = admin_user
|
||||
users = db.query(User).order_by(User.id).all()
|
||||
return [normalize_user_role(db, user) for user in users]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users",
|
||||
response_model=UserOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create user",
|
||||
)
|
||||
def create_user(
|
||||
payload: AdminUserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Create a user with an initial password and role."""
|
||||
username = payload.username.strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
if len(payload.password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
role = _normalize_role(payload.role)
|
||||
_assert_non_admin_role(role)
|
||||
user = User(
|
||||
username=username,
|
||||
password_hash=hash_password(payload.password),
|
||||
role=role,
|
||||
is_active=1 if payload.is_active else 0,
|
||||
)
|
||||
db.add(user)
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_created",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username, "role": user.role, "is_active": bool(user.is_active)},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/users/{user_id}", response_model=UserOut, summary="Update user")
|
||||
def update_user(
|
||||
user_id: int,
|
||||
payload: AdminUserUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> User:
|
||||
"""Update username, password, role or active state."""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user = normalize_user_role(db, user)
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
audit_detail: dict = {"before": {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}}
|
||||
if "username" in updates:
|
||||
username = (updates["username"] or "").strip()
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Username is required")
|
||||
if user.role == "admin" and username != settings.default_admin_username:
|
||||
raise HTTPException(status_code=400, detail="Default admin username cannot be changed")
|
||||
user.username = username
|
||||
if "password" in updates:
|
||||
password = updates["password"] or ""
|
||||
if len(password) < 6:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
|
||||
user.password_hash = hash_password(password)
|
||||
if "role" in updates:
|
||||
next_role = _normalize_role(updates["role"])
|
||||
if user.username == settings.default_admin_username:
|
||||
if next_role != "admin":
|
||||
raise HTTPException(status_code=400, detail="Cannot remove the default admin role")
|
||||
else:
|
||||
_assert_non_admin_role(next_role)
|
||||
user.role = next_role
|
||||
if "is_active" in updates:
|
||||
if user.id == admin_user.id and not updates["is_active"]:
|
||||
raise HTTPException(status_code=400, detail="Cannot deactivate yourself")
|
||||
user.is_active = 1 if updates["is_active"] else 0
|
||||
|
||||
try:
|
||||
db.commit()
|
||||
except IntegrityError as exc:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=409, detail="Username already exists") from exc
|
||||
db.refresh(user)
|
||||
audit_detail["after"] = {"username": user.username, "role": user.role, "is_active": bool(user.is_active)}
|
||||
audit_detail["password_changed"] = "password" in updates
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_updated",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail=audit_detail,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete user")
|
||||
def delete_user(
|
||||
user_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> None:
|
||||
"""Delete a user when it is safe to remove the account."""
|
||||
if user_id == admin_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user = normalize_user_role(db, user)
|
||||
if user.role == "admin":
|
||||
raise HTTPException(status_code=400, detail="Cannot delete the default admin account")
|
||||
username = user.username
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=admin_user,
|
||||
action="admin.user_deleted",
|
||||
target_type="user",
|
||||
target_id=user_id,
|
||||
detail={"username": username},
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/audit-logs", response_model=List[AuditLogOut], summary="List audit logs")
|
||||
def list_audit_logs(
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> List[AuditLog]:
|
||||
"""Return recent audit events for administrators."""
|
||||
_ = admin_user
|
||||
safe_limit = min(max(int(limit or 100), 1), 500)
|
||||
return db.query(AuditLog).order_by(AuditLog.created_at.desc(), AuditLog.id.desc()).limit(safe_limit).all()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/demo-factory-reset",
|
||||
response_model=DemoFactoryResetOut,
|
||||
summary="Reset demo data to factory defaults",
|
||||
)
|
||||
def reset_demo_factory(
|
||||
payload: DemoFactoryResetRequest,
|
||||
db: Session = Depends(get_db),
|
||||
admin_user: User = Depends(require_admin),
|
||||
) -> dict:
|
||||
"""Reset a demo deployment to one admin account, the demo video, and the demo DICOM project."""
|
||||
if payload.confirmation != DEMO_RESET_CONFIRMATION:
|
||||
raise HTTPException(status_code=400, detail="Invalid reset confirmation")
|
||||
|
||||
if not os.path.exists(settings.demo_video_path):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo video not found: {settings.demo_video_path}",
|
||||
)
|
||||
if not demo_dicom_files(settings.demo_dicom_dir):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Demo DICOM series not found: {settings.demo_dicom_dir}",
|
||||
)
|
||||
|
||||
requested_by = admin_user.username
|
||||
preserved_admin = ensure_default_admin(db)
|
||||
preserved_admin.username = settings.default_admin_username
|
||||
preserved_admin.password_hash = hash_password(settings.default_admin_password)
|
||||
preserved_admin.role = "admin"
|
||||
preserved_admin.is_active = 1
|
||||
db.flush()
|
||||
|
||||
deleted_counts = {
|
||||
"masks": db.query(Mask).delete(synchronize_session=False),
|
||||
"annotations": db.query(Annotation).delete(synchronize_session=False),
|
||||
"frames": db.query(Frame).delete(synchronize_session=False),
|
||||
"tasks": db.query(ProcessingTask).delete(synchronize_session=False),
|
||||
"projects": db.query(Project).delete(synchronize_session=False),
|
||||
"user_templates": db.query(Template).filter(Template.owner_user_id.is_not(None)).delete(synchronize_session=False),
|
||||
"audit_logs": db.query(AuditLog).delete(synchronize_session=False),
|
||||
"users": db.query(User).filter(User.id != preserved_admin.id).delete(synchronize_session=False),
|
||||
}
|
||||
db.flush()
|
||||
db.expunge_all()
|
||||
|
||||
preserved_admin = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if not preserved_admin:
|
||||
raise HTTPException(status_code=500, detail="Default admin was not preserved")
|
||||
|
||||
restored_templates = restore_default_templates(db)
|
||||
|
||||
video_project = create_parsed_video_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
video_path=settings.demo_video_path,
|
||||
project_name=DEMO_VIDEO_PROJECT_NAME,
|
||||
)
|
||||
|
||||
dicom_project = create_parsed_dicom_demo_project(
|
||||
db,
|
||||
owner=preserved_admin,
|
||||
dicom_dir=settings.demo_dicom_dir,
|
||||
project_name=DEMO_PROJECT_NAME,
|
||||
)
|
||||
db.refresh(preserved_admin)
|
||||
db.refresh(video_project)
|
||||
db.refresh(dicom_project)
|
||||
video_project.frame_count = len(video_project.frames)
|
||||
dicom_project.frame_count = len(dicom_project.frames)
|
||||
projects = [video_project, dicom_project]
|
||||
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=preserved_admin,
|
||||
action="admin.demo_factory_reset",
|
||||
target_type="project",
|
||||
target_id=dicom_project.id,
|
||||
detail={
|
||||
"project_names": [project.name for project in projects],
|
||||
"video_path": video_project.video_path,
|
||||
"dicom_path": dicom_project.video_path,
|
||||
"source_types": [project.source_type for project in projects],
|
||||
"frame_counts": {project.name: len(project.frames) for project in projects},
|
||||
"deleted_counts": deleted_counts,
|
||||
"restored_templates": [template.name for template in restored_templates],
|
||||
"requested_by": requested_by,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
"admin_user": preserved_admin,
|
||||
"project": dicom_project,
|
||||
"projects": projects,
|
||||
"deleted_counts": deleted_counts,
|
||||
"message": "演示环境已恢复出厂设置",
|
||||
}
|
||||
1228
backend/routers/ai.py
Normal file
1228
backend/routers/ai.py
Normal file
File diff suppressed because it is too large
Load Diff
222
backend/routers/auth.py
Normal file
222
backend/routers/auth.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Authentication endpoints and dependencies."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import settings
|
||||
from database import get_db
|
||||
from models import AuditLog, User
|
||||
from schemas import LoginResponse, UserOut
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Auth"])
|
||||
password_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
SUPPORTED_ROLES = {"admin", "annotator"}
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a plain password for storage."""
|
||||
return password_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify a plain password against a stored hash."""
|
||||
return password_context.verify(password, password_hash)
|
||||
|
||||
|
||||
def create_access_token(user: User, expires_delta: timedelta | None = None) -> str:
|
||||
"""Create a signed JWT access token for a user."""
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": str(user.id),
|
||||
"username": user.username,
|
||||
"role": user.role,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def ensure_default_admin(db: Session) -> User:
|
||||
"""Create and enforce the single default administrator account."""
|
||||
existing = db.query(User).filter(User.username == settings.default_admin_username).first()
|
||||
if existing:
|
||||
changed = False
|
||||
if existing.role != "admin":
|
||||
existing.role = "admin"
|
||||
changed = True
|
||||
if not existing.is_active:
|
||||
existing.is_active = 1
|
||||
changed = True
|
||||
extra_admins = db.query(User).filter(
|
||||
User.role == "admin",
|
||||
User.id != existing.id,
|
||||
).all()
|
||||
for user in extra_admins:
|
||||
user.role = "annotator"
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
return existing
|
||||
user = User(
|
||||
username=settings.default_admin_username,
|
||||
password_hash=hash_password(settings.default_admin_password),
|
||||
role="admin",
|
||||
is_active=1,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
extra_admins = db.query(User).filter(
|
||||
User.role == "admin",
|
||||
User.id != user.id,
|
||||
).all()
|
||||
if extra_admins:
|
||||
for extra_user in extra_admins:
|
||||
extra_user.role = "annotator"
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def normalize_user_role(db: Session, user: User) -> User:
|
||||
"""Keep legacy accounts within the current two-role policy."""
|
||||
desired_role = "admin" if user.username == settings.default_admin_username else "annotator"
|
||||
changed = False
|
||||
if user.role != desired_role:
|
||||
user.role = desired_role
|
||||
changed = True
|
||||
if user.username == settings.default_admin_username and not user.is_active:
|
||||
user.is_active = 1
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Resolve and validate the current user from the Bearer token."""
|
||||
if credentials is None or credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
credentials.credentials,
|
||||
settings.jwt_secret_key,
|
||||
algorithms=[settings.jwt_algorithm],
|
||||
)
|
||||
user_id = int(payload.get("sub"))
|
||||
except (JWTError, TypeError, ValueError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user:
|
||||
user = normalize_user_role(db, user)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Inactive or missing user",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require the current user to have the administrator role."""
|
||||
if current_user.role != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def require_editor(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Require a user role that can modify segmentation data."""
|
||||
if current_user.role not in SUPPORTED_ROLES:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Edit permission required")
|
||||
return current_user
|
||||
|
||||
|
||||
def write_audit_log(
|
||||
db: Session,
|
||||
*,
|
||||
actor: User | None,
|
||||
action: str,
|
||||
target_type: str | None = None,
|
||||
target_id: str | int | None = None,
|
||||
detail: dict[str, Any] | None = None,
|
||||
) -> AuditLog:
|
||||
"""Persist a compact audit event."""
|
||||
log = AuditLog(
|
||||
actor_user_id=actor.id if actor else None,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=str(target_id) if target_id is not None else None,
|
||||
detail=detail or {},
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
db.refresh(log)
|
||||
return log
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Authenticate a user and return a signed JWT."""
|
||||
ensure_default_admin(db)
|
||||
user = db.query(User).filter(User.username == payload.username).first()
|
||||
if user:
|
||||
user = normalize_user_role(db, user)
|
||||
if not user or not user.is_active or not verify_password(payload.password, user.password_hash):
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=None,
|
||||
action="auth.login_failed",
|
||||
target_type="user",
|
||||
target_id=payload.username,
|
||||
detail={"username": payload.username},
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
write_audit_log(
|
||||
db,
|
||||
actor=user,
|
||||
action="auth.login_success",
|
||||
target_type="user",
|
||||
target_id=user.id,
|
||||
detail={"username": user.username},
|
||||
)
|
||||
return {
|
||||
"token": create_access_token(user),
|
||||
"token_type": "bearer",
|
||||
"username": user.username,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserOut)
|
||||
def read_current_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""Return the authenticated user profile."""
|
||||
return current_user
|
||||
164
backend/routers/dashboard.py
Normal file
164
backend/routers/dashboard.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Dashboard overview endpoints."""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Frame, ProcessingTask, Project, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
ACTIVE_TASK_STATUSES = {"queued", "running"}
|
||||
MONITORED_TASK_STATUSES = {"queued", "running", "success", "failed", "cancelled"}
|
||||
|
||||
|
||||
def _system_load_percent() -> int:
|
||||
"""Return a real host load estimate without adding a psutil dependency."""
|
||||
try:
|
||||
load_1m = os.getloadavg()[0]
|
||||
cpu_count = os.cpu_count() or 1
|
||||
return min(100, max(0, round((load_1m / cpu_count) * 100)))
|
||||
except (AttributeError, OSError):
|
||||
return 0
|
||||
|
||||
|
||||
def _iso_or_none(value: datetime | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
result = task.result or {}
|
||||
return {
|
||||
"id": f"task-{task.id}",
|
||||
"task_id": task.id,
|
||||
"project_id": task.project_id or 0,
|
||||
"name": task.project.name if task.project else f"任务 {task.id}",
|
||||
"progress": task.progress,
|
||||
"status": task.message or task.status,
|
||||
"raw_status": task.status,
|
||||
"frame_count": result.get("frames_extracted", result.get("processed_frame_count", 0)),
|
||||
"error": task.error,
|
||||
"updated_at": _iso_or_none(task.updated_at),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/overview", summary="Get dashboard overview")
|
||||
def get_dashboard_overview(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Return live dashboard data derived from persisted backend records."""
|
||||
shared_project_ids_query = db.query(Project.id)
|
||||
project_count = db.query(func.count(Project.id)).scalar() or 0
|
||||
frame_count = db.query(func.count(Frame.id)).filter(Frame.project_id.in_(shared_project_ids_query)).scalar() or 0
|
||||
annotation_count = (
|
||||
db.query(func.count(Annotation.id))
|
||||
.filter(Annotation.project_id.in_(shared_project_ids_query))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
template_count = (
|
||||
db.query(func.count(Template.id))
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
active_task_count = (
|
||||
db.query(func.count(ProcessingTask.id))
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter(ProcessingTask.status.in_(ACTIVE_TASK_STATUSES))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.order_by(Project.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
recent_tasks = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.order_by(ProcessingTask.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
)
|
||||
tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
|
||||
|
||||
activities: list[dict[str, Any]] = []
|
||||
for task in recent_tasks[:10]:
|
||||
project_name = task.project.name if task.project else f"项目 {task.project_id}"
|
||||
activities.append({
|
||||
"id": f"task-{task.id}",
|
||||
"kind": "task",
|
||||
"time": _iso_or_none(task.updated_at),
|
||||
"message": task.message or f"任务状态: {task.status}",
|
||||
"project": project_name,
|
||||
})
|
||||
|
||||
for project in projects[:10]:
|
||||
activities.append({
|
||||
"id": f"project-{project.id}",
|
||||
"kind": "project",
|
||||
"time": _iso_or_none(project.updated_at),
|
||||
"message": f"项目状态: {project.status}",
|
||||
"project": project.name,
|
||||
})
|
||||
|
||||
recent_annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id.in_(shared_project_ids_query))
|
||||
.order_by(Annotation.updated_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
for annotation in recent_annotations:
|
||||
project_name = annotation.project.name if annotation.project else f"项目 {annotation.project_id}"
|
||||
activities.append({
|
||||
"id": f"annotation-{annotation.id}",
|
||||
"kind": "annotation",
|
||||
"time": _iso_or_none(annotation.updated_at),
|
||||
"message": f"标注已更新 #{annotation.id}",
|
||||
"project": project_name,
|
||||
})
|
||||
|
||||
recent_templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.order_by(Template.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
for template in recent_templates:
|
||||
activities.append({
|
||||
"id": f"template-{template.id}",
|
||||
"kind": "template",
|
||||
"time": _iso_or_none(template.created_at),
|
||||
"message": f"模板可用: {template.name}",
|
||||
"project": "系统",
|
||||
})
|
||||
|
||||
activities.sort(key=lambda item: item["time"] or "", reverse=True)
|
||||
|
||||
return {
|
||||
"summary": {
|
||||
"project_count": project_count,
|
||||
"parsing_task_count": active_task_count,
|
||||
"annotation_count": annotation_count,
|
||||
"frame_count": frame_count,
|
||||
"template_count": template_count,
|
||||
"system_load_percent": _system_load_percent(),
|
||||
},
|
||||
"tasks": tasks,
|
||||
"activity": activities[:10],
|
||||
}
|
||||
764
backend/routers/export.py
Normal file
764
backend/routers/export.py
Normal file
@@ -0,0 +1,764 @@
|
||||
"""Annotation export endpoints (COCO, PNG masks)."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import quote
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Annotation, Frame, Template, User
|
||||
from routers.auth import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||||
|
||||
|
||||
def _mask_from_polygon(
|
||||
polygon: List[List[float]],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> np.ndarray:
|
||||
"""Render a normalized polygon to a binary mask."""
|
||||
import cv2
|
||||
|
||||
pts = np.array(
|
||||
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
|
||||
dtype=np.int32,
|
||||
)
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [pts], 255)
|
||||
return mask
|
||||
|
||||
|
||||
def _annotation_z_index(annotation: Annotation) -> int:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
|
||||
try:
|
||||
return int(class_meta["zIndex"])
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if annotation.template and annotation.template.z_index is not None:
|
||||
return int(annotation.template.z_index)
|
||||
return 0
|
||||
|
||||
|
||||
def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
for key in ("maskId", "maskid", "mask_id"):
|
||||
if class_meta.get(key) is None:
|
||||
continue
|
||||
try:
|
||||
value = int(class_meta[key])
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value >= 0:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _annotation_category_name(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("category"):
|
||||
return str(class_meta["category"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return ""
|
||||
|
||||
|
||||
def _annotation_class_key(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
if class_meta.get("id"):
|
||||
return f"class:{class_meta['id']}"
|
||||
if class_meta.get("name"):
|
||||
return f"name:{class_meta['name']}"
|
||||
if annotation.template_id:
|
||||
return f"template:{annotation.template_id}"
|
||||
return f"annotation:{annotation.id}"
|
||||
|
||||
|
||||
def _annotation_label(annotation: Annotation) -> str:
|
||||
mask_data = annotation.mask_data or {}
|
||||
class_meta = mask_data.get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("name"):
|
||||
return str(class_meta["name"])
|
||||
if mask_data.get("label"):
|
||||
return str(mask_data["label"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return f"Annotation {annotation.id}"
|
||||
|
||||
|
||||
def _annotation_color(annotation: Annotation) -> str:
|
||||
mask_data = annotation.mask_data or {}
|
||||
class_meta = mask_data.get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("color"):
|
||||
return str(class_meta["color"])
|
||||
if mask_data.get("color"):
|
||||
return str(mask_data["color"])
|
||||
if annotation.template and annotation.template.color:
|
||||
return str(annotation.template.color)
|
||||
return "#ffffff"
|
||||
|
||||
|
||||
def _hex_to_rgb(color: str) -> list[int]:
|
||||
value = str(color or "").strip()
|
||||
if value.startswith("#"):
|
||||
value = value[1:]
|
||||
if len(value) == 3:
|
||||
value = "".join(part * 2 for part in value)
|
||||
if len(value) != 6:
|
||||
return [255, 255, 255]
|
||||
try:
|
||||
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
|
||||
except ValueError:
|
||||
return [255, 255, 255]
|
||||
|
||||
|
||||
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
|
||||
text = str(value or "").strip()
|
||||
if not text:
|
||||
text = fallback
|
||||
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
|
||||
text = re.sub(r"_+", "_", text).strip("._")
|
||||
return text or fallback
|
||||
|
||||
|
||||
def _project_video_name(project: Project) -> str:
|
||||
if project.video_path:
|
||||
stem = Path(project.video_path).name
|
||||
if "." in stem:
|
||||
stem = ".".join(stem.split(".")[:-1])
|
||||
if stem:
|
||||
return _safe_filename_part(stem, f"project_{project.id}")
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _project_export_name(project: Project) -> str:
|
||||
return _safe_filename_part(project.name, f"project_{project.id}")
|
||||
|
||||
|
||||
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
|
||||
if frame.timestamp_ms is not None:
|
||||
return float(frame.timestamp_ms)
|
||||
fps = project.parse_fps or project.original_fps or 30.0
|
||||
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
|
||||
|
||||
|
||||
def _project_frame_number(frame: Frame) -> int:
|
||||
return int(frame.frame_index) + 1
|
||||
|
||||
|
||||
def _format_timestamp_ms(value: float) -> str:
|
||||
total_ms = max(0, int(round(float(value))))
|
||||
hours = total_ms // 3_600_000
|
||||
minutes = (total_ms % 3_600_000) // 60_000
|
||||
seconds = (total_ms % 60_000) // 1_000
|
||||
milliseconds = total_ms % 1_000
|
||||
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
|
||||
|
||||
|
||||
def _frame_export_stem(project: Project, frame: Frame) -> str:
|
||||
return "_".join([
|
||||
_project_video_name(project),
|
||||
_format_timestamp_ms(_frame_timestamp_ms(frame, project)),
|
||||
f"frame{_project_frame_number(frame):06d}",
|
||||
])
|
||||
|
||||
|
||||
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
|
||||
if not frames:
|
||||
return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
|
||||
first_frame = frames[0]
|
||||
last_frame = frames[-1]
|
||||
return (
|
||||
f"{_project_export_name(project)}"
|
||||
f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
|
||||
f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
|
||||
f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
|
||||
)
|
||||
|
||||
|
||||
def _download_content_disposition(filename: str) -> str:
|
||||
ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
|
||||
ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
|
||||
if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
|
||||
ascii_fallback = f"{ascii_fallback}.zip"
|
||||
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename)}"
|
||||
|
||||
|
||||
def _frame_image_extension(frame: Frame) -> str:
|
||||
suffix = Path(frame.image_url or "").suffix.lower()
|
||||
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
|
||||
|
||||
|
||||
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
_ = current_user
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
def _project_frames(project_id: int, db: Session) -> list[Frame]:
|
||||
return (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _filter_frames(
|
||||
frames: list[Frame],
|
||||
*,
|
||||
scope: str = "all",
|
||||
start_frame: int | None = None,
|
||||
end_frame: int | None = None,
|
||||
frame_id: int | None = None,
|
||||
) -> list[Frame]:
|
||||
if scope == "current":
|
||||
if frame_id is None:
|
||||
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
|
||||
selected = [frame for frame in frames if frame.id == frame_id]
|
||||
if not selected:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return selected
|
||||
|
||||
if scope == "range":
|
||||
if start_frame is None or end_frame is None:
|
||||
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
|
||||
start = max(1, min(int(start_frame), int(end_frame)))
|
||||
end = max(1, max(int(start_frame), int(end_frame)))
|
||||
return frames[start - 1:end]
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
|
||||
if not frame_ids:
|
||||
return []
|
||||
return (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.filter(Annotation.frame_id.in_(frame_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
|
||||
images = []
|
||||
for frame in frames:
|
||||
images.append({
|
||||
"id": frame.id,
|
||||
"file_name": frame.image_url,
|
||||
"width": frame.width or 1920,
|
||||
"height": frame.height or 1080,
|
||||
"frame_index": frame.frame_index,
|
||||
})
|
||||
|
||||
categories = []
|
||||
template_id_to_cat_id: Dict[int, int] = {}
|
||||
for cat_idx, tmpl in enumerate(templates, start=1):
|
||||
categories.append({
|
||||
"id": cat_idx,
|
||||
"name": tmpl.name,
|
||||
"color": tmpl.color,
|
||||
})
|
||||
template_id_to_cat_id[tmpl.id] = cat_idx
|
||||
|
||||
coco_annotations = []
|
||||
ann_id = 1
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for ann in annotations:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
first_poly = polygons[0]
|
||||
xs = [p[0] for p in first_poly]
|
||||
ys = [p[1] for p in first_poly]
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
bbox = [
|
||||
min(xs) * width,
|
||||
min(ys) * height,
|
||||
(max(xs) - min(xs)) * width,
|
||||
(max(ys) - min(ys)) * height,
|
||||
]
|
||||
area = bbox[2] * bbox[3]
|
||||
|
||||
segmentation = []
|
||||
for poly in polygons:
|
||||
flat = []
|
||||
for p in poly:
|
||||
flat.append(p[0] * width)
|
||||
flat.append(p[1] * height)
|
||||
segmentation.append(flat)
|
||||
|
||||
coco_annotations.append({
|
||||
"id": ann_id,
|
||||
"image_id": ann.frame_id,
|
||||
"category_id": template_id_to_cat_id.get(ann.template_id, 0),
|
||||
"segmentation": segmentation,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
"iscrowd": 0,
|
||||
})
|
||||
ann_id += 1
|
||||
|
||||
return {
|
||||
"info": {
|
||||
"description": f"Annotations for {project.name}",
|
||||
"version": "1.0",
|
||||
"year": datetime.now().year,
|
||||
"date_created": datetime.now().isoformat(),
|
||||
},
|
||||
"images": images,
|
||||
"annotations": coco_annotations,
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
|
||||
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
|
||||
return {
|
||||
"key": _annotation_class_key(annotation),
|
||||
"className": _annotation_label(annotation),
|
||||
"chineseName": _annotation_label(annotation),
|
||||
"categoryName": _annotation_category_name(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"internalPriority": _annotation_z_index(annotation),
|
||||
"maskidHint": _annotation_mask_id(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
}
|
||||
|
||||
|
||||
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
|
||||
entries_by_key: dict[str, dict[str, Any]] = {}
|
||||
for annotation in annotations:
|
||||
if not annotation.mask_data or not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
entry = _class_mapping_entry(annotation)
|
||||
entries_by_key.setdefault(entry["key"], entry)
|
||||
|
||||
ordered = sorted(
|
||||
entries_by_key.values(),
|
||||
key=lambda item: (
|
||||
item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
|
||||
str(item["className"]),
|
||||
str(item["key"]),
|
||||
),
|
||||
)
|
||||
key_to_value: dict[str, int] = {}
|
||||
classes: list[dict[str, Any]] = []
|
||||
used_maskids: set[int] = set()
|
||||
next_maskid = 1
|
||||
|
||||
def next_available_maskid() -> int:
|
||||
nonlocal next_maskid
|
||||
while next_maskid in used_maskids:
|
||||
next_maskid += 1
|
||||
if next_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
value = next_maskid
|
||||
used_maskids.add(value)
|
||||
next_maskid += 1
|
||||
return value
|
||||
|
||||
for entry in ordered:
|
||||
hinted_maskid = entry.get("maskidHint")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid > 255:
|
||||
raise HTTPException(status_code=400, detail="GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间")
|
||||
if isinstance(hinted_maskid, int) and hinted_maskid == 0:
|
||||
maskid = 0
|
||||
used_maskids.add(maskid)
|
||||
elif isinstance(hinted_maskid, int) and 0 < hinted_maskid <= 255 and hinted_maskid not in used_maskids:
|
||||
maskid = hinted_maskid
|
||||
used_maskids.add(maskid)
|
||||
else:
|
||||
maskid = next_available_maskid()
|
||||
key_to_value[entry["key"]] = maskid
|
||||
classes.append({
|
||||
"gt_pixel_value": maskid,
|
||||
"maskid": maskid,
|
||||
"chineseName": entry["chineseName"],
|
||||
"className": entry["className"],
|
||||
"categoryName": entry["categoryName"],
|
||||
"rgb": _hex_to_rgb(entry["color"]),
|
||||
"color": entry["color"],
|
||||
"key": entry["key"],
|
||||
"template_id": entry["template_id"],
|
||||
})
|
||||
return key_to_value, classes
|
||||
|
||||
|
||||
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
|
||||
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
|
||||
if outputs:
|
||||
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
|
||||
invalid = parsed - allowed
|
||||
if invalid:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
|
||||
return parsed or allowed
|
||||
|
||||
if mask_type == "separate":
|
||||
return {"separate"}
|
||||
if mask_type == "gt_label":
|
||||
return {"gt_label"}
|
||||
if mask_type == "pro_label":
|
||||
return {"pro_label"}
|
||||
if mask_type == "mix_label":
|
||||
return {"mix_label"}
|
||||
return allowed
|
||||
|
||||
|
||||
def _write_original_frames(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
) -> dict[int, bytes]:
|
||||
image_bytes_by_frame: dict[int, bytes] = {}
|
||||
for frame in frames:
|
||||
image_bytes = download_file(frame.image_url)
|
||||
image_bytes_by_frame[frame.id] = image_bytes
|
||||
zf.writestr(
|
||||
f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
|
||||
image_bytes,
|
||||
)
|
||||
return image_bytes_by_frame
|
||||
|
||||
|
||||
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
if image_bytes:
|
||||
decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
if decoded is not None:
|
||||
if decoded.shape[1] != width or decoded.shape[0] != height:
|
||||
decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
|
||||
return decoded
|
||||
return np.zeros((height, width, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def _write_result_mask_outputs(
|
||||
zf: zipfile.ZipFile,
|
||||
project: Project,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
outputs: set[str],
|
||||
class_values: dict[str, int],
|
||||
class_mapping: list[dict[str, Any]],
|
||||
original_images: dict[int, bytes],
|
||||
mix_opacity: float,
|
||||
) -> None:
|
||||
import cv2
|
||||
|
||||
include_individual = "separate" in outputs
|
||||
include_semantic = "gt_label" in outputs
|
||||
include_pro_label = "pro_label" in outputs
|
||||
include_mix_label = "mix_label" in outputs
|
||||
class_rgb_by_key = {
|
||||
item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
|
||||
for item in class_mapping
|
||||
}
|
||||
annotations_by_frame: dict[int, list[Annotation]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
for annotation in annotations:
|
||||
if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
|
||||
continue
|
||||
if not annotation.mask_data.get("polygons"):
|
||||
continue
|
||||
annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)
|
||||
|
||||
for frame in frames:
|
||||
frame_annotations = annotations_by_frame.get(frame.id, [])
|
||||
if not frame_annotations:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
frame_stem = _frame_export_stem(project, frame)
|
||||
|
||||
if include_individual:
|
||||
class_masks: dict[str, np.ndarray] = {}
|
||||
class_meta: dict[str, dict[str, Any]] = {}
|
||||
for annotation in frame_annotations:
|
||||
key = _annotation_class_key(annotation)
|
||||
combined = class_masks.setdefault(key, np.zeros((height, width), dtype=np.uint8))
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined[:] = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
class_meta.setdefault(key, _class_mapping_entry(annotation))
|
||||
|
||||
folder = f"分开Mask分割结果/{frame_stem}_分别导出"
|
||||
for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
|
||||
meta = class_meta[key]
|
||||
maskid = class_values.get(key)
|
||||
if maskid is None:
|
||||
continue
|
||||
_, encoded = cv2.imencode(".png", mask)
|
||||
class_name = _safe_filename_part(meta["className"], "class")
|
||||
zf.writestr(
|
||||
f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
|
||||
encoded.tobytes(),
|
||||
)
|
||||
|
||||
needs_fused_output = include_semantic or include_pro_label or include_mix_label
|
||||
semantic = np.zeros((height, width), dtype=np.uint8) if needs_fused_output else None
|
||||
pro_label = np.zeros((height, width, 3), dtype=np.uint8) if (include_pro_label or include_mix_label) else None
|
||||
|
||||
if needs_fused_output:
|
||||
for annotation in sorted(frame_annotations, key=_annotation_z_index):
|
||||
key = _annotation_class_key(annotation)
|
||||
value = class_values.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in (annotation.mask_data or {}).get("polygons", []):
|
||||
combined = np.maximum(combined, _mask_from_polygon(poly, width, height))
|
||||
if semantic is not None:
|
||||
semantic[combined > 0] = value
|
||||
if pro_label is not None:
|
||||
rgb = class_rgb_by_key.get(key, [255, 255, 255])
|
||||
bgr = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)
|
||||
pro_label[combined > 0] = bgr
|
||||
|
||||
if include_semantic and semantic is not None:
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_pro_label and pro_label is not None:
|
||||
_, encoded = cv2.imencode(".png", pro_label)
|
||||
zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
if include_mix_label and pro_label is not None:
|
||||
original = _decode_original_image(original_images.get(frame.id), width, height)
|
||||
mask_pixels = np.any(pro_label > 0, axis=2)
|
||||
mixed = original.copy()
|
||||
opacity = min(max(float(mix_opacity), 0.0), 1.0)
|
||||
mixed[mask_pixels] = (
|
||||
original[mask_pixels].astype(np.float32) * (1.0 - opacity)
|
||||
+ pro_label[mask_pixels].astype(np.float32) * opacity
|
||||
).clip(0, 255).astype(np.uint8)
|
||||
_, encoded = cv2.imencode(".png", mixed)
|
||||
zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())
|
||||
|
||||
|
||||
def _write_mask_pngs(
|
||||
zf: zipfile.ZipFile,
|
||||
frames: list[Frame],
|
||||
annotations: list[Annotation],
|
||||
*,
|
||||
mask_type: str,
|
||||
individual_prefix: str = "",
|
||||
semantic_prefix: str = "",
|
||||
semantic_file_stem: str = "semantic_frame",
|
||||
semantic_dtype: Any = np.uint8,
|
||||
) -> list[dict[str, Any]]:
|
||||
import cv2
|
||||
|
||||
class_values: dict[str, int] = {}
|
||||
semantic_classes: list[dict[str, Any]] = []
|
||||
|
||||
def class_value(annotation: Annotation) -> int:
|
||||
key = _annotation_class_key(annotation)
|
||||
if key not in class_values:
|
||||
value = len(class_values) + 1
|
||||
class_values[key] = value
|
||||
semantic_classes.append({
|
||||
"value": value,
|
||||
"key": key,
|
||||
"label": _annotation_label(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"zIndex": _annotation_z_index(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
})
|
||||
return class_values[key]
|
||||
|
||||
include_individual = mask_type in {"separate", "both"}
|
||||
include_semantic = mask_type in {"gt_label", "both"}
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
selected_frame_ids = {frame.id for frame in frames}
|
||||
|
||||
for ann in annotations:
|
||||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
if include_individual:
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
|
||||
if include_semantic and ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
|
||||
if include_semantic:
|
||||
for frame in frames:
|
||||
entries = frame_masks.get(frame.id, [])
|
||||
if not entries:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
semantic = np.zeros((height, width), dtype=semantic_dtype)
|
||||
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
|
||||
semantic[mask > 0] = class_value(ann)
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
|
||||
if include_semantic:
|
||||
zf.writestr(
|
||||
f"{semantic_prefix}semantic_classes.json",
|
||||
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
return semantic_classes
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
_project_or_404(project_id, db, current_user)
|
||||
frames = _project_frames(project_id, db)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
_write_mask_pngs(
|
||||
zf,
|
||||
frames,
|
||||
annotations,
|
||||
mask_type="both",
|
||||
semantic_prefix="",
|
||||
individual_prefix="",
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/results",
|
||||
summary="Export segmentation results as a ZIP archive",
|
||||
)
|
||||
def export_results(
|
||||
project_id: int,
|
||||
scope: str = Query("all", pattern="^(all|range|current)$"),
|
||||
mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
|
||||
outputs: str | None = Query(None),
|
||||
mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
|
||||
start_frame: int | None = Query(None, ge=1),
|
||||
end_frame: int | None = Query(None, ge=1),
|
||||
frame_id: int | None = Query(None, ge=1),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
"""Export JSON annotations plus selected PNG mask outputs inside one ZIP.
|
||||
|
||||
`scope=all` exports the whole video. `scope=range` uses 1-based frame
|
||||
numbers from the sorted project frame sequence. `scope=current` uses the
|
||||
concrete backend `frame_id`.
|
||||
"""
|
||||
project = _project_or_404(project_id, db, current_user)
|
||||
frames = _filter_frames(
|
||||
_project_frames(project_id, db),
|
||||
scope=scope,
|
||||
start_frame=start_frame,
|
||||
end_frame=end_frame,
|
||||
frame_id=frame_id,
|
||||
)
|
||||
annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
|
||||
templates = db.query(Template).all()
|
||||
coco = _build_coco(project, frames, annotations, templates)
|
||||
class_values, class_mapping = _build_gt_class_mapping(annotations)
|
||||
selected_outputs = _parse_result_outputs(mask_type, outputs)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr(
|
||||
"annotations_coco.json",
|
||||
json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
zf.writestr(
|
||||
"maskid_GT像素值_类别映射.json",
|
||||
json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
original_images = _write_original_frames(zf, project, frames)
|
||||
_write_result_mask_outputs(
|
||||
zf,
|
||||
project,
|
||||
frames,
|
||||
annotations,
|
||||
outputs=selected_outputs,
|
||||
class_values=class_values,
|
||||
class_mapping=class_mapping,
|
||||
original_images=original_images,
|
||||
mix_opacity=mix_opacity,
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = _segmentation_results_filename(project, frames)
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": _download_content_disposition(filename)},
|
||||
)
|
||||
234
backend/routers/media.py
Normal file
234
backend/routers/media.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import PROJECT_STATUS_PARSING, PROJECT_STATUS_PENDING, TASK_STATUS_QUEUED
|
||||
from worker_tasks import parse_project_media
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
|
||||
|
||||
|
||||
def natural_filename_key(filename: str) -> tuple[object, ...]:
|
||||
return tuple(
|
||||
int(part) if part.isdigit() else part.casefold()
|
||||
for part in re.split(r"(\d+)", Path(filename).name)
|
||||
)
|
||||
|
||||
|
||||
def _get_ext(filename: str) -> str:
|
||||
return Path(filename).suffix.lower()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Upload a media file",
|
||||
)
|
||||
async def upload_media(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Accept a video, image, or DICOM file and store it in MinIO.
|
||||
|
||||
If project_id is provided, the video_path of the project is updated.
|
||||
Returns the presigned URL of the uploaded object.
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Missing filename")
|
||||
|
||||
ext = _get_ext(file.filename)
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {ext}",
|
||||
)
|
||||
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id or 'general'}/{file.filename}"
|
||||
|
||||
try:
|
||||
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Upload failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Upload to storage failed") from exc
|
||||
|
||||
file_url = get_presigned_url(object_name, expires=3600)
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
# Auto-create a project named after the file
|
||||
project = Project(
|
||||
name=file.filename,
|
||||
description="Auto-created from upload",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
video_path=object_name,
|
||||
source_type="video",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
project_id = project.id
|
||||
object_name = f"uploads/{project_id}/{file.filename}"
|
||||
# Re-upload with corrected path
|
||||
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Auto-created project id=%s for upload %s", project_id, file.filename)
|
||||
|
||||
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
|
||||
|
||||
return {
|
||||
"object_name": object_name,
|
||||
"file_url": file_url,
|
||||
"size": len(data),
|
||||
"project_id": project_id,
|
||||
"message": "Upload successful. Parsing job queued.",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload/dicom",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Upload multiple DICOM files",
|
||||
)
|
||||
async def upload_dicom_batch(
|
||||
files: List[UploadFile] = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Upload multiple .dcm files for a DICOM series.
|
||||
|
||||
If project_id is provided, files are added to the existing project.
|
||||
Otherwise a new DICOM project is created.
|
||||
"""
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files uploaded")
|
||||
|
||||
sorted_files = sorted(
|
||||
[file for file in files if file.filename and file.filename.lower().endswith(".dcm")],
|
||||
key=lambda file: natural_filename_key(file.filename or ""),
|
||||
)
|
||||
if not sorted_files:
|
||||
raise HTTPException(status_code=400, detail="No valid DICOM files uploaded")
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
# Create new DICOM project
|
||||
first_name = sorted_files[0].filename or "DICOM_Series"
|
||||
project = Project(
|
||||
name=first_name,
|
||||
description=f"DICOM series with {len(sorted_files)} files",
|
||||
status=PROJECT_STATUS_PENDING,
|
||||
source_type="dicom",
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
project_id = project.id
|
||||
logger.info("Auto-created DICOM project id=%s", project_id)
|
||||
|
||||
for file in sorted_files:
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id}/dicom/{file.filename}"
|
||||
try:
|
||||
upload_file(object_name, data, content_type="application/dicom", length=len(data))
|
||||
uploaded.append(object_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to upload DICOM %s: %s", file.filename, exc)
|
||||
|
||||
project.video_path = f"uploads/{project_id}/dicom"
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"uploaded_count": len(uploaded),
|
||||
"message": f"Uploaded {len(uploaded)} DICOM files. Parsing job queued.",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Trigger frame extraction",
|
||||
)
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: Optional[str] = None,
|
||||
parse_fps: Optional[float] = Query(None, gt=0, le=120),
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
|
||||
The Celery worker performs the heavy FFmpeg/OpenCV/pydicom work and
|
||||
updates the persisted task record as it progresses.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
effective_source = source_type or project.source_type or "video"
|
||||
effective_parse_fps = parse_fps or project.parse_fps or 30.0
|
||||
task = ProcessingTask(
|
||||
task_type=f"parse_{effective_source}",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="解析任务已入队",
|
||||
project_id=project_id,
|
||||
payload={
|
||||
"source_type": effective_source,
|
||||
"parse_fps": effective_parse_fps,
|
||||
"max_frames": max_frames,
|
||||
"target_width": target_width,
|
||||
},
|
||||
)
|
||||
project.parse_fps = effective_parse_fps
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = parse_project_media.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
logger.info("Queued parse task id=%s project_id=%s celery_id=%s", task.id, project_id, async_result.id)
|
||||
return task
|
||||
310
backend/routers/projects.py
Normal file
310
backend/routers/projects.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Project and Frame CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Annotation, Mask, Project, Frame, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProjectCopyRequest, ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
from minio_client import get_presigned_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
|
||||
def _next_project_copy_name(db: Session, source_name: str) -> str:
|
||||
base_name = f"{source_name} 副本"
|
||||
existing_names = {
|
||||
row[0]
|
||||
for row in db.query(Project.name)
|
||||
.filter(Project.name.like(f"{base_name}%"))
|
||||
.all()
|
||||
}
|
||||
if base_name not in existing_names:
|
||||
return base_name
|
||||
suffix = 2
|
||||
while f"{base_name} {suffix}" in existing_names:
|
||||
suffix += 1
|
||||
return f"{base_name} {suffix}"
|
||||
|
||||
|
||||
def _prepare_project_response(project: Project) -> Project:
|
||||
project.frame_count = len(project.frames)
|
||||
if project.thumbnail_url:
|
||||
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
|
||||
return project
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projects
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new project",
|
||||
)
|
||||
def create_project(
|
||||
payload: ProjectCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Create a new segmentation project."""
|
||||
project = Project(**payload.model_dump(), owner_user_id=current_user.id)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Created project id=%s name=%s", project.id, project.name)
|
||||
return project
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[ProjectOut],
|
||||
summary="List all projects",
|
||||
)
|
||||
def list_projects(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Project]:
|
||||
"""Retrieve a paginated list of projects."""
|
||||
projects = (
|
||||
db.query(Project)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for p in projects:
|
||||
_prepare_project_response(p)
|
||||
return projects
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Get a single project",
|
||||
)
|
||||
def get_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Project:
|
||||
"""Retrieve a project by its ID."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return _prepare_project_response(project)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{project_id}/copy",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Copy a project",
|
||||
)
|
||||
def copy_project(
|
||||
project_id: int,
|
||||
payload: ProjectCopyRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Copy a project. Reset copies media/frame sequence; full also copies annotations and mask metadata."""
|
||||
source = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
next_name = (payload.name or "").strip() if payload.name is not None else ""
|
||||
if not next_name:
|
||||
next_name = _next_project_copy_name(db, source.name)
|
||||
|
||||
copied = Project(
|
||||
name=next_name,
|
||||
description=source.description,
|
||||
video_path=source.video_path,
|
||||
thumbnail_url=source.thumbnail_url,
|
||||
status=source.status,
|
||||
source_type=source.source_type,
|
||||
original_fps=source.original_fps,
|
||||
parse_fps=source.parse_fps,
|
||||
owner_user_id=current_user.id,
|
||||
)
|
||||
db.add(copied)
|
||||
db.flush()
|
||||
|
||||
frame_id_map: dict[int, int] = {}
|
||||
for frame in sorted(source.frames, key=lambda item: item.frame_index):
|
||||
copied_frame = Frame(
|
||||
project_id=copied.id,
|
||||
frame_index=frame.frame_index,
|
||||
image_url=frame.image_url,
|
||||
width=frame.width,
|
||||
height=frame.height,
|
||||
timestamp_ms=frame.timestamp_ms,
|
||||
source_frame_number=frame.source_frame_number,
|
||||
)
|
||||
db.add(copied_frame)
|
||||
db.flush()
|
||||
frame_id_map[frame.id] = copied_frame.id
|
||||
|
||||
if payload.mode == "full":
|
||||
for annotation in sorted(source.annotations, key=lambda item: item.id):
|
||||
copied_annotation = Annotation(
|
||||
project_id=copied.id,
|
||||
frame_id=frame_id_map.get(annotation.frame_id) if annotation.frame_id is not None else None,
|
||||
template_id=annotation.template_id,
|
||||
mask_data=annotation.mask_data,
|
||||
points=annotation.points,
|
||||
bbox=annotation.bbox,
|
||||
)
|
||||
db.add(copied_annotation)
|
||||
db.flush()
|
||||
for mask in annotation.masks:
|
||||
db.add(Mask(
|
||||
annotation_id=copied_annotation.id,
|
||||
mask_url=mask.mask_url,
|
||||
format=mask.format,
|
||||
))
|
||||
|
||||
db.commit()
|
||||
db.refresh(copied)
|
||||
logger.info("Copied project id=%s to id=%s mode=%s", project_id, copied.id, payload.mode)
|
||||
return _prepare_project_response(copied)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Update a project",
|
||||
)
|
||||
def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Project:
|
||||
"""Update project fields partially."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
if key == "name":
|
||||
value = (value or "").strip()
|
||||
if not value:
|
||||
raise HTTPException(status_code=400, detail="Project name is required")
|
||||
setattr(project, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Updated project id=%s", project_id)
|
||||
return project
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{project_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a project",
|
||||
)
|
||||
def delete_project(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a project and all related frames and annotations."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
db.delete(project)
|
||||
db.commit()
|
||||
logger.info("Deleted project id=%s", project_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frames
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"/{project_id}/frames",
|
||||
response_model=FrameOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Add a frame to a project",
|
||||
)
|
||||
def create_frame(
|
||||
project_id: int,
|
||||
payload: FrameCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Frame:
|
||||
"""Register a new frame under a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
frame = Frame(project_id=project_id, **payload.model_dump(exclude={"project_id"}))
|
||||
db.add(frame)
|
||||
db.commit()
|
||||
db.refresh(frame)
|
||||
return frame
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames",
|
||||
response_model=List[FrameOut],
|
||||
summary="List frames for a project",
|
||||
)
|
||||
def list_frames(
|
||||
project_id: int,
|
||||
skip: int = 0,
|
||||
limit: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Frame]:
|
||||
"""Retrieve all frames belonging to a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
query = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.offset(skip)
|
||||
)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
frames = query.all()
|
||||
for frame in frames:
|
||||
frame.image_url = get_presigned_url(frame.image_url, expires=3600)
|
||||
return frames
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames/{frame_id}",
|
||||
response_model=FrameOut,
|
||||
summary="Get a single frame",
|
||||
)
|
||||
def get_frame(
|
||||
project_id: int,
|
||||
frame_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Frame:
|
||||
"""Retrieve a specific frame by ID."""
|
||||
frame = (
|
||||
db.query(Frame)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(
|
||||
Frame.project_id == project_id,
|
||||
Frame.id == frame_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
161
backend/routers/tasks.py
Normal file
161
backend/routers/tasks.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Processing task query endpoints."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import get_db
|
||||
from models import ProcessingTask, Project, User
|
||||
from progress_events import publish_task_progress_event
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PARSING,
|
||||
PROJECT_STATUS_PENDING,
|
||||
PROJECT_STATUS_READY,
|
||||
TASK_ACTIVE_STATUSES,
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_QUEUED,
|
||||
)
|
||||
from worker_tasks import parse_project_media, propagate_project_masks
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _get_task_or_404(task_id: int, db: Session, current_user: User) -> ProcessingTask:
|
||||
_ = current_user
|
||||
task = (
|
||||
db.query(ProcessingTask)
|
||||
.outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
.filter(ProcessingTask.id == task_id)
|
||||
.first()
|
||||
)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
|
||||
|
||||
def _project_status_after_stop(project: Project) -> str:
|
||||
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
||||
def list_tasks(
|
||||
project_id: int | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[ProcessingTask]:
|
||||
"""Return recent background processing tasks."""
|
||||
_ = current_user
|
||||
query = db.query(ProcessingTask).outerjoin(Project, Project.id == ProcessingTask.project_id)
|
||||
if project_id is not None:
|
||||
query = query.filter(ProcessingTask.project_id == project_id)
|
||||
if status is not None:
|
||||
query = query.filter(ProcessingTask.status == status)
|
||||
return query.order_by(ProcessingTask.created_at.desc()).limit(limit).all()
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
|
||||
def get_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ProcessingTask:
|
||||
"""Return one background task by id."""
|
||||
return _get_task_or_404(task_id, db, current_user)
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
|
||||
def cancel_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Cancel a queued/running background task and revoke the Celery job when possible."""
|
||||
task = _get_task_or_404(task_id, db, current_user)
|
||||
if task.status not in TASK_ACTIVE_STATUSES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Task is not cancellable in status: {task.status}",
|
||||
)
|
||||
|
||||
if task.celery_task_id:
|
||||
try:
|
||||
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
|
||||
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = "任务已取消"
|
||||
task.error = "Cancelled by user"
|
||||
task.finished_at = _now()
|
||||
if task.project:
|
||||
task.project.status = _project_status_after_stop(task.project)
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
|
||||
def retry_task(
|
||||
task_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTask:
|
||||
"""Create a fresh queued task from a failed or cancelled task."""
|
||||
previous = _get_task_or_404(task_id, db, current_user)
|
||||
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Task is not retryable in status: {previous.status}",
|
||||
)
|
||||
if previous.project_id is None:
|
||||
raise HTTPException(status_code=400, detail="Task has no project_id")
|
||||
|
||||
project = db.query(Project).filter(Project.id == previous.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
is_propagation_task = previous.task_type == "propagate_masks"
|
||||
if not is_propagation_task and not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
payload = dict(previous.payload or {})
|
||||
payload.setdefault("source_type", project.source_type or "video")
|
||||
payload["retry_of"] = previous.id
|
||||
|
||||
task = ProcessingTask(
|
||||
task_type=previous.task_type,
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message=f"重试任务已入队(源任务 #{previous.id})",
|
||||
project_id=project.id,
|
||||
payload=payload,
|
||||
)
|
||||
if not is_propagation_task:
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
|
||||
async_result = propagate_project_masks.delay(task.id) if is_propagation_task else parse_project_media.delay(task.id)
|
||||
task.celery_task_id = async_result.id
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
return task
|
||||
183
backend/routers/templates.py
Normal file
183
backend/routers/templates.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Template (Ontology) CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/templates", tags=["Templates"])
|
||||
RESERVED_UNCLASSIFIED_CLASS = {
|
||||
"id": "reserved-unclassified",
|
||||
"name": "待分类",
|
||||
"color": "#000000",
|
||||
"zIndex": 0,
|
||||
"maskId": 0,
|
||||
"category": "系统保留",
|
||||
}
|
||||
|
||||
|
||||
def _is_reserved_class(item: dict) -> bool:
|
||||
return (
|
||||
item.get("id") == RESERVED_UNCLASSIFIED_CLASS["id"]
|
||||
or item.get("name") == RESERVED_UNCLASSIFIED_CLASS["name"]
|
||||
or item.get("maskId") == 0
|
||||
)
|
||||
|
||||
|
||||
def _normalize_template_classes(classes: list[dict] | None) -> list[dict]:
|
||||
normalized = [item for item in (classes or []) if not _is_reserved_class(item)]
|
||||
return [*normalized, dict(RESERVED_UNCLASSIFIED_CLASS)]
|
||||
|
||||
|
||||
def _pack_mapping_rules(data: dict) -> dict:
|
||||
"""Pack classes/rules into mapping_rules for DB storage."""
|
||||
mapping = data.get("mapping_rules") or {}
|
||||
if "classes" in data and data["classes"] is not None:
|
||||
mapping["classes"] = _normalize_template_classes(data.pop("classes"))
|
||||
if "rules" in data and data["rules"] is not None:
|
||||
mapping["rules"] = data.pop("rules")
|
||||
if "classes" in mapping:
|
||||
mapping["classes"] = _normalize_template_classes(mapping.get("classes"))
|
||||
data["mapping_rules"] = mapping
|
||||
return data
|
||||
|
||||
|
||||
def _unpack_template(template: Template) -> Template:
|
||||
"""Unpack mapping_rules into classes/rules for response."""
|
||||
mapping = template.mapping_rules or {}
|
||||
# Set as attributes so Pydantic from_attributes can pick them up
|
||||
template.classes = _normalize_template_classes(mapping.get("classes", []))
|
||||
template.rules = mapping.get("rules", [])
|
||||
return template
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=TemplateOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new template",
|
||||
)
|
||||
def create_template(
|
||||
payload: TemplateCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Create a new ontology template / segmentation class."""
|
||||
data = payload.model_dump()
|
||||
data = _pack_mapping_rules(data)
|
||||
template = Template(**data, owner_user_id=current_user.id)
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
_unpack_template(template)
|
||||
logger.info("Created template id=%s name=%s", template.id, template.name)
|
||||
return template
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[TemplateOut],
|
||||
summary="List all templates",
|
||||
)
|
||||
def list_templates(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Template]:
|
||||
"""Retrieve all ontology templates."""
|
||||
templates = (
|
||||
db.query(Template)
|
||||
.filter(or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
for t in templates:
|
||||
_unpack_template(t)
|
||||
return templates
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Get a single template",
|
||||
)
|
||||
def get_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Template:
|
||||
"""Retrieve a template by its ID."""
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_unpack_template(template)
|
||||
return template
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Update a template",
|
||||
)
|
||||
def update_template(
|
||||
template_id: int,
|
||||
payload: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Template:
|
||||
"""Update template fields partially."""
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
if "classes" in data or "rules" in data:
|
||||
data = _pack_mapping_rules(data)
|
||||
|
||||
for key, value in data.items():
|
||||
setattr(template, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
_unpack_template(template)
|
||||
logger.info("Updated template id=%s", template_id)
|
||||
return template
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{template_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a template",
|
||||
)
|
||||
def delete_template(
|
||||
template_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> None:
|
||||
"""Delete a template. Associated annotations will have template_id set to NULL."""
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
db.delete(template)
|
||||
db.commit()
|
||||
logger.info("Deleted template id=%s", template_id)
|
||||
Reference in New Issue
Block a user