feat: 完善分割工作区导入导出与管理流程
- 新增基于 JWT 当前用户的登录恢复、角色权限、用户管理、审计日志和演示出厂重置后台接口与前端管理页。 - 重串 GT_label 导出和 GT Mask 导入逻辑:导出保留类别真实 maskid,导入仅接受灰度或 RGB 等通道 maskid 图,支持未知 maskid 策略、尺寸最近邻拉伸和导入预览。 - 统一分割结果导出体验:默认当前帧,按项目抽帧顺序和 XhXXmXXsXXXms 时间戳命名 ZIP 与图片,补齐 GT/Pro/Mix/分开 Mask 输出和映射 JSON。 - 调整工作区左侧工具栏:移除创建点/线段入口,新增画笔、橡皮擦及尺寸控制,并按绘制、布尔、导入/AI 工具分组分隔。 - 扩展 Canvas 编辑能力:画笔按语义分类绘制并可自动并入连通选中 mask,橡皮擦对选中区域扣除,优化布尔操作、选区、撤销重做和保存状态联动。 - 优化自动传播时间轴显示:同一蓝色系按传播新旧递进变暗,老传播记录达到阈值后统一旧记录色,并维护范围选择与清空后的历史显示。 - 将 AI 智能分割入口替换为更明确的 AI 元素图标,并同步侧栏、工作区和 AI 页面入口表现。 - 完善模板分类、maskid 工具函数、分类树联动、遮罩透明度、边缘平滑和传播链同步相关前端状态。 - 扩展后端项目、媒体、任务、Dashboard、模板和传播 runner 的用户隔离、任务控制、进度事件与兼容处理。 - 补充前后端测试,覆盖用户管理、GT_label 往返导入导出、GT Mask 校验和预览、画笔/橡皮擦、时间轴传播历史、导出范围、WebSocket 与 API 封装。 - 更新 AGENTS、README 和 doc 文档,记录当前接口契约、实现状态、测试计划、安装说明和 maskid/GT_label 规则。
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
@@ -8,11 +9,13 @@ from typing import Any, List
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask
|
||||
from models import Project, Frame, Template, Annotation, ProcessingTask, User
|
||||
from routers.auth import get_current_user, require_editor
|
||||
from schemas import (
|
||||
AiRuntimeStatus,
|
||||
MaskAnalysisRequest,
|
||||
@@ -38,6 +41,102 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
|
||||
project = db.query(Project).filter(
|
||||
Project.id == project_id,
|
||||
Project.owner_user_id == current_user.id,
|
||||
).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
|
||||
query = (
|
||||
db.query(Frame)
|
||||
.join(Project, Project.id == Frame.project_id)
|
||||
.filter(Frame.id == frame_id, Project.owner_user_id == current_user.id)
|
||||
)
|
||||
if project_id is not None:
|
||||
query = query.filter(Frame.project_id == project_id)
|
||||
frame = query.first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
|
||||
|
||||
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
|
||||
template = db.query(Template).filter(
|
||||
Template.id == template_id,
|
||||
or_(Template.owner_user_id == current_user.id, Template.owner_user_id.is_(None)),
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
return template
|
||||
|
||||
|
||||
def _normalize_hex_color(value: Any) -> str | None:
|
||||
if not isinstance(value, str):
|
||||
return None
|
||||
text = value.strip().lower()
|
||||
if not text:
|
||||
return None
|
||||
if not text.startswith("#"):
|
||||
text = f"#{text}"
|
||||
if len(text) == 4:
|
||||
text = "#" + "".join(char * 2 for char in text[1:])
|
||||
if len(text) != 7:
|
||||
return None
|
||||
try:
|
||||
int(text[1:], 16)
|
||||
except ValueError:
|
||||
return None
|
||||
return text
|
||||
|
||||
|
||||
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
|
||||
values = []
|
||||
for channel in rgb:
|
||||
value = int(channel)
|
||||
if value > 255:
|
||||
value = int(round(value / 257))
|
||||
values.append(min(max(value, 0), 255))
|
||||
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
|
||||
|
||||
|
||||
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]]]:
|
||||
classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
|
||||
by_maskid: dict[int, dict[str, Any]] = {}
|
||||
by_color: dict[str, dict[str, Any]] = {}
|
||||
for index, item in enumerate(classes):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
|
||||
try:
|
||||
maskid = int(maskid_value)
|
||||
except (TypeError, ValueError):
|
||||
maskid = index + 1
|
||||
color = _normalize_hex_color(item.get("color")) or "#22c55e"
|
||||
class_meta = {
|
||||
"id": str(item.get("id") or f"maskid-{maskid}"),
|
||||
"name": str(item.get("name") or f"类别 {maskid}"),
|
||||
"color": color,
|
||||
"zIndex": int(item.get("zIndex", item.get("z_index", index * 10))),
|
||||
"maskId": maskid,
|
||||
**({"category": item.get("category")} if item.get("category") else {}),
|
||||
}
|
||||
if maskid > 0:
|
||||
by_maskid[maskid] = class_meta
|
||||
by_color[color] = class_meta
|
||||
return by_maskid, by_color
|
||||
|
||||
|
||||
def _gt_unknown_label(token: int | str) -> str:
|
||||
if isinstance(token, int):
|
||||
return f"未定义类别 {token}"
|
||||
return f"未定义颜色 {token}"
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
"""Download a frame from MinIO and decode it to an RGB numpy array."""
|
||||
try:
|
||||
@@ -106,16 +205,20 @@ def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[flo
|
||||
return [polygon for polygon in (_normalize_polygon(polygon) for polygon in polygons) if len(polygon) >= 3]
|
||||
|
||||
|
||||
def _analysis_anchors(polygons: list[list[list[float]]], points: list[list[float]] | None) -> list[list[float]]:
|
||||
if points:
|
||||
return [[_clamp01(point[0]), _clamp01(point[1])] for point in points if len(point) >= 2]
|
||||
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
|
||||
if len(anchors) <= limit:
|
||||
return anchors
|
||||
step = max(1, math.ceil(len(anchors) / limit))
|
||||
return anchors[::step][:limit]
|
||||
|
||||
|
||||
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
|
||||
anchors: list[list[float]] = []
|
||||
for polygon in polygons:
|
||||
if not polygon:
|
||||
continue
|
||||
step = max(1, len(polygon) // 12)
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon[::step]])
|
||||
return anchors[:32]
|
||||
anchors.extend([[_clamp01(point[0]), _clamp01(point[1])] for point in polygon])
|
||||
return len(anchors), _sample_anchor_points(anchors)
|
||||
|
||||
|
||||
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
|
||||
@@ -129,8 +232,14 @@ def _normalize_smoothing_options(strength: float | int | None, method: str | Non
|
||||
}
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list[list[float]]:
|
||||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||||
return normalized ** curve
|
||||
|
||||
|
||||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
|
||||
points = polygon
|
||||
q = max(0.02, min(float(corner_cut), 0.25))
|
||||
for _ in range(max(0, iterations)):
|
||||
if len(points) < 3:
|
||||
break
|
||||
@@ -138,12 +247,12 @@ def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int) -> list
|
||||
for index, current in enumerate(points):
|
||||
following = points[(index + 1) % len(points)]
|
||||
next_points.append([
|
||||
_clamp01(0.75 * current[0] + 0.25 * following[0]),
|
||||
_clamp01(0.75 * current[1] + 0.25 * following[1]),
|
||||
_clamp01((1.0 - q) * current[0] + q * following[0]),
|
||||
_clamp01((1.0 - q) * current[1] + q * following[1]),
|
||||
])
|
||||
next_points.append([
|
||||
_clamp01(0.25 * current[0] + 0.75 * following[0]),
|
||||
_clamp01(0.25 * current[1] + 0.75 * following[1]),
|
||||
_clamp01(q * current[0] + (1.0 - q) * following[0]),
|
||||
_clamp01(q * current[1] + (1.0 - q) * following[1]),
|
||||
])
|
||||
points = next_points
|
||||
return points
|
||||
@@ -154,7 +263,7 @@ def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[
|
||||
return polygon
|
||||
contour = np.array([[[point[0], point[1]]] for point in polygon], dtype=np.float32)
|
||||
arc_length = cv2.arcLength(contour, True)
|
||||
epsilon = arc_length * (0.001 + (strength / 100.0) * 0.006)
|
||||
epsilon = arc_length * (0.00015 + _smoothing_ratio(strength) * 0.00735)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True).reshape(-1, 2)
|
||||
if len(approx) < 3:
|
||||
return polygon
|
||||
@@ -165,9 +274,25 @@ def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> li
|
||||
strength = float(smoothing.get("strength") or 0.0)
|
||||
if strength <= 0:
|
||||
return _normalize_polygon(polygon)
|
||||
iterations = max(1, min(3, int(strength // 35) + 1))
|
||||
smoothed = _chaikin_smooth_polygon(_normalize_polygon(polygon), iterations)
|
||||
simplified = _simplify_polygon(smoothed, strength)
|
||||
effective_strength = _smoothing_ratio(strength, curve=1.45) * 100.0
|
||||
if effective_strength >= 85:
|
||||
iterations = 4
|
||||
elif effective_strength >= 55:
|
||||
iterations = 3
|
||||
elif effective_strength >= 25:
|
||||
iterations = 2
|
||||
else:
|
||||
iterations = 1
|
||||
corner_cut = 0.03 + _smoothing_ratio(strength, curve=1.35) * 0.22
|
||||
normalized = _normalize_polygon(polygon)
|
||||
pre_simplified = _simplify_polygon(normalized, effective_strength * 0.25)
|
||||
smoothed = _chaikin_smooth_polygon(pre_simplified, iterations, corner_cut)
|
||||
simplified = _simplify_polygon(smoothed, effective_strength)
|
||||
if len(simplified) > len(normalized):
|
||||
for fallback_strength in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
|
||||
simplified = _simplify_polygon(simplified, max(effective_strength, fallback_strength))
|
||||
if len(simplified) <= len(normalized):
|
||||
break
|
||||
return simplified if len(simplified) >= 3 else _normalize_polygon(polygon)
|
||||
|
||||
|
||||
@@ -321,7 +446,11 @@ def _filter_predictions(
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def predict(
|
||||
payload: PredictRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Execute selected SAM segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
@@ -330,9 +459,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
|
||||
- **semantic**: disabled in the current SAM 2.1 point/box product flow.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(payload.image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
@@ -478,7 +605,10 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=AiRuntimeStatus,
|
||||
summary="Get SAM model and GPU runtime status",
|
||||
)
|
||||
def model_status(selected_model: str | None = None) -> dict:
|
||||
def model_status(
|
||||
selected_model: str | None = None,
|
||||
_current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return real runtime availability for GPU and the currently enabled SAM model."""
|
||||
try:
|
||||
return sam_registry.runtime_status(selected_model)
|
||||
@@ -491,12 +621,14 @@ def model_status(selected_model: str | None = None) -> dict:
|
||||
response_model=MaskAnalysisResponse,
|
||||
summary="Analyze mask geometry and prompt anchors",
|
||||
)
|
||||
def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def analyze_mask(
|
||||
payload: MaskAnalysisRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""Return backend-computed mask properties for the frontend inspector."""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
mask_data = payload.mask_data or {}
|
||||
polygons = mask_data.get("polygons") or []
|
||||
@@ -521,13 +653,13 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
else:
|
||||
confidence_source = "manual_or_imported"
|
||||
|
||||
anchors = _analysis_anchors(valid_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
|
||||
message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
|
||||
|
||||
return {
|
||||
"confidence": confidence,
|
||||
"confidence_source": confidence_source,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -541,16 +673,18 @@ def analyze_mask(payload: MaskAnalysisRequest, db: Session = Depends(get_db)) ->
|
||||
response_model=SmoothMaskResponse,
|
||||
summary="Smooth editable mask polygons with backend geometry rules",
|
||||
)
|
||||
def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def smooth_mask(
|
||||
payload: SmoothMaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Return a smoothed polygon mask without persisting it.
|
||||
|
||||
The frontend keeps this as an explicit edit operation: users preview/apply it
|
||||
to the current mask, then save through the normal annotation endpoint.
|
||||
"""
|
||||
if payload.frame_id is not None:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user)
|
||||
|
||||
polygons = payload.mask_data.get("polygons") or []
|
||||
valid_polygons = _normalize_polygons(polygons)
|
||||
@@ -564,10 +698,10 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
|
||||
area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
|
||||
bbox = _polygon_bbox(smoothed_polygons[0])
|
||||
anchors = _analysis_anchors(smoothed_polygons, payload.points)
|
||||
anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
|
||||
return {
|
||||
"polygons": smoothed_polygons,
|
||||
"topology_anchor_count": len(anchors),
|
||||
"topology_anchor_count": anchor_count,
|
||||
"topology_anchors": anchors,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
@@ -581,7 +715,11 @@ def smooth_mask(payload: SmoothMaskRequest, db: Session = Depends(get_db)) -> di
|
||||
response_model=PropagateResponse,
|
||||
summary="Propagate one current-frame region across a video frame segment",
|
||||
)
|
||||
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
def propagate(
|
||||
payload: PropagateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Track one selected region from the current frame across nearby frames.
|
||||
|
||||
SAM 2 uses the official video predictor with the selected mask as the seed.
|
||||
@@ -592,16 +730,8 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
|
||||
max_frames = max(1, min(int(payload.max_frames or 30), 500))
|
||||
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
seed = payload.seed.model_dump(exclude_none=True)
|
||||
polygons = seed.get("polygons") or []
|
||||
@@ -709,18 +839,14 @@ def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
response_model=ProcessingTaskOut,
|
||||
summary="Queue a background video propagation task",
|
||||
)
|
||||
def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(get_db)) -> ProcessingTaskOut:
|
||||
def queue_propagate_task(
|
||||
payload: PropagateTaskRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> ProcessingTaskOut:
|
||||
"""Queue multiple seed/direction propagation steps as one background task."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
|
||||
if not payload.steps:
|
||||
raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
|
||||
@@ -768,11 +894,13 @@ def queue_propagate_task(payload: PropagateTaskRequest, db: Session = Depends(ge
|
||||
response_model=PredictResponse,
|
||||
summary="Run automatic segmentation",
|
||||
)
|
||||
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def auto_segment(
|
||||
image_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> dict:
|
||||
"""Run automatic mask generation on a frame using a grid of point prompts."""
|
||||
frame = db.query(Frame).filter(Frame.id == image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
frame = _owned_frame_or_404(image_id, db, current_user)
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
try:
|
||||
@@ -792,16 +920,15 @@ def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
def save_annotation(
|
||||
payload: AnnotationCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(payload.project_id, db, current_user)
|
||||
|
||||
if payload.frame_id:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
_owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
|
||||
if payload.template_id:
|
||||
_visible_template_or_404(payload.template_id, db, current_user)
|
||||
|
||||
annotation = Annotation(**payload.model_dump())
|
||||
db.add(annotation)
|
||||
@@ -823,8 +950,10 @@ async def import_gt_mask(
|
||||
template_id: int | None = Form(None),
|
||||
label: str = Form("GT Mask"),
|
||||
color: str = Form("#22c55e"),
|
||||
unknown_color_policy: str = Form("undefined"),
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> List[Annotation]:
|
||||
"""Convert a binary/label mask image into persisted polygon annotations.
|
||||
|
||||
@@ -833,62 +962,122 @@ async def import_gt_mask(
|
||||
the frontend an editable point-region representation instead of a static
|
||||
bitmap layer.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
frame = _owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
|
||||
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
if unknown_color_policy not in {"discard", "undefined"}:
|
||||
raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
|
||||
|
||||
template: Template | None = None
|
||||
if template_id is not None:
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
template = _visible_template_or_404(template_id, db, current_user)
|
||||
|
||||
data = await file.read()
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
if image is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid mask image")
|
||||
|
||||
if image.ndim == 2:
|
||||
label_image = image
|
||||
elif image.ndim == 3 and image.shape[2] >= 3:
|
||||
channels = image[:, :, :3]
|
||||
# GT label images are maskid maps: either grayscale or RGB/BGR where
|
||||
# all three color channels contain the same maskid value [X, X, X].
|
||||
if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
label_image = channels[:, :, 0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="GT Mask 图片不符合要求:请上传灰度图,或 RGB 三通道完全相同的 maskid 图(背景 0,像素值为 maskid)。",
|
||||
)
|
||||
|
||||
width = int(frame.width or image.shape[1])
|
||||
height = int(frame.height or image.shape[0])
|
||||
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
|
||||
if not label_values:
|
||||
original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
|
||||
resized_to_frame = original_width != width or original_height != height
|
||||
if resized_to_frame:
|
||||
label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
by_maskid, _by_color = _template_class_maps(template)
|
||||
has_template_classes = bool(by_maskid)
|
||||
fallback_color = _normalize_hex_color(color) or "#22c55e"
|
||||
|
||||
import_items: list[dict[str, Any]] = []
|
||||
skipped_unknown = 0
|
||||
label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
|
||||
for label_value in label_values:
|
||||
class_meta = by_maskid.get(label_value)
|
||||
is_unknown = has_template_classes and class_meta is None
|
||||
if is_unknown and unknown_color_policy == "discard":
|
||||
skipped_unknown += 1
|
||||
continue
|
||||
if class_meta:
|
||||
annotation_label = class_meta["name"]
|
||||
annotation_color = class_meta["color"]
|
||||
elif is_unknown:
|
||||
annotation_label = _gt_unknown_label(label_value)
|
||||
annotation_color = fallback_color
|
||||
else:
|
||||
annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
|
||||
annotation_color = fallback_color
|
||||
import_items.append({
|
||||
"token": label_value,
|
||||
"binary": np.where(label_image == label_value, 255, 0).astype(np.uint8),
|
||||
"label": annotation_label,
|
||||
"color": annotation_color,
|
||||
"class": class_meta,
|
||||
"unknown": is_unknown,
|
||||
"metadata": {
|
||||
"gt_label_value": label_value,
|
||||
"gt_original_size": {"width": original_width, "height": original_height},
|
||||
"gt_resized_to_frame": resized_to_frame,
|
||||
},
|
||||
})
|
||||
|
||||
if not import_items:
|
||||
if skipped_unknown > 0:
|
||||
raise HTTPException(status_code=400, detail="No matching GT mask classes found")
|
||||
raise HTTPException(status_code=400, detail="No foreground mask regions found")
|
||||
has_multiple_labels = len(label_values) > 1
|
||||
|
||||
annotations: list[Annotation] = []
|
||||
for label_value in label_values:
|
||||
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
|
||||
for item in import_items:
|
||||
binary = item["binary"]
|
||||
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
|
||||
|
||||
for contour in contours:
|
||||
if cv2.contourArea(contour) < 1:
|
||||
continue
|
||||
|
||||
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
|
||||
polygon = _normalized_contour(contour, binary.shape[1], binary.shape[0])
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
|
||||
component = np.zeros_like(binary, dtype=np.uint8)
|
||||
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
|
||||
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
|
||||
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
|
||||
seed_point = _component_seed_point(component, binary.shape[1], binary.shape[0])
|
||||
bbox = _contour_bbox(contour, binary.shape[1], binary.shape[0])
|
||||
mask_data = {
|
||||
"polygons": [polygon],
|
||||
"label": item["label"],
|
||||
"color": item["color"],
|
||||
"source": "gt_mask",
|
||||
"image_size": {"width": width, "height": height},
|
||||
**item["metadata"],
|
||||
}
|
||||
if item["class"]:
|
||||
mask_data["class"] = item["class"]
|
||||
if item["unknown"]:
|
||||
mask_data["gt_unknown_class"] = True
|
||||
|
||||
annotation = Annotation(
|
||||
project_id=project_id,
|
||||
frame_id=frame_id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": annotation_label,
|
||||
"color": color,
|
||||
"source": "gt_mask",
|
||||
"gt_label_value": label_value,
|
||||
"image_size": {"width": width, "height": height},
|
||||
},
|
||||
mask_data=mask_data,
|
||||
points=[seed_point],
|
||||
bbox=bbox,
|
||||
)
|
||||
@@ -914,14 +1103,14 @@ def list_annotations(
|
||||
project_id: int,
|
||||
frame_id: int | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> List[Annotation]:
|
||||
"""Return persisted annotations for a project, optionally scoped to one frame."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
_owned_project_or_404(project_id, db, current_user)
|
||||
|
||||
query = db.query(Annotation).filter(Annotation.project_id == project_id)
|
||||
if frame_id is not None:
|
||||
_owned_frame_or_404(frame_id, db, current_user, project_id)
|
||||
query = query.filter(Annotation.frame_id == frame_id)
|
||||
return query.order_by(Annotation.id).all()
|
||||
|
||||
@@ -935,17 +1124,21 @@ def update_annotation(
|
||||
annotation_id: int,
|
||||
payload: AnnotationUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Annotation:
|
||||
"""Update mutable annotation fields persisted in the database."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
updates = payload.model_dump(exclude_unset=True)
|
||||
if "template_id" in updates and updates["template_id"] is not None:
|
||||
template = db.query(Template).filter(Template.id == updates["template_id"]).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
_visible_template_or_404(updates["template_id"], db, current_user)
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(annotation, field, value)
|
||||
@@ -964,9 +1157,15 @@ def update_annotation(
|
||||
def delete_annotation(
|
||||
annotation_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(require_editor),
|
||||
) -> Response:
|
||||
"""Delete an annotation and its derived mask rows through ORM cascade."""
|
||||
annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
|
||||
annotation = (
|
||||
db.query(Annotation)
|
||||
.join(Project, Project.id == Annotation.project_id)
|
||||
.filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not annotation:
|
||||
raise HTTPException(status_code=404, detail="Annotation not found")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user