Files
Pre_Seg_Server/backend/routers/ai.py
admin 29a1a87e52 feat: 完善 SAM2.1 模型选择与标注工作流
- 后端 SAM2 引擎新增 sam2.1_hiera_tiny、sam2.1_hiera_small、sam2.1_hiera_base_plus、sam2.1_hiera_large 四个变体定义,并按变体维护 checkpoint/config、image predictor、video predictor、加载状态、错误信息和真实状态回报。

- 后端 SAM registry 仅暴露当前产品启用的 SAM2.1 变体,保留 sam2 作为 tiny 兼容别名,拒绝 sam3 产品入口,并把 point、box、interactive、auto、propagate 都分发到所选 SAM2.1 变体。

- 后端默认配置和下载脚本切换到 SAM2.1 checkpoint 命名,支持 legacy SAM2 checkpoint fallback,并在状态消息中标出 fallback 使用情况。

- 前端全局 AI 模型状态新增 SAM2.1 tiny/small/base+/large 类型和默认 tiny,API 请求默认携带 sam2.1_hiera_tiny,AI 页面提供模型变体选择和所选模型状态展示。

- AI 智能分割页移除当前产品不使用的 SAM3/文本提示入口,保留正向点、反向点、框选和参数开关;AI 页只展示本页生成的候选 mask,并支持遮罩清晰度调节、候选 mask 上继续加正/反点、清空本页候选、推送到工作区编辑。

- 工作区和 Canvas 补强 SAM2 交互式细化链路:框选后正/反点继续细化同一个候选 mask,反向点请求启用背景过滤,空结果会移除被否定候选;AI 推送到工作区后保留选中态和未保存 draft mask。

- 工作区标注保存闭环补强:未保存 mask 可归档保存,dirty saved mask 可更新,保存后用后端 saved annotation 替换已提交 draft,清空/删除已保存 mask 时同步后端删除。

- Dashboard 任务进度区改为展示 queued、running、success、failed、cancelled 最近任务,处理中统计只计算 queued/running,并保留近期完成记录。

- 时间轴在顶部时间进度条和底部缩略图导航轴之间新增已编辑帧标记带,基于当前项目帧内 masks 标出已有编辑/标注的帧,并支持点击标记跳转。

- 前端测试覆盖 SAM2.1 变体选择、模型状态徽标、AI 页候选隔离、遮罩透明度、候选上追加正/反点、推送工作区保留选择、Canvas 交互式细化、VideoWorkspace 传播/保存、Dashboard 进度和时间轴已编辑帧标记。

- 后端测试覆盖 SAM2.1 变体状态、sam2 alias 兼容、sam3 禁用、semantic 禁用、传播标注保存、Dashboard 最近任务状态和 SAM3 历史测试跳过说明。

- README、AGENTS 和 doc 文档同步当前真实进度,更新 SAM2.1 变体、SAM3 禁用、接口契约、设计冻结、需求冻结、前端元素审计、实施计划、FastAPI docs 说明和测试矩阵。
2026-05-01 23:39:53 +08:00

721 lines
28 KiB
Python

"""AI inference endpoints using selectable SAM runtimes."""
import logging
import tempfile
from pathlib import Path
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Project, Frame, Template, Annotation
from schemas import (
AiRuntimeStatus,
PredictRequest,
PredictResponse,
PropagateRequest,
PropagateResponse,
AnnotationOut,
AnnotationCreate,
AnnotationUpdate,
)
from services.sam_registry import ModelUnavailableError, sam_registry
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Every route declared below is registered under the /api/ai prefix and tagged "AI".
router = APIRouter(prefix="/api/ai", tags=["AI"])
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Fetch a frame's stored image and return it as an RGB array.

    The bytes are obtained through ``download_file(frame.image_url)`` and
    decoded with OpenCV; the decoded BGR image is converted to RGB.

    Raises:
        HTTPException: 500 when the download, decode, or conversion fails.
    """
    try:
        raw_bytes = download_file(frame.image_url)
        bgr = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        if bgr is not None:
            return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # A None result from imdecode is treated as an undecodable payload.
        raise ValueError("OpenCV could not decode image")
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to load frame image: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
    """Simplify a contour and express its vertices as fractions of the image size.

    The contour is approximated as a closed curve with ``cv2.approxPolyDP``
    using a tolerance of 1% of its perimeter, never below 1.0. When the
    approximation leaves fewer than three vertices, the raw contour points are
    used instead. Each coordinate is divided by ``width`` / ``height`` (treated
    as at least 1) and clamped to ``[0, 1]``.
    """
    perimeter = cv2.arcLength(contour, True)
    tolerance = max(1.0, perimeter * 0.01)
    vertices = cv2.approxPolyDP(contour, tolerance, True).reshape(-1, 2)
    if len(vertices) < 3:
        vertices = contour.reshape(-1, 2)

    x_scale = max(width, 1)
    y_scale = max(height, 1)
    normalized: list[list[float]] = []
    for px, py in vertices:
        nx = min(max(float(px) / x_scale, 0.0), 1.0)
        ny = min(max(float(py) / y_scale, 0.0), 1.0)
        normalized.append([nx, ny])
    return normalized
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
    """Return the contour's bounding rectangle as normalized ``[x, y, w, h]``.

    The rectangle comes from ``cv2.boundingRect``; horizontal values are divided
    by ``width`` and vertical values by ``height`` (each treated as at least 1),
    then clamped to ``[0, 1]``.
    """
    rect = cv2.boundingRect(contour)
    x_scale = max(width, 1)
    y_scale = max(height, 1)
    # boundingRect yields (x, y, w, h): x and w scale by width, y and h by height.
    scales = (x_scale, y_scale, x_scale, y_scale)
    return [min(max(float(value) / scale, 0.0), 1.0) for value, scale in zip(rect, scales)]
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Compute the ``[left, top, width, height]`` box enclosing a polygon.

    Vertex coordinates are clamped to ``[0, 1]`` before the extent is measured,
    and the returned width and height are never negative.
    """
    clamped_x = [_clamp01(vertex[0]) for vertex in polygon]
    clamped_y = [_clamp01(vertex[1]) for vertex in polygon]
    left = min(clamped_x)
    top = min(clamped_y)
    box_width = max(max(clamped_x) - left, 0.0)
    box_height = max(max(clamped_y) - top, 0.0)
    return [left, top, box_width, box_height]
def _frame_window(
frames: list[Frame],
source_position: int,
direction: str,
max_frames: int,
) -> tuple[list[Frame], int]:
count = max(1, min(max_frames, len(frames)))
if direction == "backward":
start = max(0, source_position - count + 1)
return frames[start:source_position + 1], source_position - start
if direction == "both":
before = (count - 1) // 2
after = count - 1 - before
start = max(0, source_position - before)
end = min(len(frames), source_position + after + 1)
while end - start < count and start > 0:
start -= 1
while end - start < count and end < len(frames):
end += 1
return frames[start:end], source_position - start
end = min(len(frames), source_position + count)
return frames[source_position:end], 0
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download each frame and store it in ``directory`` as a numbered file.

    Files are named ``frame_000000.jpg``, ``frame_000001.jpg``, ... in the order
    of ``frames``; the downloaded bytes are written unchanged.

    Returns:
        The written file paths as strings, in the same order as ``frames``.
    """
    written: list[str] = []
    for position, item in enumerate(frames):
        target = directory / f"frame_{position:06d}.jpg"
        target.write_bytes(download_file(item.image_url))
        written.append(str(target))
    return written
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
    """Pick one normalized ``[x, y]`` point inside a binary component.

    The point is the location where the L2 distance transform of the mask is
    largest; it is divided by ``width`` / ``height`` (treated as at least 1)
    and clamped to ``[0, 1]``.
    """
    distances = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
    # minMaxLoc returns (min value, max value, min location, max location).
    peak_x, peak_y = cv2.minMaxLoc(distances)[3]
    return [
        min(max(float(peak_x) / max(width, 1), 0.0), 1.0),
        min(max(float(peak_y) / max(height, 1), 0.0), 1.0),
    ]
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
"""Return whether a normalized point is inside a normalized polygon."""
if len(polygon) < 3:
return False
x, y = point
inside = False
j = len(polygon) - 1
for i, current in enumerate(polygon):
xi, yi = current
xj, yj = polygon[j]
intersects = ((yi > y) != (yj > y)) and (
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
)
if intersects:
inside = not inside
j = i
return inside
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
    """Build normalized crop bounds ``(x1, y1, x2, y2)`` around prompt points.

    The bounds are the extent of the clamped points, grown by ``margin`` on
    every side and limited to ``[0, 1]``. An axis narrower than 0.05 is replaced
    by a span of 0.025 on each side of its centre, again limited to ``[0, 1]``.
    """

    def _padded_span(values: list[float]) -> tuple[float, float]:
        # Compute the padded (low, high) range for one axis.
        low = max(0.0, min(values) - margin)
        high = min(1.0, max(values) + margin)
        if high - low < 0.05:
            middle = (low + high) / 2
            low = max(0.0, middle - 0.025)
            high = min(1.0, middle + 0.025)
        return low, high

    xs = [_clamp01(point[0]) for point in points]
    ys = [_clamp01(point[1]) for point in points]
    x1, x2 = _padded_span(xs)
    y1, y2 = _padded_span(ys)
    return x1, y1, x2, y2
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
height, width = image.shape[:2]
x1, y1, x2, y2 = bounds
left = int(round(x1 * width))
top = int(round(y1 * height))
right = max(left + 1, int(round(x2 * width)))
bottom = max(top + 1, int(round(y2 * height)))
return image[top:bottom, left:right]
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
    """Re-express a full-image normalized point relative to the crop ``bounds``.

    The result is clamped to ``[0, 1]``; a crop with zero extent on an axis is
    treated as having an extent of ``1e-9`` so the division stays defined.
    """
    left, top, right, bottom = bounds
    span_x = max(right - left, 1e-9)
    span_y = max(bottom - top, 1e-9)
    local_x = (float(point[0]) - left) / span_x
    local_y = (float(point[1]) - top) / span_y
    return [_clamp01(local_x), _clamp01(local_y)]
def _from_crop_polygon(
    polygon: list[list[float]],
    bounds: tuple[float, float, float, float],
) -> list[list[float]]:
    """Map a polygon normalized to the crop back into full-image coordinates.

    Each vertex is scaled by the crop extent, offset by the crop origin, and
    clamped to ``[0, 1]``.
    """
    left, top, right, bottom = bounds
    span_x = right - left
    span_y = bottom - top
    restored: list[list[float]] = []
    for vertex in polygon:
        full_x = _clamp01(left + float(vertex[0]) * span_x)
        full_y = _clamp01(top + float(vertex[1]) * span_y)
        restored.append([full_x, full_y])
    return restored
def _filter_predictions(
polygons: list[list[list[float]]],
scores: list[float],
options: dict[str, Any],
negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
if not options.get("auto_filter_background"):
return polygons, scores
min_score = float(options.get("min_score", 0.0) or 0.0)
next_polygons: list[list[list[float]]] = []
next_scores: list[float] = []
for index, polygon in enumerate(polygons):
score = scores[index] if index < len(scores) else 0.0
if score < min_score:
continue
if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
continue
next_polygons.append(polygon)
next_scores.append(score)
return next_polygons, next_scores
@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Execute selected SAM segmentation given an image and a prompt.

    - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
      coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
    - **semantic**: disabled in the current SAM 2.1 point/box product flow.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()
    options = payload.options or {}

    polygons: List[List[List[float]]] = []
    scores: List[float] = []
    # Points labelled 0; handed to the post-filter after inference.
    negative_points: list[list[float]] = []

    try:
        if prompt_type == "point":
            raw_prompt = payload.prompt_data
            if isinstance(raw_prompt, dict):
                points, labels = raw_prompt.get("points"), raw_prompt.get("labels")
            else:
                points, labels = raw_prompt, None
            if not isinstance(points, list) or len(points) == 0:
                raise HTTPException(status_code=400, detail="Invalid point prompt data")
            if not isinstance(labels, list) or len(labels) != len(points):
                # Missing or mismatched labels: every point is treated as positive.
                labels = [1] * len(points)
            negative_points = [pt for pt, lbl in zip(points, labels) if lbl == 0]

            crop_bounds = None
            model_image, model_points = image, points
            if options.get("crop_to_prompt"):
                margin = float(options.get("crop_margin", 0.25) or 0.25)
                crop_bounds = _crop_bounds_from_points(points, margin)
                model_image = _crop_image(image, crop_bounds)
                model_points = [_to_crop_point(pt, crop_bounds) for pt in points]
            polygons, scores = sam_registry.predict_points(payload.model, model_image, model_points, labels)
            if crop_bounds:
                # Inference ran on the crop; map the polygons back to the full image.
                polygons = [_from_crop_polygon(poly, crop_bounds) for poly in polygons]

        elif prompt_type == "box":
            box = payload.prompt_data
            if not isinstance(box, list) or len(box) != 4:
                raise HTTPException(status_code=400, detail="Invalid box prompt data")

            crop_bounds = None
            model_image, model_box = image, box
            if options.get("crop_to_prompt"):
                margin = float(options.get("crop_margin", 0.05) or 0.05)
                top_left, bottom_right = [box[0], box[1]], [box[2], box[3]]
                crop_bounds = _crop_bounds_from_points([top_left, bottom_right], margin)
                model_image = _crop_image(image, crop_bounds)
                model_box = [
                    *_to_crop_point(top_left, crop_bounds),
                    *_to_crop_point(bottom_right, crop_bounds),
                ]
            polygons, scores = sam_registry.predict_box(payload.model, model_image, model_box)
            if crop_bounds:
                polygons = [_from_crop_polygon(poly, crop_bounds) for poly in polygons]

        elif prompt_type == "interactive":
            prompt = payload.prompt_data
            if not isinstance(prompt, dict):
                raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
            box = prompt.get("box")
            points = prompt.get("points") or []
            labels = prompt.get("labels")
            if box is not None and (not isinstance(box, list) or len(box) != 4):
                raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
            if not isinstance(points, list):
                raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
            if not box and len(points) == 0:
                raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
            if not isinstance(labels, list) or len(labels) != len(points):
                labels = [1] * len(points)
            negative_points = [pt for pt, lbl in zip(points, labels) if lbl == 0]

            crop_bounds = None
            model_image, model_box, model_points = image, box, points
            if options.get("crop_to_prompt"):
                margin = float(options.get("crop_margin", 0.05) or 0.05)
                # The crop must cover both the clicked points and the box corners.
                corners = [[box[0], box[1]], [box[2], box[3]]] if box else []
                crop_bounds = _crop_bounds_from_points([*points, *corners], margin)
                model_image = _crop_image(image, crop_bounds)
                model_points = [_to_crop_point(pt, crop_bounds) for pt in points]
                if box:
                    model_box = [
                        *_to_crop_point(corners[0], crop_bounds),
                        *_to_crop_point(corners[1], crop_bounds),
                    ]
            polygons, scores = sam_registry.predict_interactive(
                payload.model,
                model_image,
                model_box,
                model_points,
                labels,
            )
            if crop_bounds:
                polygons = [_from_crop_polygon(poly, crop_bounds) for poly in polygons]

        elif prompt_type == "semantic":
            text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
            # Only a positive, parseable min_score is forwarded as a threshold.
            confidence_threshold = None
            min_score = options.get("min_score")
            if min_score is not None:
                try:
                    parsed_min_score = float(min_score)
                except (TypeError, ValueError):
                    parsed_min_score = 0.0
                if parsed_min_score > 0:
                    confidence_threshold = parsed_min_score
            polygons, scores = sam_registry.predict_semantic(
                payload.model,
                image,
                text,
                confidence_threshold=confidence_threshold,
            )

        else:
            raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
    logger.info(
        "AI predict completed model=%s prompt_type=%s frame_id=%s polygons=%d",
        payload.model or "default",
        prompt_type,
        payload.image_id,
        len(polygons),
    )
    return {"polygons": polygons, "scores": scores}
@router.get(
    "/models/status",
    response_model=AiRuntimeStatus,
    summary="Get SAM model and GPU runtime status",
)
def model_status(selected_model: str | None = None) -> dict:
    """Return real runtime availability for GPU and the currently enabled SAM model."""
    # A ValueError raised by the registry is reported to the client as 400.
    try:
        runtime = sam_registry.runtime_status(selected_model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return runtime
@router.post(
    "/propagate",
    response_model=PropagateResponse,
    summary="Propagate one current-frame region across a video frame segment",
)
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
    """Track one selected region from the current frame across nearby frames.

    SAM 2 uses the official video predictor with the selected mask as the seed.
    SAM 3 video tracking is currently disabled in this product flow.
    """
    direction = payload.direction.lower()
    if direction not in {"forward", "backward", "both"}:
        raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
    # Requested window size: defaults to 30 and is limited to the range 1..500.
    max_frames = max(1, min(int(payload.max_frames or 30), 500))

    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    source_frame = (
        db.query(Frame)
        .filter(
            Frame.id == payload.frame_id,
            Frame.project_id == payload.project_id,
        )
        .first()
    )
    if not source_frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    seed = payload.seed.model_dump(exclude_none=True)
    if not (seed.get("polygons") or seed.get("bbox") or seed.get("points")):
        raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")

    project_frames = (
        db.query(Frame)
        .filter(Frame.project_id == payload.project_id)
        .order_by(Frame.frame_index)
        .all()
    )
    source_position = None
    for position, candidate in enumerate(project_frames):
        if candidate.id == source_frame.id:
            source_position = position
            break
    if source_position is None:
        raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")

    selected_frames, source_relative_index = _frame_window(
        project_frames, source_position, direction, max_frames
    )
    if len(selected_frames) == 0:
        raise HTTPException(status_code=400, detail="No frames available for propagation")

    try:
        # The frame files only need to exist while the registry call is running.
        with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
            frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
            propagated = sam_registry.propagate_video(
                payload.model,
                frame_paths,
                source_relative_index,
                seed,
                direction,
                len(selected_frames),
            )
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:  # noqa: BLE001
        logger.error("Video propagation failed: %s", exc)
        raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc

    created: list[Annotation] = []
    if payload.save_annotations:
        class_metadata = seed.get("class_metadata")
        template_id = seed.get("template_id")
        label = seed.get("label") or "Propagated Mask"
        color = seed.get("color") or "#06b6d4"
        model_id = sam_registry.normalize_model_id(payload.model)

        for frame_result in propagated:
            # frame_index is a position inside selected_frames; skip anything outside it.
            relative_index = int(frame_result.get("frame_index", -1))
            if not 0 <= relative_index < len(selected_frames):
                continue
            target_frame = selected_frames[relative_index]
            if target_frame.id == source_frame.id and not payload.include_source:
                continue

            result_polygons = frame_result.get("polygons") or []
            result_scores = frame_result.get("scores") or []
            for polygon_index, polygon in enumerate(result_polygons):
                if len(polygon) < 3:
                    continue
                has_score = polygon_index < len(result_scores)
                mask_data = {
                    "polygons": [polygon],
                    "label": label,
                    "color": color,
                    "source": f"{model_id}_propagation",
                    "propagated_from_frame_id": source_frame.id,
                    "propagated_from_frame_index": source_frame.frame_index,
                    "score": result_scores[polygon_index] if has_score else None,
                }
                if class_metadata:
                    mask_data["class"] = class_metadata
                annotation = Annotation(
                    project_id=payload.project_id,
                    frame_id=target_frame.id,
                    template_id=template_id,
                    mask_data=mask_data,
                    points=None,
                    bbox=_polygon_bbox(polygon),
                )
                db.add(annotation)
                created.append(annotation)

        db.commit()
        for annotation in created:
            db.refresh(annotation)

    return {
        "model": sam_registry.normalize_model_id(payload.model),
        "direction": direction,
        "source_frame_id": source_frame.id,
        "processed_frame_count": len(selected_frames),
        "created_annotation_count": len(created),
        "annotations": created,
    }
@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(
    image_id: int,
    db: Session = Depends(get_db),
    model: str | None = None,
) -> dict:
    """Run automatic mask generation on a frame using a grid of point prompts.

    The optional `model` query parameter selects the SAM variant; when it is
    omitted the registry's default model is used.
    """
    frame = db.query(Frame).filter(Frame.id == image_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    try:
        # `model` stays None unless the caller supplies it, which preserves the
        # previous behavior of this endpoint for existing clients.
        polygons, scores = sam_registry.predict_auto(model, image)
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except (NotImplementedError, ValueError) as exc:
        # Same mapping as the /predict endpoint: report these as client errors
        # instead of letting them surface as an unhandled 500.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"polygons": polygons, "scores": scores}
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database."""
    # Raises 404 when the project does not exist, or when a frame id is given
    # that does not exist inside that project.
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    if payload.frame_id is not None:
        # The frame must belong to the same project as the annotation, matching
        # the frame lookups in the propagate and import-gt-mask endpoints.
        frame = db.query(Frame).filter(
            Frame.id == payload.frame_id,
            Frame.project_id == payload.project_id,
        ).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    # Reload the row after the commit before logging and returning it.
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation
@router.post(
    "/import-gt-mask",
    response_model=List[AnnotationOut],
    status_code=status.HTTP_201_CREATED,
    summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
    project_id: int = Form(...),
    frame_id: int = Form(...),
    template_id: int | None = Form(None),
    label: str = Form("GT Mask"),
    color: str = Form("#22c55e"),
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
) -> List[Annotation]:
    """Convert a binary/label mask image into persisted polygon annotations.

    Each connected component becomes one annotation. The `points` field stores a
    positive seed point at the component's distance-transform center, which gives
    the frontend an editable point-region representation instead of a static
    bitmap layer.
    """
    if not db.query(Project).filter(Project.id == project_id).first():
        raise HTTPException(status_code=404, detail="Project not found")

    frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    if template_id is not None:
        if not db.query(Template).filter(Template.id == template_id).first():
            raise HTTPException(status_code=404, detail="Template not found")

    payload_bytes = await file.read()
    image = cv2.imdecode(np.frombuffer(payload_bytes, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid mask image")

    mask_height, mask_width = image.shape[0], image.shape[1]
    # The recorded image size prefers the frame's stored dimensions and falls
    # back to the mask's own size; normalization below uses the mask size.
    width = int(frame.width or mask_width)
    height = int(frame.height or mask_height)

    # Every distinct non-zero gray value is handled as its own label.
    label_values = [int(value) for value in np.unique(image) if int(value) > 0]
    if not label_values:
        raise HTTPException(status_code=400, detail="No foreground mask regions found")
    has_multiple_labels = len(label_values) > 1

    annotations: list[Annotation] = []
    for label_value in label_values:
        binary = np.where(image == label_value, 255, 0).astype(np.uint8)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if has_multiple_labels:
            annotation_label = f"{label} {label_value}"
        else:
            annotation_label = label

        for contour in contours:
            if cv2.contourArea(contour) < 1:
                continue
            polygon = _normalized_contour(contour, mask_width, mask_height)
            if len(polygon) < 3:
                continue

            # Rasterize only this contour (filled) so the seed point is computed
            # from this component alone.
            component = np.zeros_like(binary, dtype=np.uint8)
            cv2.drawContours(component, [contour], -1, 1, thickness=-1)
            seed_point = _component_seed_point(component, mask_width, mask_height)
            bbox = _contour_bbox(contour, mask_width, mask_height)

            mask_data = {
                "polygons": [polygon],
                "label": annotation_label,
                "color": color,
                "source": "gt_mask",
                "gt_label_value": label_value,
                "image_size": {"width": width, "height": height},
            }
            annotation = Annotation(
                project_id=project_id,
                frame_id=frame_id,
                template_id=template_id,
                mask_data=mask_data,
                points=[seed_point],
                bbox=bbox,
            )
            db.add(annotation)
            annotations.append(annotation)

    if not annotations:
        raise HTTPException(status_code=400, detail="No foreground mask regions found")

    db.commit()
    for annotation in annotations:
        db.refresh(annotation)
    logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
    return annotations
@router.get(
    "/annotations",
    response_model=List[AnnotationOut],
    summary="List saved annotations for a project",
)
def list_annotations(
    project_id: int,
    frame_id: int | None = None,
    db: Session = Depends(get_db),
) -> List[Annotation]:
    """Return persisted annotations for a project, optionally scoped to one frame."""
    if not db.query(Project).filter(Project.id == project_id).first():
        raise HTTPException(status_code=404, detail="Project not found")

    criteria = [Annotation.project_id == project_id]
    if frame_id is not None:
        criteria.append(Annotation.frame_id == frame_id)
    # Results are ordered by annotation id.
    return db.query(Annotation).filter(*criteria).order_by(Annotation.id).all()
@router.patch(
    "/annotations/{annotation_id}",
    response_model=AnnotationOut,
    summary="Update a saved annotation",
)
def update_annotation(
    annotation_id: int,
    payload: AnnotationUpdate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Update mutable annotation fields persisted in the database."""
    annotation = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not annotation:
        raise HTTPException(status_code=404, detail="Annotation not found")

    # Only fields the client actually sent are applied.
    updates = payload.model_dump(exclude_unset=True)

    new_template_id = updates.get("template_id")
    if new_template_id is not None:
        # A template id that is being set must refer to an existing template.
        if not db.query(Template).filter(Template.id == new_template_id).first():
            raise HTTPException(status_code=404, detail="Template not found")

    for field_name in updates:
        setattr(annotation, field_name, updates[field_name])

    db.commit()
    db.refresh(annotation)
    logger.info("Updated annotation id=%s", annotation.id)
    return annotation
@router.delete(
    "/annotations/{annotation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a saved annotation",
)
def delete_annotation(
    annotation_id: int,
    db: Session = Depends(get_db),
) -> Response:
    """Delete an annotation and its derived mask rows through ORM cascade."""
    target = db.query(Annotation).filter(Annotation.id == annotation_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Annotation not found")

    db.delete(target)
    db.commit()
    logger.info("Deleted annotation id=%s", annotation_id)
    # An explicit empty response is returned with the 204 status.
    return Response(status_code=status.HTTP_204_NO_CONTENT)