- 增强 DICOM/视频项目导入与演示数据:DICOM 按文件名自然顺序处理,导入后展示上传与解析任务进度,恢复演示出厂设置保留演示视频和演示 DICOM 项目,并补充 demo media seed 逻辑。 - 完善项目管理:项目支持重命名、删除、复制,删除使用站内确认弹窗,复制支持新项目重置和全内容复制,DICOM 项目不显示生成帧入口。 - 完善 GT Mask 与导出链路:只支持 8-bit maskid 图导入,非法/全背景图明确拒绝,尺寸自动适配,高精度 polygon 回显;统一导出默认当前帧,GT_label 使用 uint8 和真实 maskid,待分类 maskid 0 与背景一致。 - 完善分割工作区交互:新增画笔和橡皮擦并支持尺寸控制,移除创建点/线段入口,工具栏按类别分隔,AI 智能分割使用明确 AI 图标,取消黄色 seed point,清空/删除传播 mask 后同步清理空帧时间轴状态。 - 完善传播与时间轴:自动传播使用 SAM 2.1 权重任务,参考帧无遮罩时提示,传播历史按同一蓝色系递进变暗,删除/清空传播链时保留人工或独立 AI 标注来源。 - 完善模板库:新增头颈部 CT 分割默认模板,所有模板保留 maskid 0 待分类,支持鼠标复制模板、拖拽层级、JSON 批量导入预览、删除 label 和站内删除确认。 - 完善用户与高风险确认:用户改密码、删除用户、恢复演示出厂设置和清空人工/AI 标注帧均改为站内确认交互,避免浏览器原生 prompt/confirm。 - 补充前后端测试与文档:更新项目、模板、GT 导入、导出、传播、DICOM、用户管理等测试,并同步 README、AGENTS 和 doc 下实现/契约/测试计划文档。
1187 lines
45 KiB
Python
1187 lines
45 KiB
Python
"""AI inference endpoints using selectable SAM runtimes."""
|
||
|
||
import logging
|
||
import math
|
||
import tempfile
|
||
from pathlib import Path
|
||
from typing import Any, List
|
||
|
||
import cv2
|
||
import numpy as np
|
||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
|
||
from sqlalchemy import or_
|
||
from sqlalchemy.orm import Session
|
||
|
||
from database import get_db
|
||
from minio_client import download_file
|
||
from models import Project, Frame, Template, Annotation, ProcessingTask, User
|
||
from routers.auth import get_current_user, require_editor
|
||
from schemas import (
|
||
AiRuntimeStatus,
|
||
MaskAnalysisRequest,
|
||
MaskAnalysisResponse,
|
||
SmoothMaskRequest,
|
||
SmoothMaskResponse,
|
||
PredictRequest,
|
||
PredictResponse,
|
||
PropagateRequest,
|
||
PropagateResponse,
|
||
PropagateTaskRequest,
|
||
ProcessingTaskOut,
|
||
AnnotationOut,
|
||
AnnotationCreate,
|
||
AnnotationUpdate,
|
||
)
|
||
from progress_events import publish_task_progress_event
|
||
from statuses import TASK_STATUS_QUEUED
|
||
from worker_tasks import propagate_project_masks
|
||
from services.sam_registry import ModelUnavailableError, sam_registry
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /api/ai and tagged "AI".
router = APIRouter(prefix="/api/ai", tags=["AI"])

# Error detail (Chinese) meaning "the GT mask image contains no non-background
# maskid region". Not referenced in the visible part of this file — presumably
# used by the GT mask import endpoint; verify.
GT_MASK_EMPTY_DETAIL = "GT Mask 图片中没有非背景 maskid 区域。"

# Maximum number of vertices per polygon produced by _normalized_contour.
GT_IMPORT_MAX_CONTOUR_POINTS = 2048

# Initial cv2.approxPolyDP tolerance, as a fraction of the contour perimeter.
GT_IMPORT_CONTOUR_EPSILON_RATIO = 0.00075

# Lower bound on that tolerance, in pixels (_normalized_contour simplifies the
# contour in pixel coordinates and only then divides by the image size).
GT_IMPORT_MIN_CONTOUR_EPSILON = 0.35
|
||
|
||
|
||
def _owned_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
    """Fetch the project ``project_id`` if it is owned by ``current_user``.

    Raises:
        HTTPException: 404 when no project with that id belongs to the user.
    """
    ownership_filters = (
        Project.id == project_id,
        Project.owner_user_id == current_user.id,
    )
    owned = db.query(Project).filter(*ownership_filters).first()
    if owned:
        return owned
    raise HTTPException(status_code=404, detail="Project not found")
|
||
|
||
|
||
def _owned_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
    """Fetch a frame whose parent project is owned by ``current_user``.

    When ``project_id`` is given, the frame must also belong to that project.

    Raises:
        HTTPException: 404 when no matching frame is visible to the user.
    """
    conditions = [Frame.id == frame_id, Project.owner_user_id == current_user.id]
    if project_id is not None:
        conditions.append(Frame.project_id == project_id)
    match = (
        db.query(Frame)
        .join(Project, Project.id == Frame.project_id)
        .filter(*conditions)
        .first()
    )
    if match:
        return match
    raise HTTPException(status_code=404, detail="Frame not found")
|
||
|
||
|
||
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
    """Fetch a template the user may read.

    A template is readable if it is owned by ``current_user`` or has no owner
    (``owner_user_id IS NULL``).

    Raises:
        HTTPException: 404 when the template does not exist or is owned by someone else.
    """
    readable = or_(
        Template.owner_user_id == current_user.id,
        Template.owner_user_id.is_(None),
    )
    found = db.query(Template).filter(Template.id == template_id, readable).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Template not found")
|
||
|
||
|
||
def _normalize_hex_color(value: Any) -> str | None:
|
||
if not isinstance(value, str):
|
||
return None
|
||
text = value.strip().lower()
|
||
if not text:
|
||
return None
|
||
if not text.startswith("#"):
|
||
text = f"#{text}"
|
||
if len(text) == 4:
|
||
text = "#" + "".join(char * 2 for char in text[1:])
|
||
if len(text) != 7:
|
||
return None
|
||
try:
|
||
int(text[1:], 16)
|
||
except ValueError:
|
||
return None
|
||
return text
|
||
|
||
|
||
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
|
||
values = []
|
||
for channel in rgb:
|
||
value = int(channel)
|
||
if value > 255:
|
||
value = int(round(value / 257))
|
||
values.append(min(max(value, 0), 255))
|
||
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
|
||
|
||
|
||
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]]]:
    """Index a template's class definitions by maskid and by normalized color.

    Reads ``template.mapping_rules["classes"]`` and skips non-dict entries.
    For each remaining class:

    - ``maskId`` is read from ``maskId``/``maskid``/``mask_id``. When it is
      missing or not integer-convertible, it defaults to ``index + 1``.
    - ``color`` is normalized to ``#rrggbb`` and defaults to ``#22c55e``.
    - ``zIndex`` is read from ``zIndex``/``z_index``. When it is missing or not
      integer-convertible, it defaults to ``index * 10``.

    Returns:
        ``(by_maskid, by_color)``. ``by_maskid`` only contains classes with
        ``maskId > 0``. ``by_color`` contains every class, and a later class
        with the same color replaces an earlier one.
    """
    classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
    by_maskid: dict[int, dict[str, Any]] = {}
    by_color: dict[str, dict[str, Any]] = {}
    for index, item in enumerate(classes):
        if not isinstance(item, dict):
            continue
        maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
        try:
            maskid = int(maskid_value)
        except (TypeError, ValueError, OverflowError):
            maskid = index + 1
        color = _normalize_hex_color(item.get("color")) or "#22c55e"
        default_z_index = index * 10
        try:
            z_index = int(item.get("zIndex", item.get("z_index", default_z_index)))
        except (TypeError, ValueError, OverflowError):
            # A malformed zIndex (null, free text, ...), e.g. from a hand-edited
            # JSON template, must not abort the whole lookup. Fall back to list order.
            z_index = default_z_index
        class_meta = {
            "id": str(item.get("id") or f"maskid-{maskid}"),
            "name": str(item.get("name") or f"类别 {maskid}"),
            "color": color,
            "zIndex": z_index,
            "maskId": maskid,
            **({"category": item.get("category")} if item.get("category") else {}),
        }
        if maskid > 0:
            by_maskid[maskid] = class_meta
        by_color[color] = class_meta
    return by_maskid, by_color
|
||
|
||
|
||
def _gt_unknown_label(token: int | str) -> str:
|
||
if isinstance(token, int):
|
||
return f"未定义类别 {token}"
|
||
return f"未定义颜色 {token}"
|
||
|
||
|
||
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Download a frame from MinIO and decode it to an RGB numpy array.

    Args:
        frame: Frame whose ``image_url`` is fetched with ``download_file``.

    Returns:
        The image decoded by OpenCV (``IMREAD_COLOR``) and converted from BGR to RGB.

    Raises:
        HTTPException: 500 when downloading or decoding fails. The original
            exception is chained and logged together with its traceback.
    """
    try:
        data = download_file(frame.image_url)
        img = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("OpenCV could not decode image")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        # logger.exception keeps the traceback, so storage failures can be told
        # apart from decode failures.
        logger.exception("Failed to load frame image frame_id=%s: %s", frame.id, exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
||
|
||
|
||
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
    """Simplify a pixel-space contour and return it as a normalized polygon.

    The contour is simplified with ``cv2.approxPolyDP``. The tolerance starts
    at ``GT_IMPORT_CONTOUR_EPSILON_RATIO`` of the perimeter (at least
    ``GT_IMPORT_MIN_CONTOUR_EPSILON`` pixels) and grows by 1.5x while the result
    exceeds ``GT_IMPORT_MAX_CONTOUR_POINTS`` vertices and the tolerance is below
    2% of the perimeter.

    If simplification leaves fewer than 3 vertices, the raw contour is used
    instead. That contour may itself have fewer than 3 points, in which case
    the result does too. Any remaining excess over the cap is removed by
    uniform striding.

    Returns:
        ``[[x, y], ...]`` divided by width/height (each floored at 1) and
        clamped to ``[0, 1]``.
    """
    perimeter = cv2.arcLength(contour, True)
    tolerance = max(GT_IMPORT_MIN_CONTOUR_EPSILON, perimeter * GT_IMPORT_CONTOUR_EPSILON_RATIO)
    tolerance_ceiling = perimeter * 0.02
    simplified = cv2.approxPolyDP(contour, tolerance, True)
    while len(simplified) > GT_IMPORT_MAX_CONTOUR_POINTS and tolerance < tolerance_ceiling:
        tolerance *= 1.5
        simplified = cv2.approxPolyDP(contour, tolerance, True)

    vertices = simplified.reshape(-1, 2)
    if len(vertices) < 3:
        vertices = contour.reshape(-1, 2)
    if len(vertices) > GT_IMPORT_MAX_CONTOUR_POINTS:
        stride = int(math.ceil(len(vertices) / GT_IMPORT_MAX_CONTOUR_POINTS))
        vertices = vertices[::stride]

    x_scale = max(width, 1)
    y_scale = max(height, 1)
    normalized: list[list[float]] = []
    for x, y in vertices:
        normalized.append([
            min(max(float(x) / x_scale, 0.0), 1.0),
            min(max(float(y) / y_scale, 0.0), 1.0),
        ])
    return normalized
|
||
|
||
|
||
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
    """Return the contour's bounding rectangle as normalized ``[x, y, w, h]``.

    Values are divided by the image width/height (each floored at 1) and
    clamped to ``[0, 1]``.
    """
    left, top, box_w, box_h = cv2.boundingRect(contour)
    x_scale, y_scale = max(width, 1), max(height, 1)
    raw = (
        float(left) / x_scale,
        float(top) / y_scale,
        float(box_w) / x_scale,
        float(box_h) / y_scale,
    )
    return [min(max(value, 0.0), 1.0) for value in raw]
|
||
|
||
|
||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return ``[x, y, w, h]`` enclosing every point of ``polygon``.

    Coordinates are clamped to ``[0, 1]`` before measuring. An empty polygon
    raises ``ValueError`` (from ``min``/``max``).
    """
    xs = [_clamp01(vertex[0]) for vertex in polygon]
    ys = [_clamp01(vertex[1]) for vertex in polygon]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    span_x = max(x_max - x_min, 0.0)
    span_y = max(y_max - y_min, 0.0)
    return [x_min, y_min, span_x, span_y]
|
||
|
||
|
||
def _polygon_area(polygon: list[list[float]]) -> float:
    """Return the unsigned shoelace area of a closed polygon in normalized units.

    Coordinates are clamped to ``[0, 1]``. Fewer than three points gives 0.0.
    """
    if len(polygon) < 3:
        return 0.0
    doubled_area = 0.0
    for vertex, neighbour in zip(polygon, [*polygon[1:], polygon[0]]):
        doubled_area += _clamp01(vertex[0]) * _clamp01(neighbour[1])
        doubled_area -= _clamp01(neighbour[0]) * _clamp01(vertex[1])
    return abs(doubled_area) / 2.0
|
||
|
||
|
||
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
    """Clamp each point's x/y to ``[0, 1]``, dropping points with fewer than two coordinates."""
    normalized: list[list[float]] = []
    for point in polygon:
        if len(point) < 2:
            continue
        normalized.append([_clamp01(point[0]), _clamp01(point[1])])
    return normalized
|
||
|
||
|
||
def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[float]]]:
    """Normalize each polygon and keep only those with at least three points left."""
    kept: list[list[list[float]]] = []
    for raw_polygon in polygons:
        cleaned = _normalize_polygon(raw_polygon)
        if len(cleaned) >= 3:
            kept.append(cleaned)
    return kept
|
||
|
||
|
||
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
|
||
if len(anchors) <= limit:
|
||
return anchors
|
||
step = max(1, math.ceil(len(anchors) / limit))
|
||
return anchors[::step][:limit]
|
||
|
||
|
||
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
    """Flatten every polygon's points (clamped to ``[0, 1]``) into anchors.

    Returns:
        ``(total anchor count, downsampled anchors)``. Empty polygons are skipped.
    """
    flattened = [
        [_clamp01(point[0]), _clamp01(point[1])]
        for polygon in polygons
        if polygon
        for point in polygon
    ]
    return len(flattened), _sample_anchor_points(flattened)
|
||
|
||
|
||
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
|
||
clamped_strength = max(0.0, min(float(strength or 0.0), 100.0))
|
||
normalized_method = (method or "chaikin").lower()
|
||
if normalized_method != "chaikin":
|
||
normalized_method = "chaikin"
|
||
return {
|
||
"strength": round(clamped_strength, 2),
|
||
"method": normalized_method,
|
||
}
|
||
|
||
|
||
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
|
||
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
|
||
return normalized ** curve
|
||
|
||
|
||
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
    """Round a closed polygon's corners with Chaikin corner cutting.

    Each pass replaces every edge ``(a, b)`` with two points, at fractions
    ``cut`` and ``1 - cut`` along it, which doubles the vertex count.
    ``corner_cut`` is clamped to ``[0.02, 0.25]``. Passes stop early if the ring
    has fewer than three points. With zero passes the input list itself is
    returned. Output coordinates are clamped to ``[0, 1]``.
    """
    cut = max(0.02, min(float(corner_cut), 0.25))
    keep = 1.0 - cut
    ring = polygon
    for _pass in range(max(0, iterations)):
        if len(ring) < 3:
            break
        successors = [*ring[1:], ring[0]]
        refined: list[list[float]] = []
        for current, following in zip(ring, successors):
            refined.append([
                _clamp01(keep * current[0] + cut * following[0]),
                _clamp01(keep * current[1] + cut * following[1]),
            ])
            refined.append([
                _clamp01(cut * current[0] + keep * following[0]),
                _clamp01(cut * current[1] + keep * following[1]),
            ])
        ring = refined
    return ring
|
||
|
||
|
||
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
    """Reduce a normalized polygon's vertices with ``cv2.approxPolyDP``.

    The tolerance is a fraction of the perimeter that grows with ``strength``
    (0..100, shaped by ``_smoothing_ratio``), from 0.015% up to 0.75%. The
    input is returned unchanged if it has fewer than 3 points, if
    ``strength <= 0``, or if simplification would leave fewer than 3 points.
    """
    if len(polygon) < 3 or strength <= 0:
        return polygon
    outline = np.array([[point[0], point[1]] for point in polygon], dtype=np.float32).reshape(-1, 1, 2)
    tolerance = cv2.arcLength(outline, True) * (0.00015 + _smoothing_ratio(strength) * 0.00735)
    reduced = cv2.approxPolyDP(outline, tolerance, True).reshape(-1, 2)
    if len(reduced) < 3:
        return polygon
    return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in reduced]
|
||
|
||
|
||
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> list[list[float]]:
    """Smooth one normalized polygon according to ``smoothing["strength"]`` (0..100).

    Pipeline:
    1. Light pre-simplification.
    2. 1-4 Chaikin passes; more passes and deeper corner cuts at higher strength.
    3. Post-simplification.

    If the result has more vertices than the clamped input, it is simplified
    again with escalating strengths. A zero/absent strength, or a result with
    fewer than 3 points, yields the clamped input polygon instead.
    """
    level = float(smoothing.get("strength") or 0.0)
    if level <= 0:
        return _normalize_polygon(polygon)

    effective = _smoothing_ratio(level, curve=1.45) * 100.0
    iterations = next(
        (passes for threshold, passes in ((85, 4), (55, 3), (25, 2)) if effective >= threshold),
        1,
    )
    corner_cut = 0.03 + _smoothing_ratio(level, curve=1.35) * 0.22

    base = _normalize_polygon(polygon)
    rough = _simplify_polygon(base, effective * 0.25)
    rounded = _chaikin_smooth_polygon(rough, iterations, corner_cut)
    result = _simplify_polygon(rounded, effective)

    # Smoothing must not leave more vertices than the input had. Tighten the
    # simplification step by step until it fits or the strengths run out.
    for fallback in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
        if len(result) <= len(base):
            break
        result = _simplify_polygon(result, max(effective, fallback))

    if len(result) < 3:
        return _normalize_polygon(polygon)
    return result
|
||
|
||
|
||
def _smooth_polygons(polygons: list[list[list[float]]], smoothing: dict[str, Any]) -> list[list[list[float]]]:
    """Smooth each polygon and drop any result with fewer than three points."""
    results: list[list[list[float]]] = []
    for polygon in polygons:
        smoothed = _smooth_polygon(polygon, smoothing)
        if len(smoothed) >= 3:
            results.append(smoothed)
    return results
|
||
|
||
|
||
def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Select the contiguous slice of ``frames`` to propagate over.

    Args:
        frames: Frame sequence; list position is treated as timeline order.
        source_position: Index of the seed frame in ``frames``.
        direction: How the window is placed relative to the seed frame:
            ``"backward"`` ends the window at the seed frame.
            ``"both"`` roughly centres it on the seed frame; the extra frame
            of an even-sized window goes forward. If an edge clips one side,
            frames are taken from the other side instead.
            Any other value (forward) starts the window at the seed frame.
        max_frames: Desired window length, clamped to ``1..len(frames)``.

    Returns:
        ``(window, index of the seed frame inside window)``.
    """
    total = len(frames)
    span = max(1, min(max_frames, total))

    if direction == "backward":
        first = max(0, source_position - span + 1)
        return frames[first:source_position + 1], source_position - first

    if direction != "both":
        # Forward (the default): the window starts at the seed frame.
        return frames[source_position:min(total, source_position + span)], 0

    lead = (span - 1) // 2
    trail = span - 1 - lead
    first = max(0, source_position - lead)
    stop = min(total, source_position + trail + 1)
    # If an edge clipped the window, reclaim the missing frames, first from
    # before the window and then from after it.
    missing = span - (stop - first)
    reclaimed_before = min(missing, first)
    first -= reclaimed_before
    stop += min(missing - reclaimed_before, total - stop)
    return frames[first:stop], source_position - first
|
||
|
||
|
||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download each frame image into ``directory`` as ``frame_000000.jpg``, ``frame_000001.jpg``, ...

    The downloaded bytes are written verbatim, without re-encoding.
    NOTE(review): files are always named ``.jpg`` whatever their real
    encoding. Confirm stored frames are JPEG, or that the consumer sniffs the
    content.

    Returns:
        The written file paths as strings, in the same order as ``frames``.
    """
    written: list[str] = []
    for position, frame in enumerate(frames):
        target = directory / f"frame_{position:06d}.jpg"
        target.write_bytes(download_file(frame.image_url))
        written.append(str(target))
    return written
|
||
|
||
|
||
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
    """Reduce a binary component to one positive prompt point.

    The point is the pixel with the largest L2 distance to the nearest zero
    pixel (``cv2.distanceTransform``), i.e. the component's most interior
    pixel. It is normalized by width/height (each floored at 1) and clamped
    to ``[0, 1]``.
    """
    distances = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
    _min_val, _max_val, _min_loc, (peak_x, peak_y) = cv2.minMaxLoc(distances)
    return [
        min(max(float(peak_x) / max(width, 1), 0.0), 1.0),
        min(max(float(peak_y) / max(height, 1), 0.0), 1.0),
    ]
|
||
|
||
|
||
def _clamp01(value: float) -> float:
|
||
return min(max(float(value), 0.0), 1.0)
|
||
|
||
|
||
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
|
||
"""Return whether a normalized point is inside a normalized polygon."""
|
||
if len(polygon) < 3:
|
||
return False
|
||
x, y = point
|
||
inside = False
|
||
j = len(polygon) - 1
|
||
for i, current in enumerate(polygon):
|
||
xi, yi = current
|
||
xj, yj = polygon[j]
|
||
intersects = ((yi > y) != (yj > y)) and (
|
||
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
|
||
)
|
||
if intersects:
|
||
inside = not inside
|
||
j = i
|
||
return inside
|
||
|
||
|
||
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
|
||
xs = [_clamp01(point[0]) for point in points]
|
||
ys = [_clamp01(point[1]) for point in points]
|
||
x1 = max(0.0, min(xs) - margin)
|
||
y1 = max(0.0, min(ys) - margin)
|
||
x2 = min(1.0, max(xs) + margin)
|
||
y2 = min(1.0, max(ys) + margin)
|
||
if x2 - x1 < 0.05:
|
||
center = (x1 + x2) / 2
|
||
x1 = max(0.0, center - 0.025)
|
||
x2 = min(1.0, center + 0.025)
|
||
if y2 - y1 < 0.05:
|
||
center = (y1 + y2) / 2
|
||
y1 = max(0.0, center - 0.025)
|
||
y2 = min(1.0, center + 0.025)
|
||
return x1, y1, x2, y2
|
||
|
||
|
||
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
|
||
height, width = image.shape[:2]
|
||
x1, y1, x2, y2 = bounds
|
||
left = int(round(x1 * width))
|
||
top = int(round(y1 * height))
|
||
right = max(left + 1, int(round(x2 * width)))
|
||
bottom = max(top + 1, int(round(y2 * height)))
|
||
return image[top:bottom, left:right]
|
||
|
||
|
||
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
    """Re-express a full-image normalized point in the crop's normalized frame.

    The result is clamped to ``[0, 1]``. Degenerate spans are floored at 1e-9
    to avoid division by zero.
    """
    left, top, right, bottom = bounds
    u = (float(point[0]) - left) / max(right - left, 1e-9)
    v = (float(point[1]) - top) / max(bottom - top, 1e-9)
    return [_clamp01(u), _clamp01(v)]
|
||
|
||
|
||
def _from_crop_polygon(
    polygon: list[list[float]],
    bounds: tuple[float, float, float, float],
) -> list[list[float]]:
    """Map a polygon from crop-normalized coordinates back to full-image normalized ones.

    Output coordinates are clamped to ``[0, 1]``.
    """
    left, top, right, bottom = bounds
    width_span = right - left
    height_span = bottom - top
    mapped: list[list[float]] = []
    for point in polygon:
        mapped.append([
            _clamp01(left + float(point[0]) * width_span),
            _clamp01(top + float(point[1]) * height_span),
        ])
    return mapped
|
||
|
||
|
||
def _filter_predictions(
    polygons: list[list[list[float]]],
    scores: list[float],
    options: dict[str, Any],
    negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
    """Drop low-confidence or background-covering predictions.

    This only applies when ``options["auto_filter_background"]`` is truthy;
    otherwise the inputs are returned unchanged. A prediction is dropped when
    its score (0.0 if missing) is below ``options["min_score"]``, or when it
    contains any of ``negative_points``.
    """
    if not options.get("auto_filter_background"):
        return polygons, scores

    threshold = float(options.get("min_score", 0.0) or 0.0)

    def rejected(polygon: list[list[float]], score: float) -> bool:
        if score < threshold:
            return True
        return bool(negative_points) and any(
            _point_in_polygon(point, polygon) for point in negative_points
        )

    paired = [
        (polygon, scores[position] if position < len(scores) else 0.0)
        for position, polygon in enumerate(polygons)
    ]
    survivors = [(polygon, score) for polygon, score in paired if not rejected(polygon, score)]
    return [polygon for polygon, _ in survivors], [score for _, score in survivors]
|
||
|
||
|
||
@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM inference with a prompt",
)
def predict(
    payload: PredictRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Execute selected SAM segmentation given an image and a prompt.

    - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
      coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
    - **semantic**: disabled in the current SAM 2.1 point/box product flow.
    """
    # NOTE: the docstring above is kept verbatim; FastAPI publishes it as the
    # OpenAPI description of this route.
    # 404 unless the frame's project belongs to the caller.
    frame = _owned_frame_or_404(payload.image_id, db, current_user)

    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()
    options = payload.options or {}

    polygons: List[List[List[float]]] = []
    scores: List[float] = []
    # Label-0 prompt points, kept in full-image coordinates; used by
    # _filter_predictions after the model call.
    negative_points: list[list[float]] = []

    try:
        if prompt_type == "point":
            point_payload = payload.prompt_data
            if isinstance(point_payload, dict):
                points = point_payload.get("points")
                labels = point_payload.get("labels")
            else:
                points = point_payload
                labels = None

            if not isinstance(points, list) or len(points) == 0:
                raise HTTPException(status_code=400, detail="Invalid point prompt data")
            # Missing or length-mismatched labels: treat every point as positive (1).
            if not isinstance(labels, list) or len(labels) != len(points):
                labels = [1] * len(points)
            negative_points = [
                point for point, label in zip(points, labels) if label == 0
            ]
            inference_image = image
            inference_points = points
            crop_bounds = None
            if options.get("crop_to_prompt"):
                # Run the model on a padded crop around the points. A falsy
                # crop_margin (including 0) falls back to 0.25.
                margin = float(options.get("crop_margin", 0.25) or 0.25)
                crop_bounds = _crop_bounds_from_points(points, margin)
                inference_image = _crop_image(image, crop_bounds)
                inference_points = [_to_crop_point(point, crop_bounds) for point in points]
            polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
            if crop_bounds:
                # Map crop-relative polygons back to full-image normalized coordinates.
                polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]

        elif prompt_type == "box":
            box = payload.prompt_data
            if not isinstance(box, list) or len(box) != 4:
                raise HTTPException(status_code=400, detail="Invalid box prompt data")
            inference_image = image
            inference_box = box
            crop_bounds = None
            if options.get("crop_to_prompt"):
                # Box crops default to a tighter 0.05 margin (also used when crop_margin is falsy).
                margin = float(options.get("crop_margin", 0.05) or 0.05)
                crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
                inference_image = _crop_image(image, crop_bounds)
                inference_box = [
                    *_to_crop_point([box[0], box[1]], crop_bounds),
                    *_to_crop_point([box[2], box[3]], crop_bounds),
                ]
            polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
            if crop_bounds:
                polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]

        elif prompt_type == "interactive":
            prompt = payload.prompt_data
            if not isinstance(prompt, dict):
                raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
            box = prompt.get("box")
            points = prompt.get("points") or []
            labels = prompt.get("labels")
            if box is not None and (not isinstance(box, list) or len(box) != 4):
                raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
            if not isinstance(points, list):
                raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
            # At least one of a box or points is required.
            if not box and len(points) == 0:
                raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
            if not isinstance(labels, list) or len(labels) != len(points):
                labels = [1] * len(points)
            negative_points = [
                point for point, label in zip(points, labels) if label == 0
            ]
            inference_image = image
            inference_box = box
            inference_points = points
            crop_bounds = None
            if options.get("crop_to_prompt"):
                margin = float(options.get("crop_margin", 0.05) or 0.05)
                # The crop must enclose both the points and the box corners.
                crop_points = list(points)
                if box:
                    crop_points.extend([[box[0], box[1]], [box[2], box[3]]])
                crop_bounds = _crop_bounds_from_points(crop_points, margin)
                inference_image = _crop_image(image, crop_bounds)
                inference_points = [_to_crop_point(point, crop_bounds) for point in points]
                if box:
                    inference_box = [
                        *_to_crop_point([box[0], box[1]], crop_bounds),
                        *_to_crop_point([box[2], box[3]], crop_bounds),
                    ]
            polygons, scores = sam_registry.predict_interactive(
                payload.model,
                inference_image,
                inference_box,
                inference_points,
                labels,
            )
            if crop_bounds:
                polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]

        elif prompt_type == "semantic":
            text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
            # Only a positive, float-convertible min_score becomes a confidence
            # threshold; anything else means "no threshold".
            min_score = options.get("min_score")
            confidence_threshold = None
            if min_score is not None:
                try:
                    parsed_min_score = float(min_score)
                    if parsed_min_score > 0:
                        confidence_threshold = parsed_min_score
                except (TypeError, ValueError):
                    confidence_threshold = None
            polygons, scores = sam_registry.predict_semantic(
                payload.model,
                image,
                text,
                confidence_threshold=confidence_threshold,
            )

        else:
            raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
    # HTTPExceptions raised above are not caught here and propagate unchanged.
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except NotImplementedError as exc:
        # Presumably how disabled flows (e.g. semantic) surface from the registry — verify.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # No-op unless options["auto_filter_background"] is set.
    polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
    logger.info(
        "AI predict completed model=%s prompt_type=%s frame_id=%s polygons=%d",
        payload.model or "default",
        prompt_type,
        payload.image_id,
        len(polygons),
    )
    return {"polygons": polygons, "scores": scores}
|
||
|
||
|
||
@router.get(
    "/models/status",
    response_model=AiRuntimeStatus,
    summary="Get SAM model and GPU runtime status",
)
def model_status(
    selected_model: str | None = None,
    _current_user: User = Depends(get_current_user),
) -> dict:
    """Return real runtime availability for GPU and the currently enabled SAM model."""
    # Requires get_current_user (any signed-in user), not require_editor.
    # A ValueError from the registry (e.g. an unrecognized selected_model) is
    # reported as 400.
    try:
        status_payload = sam_registry.runtime_status(selected_model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return status_payload
|
||
|
||
|
||
@router.post(
    "/analyze-mask",
    response_model=MaskAnalysisResponse,
    summary="Analyze mask geometry and prompt anchors",
)
def analyze_mask(
    payload: MaskAnalysisRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return backend-computed mask properties for the frontend inspector."""
    # Ownership is only enforced when the request names a frame.
    if payload.frame_id is not None:
        _owned_frame_or_404(payload.frame_id, db, current_user)

    mask_data = payload.mask_data or {}
    polygons = mask_data.get("polygons") or []
    if not polygons:
        raise HTTPException(status_code=400, detail="Mask analysis requires polygons")

    # Clamp points to [0, 1] and drop polygons with fewer than three points.
    valid_polygons = _normalize_polygons(polygons)
    if not valid_polygons:
        raise HTTPException(status_code=400, detail="Mask analysis requires at least one valid polygon")

    area = sum(_polygon_area(polygon) for polygon in valid_polygons)
    # The bbox must enclose every polygon, just as `area` covers all of them.
    # Measuring only valid_polygons[0] under-reported multi-part masks.
    # A client-supplied bbox still takes precedence.
    bbox = payload.bbox or _polygon_bbox([point for polygon in valid_polygons for point in polygon])
    source = mask_data.get("source")
    raw_score = mask_data.get("score")

    confidence: float | None = None
    # bool is a subclass of int; a JSON true/false is not a model score.
    if isinstance(raw_score, (int, float)) and not isinstance(raw_score, bool):
        confidence = max(0.0, min(float(raw_score), 1.0))
        confidence_source = "model_score"
    elif source:
        confidence_source = "source_without_score"
    else:
        confidence_source = "manual_or_imported"

    anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
    message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"

    return {
        "confidence": confidence,
        "confidence_source": confidence_source,
        "topology_anchor_count": anchor_count,
        "topology_anchors": anchors,
        "area": area,
        "bbox": bbox,
        "source": source,
        "message": message,
    }
|
||
|
||
|
||
@router.post(
    "/smooth-mask",
    response_model=SmoothMaskResponse,
    summary="Smooth editable mask polygons with backend geometry rules",
)
def smooth_mask(
    payload: SmoothMaskRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Return a smoothed polygon mask without persisting it.

    The frontend keeps this as an explicit edit operation: users preview/apply it
    to the current mask, then save through the normal annotation endpoint.
    """
    # Ownership is only enforced when the request names a frame.
    if payload.frame_id is not None:
        _owned_frame_or_404(payload.frame_id, db, current_user)

    # Tolerate a missing mask_data, as analyze_mask does, instead of raising
    # AttributeError on None.
    polygons = (payload.mask_data or {}).get("polygons") or []
    valid_polygons = _normalize_polygons(polygons)
    if not valid_polygons:
        raise HTTPException(status_code=400, detail="Mask smoothing requires at least one valid polygon")

    smoothing = _normalize_smoothing_options(payload.strength, payload.method)
    smoothed_polygons = _smooth_polygons(valid_polygons, smoothing)
    if not smoothed_polygons:
        raise HTTPException(status_code=400, detail="Mask smoothing produced no valid polygons")

    area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
    # Enclose every smoothed polygon, just as `area` covers all of them.
    bbox = _polygon_bbox([point for polygon in smoothed_polygons for point in polygon])
    anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
    return {
        "polygons": smoothed_polygons,
        "topology_anchor_count": anchor_count,
        "topology_anchors": anchors,
        "area": area,
        "bbox": bbox,
        "smoothing": smoothing,
        "message": f"已应用边缘平滑强度 {smoothing['strength']:.0f}",
    }
|
||
|
||
|
||
@router.post(
    "/propagate",
    response_model=PropagateResponse,
    summary="Propagate one current-frame region across a video frame segment",
)
def propagate(
    payload: PropagateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Track one selected region from the current frame across nearby frames.

    SAM 2 uses the official video predictor with the selected mask as the seed.
    SAM 3 video tracking is currently disabled in this product flow.
    """
    # NOTE: the docstring above is kept verbatim; FastAPI publishes it as the
    # OpenAPI description of this route.
    direction = payload.direction.lower()
    if direction not in {"forward", "backward", "both"}:
        raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
    # Window length is clamped to 1..500 frames; a falsy max_frames means 30.
    max_frames = max(1, min(int(payload.max_frames or 30), 500))

    _owned_project_or_404(payload.project_id, db, current_user)
    # The seed frame must belong to payload.project_id as well as to the caller.
    source_frame = _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)

    seed = payload.seed.model_dump(exclude_none=True)
    polygons = seed.get("polygons") or []
    bbox = seed.get("bbox")
    points = seed.get("points") or []
    if not polygons and not bbox and not points:
        raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")

    # Whole project timeline ordered by frame_index; the window is cut from it.
    frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")

    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    if len(selected_frames) == 0:
        raise HTTPException(status_code=400, detail="No frames available for propagation")

    try:
        # Write the window to numbered files in a temporary directory. The
        # directory is deleted when the with-block exits, after propagate_video
        # has returned.
        with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
            frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
            propagated = sam_registry.propagate_video(
                payload.model,
                frame_paths,
                source_relative_index,
                seed,
                direction,
                len(selected_frames),
            )
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except NotImplementedError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:  # noqa: BLE001
        # NOTE(review): the 500 detail echoes the raw exception text to the
        # client, and logger.error drops the traceback. Confirm both are intended.
        logger.error("Video propagation failed: %s", exc)
        raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc

    created: list[Annotation] = []
    if payload.save_annotations:
        class_metadata = seed.get("class_metadata")
        template_id = seed.get("template_id")
        label = seed.get("label") or "Propagated Mask"
        color = seed.get("color") or "#06b6d4"
        model_id = sam_registry.normalize_model_id(payload.model)
        # Optional post-smoothing. It is only applied when the seed carries a
        # smoothing dict with a positive strength.
        seed_smoothing = seed.get("smoothing")
        smoothing = _normalize_smoothing_options(
            seed_smoothing.get("strength"),
            seed_smoothing.get("method"),
        ) if isinstance(seed_smoothing, dict) else None
        if smoothing and smoothing["strength"] <= 0:
            smoothing = None

        for frame_result in propagated:
            # frame_index in each result is relative to selected_frames;
            # out-of-range indices are skipped.
            relative_index = int(frame_result.get("frame_index", -1))
            if relative_index < 0 or relative_index >= len(selected_frames):
                continue
            frame = selected_frames[relative_index]
            if not payload.include_source and frame.id == source_frame.id:
                continue
            result_polygons = frame_result.get("polygons") or []
            scores = frame_result.get("scores") or []
            # One Annotation row per returned polygon; polygons with fewer than
            # three points are ignored.
            for polygon_index, polygon in enumerate(result_polygons):
                if len(polygon) < 3:
                    continue
                polygon_to_save = _smooth_polygon(polygon, smoothing) if smoothing else polygon
                annotation = Annotation(
                    project_id=payload.project_id,
                    frame_id=frame.id,
                    template_id=template_id,
                    mask_data={
                        "polygons": [polygon_to_save],
                        "label": label,
                        "color": color,
                        "source": f"{model_id}_propagation",
                        "propagated_from_frame_id": source_frame.id,
                        "propagated_from_frame_index": source_frame.frame_index,
                        "score": scores[polygon_index] if polygon_index < len(scores) else None,
                        **({"geometry_smoothing": smoothing} if smoothing else {}),
                        **({"class": class_metadata} if class_metadata else {}),
                    },
                    points=None,
                    bbox=_polygon_bbox(polygon_to_save),
                )
                db.add(annotation)
                created.append(annotation)

    # Single commit for all created annotations, then refresh to load DB-assigned fields.
    db.commit()
    for annotation in created:
        db.refresh(annotation)

    return {
        "model": sam_registry.normalize_model_id(payload.model),
        "direction": direction,
        "source_frame_id": source_frame.id,
        "processed_frame_count": len(selected_frames),
        "created_annotation_count": len(created),
        "annotations": created,
    }
|
||
|
||
|
||
@router.post(
    "/propagate/task",
    status_code=status.HTTP_202_ACCEPTED,
    response_model=ProcessingTaskOut,
    summary="Queue a background video propagation task",
)
def queue_propagate_task(
    payload: PropagateTaskRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> ProcessingTaskOut:
    """Queue multiple seed/direction propagation steps as one background task.

    Checks that the caller can access the project and the source frame. Then
    validates every step's direction and seed, stores a ``ProcessingTask`` row
    in the queued state, and dispatches the Celery job
    ``propagate_project_masks`` for it. A progress event is published once
    after the row is created and again after the Celery id is attached.

    Raises:
        HTTPException 400: no steps, an unknown model id, a direction other
            than forward/backward, or a step whose seed has no polygons, bbox,
            or points.
    """
    _owned_project_or_404(payload.project_id, db, current_user)
    # Called only for validation: the frame must exist within the project.
    _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)

    if not payload.steps:
        raise HTTPException(status_code=400, detail="Propagation task requires at least one step")

    try:
        model_id = sam_registry.normalize_model_id(payload.model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    normalized_directions: list[str] = []
    for step in payload.steps:
        direction = step.direction.lower()
        if direction not in {"forward", "backward"}:
            raise HTTPException(status_code=400, detail="direction must be forward or backward")
        seed = step.seed.model_dump(exclude_none=True)
        if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
            raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
        normalized_directions.append(direction)

    task_payload = payload.model_dump(exclude_none=True)
    task_payload["model"] = model_id
    # Persist the same lower-cased direction that was validated above, so the
    # background worker never receives a mixed-case value such as "Forward".
    for step_payload, direction in zip(task_payload.get("steps") or [], normalized_directions):
        if isinstance(step_payload, dict):
            step_payload["direction"] = direction

    task = ProcessingTask(
        task_type="propagate_masks",
        status=TASK_STATUS_QUEUED,
        progress=0,
        message="自动传播任务已入队",
        project_id=payload.project_id,
        payload=task_payload,
    )
    db.add(task)
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)

    # The row is committed before dispatch so the worker can load it by id.
    async_result = propagate_project_masks.delay(task.id)
    task.celery_task_id = async_result.id
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)

    logger.info("Queued propagation task id=%s project_id=%s celery_id=%s", task.id, payload.project_id, async_result.id)
    return task
|
||
|
||
|
||
@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(
    image_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Run automatic mask generation on a single frame.

    ``image_id`` is the id of a frame the current user is allowed to access.
    The frame image is loaded and passed to ``sam_registry.predict_auto`` with
    ``None`` as the model argument; that helper defines the prompting strategy
    (per the earlier docstring, a grid of point prompts). The returned dict
    holds ``polygons`` and ``scores``.

    Raises:
        HTTPException 503: the SAM runtime reports ``ModelUnavailableError``.
    """
    target_frame = _owned_frame_or_404(image_id, db, current_user)
    frame_image = _load_frame_image(target_frame)

    try:
        auto_polygons, auto_scores = sam_registry.predict_auto(None, frame_image)
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    return dict(polygons=auto_polygons, scores=auto_scores)
|
||
|
||
|
||
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    The project must be accessible to the caller. When ``frame_id`` or
    ``template_id`` is supplied, it is validated before the row is inserted:
    the frame must belong to the project and the template must be visible to
    the user. The committed and refreshed ``Annotation`` is returned.
    """
    _owned_project_or_404(payload.project_id, db, current_user)

    # Compare against None rather than relying on truthiness. An explicit id
    # of 0 is then still validated instead of silently skipping the checks;
    # this matches how update_annotation checks template_id.
    if payload.frame_id is not None:
        _owned_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
    if payload.template_id is not None:
        _visible_template_or_404(payload.template_id, db, current_user)

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation
|
||
|
||
|
||
@router.post(
    "/import-gt-mask",
    response_model=List[AnnotationOut],
    status_code=status.HTTP_201_CREATED,
    summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
    project_id: int = Form(...),
    frame_id: int = Form(...),
    template_id: int | None = Form(None),
    label: str = Form("GT Mask"),
    color: str = Form("#22c55e"),
    unknown_color_policy: str = Form("undefined"),
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> List[Annotation]:
    """Convert an 8-bit maskid image into persisted polygon annotations.

    Accepted input:
        - an 8-bit grayscale image; or
        - an 8-bit image with at least three channels whose first three
          channels are identical. Any extra channel, such as alpha, is ignored.

    Pixel value 0 is background. Every other value (1-255) is a maskid. If the
    image size differs from the frame's stored width/height, the image is
    resized with nearest-neighbour interpolation so maskids are never blended.

    Each external contour (area >= 1) of each maskid becomes one annotation.
    With a template, the template's maskid map supplies each class's name and
    color. Maskids the template does not define are handled by
    ``unknown_color_policy``:
        - ``"discard"``: skipped;
        - ``"undefined"``: kept with a fallback label and flagged
          ``gt_unknown_class``.

    The ``points`` field stores one seed point per component, computed by
    ``_component_seed_point``. That is presumably an interior point such as a
    distance-transform center (confirm in the helper). It gives the frontend
    an editable point-region representation instead of a static bitmap layer.

    Raises:
        HTTPException 400: invalid policy value, empty or undecodable upload,
            non-8-bit or non-maskid image, or no usable foreground regions.
    """
    _owned_project_or_404(project_id, db, current_user)
    frame = _owned_frame_or_404(frame_id, db, current_user, project_id)

    if unknown_color_policy not in {"discard", "undefined"}:
        raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")

    template: Template | None = None
    if template_id is not None:
        template = _visible_template_or_404(template_id, db, current_user)

    data = await file.read()
    # cv2.imdecode asserts on an empty buffer, raising cv2.error, which would
    # surface as an HTTP 500. Reject empty uploads and decoder errors
    # explicitly as a client error instead.
    image = None
    if data:
        try:
            image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
        except cv2.error:
            image = None
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid mask image")

    invalid_format_detail = (
        "GT Mask 图片不符合要求:仅支持 8-bit 灰度图,或 8-bit RGB 三通道完全相同的 maskid 图"
        "(背景 0,像素值为 1-255 的 maskid)。"
    )
    if image.dtype != np.uint8:
        raise HTTPException(status_code=400, detail=invalid_format_detail)

    if image.ndim == 2:
        label_image = image
    elif image.ndim == 3 and image.shape[2] >= 3:
        channels = image[:, :, :3]
        # GT label images are maskid maps: either grayscale or RGB/BGR where
        # all three color channels contain the same maskid value [X, X, X].
        if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
            raise HTTPException(status_code=400, detail=invalid_format_detail)
        label_image = channels[:, :, 0]
    else:
        raise HTTPException(status_code=400, detail=invalid_format_detail)

    # Target size comes from the frame record, falling back to the image size.
    width = int(frame.width or image.shape[1])
    height = int(frame.height or image.shape[0])
    original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
    resized_to_frame = original_width != width or original_height != height
    if resized_to_frame:
        # INTER_NEAREST keeps maskid values intact (no interpolated ids).
        label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)

    by_maskid, _by_color = _template_class_maps(template)
    has_template_classes = bool(by_maskid)
    fallback_color = _normalize_hex_color(color) or "#22c55e"

    import_items: list[dict[str, Any]] = []
    skipped_unknown = 0
    label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
    for label_value in label_values:
        class_meta = by_maskid.get(label_value)
        # A maskid is "unknown" only when a template with classes was given
        # and it does not define this maskid.
        is_unknown = has_template_classes and class_meta is None
        if is_unknown and unknown_color_policy == "discard":
            skipped_unknown += 1
            continue
        if class_meta:
            annotation_label = class_meta["name"]
            annotation_color = class_meta["color"]
        elif is_unknown:
            annotation_label = _gt_unknown_label(label_value)
            annotation_color = fallback_color
        else:
            annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
            annotation_color = fallback_color
        import_items.append({
            "token": label_value,
            "binary": np.where(label_image == label_value, 255, 0).astype(np.uint8),
            "label": annotation_label,
            "color": annotation_color,
            "class": class_meta,
            "unknown": is_unknown,
            "metadata": {
                "gt_label_value": label_value,
                "gt_original_size": {"width": original_width, "height": original_height},
                "gt_resized_to_frame": resized_to_frame,
            },
        })

    if not import_items:
        if skipped_unknown > 0:
            raise HTTPException(status_code=400, detail="No matching GT mask classes found")
        raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)

    annotations: list[Annotation] = []
    for item in import_items:
        binary = item["binary"]
        # RETR_EXTERNAL: only outer boundaries; interior holes are not kept.
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

        for contour in contours:
            if cv2.contourArea(contour) < 1:
                continue

            polygon = _normalized_contour(contour, binary.shape[1], binary.shape[0])
            if len(polygon) < 3:
                continue

            # Filled single-component mask used to place the seed point.
            component = np.zeros_like(binary, dtype=np.uint8)
            cv2.drawContours(component, [contour], -1, 1, thickness=-1)
            seed_point = _component_seed_point(component, binary.shape[1], binary.shape[0])
            bbox = _contour_bbox(contour, binary.shape[1], binary.shape[0])
            mask_data = {
                "polygons": [polygon],
                "label": item["label"],
                "color": item["color"],
                "source": "gt_mask",
                "image_size": {"width": width, "height": height},
                **item["metadata"],
            }
            if item["class"]:
                mask_data["class"] = item["class"]
            if item["unknown"]:
                mask_data["gt_unknown_class"] = True

            annotation = Annotation(
                project_id=project_id,
                frame_id=frame_id,
                template_id=template_id,
                mask_data=mask_data,
                points=[seed_point],
                bbox=bbox,
            )
            db.add(annotation)
            annotations.append(annotation)

    # Every region may have been filtered out as degenerate.
    if not annotations:
        raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)

    db.commit()
    for annotation in annotations:
        db.refresh(annotation)
    logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
    return annotations
|
||
|
||
|
||
@router.get(
    "/annotations",
    response_model=List[AnnotationOut],
    summary="List saved annotations for a project",
)
def list_annotations(
    project_id: int,
    frame_id: int | None = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> List[Annotation]:
    """List a project's persisted annotations in ascending id order.

    When ``frame_id`` is given, that frame is first checked to belong to the
    project, and only annotations on that frame are returned.
    """
    _owned_project_or_404(project_id, db, current_user)

    criteria = [Annotation.project_id == project_id]
    if frame_id is not None:
        _owned_frame_or_404(frame_id, db, current_user, project_id)
        criteria.append(Annotation.frame_id == frame_id)

    return db.query(Annotation).filter(*criteria).order_by(Annotation.id).all()
|
||
|
||
|
||
@router.patch(
    "/annotations/{annotation_id}",
    response_model=AnnotationOut,
    summary="Update a saved annotation",
)
def update_annotation(
    annotation_id: int,
    payload: AnnotationUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Annotation:
    """Apply a partial update to an annotation whose project the caller owns.

    Only fields explicitly set in the request body are written. A non-null
    ``template_id`` must refer to a template visible to the user.

    Raises:
        HTTPException 404: the annotation does not exist or its project is not
            owned by the current user.
    """
    owned = (
        db.query(Annotation)
        .join(Project, Project.id == Annotation.project_id)
        .filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
        .first()
    )
    if owned is None:
        raise HTTPException(status_code=404, detail="Annotation not found")

    changes = payload.model_dump(exclude_unset=True)
    requested_template = changes.get("template_id")
    if requested_template is not None:
        _visible_template_or_404(requested_template, db, current_user)

    for attr_name in changes:
        setattr(owned, attr_name, changes[attr_name])

    db.commit()
    db.refresh(owned)
    logger.info("Updated annotation id=%s", owned.id)
    return owned
|
||
|
||
|
||
@router.delete(
    "/annotations/{annotation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a saved annotation",
)
def delete_annotation(
    annotation_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Response:
    """Delete an annotation whose project the caller owns.

    Dependent rows are removed only if the ORM relationships are configured
    to cascade (per the original docstring; verify on the ``Annotation``
    model). Returns an empty 204 response.

    Raises:
        HTTPException 404: the annotation does not exist or its project is not
            owned by the current user.
    """
    target = (
        db.query(Annotation)
        .join(Project, Project.id == Annotation.project_id)
        .filter(Annotation.id == annotation_id, Project.owner_user_id == current_user.id)
        .first()
    )
    if target is None:
        raise HTTPException(status_code=404, detail="Annotation not found")

    db.delete(target)
    db.commit()
    logger.info("Deleted annotation id=%s", annotation_id)
    return Response(status_code=status.HTTP_204_NO_CONTENT)
|