feat: 建立 SAM2 标注闭环基线

- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。

- 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。

- 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。

- 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。

- 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。

- 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。

- 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。

- 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。

- 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。

- 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。

- 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。

- 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。

- 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。

- 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。

- 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。

- 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
2026-05-01 15:26:25 +08:00
parent f020ff3b4f
commit 689a9ba283
48 changed files with 3280 additions and 176 deletions

View File

@@ -22,7 +22,12 @@ class Settings(BaseSettings):
sam_default_model: str = "sam2"
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
sam3_model_version: str = "sam3.1"
sam3_model_version: str = "sam3"
sam3_external_enabled: bool = True
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
sam3_timeout_seconds: int = 300
sam3_status_cache_seconds: int = 30
sam3_confidence_threshold: float = 0.5
# App
app_env: str = "development"

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timezone
from typing import Any
from redis_client import get_redis_client
from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
from statuses import TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
logger = logging.getLogger(__name__)
@@ -22,6 +22,8 @@ def _iso_now() -> str:
def _event_type(task_status: str) -> str:
if task_status == TASK_STATUS_SUCCESS:
return "complete"
if task_status == TASK_STATUS_CANCELLED:
return "cancelled"
if task_status == TASK_STATUS_FAILED:
return "error"
return "progress"

View File

@@ -5,7 +5,7 @@ from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
@@ -39,6 +39,140 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
"""Approximate a contour and convert it to normalized polygon coordinates."""
arc_length = cv2.arcLength(contour, True)
epsilon = max(1.0, arc_length * 0.01)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape(-1, 2)
if len(points) < 3:
points = contour.reshape(-1, 2)
return [
[
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
for x, y in points
]
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
x, y, w, h = cv2.boundingRect(contour)
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
min(max(float(w) / max(width, 1), 0.0), 1.0),
min(max(float(h) / max(height, 1), 0.0), 1.0),
]
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
"""Reduce a binary component to one positive prompt point using distance transform."""
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
_, _, _, max_loc = cv2.minMaxLoc(dist)
x, y = max_loc
return [
min(max(float(x) / max(width, 1), 0.0), 1.0),
min(max(float(y) / max(height, 1), 0.0), 1.0),
]
def _clamp01(value: float) -> float:
return min(max(float(value), 0.0), 1.0)
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
"""Return whether a normalized point is inside a normalized polygon."""
if len(polygon) < 3:
return False
x, y = point
inside = False
j = len(polygon) - 1
for i, current in enumerate(polygon):
xi, yi = current
xj, yj = polygon[j]
intersects = ((yi > y) != (yj > y)) and (
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
)
if intersects:
inside = not inside
j = i
return inside
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
xs = [_clamp01(point[0]) for point in points]
ys = [_clamp01(point[1]) for point in points]
x1 = max(0.0, min(xs) - margin)
y1 = max(0.0, min(ys) - margin)
x2 = min(1.0, max(xs) + margin)
y2 = min(1.0, max(ys) + margin)
if x2 - x1 < 0.05:
center = (x1 + x2) / 2
x1 = max(0.0, center - 0.025)
x2 = min(1.0, center + 0.025)
if y2 - y1 < 0.05:
center = (y1 + y2) / 2
y1 = max(0.0, center - 0.025)
y2 = min(1.0, center + 0.025)
return x1, y1, x2, y2
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
height, width = image.shape[:2]
x1, y1, x2, y2 = bounds
left = int(round(x1 * width))
top = int(round(y1 * height))
right = max(left + 1, int(round(x2 * width)))
bottom = max(top + 1, int(round(y2 * height)))
return image[top:bottom, left:right]
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
x1, y1, x2, y2 = bounds
return [
_clamp01((float(point[0]) - x1) / max(x2 - x1, 1e-9)),
_clamp01((float(point[1]) - y1) / max(y2 - y1, 1e-9)),
]
def _from_crop_polygon(
polygon: list[list[float]],
bounds: tuple[float, float, float, float],
) -> list[list[float]]:
x1, y1, x2, y2 = bounds
return [
[
_clamp01(x1 + float(point[0]) * (x2 - x1)),
_clamp01(y1 + float(point[1]) * (y2 - y1)),
]
for point in polygon
]
def _filter_predictions(
polygons: list[list[list[float]]],
scores: list[float],
options: dict[str, Any],
negative_points: list[list[float]] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
if not options.get("auto_filter_background"):
return polygons, scores
min_score = float(options.get("min_score", 0.0) or 0.0)
next_polygons: list[list[list[float]]] = []
next_scores: list[float] = []
for index, polygon in enumerate(polygons):
score = scores[index] if index < len(scores) else 0.0
if score < min_score:
continue
if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
continue
next_polygons.append(polygon)
next_scores.append(score)
return next_polygons, next_scores
@router.post(
"/predict",
response_model=PredictResponse,
@@ -58,9 +192,11 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
options = payload.options or {}
polygons: List[List[List[float]]] = []
scores: List[float] = []
negative_points: list[list[float]] = []
try:
if prompt_type == "point":
@@ -76,13 +212,39 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
if not isinstance(labels, list) or len(labels) != len(points):
labels = [1] * len(points)
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
negative_points = [
point for point, label in zip(points, labels) if label == 0
]
inference_image = image
inference_points = points
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.25) or 0.25)
crop_bounds = _crop_bounds_from_points(points, margin)
inference_image = _crop_image(image, crop_bounds)
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_registry.predict_box(payload.model, image, box)
inference_image = image
inference_box = box
crop_bounds = None
if options.get("crop_to_prompt"):
margin = float(options.get("crop_margin", 0.05) or 0.05)
crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
inference_image = _crop_image(image, crop_bounds)
inference_box = [
*_to_crop_point([box[0], box[1]], crop_bounds),
*_to_crop_point([box[2], box[3]], crop_bounds),
]
polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
if crop_bounds:
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
elif prompt_type == "semantic":
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
@@ -95,8 +257,9 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
except NotImplementedError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise HTTPException(status_code=400, detail=str(exc)) from exc
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
return {"polygons": polygons, "scores": scores}
@@ -161,6 +324,100 @@ def save_annotation(
return annotation
@router.post(
"/import-gt-mask",
response_model=List[AnnotationOut],
status_code=status.HTTP_201_CREATED,
summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
project_id: int = Form(...),
frame_id: int = Form(...),
template_id: int | None = Form(None),
label: str = Form("GT Mask"),
color: str = Form("#22c55e"),
file: UploadFile = File(...),
db: Session = Depends(get_db),
) -> List[Annotation]:
"""Convert a binary/label mask image into persisted polygon annotations.
Each connected component becomes one annotation. The `points` field stores a
positive seed point at the component's distance-transform center, which gives
the frontend an editable point-region representation instead of a static
bitmap layer.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
if template_id is not None:
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
data = await file.read()
image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
if image is None:
raise HTTPException(status_code=400, detail="Invalid mask image")
width = int(frame.width or image.shape[1])
height = int(frame.height or image.shape[0])
label_values = [int(value) for value in np.unique(image) if int(value) > 0]
if not label_values:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
has_multiple_labels = len(label_values) > 1
annotations: list[Annotation] = []
for label_value in label_values:
binary = np.where(image == label_value, 255, 0).astype(np.uint8)
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
annotation_label = f"{label} {label_value}" if has_multiple_labels else label
for contour in contours:
if cv2.contourArea(contour) < 1:
continue
polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
if len(polygon) < 3:
continue
component = np.zeros_like(binary, dtype=np.uint8)
cv2.drawContours(component, [contour], -1, 1, thickness=-1)
seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
bbox = _contour_bbox(contour, image.shape[1], image.shape[0])
annotation = Annotation(
project_id=project_id,
frame_id=frame_id,
template_id=template_id,
mask_data={
"polygons": [polygon],
"label": annotation_label,
"color": color,
"source": "gt_mask",
"gt_label_value": label_value,
"image_size": {"width": width, "height": height},
},
points=[seed_point],
bbox=bbox,
)
db.add(annotation)
annotations.append(annotation)
if not annotations:
raise HTTPException(status_code=400, detail="No foreground mask regions found")
db.commit()
for annotation in annotations:
db.refresh(annotation)
logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
return annotations
@router.get(
"/annotations",
response_model=List[AnnotationOut],

View File

@@ -14,6 +14,7 @@ from models import Annotation, Frame, ProcessingTask, Project, Template
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
ACTIVE_TASK_STATUSES = {"queued", "running"}
MONITORED_TASK_STATUSES = {"queued", "running", "failed", "cancelled"}
def _system_load_percent() -> int:
@@ -42,7 +43,9 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
"name": task.project.name if task.project else f"任务 {task.id}",
"progress": task.progress,
"status": task.message or task.status,
"raw_status": task.status,
"frame_count": (task.result or {}).get("frames_extracted", 0),
"error": task.error,
"updated_at": _iso_or_none(task.updated_at),
}
@@ -68,7 +71,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
.limit(50)
.all()
)
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
activities: list[dict[str, Any]] = []
for task in recent_tasks[:10]:

View File

@@ -37,6 +37,54 @@ def _mask_from_polygon(
return mask
def _annotation_z_index(annotation: Annotation) -> int:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
try:
return int(class_meta["zIndex"])
except (TypeError, ValueError):
pass
if annotation.template and annotation.template.z_index is not None:
return int(annotation.template.z_index)
return 0
def _annotation_class_key(annotation: Annotation) -> str:
class_meta = (annotation.mask_data or {}).get("class") or {}
if isinstance(class_meta, dict):
if class_meta.get("id"):
return f"class:{class_meta['id']}"
if class_meta.get("name"):
return f"name:{class_meta['name']}"
if annotation.template_id:
return f"template:{annotation.template_id}"
return f"annotation:{annotation.id}"
def _annotation_label(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("name"):
return str(class_meta["name"])
if mask_data.get("label"):
return str(mask_data["label"])
if annotation.template and annotation.template.name:
return str(annotation.template.name)
return f"Annotation {annotation.id}"
def _annotation_color(annotation: Annotation) -> str:
mask_data = annotation.mask_data or {}
class_meta = mask_data.get("class") or {}
if isinstance(class_meta, dict) and class_meta.get("color"):
return str(class_meta["color"])
if mask_data.get("color"):
return str(mask_data["color"])
if annotation.template and annotation.template.color:
return str(annotation.template.color)
return "#ffffff"
@router.get(
"/{project_id}/coco",
summary="Export annotations in COCO format",
@@ -150,19 +198,46 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
summary="Export PNG masks as a ZIP archive",
)
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
"""Export all annotation masks as individual PNG files inside a ZIP archive."""
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
import cv2
annotations = (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.all()
)
frames = (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.all()
)
class_values: dict[str, int] = {}
semantic_classes: list[dict[str, Any]] = []
def class_value(annotation: Annotation) -> int:
key = _annotation_class_key(annotation)
if key not in class_values:
value = len(class_values) + 1
class_values[key] = value
semantic_classes.append({
"value": value,
"key": key,
"label": _annotation_label(annotation),
"color": _annotation_color(annotation),
"zIndex": _annotation_z_index(annotation),
"template_id": annotation.template_id,
})
return class_values[key]
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
for ann in annotations:
if not ann.mask_data:
continue
@@ -178,11 +253,28 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
mask = _mask_from_polygon(poly, width, height)
combined = np.maximum(combined, mask)
# Encode PNG
import cv2
_, encoded = cv2.imencode(".png", combined)
fname = f"mask_{ann.id:06d}.png"
zf.writestr(fname, encoded.tobytes())
if ann.frame_id is not None:
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
for frame in frames:
entries = frame_masks.get(frame.id, [])
if not entries:
continue
width = frame.width or 1920
height = frame.height or 1080
semantic = np.zeros((height, width), dtype=np.uint8)
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
semantic[mask > 0] = class_value(ann)
_, encoded = cv2.imencode(".png", semantic)
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
zf.writestr(
"semantic_classes.json",
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
)
zip_buffer.seek(0)
filename = f"project_{project_id}_masks.zip"

View File

@@ -1,15 +1,45 @@
"""Processing task query endpoints."""
import logging
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from celery_app import celery_app
from database import get_db
from models import ProcessingTask
from models import ProcessingTask, Project
from progress_events import publish_task_progress_event
from schemas import ProcessingTaskOut
from statuses import (
PROJECT_STATUS_PARSING,
PROJECT_STATUS_PENDING,
PROJECT_STATUS_READY,
TASK_ACTIVE_STATUSES,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_QUEUED,
)
from worker_tasks import parse_project_media
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return task
def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
@@ -31,7 +61,78 @@ def list_tasks(
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Return one background task by id."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
return _get_task_or_404(task_id, db)
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Cancel a queued/running background task and revoke the Celery job when possible."""
task = _get_task_or_404(task_id, db)
if task.status not in TASK_ACTIVE_STATUSES:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not cancellable in status: {task.status}",
)
if task.celery_task_id:
try:
celery_app.control.revoke(task.celery_task_id, terminate=True, signal="SIGTERM")
except Exception as exc: # noqa: BLE001
logger.warning("Failed to revoke celery task %s: %s", task.celery_task_id, exc)
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = "任务已取消"
task.error = "Cancelled by user"
task.finished_at = _now()
if task.project:
task.project.status = _project_status_after_stop(task.project)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task
@router.post("/{task_id}/retry", response_model=ProcessingTaskOut, status_code=status.HTTP_202_ACCEPTED, summary="Retry processing task")
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
"""Create a fresh queued task from a failed or cancelled task."""
previous = _get_task_or_404(task_id, db)
if previous.status not in {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Task is not retryable in status: {previous.status}",
)
if previous.project_id is None:
raise HTTPException(status_code=400, detail="Task has no project_id")
project = db.query(Project).filter(Project.id == previous.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
payload = dict(previous.payload or {})
payload.setdefault("source_type", project.source_type or "video")
payload["retry_of"] = previous.id
task = ProcessingTask(
task_type=previous.task_type,
status=TASK_STATUS_QUEUED,
progress=0,
message=f"重试任务已入队(源任务 #{previous.id}",
project_id=project.id,
payload=payload,
)
project.status = PROJECT_STATUS_PARSING
db.add(task)
db.commit()
db.refresh(task)
publish_task_progress_event(task)
async_result = parse_project_media.delay(task.id)
task.celery_task_id = async_result.id
db.commit()
db.refresh(task)
publish_task_progress_event(task)
return task

View File

@@ -180,6 +180,7 @@ class PredictRequest(BaseModel):
prompt_type: str # point / box / semantic
prompt_data: Any
model: Optional[str] = None
options: Optional[dict[str, Any]] = None
class PredictResponse(BaseModel):
@@ -201,6 +202,8 @@ class AiModelStatus(BaseModel):
python_ok: bool = True
torch_ok: bool = True
cuda_required: bool = False
external_available: bool = False
external_python: Optional[str] = None
class GpuStatus(BaseModel):

View File

@@ -20,9 +20,11 @@ from services.frame_parser import (
upload_frames_to_minio,
)
from statuses import (
PROJECT_STATUS_PENDING,
PROJECT_STATUS_ERROR,
PROJECT_STATUS_PARSING,
PROJECT_STATUS_READY,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
@@ -31,6 +33,10 @@ from statuses import (
logger = logging.getLogger(__name__)
class TaskCancelled(RuntimeError):
"""Raised internally when a persisted task has been cancelled."""
def _now() -> datetime:
return datetime.now(timezone.utc)
@@ -66,12 +72,29 @@ def _set_task_state(
publish_task_progress_event(task)
def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
db.refresh(task)
if task.status == TASK_STATUS_CANCELLED:
raise TaskCancelled("Task was cancelled")
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
"""Parse one project's media and update task progress in the database."""
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
if not task:
raise ValueError(f"Task not found: {task_id}")
if task.status == TASK_STATUS_CANCELLED:
return {
"task_id": task.id,
"status": TASK_STATUS_CANCELLED,
"message": task.message or "任务已取消",
}
if task.project_id is None:
_set_task_state(
db,
@@ -111,6 +134,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
db.commit()
raise ValueError("Project has no media uploaded")
_ensure_not_cancelled(db, task)
project.status = PROJECT_STATUS_PARSING
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
@@ -121,6 +145,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
os.makedirs(output_dir, exist_ok=True)
try:
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
if effective_source == "dicom":
dcm_dir = os.path.join(tmp_dir, "dcm")
@@ -129,20 +154,24 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
client = get_minio_client()
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
for obj in objects:
_ensure_not_cancelled(db, task)
if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f:
f.write(data)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
frame_files = parse_dicom(dcm_dir, output_dir)
else:
_ensure_not_cancelled(db, task)
media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
project.original_fps = original_fps
@@ -158,12 +187,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
object_names = upload_frames_to_minio(frame_files, project.id)
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=85, message="正在写入帧索引")
frames_out = []
for idx, obj_name in enumerate(object_names):
_ensure_not_cancelled(db, task)
local_frame = frame_files[idx]
try:
import cv2
@@ -203,6 +235,23 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
)
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
return result
except TaskCancelled:
project.status = _project_status_after_stop(project)
task.status = TASK_STATUS_CANCELLED
task.progress = 100
task.message = task.message or "任务已取消"
task.error = task.error or "Cancelled by user"
task.finished_at = task.finished_at or _now()
db.commit()
db.refresh(task)
publish_task_progress_event(task)
logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id)
return {
"task_id": task.id,
"project_id": project.id,
"status": TASK_STATUS_CANCELLED,
"message": task.message,
}
except Exception as exc: # noqa: BLE001
project.status = PROJECT_STATUS_ERROR
_set_task_state(

View File

@@ -9,8 +9,14 @@ the package.
from __future__ import annotations
import importlib.util
import json
import logging
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
@@ -41,6 +47,8 @@ class SAM3Engine:
self._processor: Any | None = None
self._model_loaded = False
self._last_error: str | None = None
self._external_status_cache: dict[str, Any] | None = None
self._external_status_checked_at = 0.0
def _python_ok(self) -> bool:
return sys.version_info >= (3, 12)
@@ -51,6 +59,81 @@ class SAM3Engine:
def _can_load(self) -> bool:
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
def _worker_path(self) -> Path:
return Path(__file__).with_name("sam3_external_worker.py")
def _external_python_exists(self) -> bool:
return bool(settings.sam3_external_enabled and os.path.isfile(settings.sam3_external_python))
def _external_status(self, force: bool = False) -> dict[str, Any]:
now = time.monotonic()
if (
not force
and self._external_status_cache is not None
and now - self._external_status_checked_at < settings.sam3_status_cache_seconds
):
return self._external_status_cache
if not settings.sam3_external_enabled:
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": "SAM 3 external runtime is disabled.",
}
elif not self._external_python_exists():
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external Python not found: {settings.sam3_external_python}",
}
else:
try:
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--status"],
capture_output=True,
text=True,
timeout=min(settings.sam3_timeout_seconds, 30),
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external status failed: {detail}",
}
else:
status = json.loads(completed.stdout)
except Exception as exc: # noqa: BLE001
status = {
"available": False,
"package_available": False,
"python_ok": False,
"torch_ok": False,
"cuda_available": False,
"device": "unavailable",
"message": f"SAM 3 external status failed: {exc}",
}
self._external_status_cache = status
self._external_status_checked_at = now
return status
def _load_model(self) -> None:
if self._model_loaded:
return
@@ -92,26 +175,86 @@ class SAM3Engine:
return "SAM 3 dependencies are present; model will load on first inference."
def status(self) -> dict:
available = self._can_load()
external_status = self._external_status()
available = bool(self._can_load() or external_status.get("available"))
external_ready = bool(external_status.get("available"))
message = self._last_error or self._status_message()
if self._processor is not None:
message = "SAM 3 model loaded and ready."
elif external_ready:
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
elif external_status.get("message") and not self._can_load():
message = str(external_status["message"])
return {
"id": "sam3",
"label": "SAM 3",
"available": available,
"loaded": self._processor is not None,
"device": "cuda" if self._gpu_ok() else "unavailable",
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
"supports": ["semantic"],
"message": "SAM 3 model loaded and ready." if self._processor is not None else (self._last_error or self._status_message()),
"package_available": SAM3_PACKAGE_AVAILABLE,
"checkpoint_exists": SAM3_PACKAGE_AVAILABLE,
"message": message,
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
"checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
"python_ok": self._python_ok(),
"torch_ok": TORCH_AVAILABLE,
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
"cuda_required": True,
"external_available": external_ready,
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
}
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir:
tmp_path = Path(tmpdir)
image_path = tmp_path / "image.png"
request_path = tmp_path / "request.json"
Image.fromarray(image).save(image_path)
request_path.write_text(
json.dumps(
{
"image_path": str(image_path),
"text": text.strip(),
"model_version": settings.sam3_model_version,
"confidence_threshold": settings.sam3_confidence_threshold,
},
ensure_ascii=False,
),
encoding="utf-8",
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
text=True,
timeout=settings.sam3_timeout_seconds,
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
try:
parsed = json.loads(detail)
detail = parsed.get("error", detail)
except Exception: # noqa: BLE001
pass
raise RuntimeError(f"SAM 3 external inference failed: {detail}")
payload = json.loads(completed.stdout)
if payload.get("error"):
raise RuntimeError(str(payload["error"]))
return payload.get("polygons", []), payload.get("scores", [])
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
if not self._can_load() and self._external_status().get("available"):
return self._predict_semantic_external(image, text)
if not self._ensure_ready():
raise RuntimeError(self.status()["message"])

View File

@@ -0,0 +1,190 @@
"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime.
The main FastAPI backend can keep running in the existing Python 3.11/SAM 2
environment while this helper is executed with a separate conda env that meets
SAM 3's stricter runtime requirements.
"""
from __future__ import annotations
import argparse
import importlib.util
import json
import os
import sys
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
def _torch_status() -> tuple[bool, str | None, str | None, str | None]:
try:
import torch
cuda_available = bool(torch.cuda.is_available())
return (
cuda_available,
getattr(torch, "__version__", None),
getattr(torch.version, "cuda", None),
torch.cuda.get_device_name(0) if cuda_available else None,
)
except Exception: # noqa: BLE001
return False, None, None, None
def _compact_error(exc: Exception) -> str:
lines = [line.strip() for line in str(exc).splitlines() if line.strip()]
for line in lines:
if "Access to model" in line or "Cannot access gated repo" in line:
return line
return lines[0] if lines else exc.__class__.__name__
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
try:
from huggingface_hub import hf_hub_download
repo_id = "facebook/sam3.1" if model_version == "sam3.1" else "facebook/sam3"
hf_hub_download(repo_id=repo_id, filename="config.json")
return True, None
except Exception as exc: # noqa: BLE001
return False, _compact_error(exc)
def runtime_status() -> dict[str, Any]:
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
package_error = None
package_available = importlib.util.find_spec("sam3") is not None
if package_available:
try:
import sam3 # noqa: F401
except Exception as exc: # noqa: BLE001
package_available = False
package_error = str(exc)
cuda_available, torch_version, cuda_version, device_name = _torch_status()
python_ok = sys.version_info >= (3, 12)
checkpoint_access = False
checkpoint_error = None
if package_available:
checkpoint_access, checkpoint_error = _checkpoint_access(model_version)
available = bool(package_available and python_ok and cuda_available and checkpoint_access)
missing = []
if not python_ok:
missing.append("Python 3.12+ runtime")
if not package_available:
missing.append(f"sam3 package ({package_error})" if package_error else "sam3 package")
if torch_version is None:
missing.append("PyTorch")
if not cuda_available:
missing.append("CUDA GPU")
if package_available and not checkpoint_access:
missing.append(f"Hugging Face checkpoint access ({checkpoint_error})")
return {
"available": available,
"package_available": package_available,
"checkpoint_access": checkpoint_access,
"python_ok": python_ok,
"torch_ok": torch_version is not None,
"torch_version": torch_version,
"cuda_version": cuda_version,
"cuda_available": cuda_available,
"device": "cuda" if cuda_available else "unavailable",
"device_name": device_name,
"message": (
"SAM 3 external runtime is ready."
if available
else f"SAM 3 external runtime unavailable: missing {', '.join(missing)}."
),
}
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
import cv2
if mask.dtype != np.uint8:
mask = (mask > 0).astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
height, width = mask.shape[:2]
largest = []
for contour in contours:
if len(contour) > len(largest):
largest = contour
if len(largest) < 3:
return []
return [[float(point[0][0]) / width, float(point[0][1]) / height] for point in largest]
def _to_numpy(value: Any) -> np.ndarray:
if hasattr(value, "detach"):
value = value.detach().cpu().numpy()
elif hasattr(value, "cpu"):
value = value.cpu().numpy()
return np.asarray(value)
def predict(request_path: Path) -> dict[str, Any]:
import torch
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
payload = json.loads(request_path.read_text(encoding="utf-8"))
image_path = Path(payload["image_path"])
text = str(payload["text"]).strip()
threshold = float(payload.get("confidence_threshold", 0.5))
if not text:
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
image = Image.open(image_path).convert("RGB")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
model = build_sam3_image_model()
processor = Sam3Processor(model, confidence_threshold=threshold)
state = processor.set_image(image)
output = processor.set_text_prompt(state=state, prompt=text)
masks = _to_numpy(output.get("masks", []))
scores = _to_numpy(output.get("scores", []))
if masks.ndim == 4:
masks = masks[:, 0]
elif masks.ndim == 3 and masks.shape[0] == 1:
masks = masks[None, 0]
polygons = []
for mask in masks:
polygon = _mask_to_polygon(mask)
if polygon:
polygons.append(polygon)
return {
"polygons": polygons,
"scores": scores.astype(float).tolist() if scores.size else [],
}
def main() -> int:
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
parser.add_argument("--status", action="store_true")
parser.add_argument("--request", type=Path)
args = parser.parse_args()
try:
if args.status:
print(json.dumps(runtime_status(), ensure_ascii=False))
return 0
if args.request:
print(json.dumps(predict(args.request), ensure_ascii=False))
return 0
parser.error("Use --status or --request")
except Exception as exc: # noqa: BLE001
print(json.dumps({"error": str(exc)}, ensure_ascii=False), file=sys.stderr)
return 1
return 2
if __name__ == "__main__":
raise SystemExit(main())

24
backend/setup_sam3_env.sh Executable file
View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
# Create the dedicated SAM 3 runtime used by backend/services/sam3_external_worker.py.
# Keep Hugging Face tokens outside this repository, for example:
# export HF_TOKEN=...
# huggingface-cli login --token "$HF_TOKEN"
set -euo pipefail
ENV_NAME="${SAM3_CONDA_ENV:-sam3}"
source /home/wkmgc/miniconda3/etc/profile.d/conda.sh
if ! conda env list | awk '{print $1}' | grep -qx "$ENV_NAME"; then
conda create -y -n "$ENV_NAME" python=3.12
fi
conda activate "$ENV_NAME"
python -m pip install --upgrade pip
python -m pip install "setuptools<81"
python -m pip install torch==2.10.0 torchvision --index-url https://download.pytorch.org/whl/cu128
python -m pip install opencv-python pillow numpy huggingface_hub einops pycocotools psutil
python -m pip install git+https://github.com/facebookresearch/sam3.git
python /home/wkmgc/Desktop/Seg_Server/backend/services/sam3_external_worker.py --status

View File

@@ -9,3 +9,7 @@ TASK_STATUS_QUEUED = "queued"
TASK_STATUS_RUNNING = "running"
TASK_STATUS_SUCCESS = "success"
TASK_STATUS_FAILED = "failed"
TASK_STATUS_CANCELLED = "cancelled"
TASK_ACTIVE_STATUSES = {TASK_STATUS_QUEUED, TASK_STATUS_RUNNING}
TASK_TERMINAL_STATUSES = {TASK_STATUS_SUCCESS, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}

View File

@@ -1,4 +1,5 @@
import numpy as np
import cv2
def _create_project_and_frame(client):
@@ -46,6 +47,46 @@ def test_predict_accepts_point_object_with_labels(client, monkeypatch):
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client)
calls = {}
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((100, 200, 3), dtype=np.uint8))
def fake_predict_points(model, image, points, labels):
calls["shape"] = image.shape
calls["points"] = points
calls["labels"] = labels
return (
[
[[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]],
[[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]],
],
[0.9, 0.01],
)
monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)
response = client.post("/api/ai/predict", json={
"image_id": frame["id"],
"prompt_type": "point",
"prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
"options": {
"crop_to_prompt": True,
"crop_margin": 0.1,
"auto_filter_background": True,
"min_score": 0.05,
},
})
assert response.status_code == 200
assert calls["shape"][0] < 100
assert calls["shape"][1] < 200
assert calls["labels"] == [1, 0]
assert response.json()["scores"] == [0.9]
polygon = response.json()["polygons"][0]
assert all(0.0 <= coord <= 1.0 for point in polygon for coord in point)
def test_predict_box_and_semantic_fallback(client, monkeypatch):
_, frame, _ = _create_project_and_frame(client)
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
@@ -246,3 +287,62 @@ def test_update_and_delete_annotation_validation(client):
f"/api/ai/annotations/{saved['id']}",
json={"template_id": 999},
).status_code == 404
def test_import_gt_mask_creates_annotations_with_seed_points(client):
project, frame, template = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
cv2.rectangle(mask, (100, 80), (260, 220), 255, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"template_id": str(template["id"]),
"label": "Imported GT",
"color": "#22c55e",
},
files={"file": ("mask.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = response.json()
assert len(body) == 1
assert body[0]["project_id"] == project["id"]
assert body[0]["frame_id"] == frame["id"]
assert body[0]["template_id"] == template["id"]
assert body[0]["mask_data"]["label"] == "Imported GT"
assert body[0]["mask_data"]["source"] == "gt_mask"
assert body[0]["mask_data"]["gt_label_value"] == 255
assert len(body[0]["mask_data"]["polygons"][0]) >= 3
assert len(body[0]["points"]) == 1
assert 0.0 <= body[0]["points"][0][0] <= 1.0
assert 0.0 <= body[0]["points"][0][1] <= 1.0
def test_import_gt_mask_splits_label_values(client):
project, frame, _ = _create_project_and_frame(client)
mask = np.zeros((360, 640), dtype=np.uint8)
cv2.rectangle(mask, (20, 20), (120, 120), 1, thickness=-1)
cv2.rectangle(mask, (220, 80), (320, 180), 2, thickness=-1)
ok, encoded = cv2.imencode(".png", mask)
assert ok
response = client.post(
"/api/ai/import-gt-mask",
data={
"project_id": str(project["id"]),
"frame_id": str(frame["id"]),
"label": "GT Class",
},
files={"file": ("labels.png", encoded.tobytes(), "image/png")},
)
assert response.status_code == 201
body = sorted(response.json(), key=lambda item: item["mask_data"]["gt_label_value"])
assert [item["mask_data"]["gt_label_value"] for item in body] == [1, 2]
assert [item["mask_data"]["label"] for item in body] == ["GT Class 1", "GT Class 2"]
assert all(len(item["points"]) == 1 for item in body)

View File

@@ -59,7 +59,9 @@ def test_dashboard_overview_uses_persisted_records(client, db_session):
"name": "Pending Project",
"progress": 35,
"status": "正在使用 FFmpeg/OpenCV 拆帧",
"raw_status": "running",
"frame_count": 0,
"error": None,
"updated_at": body["tasks"][0]["updated_at"],
},
]

View File

@@ -1,6 +1,10 @@
import zipfile
import json
from io import BytesIO
import cv2
import numpy as np
def _seed_export_data(client):
project = client.post("/api/projects", json={"name": "Export Project"}).json()
@@ -58,7 +62,55 @@ def test_export_masks_zip(client):
assert response.status_code == 200
assert response.headers["content-type"].startswith("application/zip")
with zipfile.ZipFile(BytesIO(response.content)) as archive:
assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"]
assert archive.namelist() == [
f"mask_{annotation['id']:06d}.png",
"semantic_frame_000000.png",
"semantic_classes.json",
]
def test_export_masks_uses_z_index_for_semantic_fusion(client):
project = client.post("/api/projects", json={"name": "Fusion Project"}).json()
frame = client.post(f"/api/projects/{project['id']}/frames", json={
"project_id": project["id"],
"frame_index": 0,
"image_url": "frames/0.jpg",
"width": 20,
"height": 20,
}).json()
low = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]]],
"label": "Low",
"color": "#00ff00",
"class": {"id": "low", "name": "Low", "color": "#00ff00", "zIndex": 10},
},
}).json()
high = client.post("/api/ai/annotate", json={
"project_id": project["id"],
"frame_id": frame["id"],
"mask_data": {
"polygons": [[[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]]],
"label": "High",
"color": "#ff0000",
"class": {"id": "high", "name": "High", "color": "#ff0000", "zIndex": 20},
},
}).json()
response = client.get(f"/api/export/{project['id']}/masks")
assert response.status_code == 200
with zipfile.ZipFile(BytesIO(response.content)) as archive:
assert f"mask_{low['id']:06d}.png" in archive.namelist()
assert f"mask_{high['id']:06d}.png" in archive.namelist()
legend = json.loads(archive.read("semantic_classes.json"))
high_value = next(item["value"] for item in legend["classes"] if item["key"] == "class:high")
semantic_bytes = np.frombuffer(archive.read("semantic_frame_000000.png"), dtype=np.uint8)
semantic = cv2.imdecode(semantic_bytes, cv2.IMREAD_GRAYSCALE)
assert semantic[10, 10] == high_value
def test_export_missing_project_returns_404(client):

View File

@@ -140,3 +140,25 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
assert project_detail["status"] == "ready"
frames = client.get(f"/api/projects/{project['id']}/frames").json()
assert "frame_000001.jpg" in frames[0]["image_url"]
def test_parse_task_runner_skips_already_cancelled_task(db_session):
from models import ProcessingTask
from services.media_task_runner import run_parse_media_task
task = ProcessingTask(
task_type="parse_video",
status="cancelled",
progress=100,
message="任务已取消",
project_id=1,
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
result = run_parse_media_task(db_session, task.id)
assert result["status"] == "cancelled"
assert result["message"] == "任务已取消"

View File

@@ -26,6 +26,25 @@ def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
assert payload["status"] == "解析完成"
def test_task_progress_payload_marks_cancelled_tasks():
task = SimpleNamespace(
id=13,
project_id=7,
project=SimpleNamespace(name="demo.mp4"),
status="cancelled",
progress=100,
message="任务已取消",
error="Cancelled by user",
updated_at=None,
)
payload = task_progress_payload(task)
assert payload["type"] == "cancelled"
assert payload["status"] == "任务已取消"
assert payload["error"] == "Cancelled by user"
def test_publish_progress_event_writes_json_to_redis(monkeypatch):
calls = []

View File

@@ -0,0 +1,112 @@
import json
from pathlib import Path
import numpy as np
from services.sam3_engine import SAM3Engine
class _Completed:
def __init__(self, returncode=0, stdout="", stderr=""):
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
def _external_settings(monkeypatch, python_path: Path):
python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8")
python_path.chmod(0o755)
monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False)
monkeypatch.setattr("services.sam3_engine.TORCH_AVAILABLE", False)
monkeypatch.setattr("services.sam3_engine.settings.sam3_external_enabled", True)
monkeypatch.setattr("services.sam3_engine.settings.sam3_external_python", str(python_path))
monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10)
monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30)
monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4)
def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
def fake_run(args, **_kwargs):
assert "--status" in args
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
status = SAM3Engine().status()
assert status["available"] is True
assert status["external_available"] is True
assert status["package_available"] is True
assert status["python_ok"] is True
assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference."
def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
calls = []
def fake_run(args, **_kwargs):
calls.append(args)
if "--status" in args:
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
request_path = Path(args[-1])
request = json.loads(request_path.read_text(encoding="utf-8"))
assert request["text"] == "vessel"
assert request["confidence_threshold"] == 0.4
assert Path(request["image_path"]).exists()
return _Completed(stdout=json.dumps({
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
"scores": [0.91],
}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
polygons, scores = SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), " vessel ")
assert polygons == [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]
assert scores == [0.91]
assert any("--request" in args for args in calls)
def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
_external_settings(monkeypatch, tmp_path / "python")
def fake_run(args, **_kwargs):
if "--status" in args:
return _Completed(stdout=json.dumps({
"available": True,
"package_available": True,
"python_ok": True,
"torch_ok": True,
"cuda_available": True,
"device": "cuda",
"message": "ready",
}))
return _Completed(returncode=1, stderr=json.dumps({"error": "HF access denied"}))
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
try:
SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), "vessel")
except RuntimeError as exc:
assert "HF access denied" in str(exc)
else:
raise AssertionError("Expected SAM 3 external inference failure.")

104
backend/tests/test_tasks.py Normal file
View File

@@ -0,0 +1,104 @@
from models import ProcessingTask
def test_cancel_task_revokes_celery_and_updates_project(client, db_session, monkeypatch):
project = client.post("/api/projects", json={
"name": "Cancelable",
"video_path": "uploads/1/clip.mp4",
"status": "parsing",
}).json()
task = ProcessingTask(
task_type="parse_video",
status="running",
progress=35,
message="正在使用 FFmpeg/OpenCV 拆帧",
project_id=project["id"],
celery_task_id="celery-1",
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
revoked = []
published = []
monkeypatch.setattr(
"routers.tasks.celery_app.control.revoke",
lambda celery_id, terminate, signal: revoked.append((celery_id, terminate, signal)),
)
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append(event_task.status))
response = client.post(f"/api/tasks/{task.id}/cancel")
assert response.status_code == 200
body = response.json()
assert body["status"] == "cancelled"
assert body["progress"] == 100
assert body["message"] == "任务已取消"
assert body["error"] == "Cancelled by user"
assert revoked == [("celery-1", True, "SIGTERM")]
assert published == ["cancelled"]
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "pending"
def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch):
project = client.post("/api/projects", json={
"name": "Retryable",
"video_path": "uploads/2/clip.mp4",
"source_type": "video",
"status": "error",
}).json()
task = ProcessingTask(
task_type="parse_video",
status="failed",
progress=100,
message="解析失败",
error="ffmpeg failed",
project_id=project["id"],
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
class FakeAsyncResult:
id = "celery-retry"
queued = []
published = []
monkeypatch.setattr("routers.tasks.parse_project_media.delay", lambda task_id: queued.append(task_id) or FakeAsyncResult())
monkeypatch.setattr("routers.tasks.publish_task_progress_event", lambda event_task: published.append((event_task.id, event_task.status)))
response = client.post(f"/api/tasks/{task.id}/retry")
assert response.status_code == 202
body = response.json()
assert body["id"] != task.id
assert body["status"] == "queued"
assert body["progress"] == 0
assert body["celery_task_id"] == "celery-retry"
assert body["payload"]["retry_of"] == task.id
assert queued == [body["id"]]
assert published[0] == (body["id"], "queued")
assert published[-1] == (body["id"], "queued")
assert client.get(f"/api/projects/{project['id']}").json()["status"] == "parsing"
def test_task_actions_reject_invalid_states(client, db_session):
project = client.post("/api/projects", json={
"name": "Done",
"video_path": "uploads/3/clip.mp4",
}).json()
task = ProcessingTask(
task_type="parse_video",
status="success",
progress=100,
project_id=project["id"],
payload={"source_type": "video"},
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
assert client.post(f"/api/tasks/{task.id}/cancel").status_code == 409
assert client.post(f"/api/tasks/{task.id}/retry").status_code == 409