feat: 完善分割工作区导入导出与管理流程

- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。

- 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度图或 RGB 三通道数值完全相同的 maskid 图,支持未知 maskid 策略、尺寸与帧不一致时的最近邻缩放和导入预览。

- 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。

- 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。

- 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。

- 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。

- 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。

- 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。

- 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。

- 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。

- 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
2026-05-03 03:52:32 +08:00
parent 4c1d3dba73
commit afcddfaeb9
62 changed files with 6572 additions and 849 deletions

View File

@@ -1,6 +1,7 @@
"""AI inference endpoints using selectable SAM runtimes."""
import logging
import math
import tempfile
from pathlib import Path
from typing import Any, List
@@ -8,11 +9,13 @@ from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Project, Frame, Template, Annotation, ProcessingTask
from models import Project, Frame, Template, Annotation, ProcessingTask, User
from routers.auth import get_current_user, require_editor
from schemas import (
AiRuntimeStatus,
MaskAnalysisRequest,
@@ -38,6 +41,102 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
    """Fetch a project by id, restricted to projects owned by ``current_user``.

    Args:
        project_id: Primary key of the project to load.
        db: Active SQLAlchemy session.
        current_user: Authenticated user whose ``id`` must match the
            project's ``owner_user_id``.

    Returns:
        The matching ``Project`` row.

    Raises:
        HTTPException: 404 when no project with that id is owned by the user.
    """
    owned = (
        db.query(Project)
        .filter(Project.id == project_id)
        .filter(Project.owner_user_id == current_user.id)
        .first()
    )
    if owned:
        return owned
    raise HTTPException(status_code=404, detail="Project not found")
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
    """Fetch a frame by id, restricted to frames whose project the user owns.

    Args:
        frame_id: Primary key of the frame to load.
        db: Active SQLAlchemy session.
        current_user: Authenticated user; the frame's project must have this
            user's ``id`` as ``owner_user_id``.
        project_id: When given, the frame must additionally belong to this
            project.

    Returns:
        The matching ``Frame`` row.

    Raises:
        HTTPException: 404 when no frame satisfies all conditions.
    """
    conditions = [Frame.id == frame_id, Project.owner_user_id == current_user.id]
    if project_id is not None:
        conditions.append(Frame.project_id == project_id)

    match = (
        db.query(Frame)
        .join(Project, Project.id == Frame.project_id)
        .filter(*conditions)
        .first()
    )
    if match:
        return match
    raise HTTPException(status_code=404, detail="Frame not found")
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
    """Fetch a template the user may see: their own, or one with no owner.

    Args:
        template_id: Primary key of the template to load.
        db: Active SQLAlchemy session.
        current_user: Authenticated user requesting the template.

    Returns:
        The matching ``Template`` row.

    Raises:
        HTTPException: 404 when the template does not exist or is owned by a
            different user.
    """
    owned_or_ownerless = or_(
        Template.owner_user_id == current_user.id,
        Template.owner_user_id.is_(None),
    )
    found = db.query(Template).filter(Template.id == template_id, owned_or_ownerless).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Template not found")
def _normalize_hex_color(value: Any) -> str | None:
if not isinstance(value, str):
return None
text = value.strip().lower()
if not text:
return None
if not text.startswith("#"):
text = f"#{text}"
if len(text) == 4:
text = "#" + "".join(char * 2 for char in text[1:])
if len(text) != 7:
return None
try:
int(text[1:], 16)
except ValueError:
return None
return text
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
values = []
for channel in rgb:
value = int(channel)
if value > 255:
value = int(round(value / 257))
values.append(min(max(value, 0), 255))
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
# Build two lookup tables from a template's class list: one keyed by maskid
# and one keyed by normalized hex color. The class list is read from
# template.mapping_rules["classes"]; a missing template, missing rules or a
# missing "classes" entry all yield two empty dicts.
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# `by_color[color] = class_meta` assignment sits under the `maskid > 0` guard —
# confirm against the real file before relying on by_color for maskid <= 0.
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]]]:
classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
by_maskid: dict[int, dict[str, Any]] = {}
by_color: dict[str, dict[str, Any]] = {}
for index, item in enumerate(classes):
# Entries that are not dicts are ignored.
if not isinstance(item, dict):
continue
# The maskid may be stored under any of three key spellings.
maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
try:
maskid = int(maskid_value)
except (TypeError, ValueError):
# Missing or non-numeric maskid: fall back to the 1-based list position.
maskid = index + 1
# Invalid or missing colors fall back to the default "#22c55e".
color = _normalize_hex_color(item.get("color")) or "#22c55e"
class_meta = {
"id": str(item.get("id") or f"maskid-{maskid}"),
"name": str(item.get("name") or f"类别 {maskid}"),
"color": color,
# NOTE(review): unlike maskid, this int() conversion is not guarded; a
# None or non-numeric zIndex/z_index value raises here.
"zIndex": int(item.get("zIndex", item.get("z_index", index * 10))),
"maskId": maskid,
# "category" is only included when the entry has a truthy value for it.
**({"category": item.get("category")} if item.get("category") else {}),
}
# Only positive maskids are registered in by_maskid.
if maskid > 0:
by_maskid[maskid] = class_meta
# Later entries with the same color overwrite earlier ones.
by_color[color] = class_meta
return by_maskid, by_color
def _gt_unknown_label(token: int | str) -> str:
if isinstance(token, int):
return f"未定义类别 {token}"
return f"未定义颜色 {token}"
def _load_frame_image(frame: Frame) -> np.ndarray:
"""Download a frame from MinIO and decode it to an RGB numpy array."""
try:
@@ -106,16 +205,20 @@ def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[flo
return [polygon for polygon in (_normalize_polygon(polygon) for polygon in polygons) if len(polygon) >= 3]
def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
if points:
return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
if len(anchors) <= limit:
return anchors
step = max(1, math.ceil(len(anchors) / limit))
return anchors[::step][:limit]
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
anchors: list[list[float]] = []
for polygon in polygons:
if not polygon:
continue
step = max(1, len(polygon) // 12)
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon[::step]])
return anchors[:32]
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon])
return len(anchors), _sample_anchor_points(anchors)
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
@@ -129,8 +232,14 @@ def _normalize_smoothing_options(strength: float | int | None, method: str | Non
}
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
return normalized ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
points = polygon
q = max(0.02, min(float(corner_cut), 0.25))
for _ in range(max(0, iterations)):
if len(points) < 3:
break
@@ -138,12 +247,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
for index, current in enumerate(points):
following = points[(index + 1) % len(points)]
next_points.append([
_clamp01(0.75 * current[0] + 0.25 * following[0]),
_clamp01(0.75 * current[1] + 0.25 * following[1]),
_clamp01((1.0 - q) * current[0] + q * following[0]),
_clamp01((1.0 - q) * current[1] + q * following[1]),
])
next_points.append([
_clamp01(0.25 * current[0] + 0.75 * following[0]),
_clamp01(0.25 * current[1] + 0.75 * following[1]),
_clamp01(q * current[0] + (1.0 - q) * following[0]),
_clamp01(q * current[1] + (1.0 - q) * following[1]),
])
points = next_points
return points
@@ -154,7 +263,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
return polygon
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
arc_length = cv2.arcLength(contour, True)
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
if len(approx) < 3:
return polygon
@@ -165,9 +274,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> li
strength = float(smoothing.get("strength") or 0.0)
if strength <= 0:
return _normalize_polygon(polygon)
iterations = max(1, min(3, int(strength // 35) + 1))
smoothed = _chaikin_smooth_polygon(_normalize_polygon(polygon), iterations)
simplified = _simplify_polygon(smoothed, strength)
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
if effective_strength >= 85:
iterations = 4
elif effective_strength >= 55:
iterations = 3
elif effective_strength >= 25:
iterations = 2
else:
iterations = 1
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
normalized = _normalize_polygon(polygon)
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
simplified = _simplify_polygon(smoothed, effective_strength)
if len(simplified) > len(normalized):
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
if len(simplified) <= len(normalized):
break
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
@@ -321,7 +446,11 @@ def _filter_predictions(
response_model=PredictResponse,
summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
def predict(
payload: PredictRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Execute selected SAM segmentation given an image and a prompt.
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
@@ -330,9 +459,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
- **semantic**: disabled in the current SAM 2.1 point/box product flow.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
frame = _owned_frame_or_404(payload.image_id, db, current_user)
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
@@ -478,7 +605,10 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
response_model=AiRuntimeStatus,
summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
def model_status(
selected_model: str | None = None,
_current_user: User = Depends(get_current_user),
) -> dict:
"""Return real runtime availability for GPU and the currently enabled SAM model."""
try:
return sam_registry.runtime_status(selected_model)
@@ -491,12 +621,14 @@ def model_status(selected_model: str | None = None) -> dict:
response_model=MaskAnalysisResponse,
summary="Analyze mask geometry and prompt anchors",
)
def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) -> dict:
def analyze_mask(
payload: MaskAnalysisRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> dict:
"""Return backend-computed mask properties for the frontend inspector."""
if payload.frame_id is not None:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_frame_or_404(payload.frame_id, db, current_user)
mask_data = payload.mask_data or {}
polygons = mask_data.get("polygons") or []
@@ -521,13 +653,13 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
else:
confidence_source = "manual_or_imported"
anchors = _analysis_anchors(valid_polygons, payload.points)
anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
return {
"confidence": confidence,
"confidence_source": confidence_source,
"topology_anchor_count": len(anchors),
"topology_anchor_count": anchor_count,
"topology_anchors": anchors,
"area": area,
"bbox": bbox,
@@ -541,16 +673,18 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
response_model=SmoothMaskResponse,
summary="Smooth editable mask polygons with backend geometry rules",
)
def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> dict:
def smooth_mask(
payload: SmoothMaskRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Return a smoothed polygon mask without persisting it.
The frontend keeps this as an explicit edit operation: users preview/apply it
to the current mask, then save through the normal annotation endpoint.
"""
if payload.frame_id is not None:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_frame_or_404(payload.frame_id, db, current_user)
polygons = payload.mask_data.get("polygons") or []
valid_polygons = _normalize_polygons(polygons)
@@ -564,10 +698,10 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
bbox = _polygon_bbox(smoothed_polygons[0])
anchors = _analysis_anchors(smoothed_polygons, payload.points)
anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
return {
"polygons": smoothed_polygons,
"topology_anchor_count": len(anchors),
"topology_anchor_count": anchor_count,
"topology_anchors": anchors,
"area": area,
"bbox": bbox,
@@ -581,7 +715,11 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
response_model=PropagateResponse,
summary="Propagate one current-frame region across a video frame segment",
)
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
def propagate(
payload: PropagateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Track one selected region from the current frame across nearby frames.
SAM 2 uses the official video predictor with the selected mask as the seed.
@@ -592,16 +730,8 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
max_frames = max(1, min(int(payload.max_frames or 30), 500))
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
source_frame = db.query(Frame).filter(
Frame.id == payload.frame_id,
Frame.project_id == payload.project_id,
).first()
if not source_frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_project_or_404(payload.project_id, db, current_user)
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
seed = payload.seed.model_dump(exclude_none=True)
polygons = seed.get("polygons") or []
@@ -709,18 +839,14 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
response_model=ProcessingTaskOut,
summary="Queue a background video propagation task",
)
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
def queue_propagate_task(
payload: PropagateTaskRequest,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> ProcessingTaskOut:
"""Queue multiple seed/direction propagation steps as one background task."""
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
source_frame = db.query(Frame).filter(
Frame.id == payload.frame_id,
Frame.project_id == payload.project_id,
).first()
if not source_frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_project_or_404(payload.project_id, db, current_user)
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
if not payload.steps:
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
@@ -768,11 +894,13 @@ def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(ge
response_model=PredictResponse,
summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
def auto_segment(
image_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> dict:
"""Run automatic mask generation on a frame using a grid of point prompts."""
frame = db.query(Frame).filter(Frame.id == image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
frame = _owned_frame_or_404(image_id, db, current_user)
image = _load_frame_image(frame)
try:
@@ -792,16 +920,15 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
def save_annotation(
payload: AnnotationCreate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
project = db.query(Project).filter(Project.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_owned_project_or_404(payload.project_id, db, current_user)
if payload.frame_id:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
_owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
if payload.template_id:
_visible_template_or_404(payload.template_id, db, current_user)
annotation = Annotation(**payload.model_dump())
db.add(annotation)
@@ -823,8 +950,10 @@ async def import_gt_mask(
template_id: int | None = Form(None),
label: str = Form("GT Mask"),
color: str = Form("#22c55e"),
unknown_color_policy: str = Form("undefined"),
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> List[Annotation]:
"""Convert a binary/label mask image into persisted polygon annotations.
@@ -833,62 +962,122 @@ async def import_gt_mask(
the frontend an editable point-region representation instead of a static
bitmap layer.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_owned_project_or_404(project_id, db, current_user)
frame = _owned_frame_or_404(frame_id, db, current_user, project_id)
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
if unknown_color_policy not in {"discard", "undefined"}:
raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
template: Template | None = None
if template_id is not None:
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
template = _visible_template_or_404(template_id, db, current_user)
data = await file.read()
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
if image is None:
raise HTTPException(status_code=400, detail="Invalid mask image")
if image.ndim == 2:
label_image = image
elif image.ndim == 3 and image.shape[2] >= 3:
channels = image[:, :, :3]
# GT label images are maskid maps: either grayscale or RGB/BGR where
# all three color channels contain the same maskid value [X, X, X].
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
raise HTTPException(
status_code=400,
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0像素值为 maskid",
)
label_image = channels[:, :, 0]
else:
raise HTTPException(
status_code=400,
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0像素值为 maskid",
)
width = int(frame.width or image.shape[1])
height = int(frame.height or image.shape[0])
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
if not label_values:
original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
resized_to_frame = original_width != width or original_height != height
if resized_to_frame:
label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)
by_maskid, _by_color = _template_class_maps(template)
has_template_classes = bool(by_maskid)
fallback_color = _normalize_hex_color(color) or "#22c55e"
import_items: list[dict[str, Any]] = []
skipped_unknown = 0
label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
for label_value in label_values:
class_meta = by_maskid.get(label_value)
is_unknown = has_template_classes and class_meta is None
if is_unknown and unknown_color_policy == "discard":
skipped_unknown += 1
continue
if class_meta:
annotation_label = class_meta["name"]
annotation_color = class_meta["color"]
elif is_unknown:
annotation_label = _gt_unknown_label(label_value)
annotation_color = fallback_color
else:
annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
annotation_color = fallback_color
import_items.append({
"token": label_value,
"binary": np.where(label_image == label_value, 255, 0).astype(np.uint8),
"label": annotation_label,
"color": annotation_color,
"class": class_meta,
"unknown": is_unknown,
"metadata": {
"gt_label_value": label_value,
"gt_original_size": {"width": original_width, "height": original_height},
"gt_resized_to_frame": resized_to_frame,
},
})
if not import_items:
if skipped_unknown > 0:
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
raise HTTPException(status_code=400, detail="No foreground mask regions found")
has_multiple_labels = len(label_values) > 1
annotations: list[Annotation] = []
for label_value in label_values:
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
for item in import_items:
binary = item["binary"]
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
for contour in contours:
if cv2.contourArea(contour) < 1:
continue
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
polygon = _normalized_contour(contour, binary.shape[1], binary.shape[0])
if len(polygon) < 3:
continue
component = np.zeros_like(binary, dtype=np.uint8)
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
seed_point = _component_seed_point(component, binary.shape[1], binary.shape[0])
bbox = _contour_bbox(contour, binary.shape[1], binary.shape[0])
mask_data = {
"polygons": [polygon],
"label": item["label"],
"color": item["color"],
"source": "gt_mask",
"image_size": {"width": width, "height": height},
**item["metadata"],
}
if item["class"]:
mask_data["class"] = item["class"]
if item["unknown"]:
mask_data["gt_unknown_class"] = True
annotation = Annotation(
project_id=project_id,
frame_id=frame_id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": annotation_label,
"color": color,
"source": "gt_mask",
"gt_label_value": label_value,
"image_size": {"width": width, "height": height},
},
mask_data=mask_data,
points=[seed_point],
bbox=bbox,
)
@@ -914,14 +1103,14 @@ def list_annotations(
project_id: int,
frame_id: int | None = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> List[Annotation]:
"""Return persisted annotations for a project, optionally scoped to one frame."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
_owned_project_or_404(project_id, db, current_user)
query = db.query(Annotation).filter(Annotation.project_id == project_id)
if frame_id is not None:
_owned_frame_or_404(frame_id, db, current_user, project_id)
query = query.filter(Annotation.frame_id == frame_id)
return query.order_by(Annotation.id).all()
@@ -935,17 +1124,21 @@ def update_annotation(
annotation_id: int,
payload: AnnotationUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Annotation:
"""Update mutable annotation fields persisted in the database."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
annotation = (
db.query(Annotation)
.join(Project, Project.id == Annotation.project_id)
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
.first()
)
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")
updates = payload.model_dump(exclude_unset=True)
if "template_id" in updates and updates["template_id"] is not None:
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
_visible_template_or_404(updates["template_id"], db, current_user)
for field, value in updates.items():
setattr(annotation, field, value)
@@ -964,9 +1157,15 @@ def update_annotation(
def delete_annotation(
annotation_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(require_editor),
) -> Response:
"""Delete an annotation and its derived mask rows through ORM cascade."""
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
annotation = (
db.query(Annotation)
.join(Project, Project.id == Annotation.project_id)
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
.first()
)
if not annotation:
raise HTTPException(status_code=404, detail="Annotation not found")