Files
Pre_Seg_Server/backend/routers/ai.py
admin b5413066a0 添加Docker自包含部署分支
- 新增 Seg_Server_Docker 自包含部署内容,包含前后端、FastAPI、Celery、PostgreSQL、Redis、MinIO、演示视频和 DICOM 数据。

- 保留 demo 数据以支持恢复演示出厂设置,排除 SAM 2.1 .pt 权重并在 README 中补充下载命令。

- 补充 GPU 部署、backend/worker 镜像复用、frpc/frps + NPM 公网域名反代部署说明。

- 在 .env/.env.example 中用 # XXXX 标注局域网和公网域名部署需要修改的配置项。

- 添加部署分支 .gitignore,忽略本地模型权重、构建产物、缓存和日志。
2026-05-07 19:06:07 +08:00

1229 lines
48 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI inference endpoints using selectable SAM runtimes."""
import logging
import math
import tempfile
from pathlib import Path
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy import or_
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Project, Frame, Template, Annotation, ProcessingTask, User
from routers.auth import get_current_user, require_editor
from schemas import (
AiRuntimeStatus,
MaskAnalysisRequest,
MaskAnalysisResponse,
SmoothMaskRequest,
SmoothMaskResponse,
PredictRequest,
PredictResponse,
PropagateRequest,
PropagateResponse,
PropagateTaskRequest,
ProcessingTaskOut,
AnnotationOut,
AnnotationCreate,
AnnotationUpdate,
)
from progress_events import publish_task_progress_event
from statuses import TASK_STATUS_QUEUED
from worker_tasks import propagate_project_masks
from services.sam_registry import ModelUnavailableError, sam_registry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])

# Client-facing error detail (Chinese: "the GT mask image has no non-background maskid region").
GT_MASK_EMPTY_DETAIL = "GT Mask 图片中没有非背景 maskid 区域。"

# Tuning for _normalized_contour when turning imported GT mask contours into polygons.
GT_IMPORT_MAX_CONTOUR_POINTS = 2048  # vertex cap per imported polygon
GT_IMPORT_CONTOUR_EPSILON_RATIO = 0.00075  # approxPolyDP tolerance as a fraction of contour perimeter
GT_IMPORT_MIN_CONTOUR_EPSILON = 0.35  # lower bound for that tolerance, in contour (pixel) units

# Reserved "unclassified" class with maskId 0.
# name "待分类" = "to be classified"; category "系统保留" = "system reserved".
RESERVED_UNCLASSIFIED_CLASS = dict(
    id="reserved-unclassified",
    name="待分类",
    color="#000000",
    zIndex=0,
    maskId=0,
    category="系统保留",
)
def _shared_project_or_404(project_id: int, db: Session, current_user: User) -> Project:
    """Look up a project by id, raising HTTP 404 when it does not exist.

    Projects are shared across users here: ``current_user`` is accepted for
    signature parity with the other lookup helpers but does not filter.
    """
    del current_user  # intentionally unused (shared projects)
    found = db.query(Project).filter(Project.id == project_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="Project not found")
    return found
def _shared_frame_or_404(frame_id: int, db: Session, current_user: User, project_id: int | None = None) -> Frame:
    """Look up a frame by id, raising HTTP 404 when it cannot be found.

    The query inner-joins ``Project``, so a frame whose project row is missing
    is treated as not found. When ``project_id`` is given, the frame must also
    belong to that project. ``current_user`` is unused (frames are shared).
    """
    del current_user  # intentionally unused (shared frames)
    conditions = [Frame.id == frame_id]
    if project_id is not None:
        conditions.append(Frame.project_id == project_id)
    found = (
        db.query(Frame)
        .join(Project, Project.id == Frame.project_id)
        .filter(*conditions)
        .first()
    )
    if not found:
        raise HTTPException(status_code=404, detail="Frame not found")
    return found
def _visible_template_or_404(template_id: int, db: Session, current_user: User) -> Template:
    """Return a template owned by ``current_user`` or having no owner; else raise HTTP 404."""
    visible_to_user = or_(
        Template.owner_user_id == current_user.id,
        Template.owner_user_id.is_(None),
    )
    found = db.query(Template).filter(Template.id == template_id, visible_to_user).first()
    if not found:
        raise HTTPException(status_code=404, detail="Template not found")
    return found
def _normalize_hex_color(value: Any) -> str | None:
if not isinstance(value, str):
return None
text = value.strip().lower()
if not text:
return None
if not text.startswith("#"):
text = f"#{text}"
if len(text) == 4:
text = "#" + "".join(char * 2 for char in text[1:])
if len(text) != 7:
return None
try:
int(text[1:], 16)
except ValueError:
return None
return text
def _rgb_tuple_to_hex(rgb: tuple[int, int, int]) -> str:
values = []
for channel in rgb:
value = int(channel)
if value > 255:
value = int(round(value / 257))
values.append(min(max(value, 0), 255))
return f"#{values[0]:02x}{values[1]:02x}{values[2]:02x}"
def _template_class_maps(template: Template | None) -> tuple[dict[int, dict[str, Any]], dict[str, dict[str, Any]], dict[str, Any]]:
    """Index a template's class definitions by mask id and by color.

    Reads ``template.mapping_rules["classes"]``; no template, no rules or no
    classes yields empty maps. Non-dict entries are skipped. For each class:

    - the mask id comes from ``maskId`` / ``maskid`` / ``mask_id`` and falls
      back to ``index + 1`` when absent or not integer-like;
    - the color is normalized via ``_normalize_hex_color`` (default ``#22c55e``);
    - the z-index comes from ``zIndex`` / ``z_index`` and falls back to
      ``index * 10`` when absent or not integer-like (rather than raising on
      malformed template JSON).

    Entries that look like the reserved "unclassified" class (mask id 0, or
    the reserved id or name) are not indexed, and neither are negative mask
    ids. When several classes share a color, the last one wins in ``by_color``.

    Returns:
        ``(by_maskid, by_color, unclassified)``. ``unclassified`` is always a
        fresh copy of ``RESERVED_UNCLASSIFIED_CLASS``; templates cannot override it.
    """
    classes = ((template.mapping_rules or {}).get("classes") if template else None) or []
    by_maskid: dict[int, dict[str, Any]] = {}
    by_color: dict[str, dict[str, Any]] = {}
    for index, item in enumerate(classes):
        if not isinstance(item, dict):
            continue
        maskid_value = item.get("maskId", item.get("maskid", item.get("mask_id")))
        try:
            maskid = int(maskid_value)
        except (TypeError, ValueError, OverflowError):
            maskid = index + 1
        default_z_index = index * 10
        try:
            z_index = int(item.get("zIndex", item.get("z_index", default_z_index)))
        except (TypeError, ValueError, OverflowError):
            # e.g. "zIndex": null or "top" in user-edited template JSON
            z_index = default_z_index
        color = _normalize_hex_color(item.get("color")) or "#22c55e"
        class_meta = {
            "id": str(item.get("id") or f"maskid-{maskid}"),
            "name": str(item.get("name") or f"类别 {maskid}"),
            "color": color,
            "zIndex": z_index,
            "maskId": maskid,
            **({"category": item.get("category")} if item.get("category") else {}),
        }
        is_reserved = (
            maskid == 0
            or class_meta["id"] == RESERVED_UNCLASSIFIED_CLASS["id"]
            or class_meta["name"] == RESERVED_UNCLASSIFIED_CLASS["name"]
        )
        if is_reserved:
            continue
        if maskid > 0:
            by_maskid[maskid] = class_meta
            by_color[color] = class_meta
    return by_maskid, by_color, dict(RESERVED_UNCLASSIFIED_CLASS)
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Download a frame from MinIO and decode it to an RGB numpy array.

    Raises:
        HTTPException: 500 if downloading or decoding fails for any reason.
    """
    try:
        data = download_file(frame.image_url)
        arr = np.frombuffer(data, dtype=np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("OpenCV could not decode image")
        # OpenCV decodes to BGR; callers receive RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as exc:  # noqa: BLE001
        # logger.exception keeps the traceback, which logger.error dropped;
        # the client still only sees a generic 500.
        logger.exception("Failed to load frame image frame_id=%s: %s", frame.id, exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
    """Simplify a pixel contour into a normalized polygon with at most ~2048 points.

    Runs approxPolyDP with a tolerance proportional to the perimeter (with a
    floor), raising the tolerance 1.5x while the result exceeds
    ``GT_IMPORT_MAX_CONTOUR_POINTS`` and the tolerance is under 2% of the
    perimeter. It falls back to the raw contour if simplification leaves fewer
    than 3 points, and then stride-subsamples anything still over the cap.
    Coordinates are divided by the image size and clamped to [0, 1].
    """
    perimeter = cv2.arcLength(contour, True)
    tolerance = max(GT_IMPORT_MIN_CONTOUR_EPSILON, perimeter * GT_IMPORT_CONTOUR_EPSILON_RATIO)
    tolerance_ceiling = perimeter * 0.02
    simplified = cv2.approxPolyDP(contour, tolerance, True)
    while len(simplified) > GT_IMPORT_MAX_CONTOUR_POINTS and tolerance < tolerance_ceiling:
        tolerance *= 1.5
        simplified = cv2.approxPolyDP(contour, tolerance, True)

    vertices = simplified.reshape(-1, 2)
    if len(vertices) < 3:
        vertices = contour.reshape(-1, 2)
    if len(vertices) > GT_IMPORT_MAX_CONTOUR_POINTS:
        stride = int(math.ceil(len(vertices) / GT_IMPORT_MAX_CONTOUR_POINTS))
        vertices = vertices[::stride]

    x_scale = max(width, 1)
    y_scale = max(height, 1)
    normalized: list[list[float]] = []
    for x, y in vertices:
        normalized.append([
            min(max(float(x) / x_scale, 0.0), 1.0),
            min(max(float(y) / y_scale, 0.0), 1.0),
        ])
    return normalized
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
    """Return the contour's bounding rectangle as normalized ``[x, y, w, h]``, each clamped to [0, 1]."""
    left, top, box_width, box_height = cv2.boundingRect(contour)
    x_scale = max(width, 1)
    y_scale = max(height, 1)
    components = ((left, x_scale), (top, y_scale), (box_width, x_scale), (box_height, y_scale))
    return [min(max(float(value) / scale, 0.0), 1.0) for value, scale in components]
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
    """Return the normalized ``[x, y, w, h]`` bounding box of a single polygon.

    Delegates to ``_polygons_bbox`` so both helpers share one implementation.
    As a result, an empty polygon yields ``[0.0, 0.0, 0.0, 0.0]`` and points
    with fewer than two coordinates are skipped, instead of raising.
    """
    return _polygons_bbox([polygon])
def _polygons_bbox(polygons: list[list[list[float]]]) -> list[float]:
    """Return the normalized union bounding box ``[x, y, w, h]`` of all polygons.

    Points with fewer than two coordinates are ignored and coordinates are
    clamped to [0, 1]. With no usable points the result is all zeros.
    """
    usable = [point for polygon in polygons for point in polygon if len(point) >= 2]
    if not usable:
        return [0.0, 0.0, 0.0, 0.0]
    xs = [_clamp01(point[0]) for point in usable]
    ys = [_clamp01(point[1]) for point in usable]
    left = min(xs)
    top = min(ys)
    span_x = max(max(xs) - left, 0.0)
    span_y = max(max(ys) - top, 0.0)
    return [left, top, span_x, span_y]
def _polygon_area(polygon: list[list[float]]) -> float:
    """Return the shoelace area of a polygon whose coordinates are clamped to [0, 1].

    Polygons with fewer than three vertices have zero area.
    """
    if len(polygon) < 3:
        return 0.0
    # Pair each vertex with its successor, wrapping the last back to the first.
    successors = [*polygon[1:], polygon[0]]
    doubled_area = 0.0
    for current, following in zip(polygon, successors):
        doubled_area += _clamp01(current[0]) * _clamp01(following[1])
        doubled_area -= _clamp01(following[0]) * _clamp01(current[1])
    return abs(doubled_area) / 2.0
def _normalize_polygon(polygon: list[list[float]]) -> list[list[float]]:
    """Clamp each point's first two coordinates to [0, 1], dropping points with fewer than two."""
    cleaned: list[list[float]] = []
    for point in polygon:
        if len(point) < 2:
            continue
        cleaned.append([_clamp01(point[0]), _clamp01(point[1])])
    return cleaned
def _normalize_polygons(polygons: list[list[list[float]]]) -> list[list[list[float]]]:
    """Normalize every polygon and keep only those left with at least three points."""
    kept: list[list[list[float]]] = []
    for raw_polygon in polygons:
        cleaned = _normalize_polygon(raw_polygon)
        if len(cleaned) >= 3:
            kept.append(cleaned)
    return kept
def _sample_anchor_points(anchors: list[list[float]], limit: int = 64) -> list[list[float]]:
if len(anchors) <= limit:
return anchors
step = max(1, math.ceil(len(anchors) / limit))
return anchors[::step][:limit]
def _analysis_anchor_summary(polygons: list[list[list[float]]]) -> tuple[int, list[list[float]]]:
    """Flatten every vertex (clamped to [0, 1]) of all non-empty polygons.

    Returns the total vertex count and a uniformly sampled subset (at most
    ``_sample_anchor_points``' default limit) for display.
    """
    flattened = [
        [_clamp01(point[0]), _clamp01(point[1])]
        for polygon in polygons
        if polygon
        for point in polygon
    ]
    return len(flattened), _sample_anchor_points(flattened)
def _normalize_smoothing_options(strength: float | int | None, method: str | None = None) -> dict[str, Any]:
clamped_strength = max(0.0, min(float(strength or 0.0), 100.0))
normalized_method = (method or "chaikin").lower()
if normalized_method != "chaikin":
normalized_method = "chaikin"
return {
"strength": round(clamped_strength, 2),
"method": normalized_method,
}
def _smoothing_ratio(strength: float, curve: float = 1.65) -> float:
normalized = max(0.0, min(float(strength or 0.0), 100.0)) / 100.0
return normalized ** curve
def _chaikin_smooth_polygon(polygon: list[list[float]], iterations: int, corner_cut: float = 0.25) -> list[list[float]]:
    """Apply Chaikin corner cutting to a closed polygon.

    Each pass replaces every edge (P, Q) with two points located at fraction
    ``q`` and ``1 - q`` along it, doubling the vertex count. ``q`` is
    ``corner_cut`` clamped to [0.02, 0.25]. Passes stop early if fewer than
    three points remain, and output coordinates are clamped to [0, 1].
    """
    q = max(0.02, min(float(corner_cut), 0.25))
    keep = 1.0 - q
    current = polygon
    for _ in range(max(0, iterations)):
        if len(current) < 3:
            break
        successors = [*current[1:], current[0]]
        refined: list[list[float]] = []
        for start, end in zip(current, successors):
            refined.append([
                _clamp01(keep * start[0] + q * end[0]),
                _clamp01(keep * start[1] + q * end[1]),
            ])
            refined.append([
                _clamp01(q * start[0] + keep * end[0]),
                _clamp01(q * start[1] + keep * end[1]),
            ])
        current = refined
    return current
def _simplify_polygon(polygon: list[list[float]], strength: float) -> list[list[float]]:
    """Douglas-Peucker simplify a normalized polygon with a strength-scaled tolerance.

    The tolerance is ``perimeter * (0.00015 + _smoothing_ratio(strength) * 0.00735)``.
    The input is returned unchanged when it has fewer than three points, when
    ``strength <= 0``, or when simplification would leave fewer than three points.
    """
    if len(polygon) < 3 or strength <= 0:
        return polygon
    contour = np.array([[(point[0], point[1])] for point in polygon], dtype=np.float32)
    perimeter = cv2.arcLength(contour, True)
    tolerance = perimeter * (0.00015 + _smoothing_ratio(strength) * 0.00735)
    vertices = cv2.approxPolyDP(contour, tolerance, True).reshape(-1, 2)
    if len(vertices) < 3:
        return polygon
    return [[_clamp01(float(x)), _clamp01(float(y))] for x, y in vertices]
def _smooth_polygon(polygon: list[list[float]], smoothing: dict[str, Any]) -> list[list[float]]:
    """Smooth one normalized polygon according to ``smoothing["strength"]`` (0-100).

    Pipeline: light pre-simplification -> 1-4 Chaikin passes (more for higher
    strength) -> simplification. If the result still has more vertices than
    the input, it is re-simplified with escalating strength. A non-positive
    strength, or a result with fewer than three points, returns the input
    merely normalized.
    """
    requested = float(smoothing.get("strength") or 0.0)
    if requested <= 0:
        return _normalize_polygon(polygon)

    # Both easing curves are convex, so small strengths stay subtle.
    effective = _smoothing_ratio(requested, curve=1.45) * 100.0
    passes = next(
        (count for floor, count in ((85, 4), (55, 3), (25, 2)) if effective >= floor),
        1,
    )
    cut = 0.03 + _smoothing_ratio(requested, curve=1.35) * 0.22

    baseline = _normalize_polygon(polygon)
    rounded = _chaikin_smooth_polygon(_simplify_polygon(baseline, effective * 0.25), passes, cut)
    result = _simplify_polygon(rounded, effective)

    # Chaikin doubles the vertex count per pass; keep the output no larger than the input.
    if len(result) > len(baseline):
        for floor in (25.0, 35.0, 50.0, 70.0, 90.0, 100.0):
            result = _simplify_polygon(result, max(effective, floor))
            if len(result) <= len(baseline):
                break

    if len(result) < 3:
        return _normalize_polygon(polygon)
    return result
def _smooth_polygons(polygons: list[list[list[float]]], smoothing: dict[str, Any]) -> list[list[list[float]]]:
    """Smooth each polygon and drop results with fewer than three vertices."""
    kept: list[list[list[float]]] = []
    for polygon in polygons:
        smoothed = _smooth_polygon(polygon, smoothing)
        if len(smoothed) >= 3:
            kept.append(smoothed)
    return kept
def _frame_window(
    frames: list[Frame],
    source_position: int,
    direction: str,
    max_frames: int,
) -> tuple[list[Frame], int]:
    """Choose the contiguous slice of ``frames`` to process around the source frame.

    The window holds ``max(1, min(max_frames, len(frames)))`` frames where possible:

    - ``"backward"``: the window ends at the source frame;
    - ``"both"``: roughly centered on the source (one extra frame after it when
      the size is even); frames lost to clipping at one end are added at the other;
    - anything else (forward): the window starts at the source frame.

    Returns:
        ``(window, index of the source frame inside window)``.
    """
    total = len(frames)
    size = max(1, min(max_frames, total))

    if direction not in ("backward", "both"):
        return frames[source_position:min(total, source_position + size)], 0

    if direction == "backward":
        first = max(0, source_position - size + 1)
        return frames[first:source_position + 1], source_position - first

    lead = (size - 1) // 2
    trail = size - 1 - lead
    first = max(0, source_position - lead)
    stop = min(total, source_position + trail + 1)
    shortfall = size - (stop - first)
    if shortfall > 0:
        first = max(0, first - shortfall)
    shortfall = size - (stop - first)
    if shortfall > 0:
        stop = min(total, stop + shortfall)
    return frames[first:stop], source_position - first
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
    """Download frames into ``directory`` as ``frame_000000.jpg``, ``frame_000001.jpg``, ...

    Files are numbered in list order and the downloaded bytes are written
    verbatim (no re-encoding) regardless of the ``.jpg`` suffix.

    Returns:
        The written file paths as strings, in list order.
    """
    written: list[str] = []
    for position, frame in enumerate(frames):
        payload_bytes = download_file(frame.image_url)
        target = directory / f"frame_{position:06d}.jpg"
        target.write_bytes(payload_bytes)
        written.append(str(target))
    return written
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
    """Reduce a binary component to one normalized positive prompt point.

    The point is the pixel with the largest L2 distance to the nearest
    zero pixel, i.e. the deepest interior point of the component, divided by
    the image size and clamped to [0, 1].
    """
    distance = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
    peak_x, peak_y = cv2.minMaxLoc(distance)[3]  # (min, max, minLoc, maxLoc)[3] -> maxLoc as (x, y)
    return [
        min(max(float(peak_x) / max(width, 1), 0.0), 1.0),
        min(max(float(peak_y) / max(height, 1), 0.0), 1.0),
    ]
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
"""Return whether a normalized point is inside a normalized polygon."""
if len(polygon) < 3:
return False
x, y = point
inside = False
j = len(polygon) - 1
for i, current in enumerate(polygon):
xi, yi = current
xj, yj = polygon[j]
intersects = ((yi > y) != (yj > y)) and (
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
)
if intersects:
inside = not inside
j = i
return inside
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
    """Compute normalized crop bounds ``(x1, y1, x2, y2)`` around ``points``.

    Each axis spans the clamped points' min..max, padded by ``margin`` on both
    sides and clipped to [0, 1]. An axis narrower than 0.05 is replaced by a
    0.05-wide span around its center (still clipped, so it can be narrower
    at the image border).
    """
    xs = [_clamp01(point[0]) for point in points]
    ys = [_clamp01(point[1]) for point in points]

    def padded_span(values: list[float]) -> tuple[float, float]:
        low = max(0.0, min(values) - margin)
        high = min(1.0, max(values) + margin)
        if high - low < 0.05:
            middle = (low + high) / 2
            low = max(0.0, middle - 0.025)
            high = min(1.0, middle + 0.025)
        return low, high

    x1, x2 = padded_span(xs)
    y1, y2 = padded_span(ys)
    return x1, y1, x2, y2
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
height, width = image.shape[:2]
x1, y1, x2, y2 = bounds
left = int(round(x1 * width))
top = int(round(y1 * height))
right = max(left + 1, int(round(x2 * width)))
bottom = max(top + 1, int(round(y2 * height)))
return image[top:bottom, left:right]
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
    """Map a full-image normalized point into crop-relative normalized coordinates (clamped to [0, 1])."""
    left, top, right, bottom = bounds
    span_x = max(right - left, 1e-9)
    span_y = max(bottom - top, 1e-9)
    return [
        _clamp01((float(point[0]) - left) / span_x),
        _clamp01((float(point[1]) - top) / span_y),
    ]
def _from_crop_polygon(
    polygon: list[list[float]],
    bounds: tuple[float, float, float, float],
) -> list[list[float]]:
    """Map a crop-relative normalized polygon back to full-image normalized coordinates (clamped)."""
    left, top, right, bottom = bounds
    span_x = right - left
    span_y = bottom - top
    mapped: list[list[float]] = []
    for point in polygon:
        mapped.append([
            _clamp01(left + float(point[0]) * span_x),
            _clamp01(top + float(point[1]) * span_y),
        ])
    return mapped
def _filter_predictions(
    polygons: list[list[list[float]]],
    scores: list[float],
    options: dict[str, Any],
    negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
    """Optionally drop background-looking predictions.

    Only active when ``options["auto_filter_background"]`` is truthy; otherwise
    the inputs are returned untouched. When active, a polygon is dropped if
    its score is below ``options["min_score"]`` (default 0.0; a polygon with no
    matching score counts as 0.0) or if it contains any of ``negative_points``.

    Returns:
        The surviving polygons and their scores, index-aligned.
    """
    if not options.get("auto_filter_background"):
        return polygons, scores
    threshold = float(options.get("min_score", 0.0) or 0.0)
    padded_scores = list(scores) + [0.0] * max(0, len(polygons) - len(scores))

    def survives(polygon: list[list[float]], score: float) -> bool:
        if score < threshold:
            return False
        return not (negative_points and any(_point_in_polygon(point, polygon) for point in negative_points))

    kept = [
        (polygon, score)
        for polygon, score in zip(polygons, padded_scores)
        if survives(polygon, score)
    ]
    return [polygon for polygon, _ in kept], [score for _, score in kept]
@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM inference with a prompt",
)
def predict(
    payload: PredictRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Execute selected SAM segmentation given an image and a prompt.

    - **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
      coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
    - **semantic**: disabled in the current SAM 2.1 point/box product flow.

    Options (`options`):
    - `crop_to_prompt`: run inference on a crop around the prompt, padded by
      `crop_margin` (default 0.25 for point prompts, 0.05 for box/interactive;
      a falsy margin falls back to the default). Polygons are mapped back to
      full-image coordinates.
    - `auto_filter_background` / `min_score`: drop polygons scoring below
      `min_score` or containing a negative (label 0) point.

    Errors: 404 unknown frame, 500 image load failure, 400 invalid prompt or
    runtime rejection, 503 model unavailable.
    """
    frame = _shared_frame_or_404(payload.image_id, db, current_user)
    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()
    options = payload.options or {}
    polygons: List[List[List[float]]] = []
    scores: List[float] = []
    # Negative (label 0) clicks, later used by _filter_predictions.
    negative_points: list[list[float]] = []
    try:
        if prompt_type == "point":
            point_payload = payload.prompt_data
            if isinstance(point_payload, dict):
                points = point_payload.get("points")
                labels = point_payload.get("labels")
            else:
                points = point_payload
                labels = None
            if not isinstance(points, list) or len(points) == 0:
                raise HTTPException(status_code=400, detail="Invalid point prompt data")
            # Missing or mismatched labels: treat every point as positive.
            if not isinstance(labels, list) or len(labels) != len(points):
                labels = [1] * len(points)
            negative_points = [
                point for point, label in zip(points, labels) if label == 0
            ]
            inference_image = image
            inference_points = points
            crop_bounds = None
            if options.get("crop_to_prompt"):
                margin = float(options.get("crop_margin", 0.25) or 0.25)
                crop_bounds = _crop_bounds_from_points(points, margin)
                inference_image = _crop_image(image, crop_bounds)
                inference_points = [_to_crop_point(point, crop_bounds) for point in points]
            polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
            if crop_bounds:
                # Convert crop-relative polygons back to full-image coordinates.
                polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
        elif prompt_type == "box":
            box = payload.prompt_data
            if not isinstance(box, list) or len(box) != 4:
                raise HTTPException(status_code=400, detail="Invalid box prompt data")
            inference_image = image
            inference_box = box
            crop_bounds = None
            if options.get("crop_to_prompt"):
                margin = float(options.get("crop_margin", 0.05) or 0.05)
                crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
                inference_image = _crop_image(image, crop_bounds)
                inference_box = [
                    *_to_crop_point([box[0], box[1]], crop_bounds),
                    *_to_crop_point([box[2], box[3]], crop_bounds),
                ]
            polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
            if crop_bounds:
                polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
        elif prompt_type == "interactive":
            prompt = payload.prompt_data
            if not isinstance(prompt, dict):
                raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
            box = prompt.get("box")
            points = prompt.get("points") or []
            labels = prompt.get("labels")
            if box is not None and (not isinstance(box, list) or len(box) != 4):
                raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
            if not isinstance(points, list):
                raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
            if not box and len(points) == 0:
                raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
            if not isinstance(labels, list) or len(labels) != len(points):
                labels = [1] * len(points)
            negative_points = [
                point for point, label in zip(points, labels) if label == 0
            ]
            inference_image = image
            inference_box = box
            inference_points = points
            crop_bounds = None
            if options.get("crop_to_prompt"):
                margin = float(options.get("crop_margin", 0.05) or 0.05)
                # The crop must enclose both the clicked points and the box corners.
                crop_points = list(points)
                if box:
                    crop_points.extend([[box[0], box[1]], [box[2], box[3]]])
                crop_bounds = _crop_bounds_from_points(crop_points, margin)
                inference_image = _crop_image(image, crop_bounds)
                inference_points = [_to_crop_point(point, crop_bounds) for point in points]
                if box:
                    inference_box = [
                        *_to_crop_point([box[0], box[1]], crop_bounds),
                        *_to_crop_point([box[2], box[3]], crop_bounds),
                    ]
            polygons, scores = sam_registry.predict_interactive(
                payload.model,
                inference_image,
                inference_box,
                inference_points,
                labels,
            )
            if crop_bounds:
                polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
        elif prompt_type == "semantic":
            # NOTE(review): the docstring says semantic prompts are disabled, yet the
            # request is still forwarded; presumably sam_registry.predict_semantic raises
            # NotImplementedError (-> 400) for SAM 2.1 — verify against the registry.
            text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
            min_score = options.get("min_score")
            # A positive, parseable min_score doubles as the confidence threshold.
            confidence_threshold = None
            if min_score is not None:
                try:
                    parsed_min_score = float(min_score)
                    if parsed_min_score > 0:
                        confidence_threshold = parsed_min_score
                except (TypeError, ValueError):
                    confidence_threshold = None
            polygons, scores = sam_registry.predict_semantic(
                payload.model,
                image,
                text,
                confidence_threshold=confidence_threshold,
            )
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
    # HTTPExceptions raised above are not subclasses of these and propagate unchanged.
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except NotImplementedError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
    logger.info(
        "AI predict completed model=%s prompt_type=%s frame_id=%s polygons=%d",
        payload.model or "default",
        prompt_type,
        payload.image_id,
        len(polygons),
    )
    return {"polygons": polygons, "scores": scores}
@router.get(
    "/models/status",
    response_model=AiRuntimeStatus,
    summary="Get SAM model and GPU runtime status",
)
def model_status(
    selected_model: str | None = None,
    _current_user: User = Depends(get_current_user),
) -> dict:
    """Return real runtime availability for GPU and the currently enabled SAM model.

    An unrecognized ``selected_model`` (``ValueError`` from the registry) becomes HTTP 400.
    """
    try:
        runtime = sam_registry.runtime_status(selected_model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return runtime
@router.post(
    "/analyze-mask",
    response_model=MaskAnalysisResponse,
    summary="Analyze mask geometry and prompt anchors",
)
def analyze_mask(
    payload: MaskAnalysisRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return backend-computed mask properties for the frontend inspector.

    Computes the summed area of all valid polygons, a bounding box (the
    client-supplied ``bbox`` if any, otherwise the union of all polygons), a
    confidence value when the mask carries a numeric ``score``, and sampled
    topology anchors. Nothing is persisted.

    Raises:
        HTTPException: 404 if ``frame_id`` is given but not found; 400 if the
        mask has no polygons or no valid polygon.
    """
    if payload.frame_id is not None:
        _shared_frame_or_404(payload.frame_id, db, current_user)
    mask_data = payload.mask_data or {}
    polygons = mask_data.get("polygons") or []
    if not polygons:
        raise HTTPException(status_code=400, detail="Mask analysis requires polygons")
    valid_polygons = _normalize_polygons(polygons)
    if not valid_polygons:
        raise HTTPException(status_code=400, detail="Mask analysis requires at least one valid polygon")
    area = sum(_polygon_area(polygon) for polygon in valid_polygons)
    # The union bbox covers every polygon, matching the summed area (and the
    # bbox /propagate stores via _polygons_bbox); the first polygon alone
    # understated multi-part masks.
    bbox = payload.bbox or _polygons_bbox(valid_polygons)
    source = mask_data.get("source")
    raw_score = mask_data.get("score")
    confidence: float | None = None
    # bool is an int subclass; a stray true/false is not a model score.
    if isinstance(raw_score, (int, float)) and not isinstance(raw_score, bool):
        confidence = max(0.0, min(float(raw_score), 1.0))
        confidence_source = "model_score"
    elif source:
        confidence_source = "source_without_score"
    else:
        confidence_source = "manual_or_imported"
    anchor_count, anchors = _analysis_anchor_summary(valid_polygons)
    message = "已从后端重新提取几何拓扑锚点" if payload.extract_skeleton else "已读取后端几何属性"
    return {
        "confidence": confidence,
        "confidence_source": confidence_source,
        "topology_anchor_count": anchor_count,
        "topology_anchors": anchors,
        "area": area,
        "bbox": bbox,
        "source": source,
        "message": message,
    }
@router.post(
    "/smooth-mask",
    response_model=SmoothMaskResponse,
    summary="Smooth editable mask polygons with backend geometry rules",
)
def smooth_mask(
    payload: SmoothMaskRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Return a smoothed polygon mask without persisting it.

    The frontend keeps this as an explicit edit operation: users preview/apply it
    to the current mask, then save through the normal annotation endpoint.

    Raises:
        HTTPException: 404 if ``frame_id`` is given but not found; 400 if there
        is no valid input polygon or smoothing leaves none.
    """
    if payload.frame_id is not None:
        _shared_frame_or_404(payload.frame_id, db, current_user)
    # Guard a missing mask_data the same way analyze_mask does.
    polygons = (payload.mask_data or {}).get("polygons") or []
    valid_polygons = _normalize_polygons(polygons)
    if not valid_polygons:
        raise HTTPException(status_code=400, detail="Mask smoothing requires at least one valid polygon")
    smoothing = _normalize_smoothing_options(payload.strength, payload.method)
    smoothed_polygons = _smooth_polygons(valid_polygons, smoothing)
    if not smoothed_polygons:
        raise HTTPException(status_code=400, detail="Mask smoothing produced no valid polygons")
    area = sum(_polygon_area(polygon) for polygon in smoothed_polygons)
    # Union bbox over all polygons, consistent with the summed area.
    bbox = _polygons_bbox(smoothed_polygons)
    anchor_count, anchors = _analysis_anchor_summary(smoothed_polygons)
    return {
        "polygons": smoothed_polygons,
        "topology_anchor_count": anchor_count,
        "topology_anchors": anchors,
        "area": area,
        "bbox": bbox,
        "smoothing": smoothing,
        "message": f"已应用边缘平滑强度 {smoothing['strength']:.0f}",
    }
@router.post(
    "/propagate",
    response_model=PropagateResponse,
    summary="Propagate one current-frame region across a video frame segment",
)
def propagate(
    payload: PropagateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Track one selected region from the current frame across nearby frames.

    SAM 2 uses the official video predictor with the selected mask as the seed.
    SAM 3 video tracking is currently disabled in this product flow.

    Runs synchronously inside the request. When ``save_annotations`` is set,
    one Annotation is created per frame that yields at least one polygon with
    three or more points.
    """
    direction = payload.direction.lower()
    if direction not in {"forward", "backward", "both"}:
        raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
    # Falsy max_frames defaults to 30; the window is clamped to 1..500 frames.
    max_frames = max(1, min(int(payload.max_frames or 30), 500))
    _shared_project_or_404(payload.project_id, db, current_user)
    source_frame = _shared_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
    # Seed as a plain dict with None fields dropped; needs polygons, bbox, or points.
    seed = payload.seed.model_dump(exclude_none=True)
    polygons = seed.get("polygons") or []
    bbox = seed.get("bbox")
    points = seed.get("points") or []
    if not polygons and not bbox and not points:
        raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
    # Whole project sequence ordered by frame_index; a window is then cut around the source.
    frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all()
    source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
    if source_position is None:
        raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")
    selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
    if len(selected_frames) == 0:
        raise HTTPException(status_code=400, detail="No frames available for propagation")
    try:
        # Frames are downloaded into a temporary directory that is deleted once
        # propagate_video returns or raises.
        with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
            frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
            propagated = sam_registry.propagate_video(
                payload.model,
                frame_paths,
                source_relative_index,
                seed,
                direction,
                len(selected_frames),
            )
    except ModelUnavailableError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except NotImplementedError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:  # noqa: BLE001
        # Anything else (including frame download failures) becomes a 500 whose
        # detail includes the exception text.
        logger.error("Video propagation failed: %s", exc)
        raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc
    created: list[Annotation] = []
    if payload.save_annotations:
        # Metadata copied from the seed onto every created annotation.
        class_metadata = seed.get("class_metadata")
        template_id = seed.get("template_id")
        label = seed.get("label") or "Propagated Mask"
        color = seed.get("color") or "#06b6d4"
        model_id = sam_registry.normalize_model_id(payload.model)
        source_annotation_id = seed.get("source_annotation_id")
        source_mask_id = seed.get("source_mask_id")
        # Instance id linking propagated masks to their origin: an explicit value
        # first, else derived from the source annotation id, else from the source mask id.
        source_instance_id = (
            seed.get("source_instance_id")
            or (f"annotation:{source_annotation_id}" if source_annotation_id is not None else None)
            or (f"mask:{source_mask_id}" if source_mask_id else None)
        )
        # Optional geometry smoothing requested by the seed; non-positive strength disables it.
        seed_smoothing = seed.get("smoothing")
        smoothing = _normalize_smoothing_options(
            seed_smoothing.get("strength"),
            seed_smoothing.get("method"),
        ) if isinstance(seed_smoothing, dict) else None
        if smoothing and smoothing["strength"] <= 0:
            smoothing = None
        for frame_result in propagated:
            # frame_index is relative to selected_frames; out-of-range results are ignored.
            relative_index = int(frame_result.get("frame_index", -1))
            if relative_index < 0 or relative_index >= len(selected_frames):
                continue
            frame = selected_frames[relative_index]
            if not payload.include_source and frame.id == source_frame.id:
                continue
            result_polygons = frame_result.get("polygons") or []
            result_holes = frame_result.get("holes") or []
            scores = frame_result.get("scores") or []
            polygons_to_save: list[list[list[float]]] = []
            holes_to_save: list[list[list[list[float]]]] = []
            score_values: list[float] = []
            for polygon_index, polygon in enumerate(result_polygons):
                if len(polygon) < 3:
                    continue
                # Only the outer polygon is smoothed; its hole group is stored as returned.
                polygons_to_save.append(_smooth_polygon(polygon, smoothing) if smoothing else polygon)
                # holes_to_save stays index-aligned with polygons_to_save ([] when absent or not a list).
                hole_group = result_holes[polygon_index] if polygon_index < len(result_holes) and isinstance(result_holes[polygon_index], list) else []
                holes_to_save.append(hole_group if isinstance(hole_group, list) else [])
                if polygon_index < len(scores):
                    try:
                        score_values.append(float(scores[polygon_index]))
                    except (TypeError, ValueError):
                        # Non-numeric scores are skipped rather than failing the request.
                        pass
            if not polygons_to_save:
                continue
            annotation = Annotation(
                project_id=payload.project_id,
                frame_id=frame.id,
                template_id=template_id,
                mask_data={
                    "polygons": polygons_to_save,
                    **({"holes": holes_to_save, "hasHoles": True} if any(holes_to_save) else {}),
                    "label": label,
                    "color": color,
                    "source": f"{model_id}_propagation",
                    "propagated_from_frame_id": source_frame.id,
                    "propagated_from_frame_index": source_frame.frame_index,
                    **({"instance_id": source_instance_id, "source_instance_id": source_instance_id} if source_instance_id else {}),
                    **({"source_annotation_id": source_annotation_id} if source_annotation_id is not None else {}),
                    **({"source_mask_id": source_mask_id} if source_mask_id else {}),
                    # Frame-level score is the best polygon score; all scores kept when several.
                    "score": max(score_values) if score_values else None,
                    **({"scores": score_values} if len(score_values) > 1 else {}),
                    **({"geometry_smoothing": smoothing} if smoothing else {}),
                    **({"class": class_metadata} if class_metadata else {}),
                },
                points=None,
                bbox=_polygons_bbox(polygons_to_save),
            )
            db.add(annotation)
            created.append(annotation)
        # One commit for all frames, then refresh to load generated ids.
        db.commit()
        for annotation in created:
            db.refresh(annotation)
    return {
        "model": sam_registry.normalize_model_id(payload.model),
        "direction": direction,
        "source_frame_id": source_frame.id,
        "processed_frame_count": len(selected_frames),
        "created_annotation_count": len(created),
        "annotations": created,
    }
@router.post(
    "/propagate/task",
    status_code=status.HTTP_202_ACCEPTED,
    response_model=ProcessingTaskOut,
    summary="Queue a background video propagation task",
)
def queue_propagate_task(
    payload: PropagateTaskRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> ProcessingTaskOut:
    """Queue multiple seed/direction propagation steps as one background task.

    Every step is validated up front (forward/backward direction, and a seed
    with polygons, bbox, or points). A ``ProcessingTask`` row is then stored
    with the normalized model id, and the Celery job is dispatched. A progress
    event is published after the row is created and again once the Celery id
    is recorded.
    """
    _shared_project_or_404(payload.project_id, db, current_user)
    # Only checks that the source frame exists in this project (404 otherwise).
    _shared_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
    if not payload.steps:
        raise HTTPException(status_code=400, detail="Propagation task requires at least one step")
    try:
        model_id = sam_registry.normalize_model_id(payload.model)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    for step in payload.steps:
        if step.direction.lower() not in {"forward", "backward"}:
            raise HTTPException(status_code=400, detail="direction must be forward or backward")
        seed = step.seed.model_dump(exclude_none=True)
        if not any(seed.get(key) for key in ("polygons", "bbox", "points")):
            raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")

    task = ProcessingTask(
        task_type="propagate_masks",
        status=TASK_STATUS_QUEUED,
        progress=0,
        message="自动传播任务已入队",
        project_id=payload.project_id,
        payload={**payload.model_dump(exclude_none=True), "model": model_id},
    )

    def persist_and_announce() -> None:
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)

    db.add(task)
    persist_and_announce()
    celery_result = propagate_project_masks.delay(task.id)
    task.celery_task_id = celery_result.id
    persist_and_announce()
    logger.info("Queued propagation task id=%s project_id=%s celery_id=%s", task.id, payload.project_id, celery_result.id)
    return task
@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(
    image_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> dict:
    """Segment an entire frame automatically, without any user prompts.

    Args:
        image_id: Id of the frame to segment (query parameter).

    Returns:
        A dict with the ``polygons`` and ``scores`` produced by
        ``sam_registry.predict_auto``.

    Raises:
        HTTPException(503): the SAM runtime raised ``ModelUnavailableError``.
    """
    target_frame = _shared_frame_or_404(image_id, db, current_user)
    pixels = _load_frame_image(target_frame)
    # NOTE(review): ``None`` is passed as the model argument, presumably
    # selecting the registry's default runtime; confirm in sam_registry.
    try:
        found_polygons, found_scores = sam_registry.predict_auto(None, pixels)
    except ModelUnavailableError as err:
        raise HTTPException(status_code=503, detail=str(err)) from err
    return dict(polygons=found_polygons, scores=found_scores)
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    The project must pass the shared-access lookup. When a frame or template id
    is supplied, it is validated as well: the frame must belong to the project,
    and the template must be visible to the current user. The payload fields are
    then copied verbatim onto a new ``Annotation`` row.

    Returns:
        The committed and refreshed ``Annotation``.
    """
    _shared_project_or_404(payload.project_id, db, current_user)
    # Compare against None, not truthiness, so that every supplied id is validated
    # (matches the template check in ``update_annotation``).
    if payload.frame_id is not None:
        _shared_frame_or_404(payload.frame_id, db, current_user, payload.project_id)
    if payload.template_id is not None:
        _visible_template_or_404(payload.template_id, db, current_user)
    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation
@router.post(
    "/import-gt-mask",
    response_model=List[AnnotationOut],
    status_code=status.HTTP_201_CREATED,
    summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
    project_id: int = Form(...),
    frame_id: int = Form(...),
    template_id: int | None = Form(None),
    label: str = Form("GT Mask"),
    color: str = Form("#22c55e"),
    unknown_color_policy: str = Form("undefined"),
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> List[Annotation]:
    """Convert a maskid label image into persisted polygon annotations.

    Accepted input: an 8-bit single-channel image, or an 8-bit image with at
    least three channels whose first three channels are identical. Pixel value 0
    is background; every value 1-255 is a maskid. If the frame has a stored size
    that differs from the image, the label image is resized to the frame size
    with nearest-neighbour interpolation.

    Label/colour per maskid:
      * the template defines that maskid -> the template class's name/colour;
      * the template has classes but not this maskid -> skipped when
        ``unknown_color_policy == "discard"``, otherwise mapped to the
        template's unclassified class and flagged ``gt_unknown_class``;
      * no template classes -> ``label`` (suffixed with the maskid when several
        maskids are present) and the normalized ``color``.

    Each outer contour (``cv2.RETR_EXTERNAL``) of each maskid becomes one
    annotation; holes inside a component are not represented separately.
    The ``points`` field stores one positive seed point from
    ``_component_seed_point`` (intended to be the component's
    distance-transform centre), which gives the frontend an editable
    point-region representation instead of a static bitmap layer.

    Raises:
        HTTPException(400): bad ``unknown_color_policy``, empty or undecodable
            upload, unsupported image format, or no importable regions.
    """
    # NOTE(review): this is an ``async def`` route, so the synchronous DB calls
    # and OpenCV work below execute on the event loop; large masks can stall
    # other requests. Consider a plain ``def`` route or ``run_in_threadpool``.
    _shared_project_or_404(project_id, db, current_user)
    frame = _shared_frame_or_404(frame_id, db, current_user, project_id)
    if unknown_color_policy not in {"discard", "undefined"}:
        raise HTTPException(status_code=400, detail="unknown_color_policy must be discard or undefined")
    template: Template | None = None
    if template_id is not None:
        template = _visible_template_or_404(template_id, db, current_user)

    data = await file.read()
    # cv2.imdecode asserts on an empty buffer (surfacing as a 500); reject it
    # as an invalid image instead.
    if not data:
        raise HTTPException(status_code=400, detail="Invalid mask image")
    image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid mask image")
    invalid_format_detail = (
        "GT Mask 图片不符合要求:仅支持 8-bit 灰度图,或 8-bit RGB 三通道完全相同的 maskid 图"
        "(背景 0像素值为 1-255 的 maskid"
    )
    if image.dtype != np.uint8:
        raise HTTPException(status_code=400, detail=invalid_format_detail)
    if image.ndim == 2:
        label_image = image
    elif image.ndim == 3 and image.shape[2] >= 3:
        channels = image[:, :, :3]
        # GT label images are maskid maps: either grayscale or RGB/BGR where
        # all three color channels contain the same maskid value [X, X, X].
        # Any extra channel (e.g. alpha) is ignored.
        if not (np.array_equal(channels[:, :, 0], channels[:, :, 1]) and np.array_equal(channels[:, :, 1], channels[:, :, 2])):
            raise HTTPException(status_code=400, detail=invalid_format_detail)
        label_image = channels[:, :, 0]
    else:
        raise HTTPException(status_code=400, detail=invalid_format_detail)

    width = int(frame.width or image.shape[1])
    height = int(frame.height or image.shape[0])
    original_height, original_width = int(label_image.shape[0]), int(label_image.shape[1])
    resized_to_frame = original_width != width or original_height != height
    if resized_to_frame:
        # Nearest-neighbour keeps maskid values exact (no blended in-between ids).
        label_image = cv2.resize(label_image, (width, height), interpolation=cv2.INTER_NEAREST)

    by_maskid, _by_color, unclassified_class = _template_class_maps(template)
    has_template_classes = bool(by_maskid)
    fallback_color = _normalize_hex_color(color) or "#22c55e"

    import_items: list[dict[str, Any]] = []
    skipped_unknown = 0
    label_values = [int(value) for value in np.unique(label_image) if int(value) > 0]
    for label_value in label_values:
        class_meta = by_maskid.get(label_value)
        is_unknown = has_template_classes and class_meta is None
        if is_unknown and unknown_color_policy == "discard":
            skipped_unknown += 1
            continue
        if class_meta:
            annotation_label = class_meta["name"]
            annotation_color = class_meta["color"]
        elif is_unknown:
            annotation_label = unclassified_class["name"]
            annotation_color = unclassified_class["color"]
            class_meta = unclassified_class
        else:
            annotation_label = f"{label} {label_value}" if len(label_values) > 1 else label
            annotation_color = fallback_color
        # The binary mask is built lazily in the contour loop below, so at most
        # one full-frame mask is alive at a time (instead of up to 255).
        import_items.append({
            "label_value": label_value,
            "label": annotation_label,
            "color": annotation_color,
            "class": class_meta,
            "unknown": is_unknown,
            "metadata": {
                "gt_label_value": label_value,
                "gt_original_size": {"width": original_width, "height": original_height},
                "gt_resized_to_frame": resized_to_frame,
            },
        })
    if not import_items:
        if skipped_unknown > 0:
            raise HTTPException(status_code=400, detail="No matching GT mask classes found")
        raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)

    annotations: list[Annotation] = []
    for item in import_items:
        binary = np.where(label_image == item["label_value"], 255, 0).astype(np.uint8)
        mask_width, mask_height = binary.shape[1], binary.shape[0]
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        for contour in contours:
            # Drop degenerate (zero-area) contours and those that simplify to
            # fewer than three points.
            if cv2.contourArea(contour) < 1:
                continue
            polygon = _normalized_contour(contour, mask_width, mask_height)
            if len(polygon) < 3:
                continue
            # Rasterize just this component (filled) to derive its seed point.
            component = np.zeros_like(binary, dtype=np.uint8)
            cv2.drawContours(component, [contour], -1, 1, thickness=-1)
            seed_point = _component_seed_point(component, mask_width, mask_height)
            bbox = _contour_bbox(contour, mask_width, mask_height)
            mask_data = {
                "polygons": [polygon],
                "label": item["label"],
                "color": item["color"],
                "source": "gt_mask",
                "image_size": {"width": width, "height": height},
                **item["metadata"],
            }
            if item["class"]:
                mask_data["class"] = item["class"]
            if item["unknown"]:
                mask_data["gt_unknown_class"] = True
            annotation = Annotation(
                project_id=project_id,
                frame_id=frame_id,
                template_id=template_id,
                mask_data=mask_data,
                points=[seed_point],
                bbox=bbox,
            )
            db.add(annotation)
            annotations.append(annotation)
    if not annotations:
        raise HTTPException(status_code=400, detail=GT_MASK_EMPTY_DETAIL)
    db.commit()
    for annotation in annotations:
        db.refresh(annotation)
    logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
    return annotations
@router.get(
    "/annotations",
    response_model=List[AnnotationOut],
    summary="List saved annotations for a project",
)
def list_annotations(
    project_id: int,
    frame_id: int | None = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> List[Annotation]:
    """List a project's stored annotations in ascending id order.

    Args:
        project_id: Project whose annotations are returned; it must pass the
            shared-access lookup.
        frame_id: Optional frame filter. When given, the frame must belong to
            the project, and only that frame's annotations are returned.
    """
    _shared_project_or_404(project_id, db, current_user)
    criteria = [Annotation.project_id == project_id]
    if frame_id is not None:
        _shared_frame_or_404(frame_id, db, current_user, project_id)
        criteria.append(Annotation.frame_id == frame_id)
    return db.query(Annotation).filter(*criteria).order_by(Annotation.id).all()
@router.patch(
    "/annotations/{annotation_id}",
    response_model=AnnotationOut,
    summary="Update a saved annotation",
)
def update_annotation(
    annotation_id: int,
    payload: AnnotationUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Annotation:
    """Partially update a persisted annotation.

    Only fields explicitly present in the request body are written
    (``model_dump(exclude_unset=True)``). A non-null ``template_id`` must refer
    to a template that is visible to the current user.

    Raises:
        HTTPException(404): the annotation (or its project row) does not exist.
        Whatever ``_shared_project_or_404`` raises when the project is not
        accessible to the current user.
    """
    annotation = (
        db.query(Annotation)
        .join(Project, Project.id == Annotation.project_id)
        .filter(Annotation.id == annotation_id)
        .first()
    )
    if not annotation:
        raise HTTPException(status_code=404, detail="Annotation not found")
    # The join above only proves the project row exists; apply the same project
    # access check as the create/list endpoints before mutating anything.
    _shared_project_or_404(annotation.project_id, db, current_user)
    updates = payload.model_dump(exclude_unset=True)
    if "template_id" in updates and updates["template_id"] is not None:
        _visible_template_or_404(updates["template_id"], db, current_user)
    for field, value in updates.items():
        setattr(annotation, field, value)
    db.commit()
    db.refresh(annotation)
    logger.info("Updated annotation id=%s", annotation.id)
    return annotation
@router.delete(
    "/annotations/{annotation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a saved annotation",
)
def delete_annotation(
    annotation_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(require_editor),
) -> Response:
    """Delete an annotation and its derived mask rows through ORM cascade.

    Returns:
        An empty 204 response.

    Raises:
        HTTPException(404): the annotation (or its project row) does not exist.
        Whatever ``_shared_project_or_404`` raises when the project is not
        accessible to the current user.
    """
    annotation = (
        db.query(Annotation)
        .join(Project, Project.id == Annotation.project_id)
        .filter(Annotation.id == annotation_id)
        .first()
    )
    if not annotation:
        raise HTTPException(status_code=404, detail="Annotation not found")
    # The join above only proves the project row exists; apply the same project
    # access check as the create/list endpoints before deleting.
    _shared_project_or_404(annotation.project_id, db, current_user)
    db.delete(annotation)
    db.commit()
    logger.info("Deleted annotation id=%s", annotation_id)
    return Response(status_code=status.HTTP_204_NO_CONTENT)