feat: 建立 SAM2 标注闭环基线
- 打通工作区真实标注闭环:支持手工多边形、矩形、圆形、点区域和线段生成 mask,并可保存、回显、更新和删除后端 annotation。
- 增强 polygon 编辑器:支持顶点拖动、顶点删除、边中点插入、多 polygon 子区域选择编辑,以及区域合并和区域去除。
- 接入 GT mask 导入:后端支持二值/多类别 mask 拆分、contour 转 polygon、distance transform seed point,前端支持导入、回显和 seed point 拖动编辑。
- 完善导出能力:COCO JSON 导出对齐前端,PNG mask ZIP 同时包含单标注 mask、按 zIndex 融合的 semantic_frame 和 semantic_classes.json。
- 打通异步任务管理:新增任务取消、重试、失败详情接口与 Dashboard 控件,worker 支持取消状态检查并通过 Redis/WebSocket 推送 cancelled 事件。
- 对接 Dashboard 后端数据:概览统计、解析队列和实时流转记录从 FastAPI 聚合接口与 WebSocket 更新。
- 增强 AI 推理参数:前端发送 crop_to_prompt、auto_filter_background 和 min_score,后端支持点/框 prompt 局部裁剪推理、结果回映射和负向点/低分过滤。
- 接入 SAM3 基础设施:新增独立 Python 3.12 sam3 环境安装脚本、外部 worker helper、后端桥接和真实 Python/CUDA/包/HF checkpoint access 状态检测。
- 保留 SAM3 授权边界:当前官方 facebook/sam3 gated 权重未授权时状态接口会返回不可用,不伪装成可推理。
- 增强前端状态管理:新增 mask undo/redo 历史栈、AI 模型选择状态、保存状态 dirty/draft/saved 流转和项目状态归一化。
- 更新前端 API 封装:补充 annotation CRUD、GT mask import、mask ZIP export、task cancel/retry/detail、AI runtime status 和 prediction options。
- 更新 UI 控件:ToolsPalette、AISegmentation、VideoWorkspace 和 CanvasArea 接入真实操作、导入导出、撤销重做、任务控制和模型状态。
- 新增 polygon-clipping 依赖,用于前端区域 union/difference 几何运算。
- 完善后端 schemas/status/progress:补充 AI 模型外部状态字段、任务 cancelled 状态和进度事件 payload。
- 补充测试覆盖:新增后端任务控制、SAM3 桥接、GT mask、导出融合、AI options 测试;补充前端 Canvas、Dashboard、VideoWorkspace、ToolsPalette、API 和 store 测试。
- 更新 README、AGENTS 和 doc 文档:冻结当前需求/设计/测试计划,标注真实功能、剩余 Mock、SAM3 授权边界和后续实施顺序。
This commit is contained in:
@@ -22,7 +22,12 @@ class Settings(BaseSettings):
|
||||
sam_default_model: str = "sam2"
|
||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
|
||||
sam3_model_version: str = "sam3.1"
|
||||
sam3_model_version: str = "sam3"
|
||||
sam3_external_enabled: bool = True
|
||||
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
|
||||
sam3_timeout_seconds: int = 300
|
||||
sam3_status_cache_seconds: int = 30
|
||||
sam3_confidence_threshold: float = 0.5
|
||||
|
||||
# App
|
||||
app_env: str = "development"
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from redis_client import get_redis_client
|
||||
from statuses import TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
|
||||
from statuses import TASK_STATUS_CANCELLED, TASK_STATUS_FAILED, TASK_STATUS_SUCCESS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,6 +22,8 @@ def _iso_now() -> str:
|
||||
def _event_type(task_status: str) -> str:
    """Translate a task status into the event type string for that status.

    Success maps to ``"complete"``, cancellation to ``"cancelled"`` and
    failure to ``"error"``. Any other status is reported as ``"progress"``.
    """
    # Checked in order, mirroring a chain of equality tests.
    terminal_events = (
        (TASK_STATUS_SUCCESS, "complete"),
        (TASK_STATUS_CANCELLED, "cancelled"),
        (TASK_STATUS_FAILED, "error"),
    )
    for status_value, event_name in terminal_events:
        if task_status == status_value:
            return event_name
    return "progress"
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Response, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
@@ -39,6 +39,140 @@ def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
||||
|
||||
|
||||
def _normalized_contour(contour: np.ndarray, width: int, height: int) -> list[list[float]]:
    """Simplify a contour and express its vertices as fractions of the image size.

    The contour is approximated as a closed curve with a tolerance of 1% of
    its perimeter (never less than 1.0). When the approximation has fewer
    than three vertices, the raw contour points are used instead. Each
    coordinate is divided by ``width``/``height`` (treated as at least 1) and
    clamped to ``[0, 1]``.
    """
    perimeter = cv2.arcLength(contour, True)
    tolerance = max(1.0, perimeter * 0.01)
    vertices = cv2.approxPolyDP(contour, tolerance, True).reshape(-1, 2)
    if len(vertices) < 3:
        # Too coarse to form a polygon; fall back to the unsimplified outline.
        vertices = contour.reshape(-1, 2)

    x_scale = max(width, 1)
    y_scale = max(height, 1)
    normalized: list[list[float]] = []
    for px, py in vertices:
        nx = min(max(float(px) / x_scale, 0.0), 1.0)
        ny = min(max(float(py) / y_scale, 0.0), 1.0)
        normalized.append([nx, ny])
    return normalized
|
||||
|
||||
|
||||
def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
    """Return the contour's bounding rectangle as normalized ``[x, y, w, h]``.

    Horizontal values are divided by ``width`` and vertical values by
    ``height`` (each treated as at least 1); every entry is clamped to
    ``[0, 1]``.
    """
    left, top, box_w, box_h = cv2.boundingRect(contour)
    x_scale = max(width, 1)
    y_scale = max(height, 1)
    fractions = (
        float(left) / x_scale,
        float(top) / y_scale,
        float(box_w) / x_scale,
        float(box_h) / y_scale,
    )
    return [min(max(value, 0.0), 1.0) for value in fractions]
|
||||
|
||||
|
||||
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
    """Pick one normalized seed point for a component via distance transform.

    The mask is cast to ``uint8`` and the location of the maximum of its L2
    distance transform is returned as ``[x, y]``, divided by
    ``width``/``height`` (treated as at least 1) and clamped to ``[0, 1]``.
    """
    distances = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
    # minMaxLoc returns (min value, max value, min location, max location).
    peak_x, peak_y = cv2.minMaxLoc(distances)[3]
    x_fraction = float(peak_x) / max(width, 1)
    y_fraction = float(peak_y) / max(height, 1)
    return [
        min(max(x_fraction, 0.0), 1.0),
        min(max(y_fraction, 0.0), 1.0),
    ]
|
||||
|
||||
|
||||
def _clamp01(value: float) -> float:
|
||||
return min(max(float(value), 0.0), 1.0)
|
||||
|
||||
|
||||
def _point_in_polygon(point: list[float], polygon: list[list[float]]) -> bool:
|
||||
"""Return whether a normalized point is inside a normalized polygon."""
|
||||
if len(polygon) < 3:
|
||||
return False
|
||||
x, y = point
|
||||
inside = False
|
||||
j = len(polygon) - 1
|
||||
for i, current in enumerate(polygon):
|
||||
xi, yi = current
|
||||
xj, yj = polygon[j]
|
||||
intersects = ((yi > y) != (yj > y)) and (
|
||||
x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-9) + xi
|
||||
)
|
||||
if intersects:
|
||||
inside = not inside
|
||||
j = i
|
||||
return inside
|
||||
|
||||
|
||||
def _crop_bounds_from_points(points: list[list[float]], margin: float) -> tuple[float, float, float, float]:
|
||||
xs = [_clamp01(point[0]) for point in points]
|
||||
ys = [_clamp01(point[1]) for point in points]
|
||||
x1 = max(0.0, min(xs) - margin)
|
||||
y1 = max(0.0, min(ys) - margin)
|
||||
x2 = min(1.0, max(xs) + margin)
|
||||
y2 = min(1.0, max(ys) + margin)
|
||||
if x2 - x1 < 0.05:
|
||||
center = (x1 + x2) / 2
|
||||
x1 = max(0.0, center - 0.025)
|
||||
x2 = min(1.0, center + 0.025)
|
||||
if y2 - y1 < 0.05:
|
||||
center = (y1 + y2) / 2
|
||||
y1 = max(0.0, center - 0.025)
|
||||
y2 = min(1.0, center + 0.025)
|
||||
return x1, y1, x2, y2
|
||||
|
||||
|
||||
def _crop_image(image: np.ndarray, bounds: tuple[float, float, float, float]) -> np.ndarray:
|
||||
height, width = image.shape[:2]
|
||||
x1, y1, x2, y2 = bounds
|
||||
left = int(round(x1 * width))
|
||||
top = int(round(y1 * height))
|
||||
right = max(left + 1, int(round(x2 * width)))
|
||||
bottom = max(top + 1, int(round(y2 * height)))
|
||||
return image[top:bottom, left:right]
|
||||
|
||||
|
||||
def _to_crop_point(point: list[float], bounds: tuple[float, float, float, float]) -> list[float]:
    """Map a full-image normalized point into the crop window's coordinates.

    The point is expressed relative to ``bounds = (x1, y1, x2, y2)`` and each
    coordinate is clamped to ``[0, 1]``. A zero-sized window is treated as
    having a size of ``1e-9`` to avoid division by zero.
    """
    left, top, right, bottom = bounds
    span_x = max(right - left, 1e-9)
    span_y = max(bottom - top, 1e-9)
    relative_x = (float(point[0]) - left) / span_x
    relative_y = (float(point[1]) - top) / span_y
    return [_clamp01(relative_x), _clamp01(relative_y)]
|
||||
|
||||
|
||||
def _from_crop_polygon(
    polygon: list[list[float]],
    bounds: tuple[float, float, float, float],
) -> list[list[float]]:
    """Map a polygon from crop-window coordinates back to full-image coordinates.

    Each vertex is scaled by the window size of ``bounds = (x1, y1, x2, y2)``,
    offset by the window origin and clamped to ``[0, 1]``.
    """
    left, top, right, bottom = bounds
    span_x = right - left
    span_y = bottom - top
    restored: list[list[float]] = []
    for vertex in polygon:
        full_x = left + float(vertex[0]) * span_x
        full_y = top + float(vertex[1]) * span_y
        restored.append([_clamp01(full_x), _clamp01(full_y)])
    return restored
|
||||
|
||||
|
||||
def _filter_predictions(
|
||||
polygons: list[list[list[float]]],
|
||||
scores: list[float],
|
||||
options: dict[str, Any],
|
||||
negative_points: list[list[float]] | None = None,
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not options.get("auto_filter_background"):
|
||||
return polygons, scores
|
||||
|
||||
min_score = float(options.get("min_score", 0.0) or 0.0)
|
||||
next_polygons: list[list[list[float]]] = []
|
||||
next_scores: list[float] = []
|
||||
for index, polygon in enumerate(polygons):
|
||||
score = scores[index] if index < len(scores) else 0.0
|
||||
if score < min_score:
|
||||
continue
|
||||
if negative_points and any(_point_in_polygon(point, polygon) for point in negative_points):
|
||||
continue
|
||||
next_polygons.append(polygon)
|
||||
next_scores.append(score)
|
||||
return next_polygons, next_scores
|
||||
|
||||
|
||||
@router.post(
|
||||
"/predict",
|
||||
response_model=PredictResponse,
|
||||
@@ -58,9 +192,11 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
options = payload.options or {}
|
||||
|
||||
polygons: List[List[List[float]]] = []
|
||||
scores: List[float] = []
|
||||
negative_points: list[list[float]] = []
|
||||
|
||||
try:
|
||||
if prompt_type == "point":
|
||||
@@ -76,13 +212,39 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
if not isinstance(labels, list) or len(labels) != len(points):
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_registry.predict_points(payload.model, image, points, labels)
|
||||
negative_points = [
|
||||
point for point, label in zip(points, labels) if label == 0
|
||||
]
|
||||
inference_image = image
|
||||
inference_points = points
|
||||
crop_bounds = None
|
||||
if options.get("crop_to_prompt"):
|
||||
margin = float(options.get("crop_margin", 0.25) or 0.25)
|
||||
crop_bounds = _crop_bounds_from_points(points, margin)
|
||||
inference_image = _crop_image(image, crop_bounds)
|
||||
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
|
||||
polygons, scores = sam_registry.predict_points(payload.model, inference_image, inference_points, labels)
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_registry.predict_box(payload.model, image, box)
|
||||
inference_image = image
|
||||
inference_box = box
|
||||
crop_bounds = None
|
||||
if options.get("crop_to_prompt"):
|
||||
margin = float(options.get("crop_margin", 0.05) or 0.05)
|
||||
crop_bounds = _crop_bounds_from_points([[box[0], box[1]], [box[2], box[3]]], margin)
|
||||
inference_image = _crop_image(image, crop_bounds)
|
||||
inference_box = [
|
||||
*_to_crop_point([box[0], box[1]], crop_bounds),
|
||||
*_to_crop_point([box[2], box[3]], crop_bounds),
|
||||
]
|
||||
polygons, scores = sam_registry.predict_box(payload.model, inference_image, inference_box)
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||
@@ -95,8 +257,9 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
except NotImplementedError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
polygons, scores = _filter_predictions(polygons, scores, options, negative_points)
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@@ -161,6 +324,100 @@ def save_annotation(
|
||||
return annotation
|
||||
|
||||
|
||||
@router.post(
    "/import-gt-mask",
    response_model=List[AnnotationOut],
    status_code=status.HTTP_201_CREATED,
    summary="Import a GT mask and reduce components to editable point regions",
)
async def import_gt_mask(
    project_id: int = Form(...),
    frame_id: int = Form(...),
    template_id: int | None = Form(None),
    label: str = Form("GT Mask"),
    color: str = Form("#22c55e"),
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
) -> List[Annotation]:
    """Convert a binary/label mask image into persisted polygon annotations.

    The upload is decoded as a grayscale image. Every distinct non-zero pixel
    value is treated as one label, and each external contour of a label
    becomes one annotation. The `points` field stores a positive seed point
    at the component's distance-transform maximum, which gives the frontend
    an editable point-region representation instead of a static bitmap layer.
    When the mask holds more than one label value, the value is appended to
    the annotation label.

    Raises:
        HTTPException: 404 when the project, the frame (within the project)
            or the template does not exist; 400 when the upload is empty,
            cannot be decoded, or yields no usable foreground region.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    frame = db.query(Frame).filter(Frame.id == frame_id, Frame.project_id == project_id).first()
    if not frame:
        raise HTTPException(status_code=404, detail="Frame not found")

    if template_id is not None:
        template = db.query(Template).filter(Template.id == template_id).first()
        if not template:
            raise HTTPException(status_code=404, detail="Template not found")

    data = await file.read()
    if not data:
        # cv2.imdecode raises on an empty buffer instead of returning None,
        # so an empty upload is rejected here as a client error.
        raise HTTPException(status_code=400, detail="Invalid mask image")
    image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise HTTPException(status_code=400, detail="Invalid mask image")

    # Stored frame dimensions win; the mask's own size is only a fallback.
    width = int(frame.width or image.shape[1])
    height = int(frame.height or image.shape[0])
    label_values = [int(value) for value in np.unique(image) if int(value) > 0]
    if not label_values:
        raise HTTPException(status_code=400, detail="No foreground mask regions found")
    has_multiple_labels = len(label_values) > 1

    annotations: list[Annotation] = []
    for label_value in label_values:
        binary = np.where(image == label_value, 255, 0).astype(np.uint8)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        annotation_label = f"{label} {label_value}" if has_multiple_labels else label

        for contour in contours:
            # Skip degenerate contours that enclose less than one pixel.
            if cv2.contourArea(contour) < 1:
                continue

            # Geometry is normalized against the mask image's own size.
            polygon = _normalized_contour(contour, image.shape[1], image.shape[0])
            if len(polygon) < 3:
                continue

            # Rasterize just this contour so the seed point is computed per component.
            component = np.zeros_like(binary, dtype=np.uint8)
            cv2.drawContours(component, [contour], -1, 1, thickness=-1)
            seed_point = _component_seed_point(component, image.shape[1], image.shape[0])
            bbox = _contour_bbox(contour, image.shape[1], image.shape[0])

            annotation = Annotation(
                project_id=project_id,
                frame_id=frame_id,
                template_id=template_id,
                mask_data={
                    "polygons": [polygon],
                    "label": annotation_label,
                    "color": color,
                    "source": "gt_mask",
                    "gt_label_value": label_value,
                    "image_size": {"width": width, "height": height},
                },
                points=[seed_point],
                bbox=bbox,
            )
            db.add(annotation)
            annotations.append(annotation)

    if not annotations:
        raise HTTPException(status_code=400, detail="No foreground mask regions found")

    db.commit()
    for annotation in annotations:
        db.refresh(annotation)
    logger.info("Imported %s GT mask annotations for project_id=%s frame_id=%s", len(annotations), project_id, frame_id)
    return annotations
|
||||
|
||||
|
||||
@router.get(
|
||||
"/annotations",
|
||||
response_model=List[AnnotationOut],
|
||||
|
||||
@@ -14,6 +14,7 @@ from models import Annotation, Frame, ProcessingTask, Project, Template
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
ACTIVE_TASK_STATUSES = {"queued", "running"}
|
||||
MONITORED_TASK_STATUSES = {"queued", "running", "failed", "cancelled"}
|
||||
|
||||
|
||||
def _system_load_percent() -> int:
|
||||
@@ -42,7 +43,9 @@ def _task_payload(task: ProcessingTask) -> dict[str, Any]:
|
||||
"name": task.project.name if task.project else f"任务 {task.id}",
|
||||
"progress": task.progress,
|
||||
"status": task.message or task.status,
|
||||
"raw_status": task.status,
|
||||
"frame_count": (task.result or {}).get("frames_extracted", 0),
|
||||
"error": task.error,
|
||||
"updated_at": _iso_or_none(task.updated_at),
|
||||
}
|
||||
|
||||
@@ -68,7 +71,7 @@ def get_dashboard_overview(db: Session = Depends(get_db)) -> dict[str, Any]:
|
||||
.limit(50)
|
||||
.all()
|
||||
)
|
||||
tasks = [_task_payload(task) for task in recent_tasks if task.status in ACTIVE_TASK_STATUSES]
|
||||
tasks = [_task_payload(task) for task in recent_tasks if task.status in MONITORED_TASK_STATUSES]
|
||||
|
||||
activities: list[dict[str, Any]] = []
|
||||
for task in recent_tasks[:10]:
|
||||
|
||||
@@ -37,6 +37,54 @@ def _mask_from_polygon(
|
||||
return mask
|
||||
|
||||
|
||||
def _annotation_z_index(annotation: Annotation) -> int:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
|
||||
try:
|
||||
return int(class_meta["zIndex"])
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
if annotation.template and annotation.template.z_index is not None:
|
||||
return int(annotation.template.z_index)
|
||||
return 0
|
||||
|
||||
|
||||
def _annotation_class_key(annotation: Annotation) -> str:
|
||||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||||
if isinstance(class_meta, dict):
|
||||
if class_meta.get("id"):
|
||||
return f"class:{class_meta['id']}"
|
||||
if class_meta.get("name"):
|
||||
return f"name:{class_meta['name']}"
|
||||
if annotation.template_id:
|
||||
return f"template:{annotation.template_id}"
|
||||
return f"annotation:{annotation.id}"
|
||||
|
||||
|
||||
def _annotation_label(annotation: Annotation) -> str:
|
||||
mask_data = annotation.mask_data or {}
|
||||
class_meta = mask_data.get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("name"):
|
||||
return str(class_meta["name"])
|
||||
if mask_data.get("label"):
|
||||
return str(mask_data["label"])
|
||||
if annotation.template and annotation.template.name:
|
||||
return str(annotation.template.name)
|
||||
return f"Annotation {annotation.id}"
|
||||
|
||||
|
||||
def _annotation_color(annotation: Annotation) -> str:
|
||||
mask_data = annotation.mask_data or {}
|
||||
class_meta = mask_data.get("class") or {}
|
||||
if isinstance(class_meta, dict) and class_meta.get("color"):
|
||||
return str(class_meta["color"])
|
||||
if mask_data.get("color"):
|
||||
return str(mask_data["color"])
|
||||
if annotation.template and annotation.template.color:
|
||||
return str(annotation.template.color)
|
||||
return "#ffffff"
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
@@ -150,19 +198,46 @@ def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResp
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotation masks as individual PNG files inside a ZIP archive."""
|
||||
"""Export individual masks plus z-index fused semantic masks inside a ZIP."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
import cv2
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
|
||||
class_values: dict[str, int] = {}
|
||||
semantic_classes: list[dict[str, Any]] = []
|
||||
|
||||
def class_value(annotation: Annotation) -> int:
|
||||
key = _annotation_class_key(annotation)
|
||||
if key not in class_values:
|
||||
value = len(class_values) + 1
|
||||
class_values[key] = value
|
||||
semantic_classes.append({
|
||||
"value": value,
|
||||
"key": key,
|
||||
"label": _annotation_label(annotation),
|
||||
"color": _annotation_color(annotation),
|
||||
"zIndex": _annotation_z_index(annotation),
|
||||
"template_id": annotation.template_id,
|
||||
})
|
||||
return class_values[key]
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
@@ -178,11 +253,28 @@ def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingRes
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
# Encode PNG
|
||||
import cv2
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
fname = f"mask_{ann.id:06d}.png"
|
||||
zf.writestr(fname, encoded.tobytes())
|
||||
if ann.frame_id is not None:
|
||||
frame_masks.setdefault(ann.frame_id, []).append((ann, combined))
|
||||
|
||||
for frame in frames:
|
||||
entries = frame_masks.get(frame.id, [])
|
||||
if not entries:
|
||||
continue
|
||||
width = frame.width or 1920
|
||||
height = frame.height or 1080
|
||||
semantic = np.zeros((height, width), dtype=np.uint8)
|
||||
for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
|
||||
semantic[mask > 0] = class_value(ann)
|
||||
_, encoded = cv2.imencode(".png", semantic)
|
||||
zf.writestr(f"semantic_frame_{frame.frame_index:06d}.png", encoded.tobytes())
|
||||
|
||||
zf.writestr(
|
||||
"semantic_classes.json",
|
||||
json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
|
||||
)
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
|
||||
@@ -1,15 +1,45 @@
|
||||
"""Processing task query endpoints."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from celery_app import celery_app
|
||||
from database import get_db
|
||||
from models import ProcessingTask
|
||||
from models import ProcessingTask, Project
|
||||
from progress_events import publish_task_progress_event
|
||||
from schemas import ProcessingTaskOut
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PARSING,
|
||||
PROJECT_STATUS_PENDING,
|
||||
PROJECT_STATUS_READY,
|
||||
TASK_ACTIVE_STATUSES,
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_QUEUED,
|
||||
)
|
||||
from worker_tasks import parse_project_media
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["Tasks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _get_task_or_404(task_id: int, db: Session) -> ProcessingTask:
    """Fetch the processing task with the given id.

    Raises:
        HTTPException: 404 with detail ``"Task not found"`` when no row matches.
    """
    matching = db.query(ProcessingTask).filter(ProcessingTask.id == task_id)
    found = matching.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
def _project_status_after_stop(project: Project) -> str:
    """Pick the project status to use once its task has stopped.

    A project that already has frames is ready; otherwise it is pending.
    """
    if project.frames:
        return PROJECT_STATUS_READY
    return PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
@router.get("", response_model=List[ProcessingTaskOut], summary="List processing tasks")
|
||||
@@ -31,7 +61,78 @@ def list_tasks(
|
||||
@router.get("/{task_id}", response_model=ProcessingTaskOut, summary="Get processing task")
def get_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
    """Return one background task by id.

    The lookup and the not-found handling are delegated to
    ``_get_task_or_404`` so the task is queried only once.

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    return _get_task_or_404(task_id, db)
|
||||
|
||||
|
||||
@router.post("/{task_id}/cancel", response_model=ProcessingTaskOut, summary="Cancel processing task")
def cancel_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
    """Cancel an active background task and revoke its Celery job when possible.

    The task must be in one of ``TASK_ACTIVE_STATUSES``. If it has a Celery id,
    a revoke with ``terminate=True`` and ``SIGTERM`` is attempted; a failure
    there is logged and does not stop the cancellation. The task is then
    marked cancelled (progress 100, finish time set), the owning project's
    status is recomputed, the change is committed and a progress event is
    published.

    Raises:
        HTTPException: 404 when the task does not exist; 409 when the task is
            not in an active status.
    """
    task = _get_task_or_404(task_id, db)
    cancellable = task.status in TASK_ACTIVE_STATUSES
    if not cancellable:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Task is not cancellable in status: {task.status}",
        )

    celery_id = task.celery_task_id
    if celery_id:
        try:
            celery_app.control.revoke(celery_id, terminate=True, signal="SIGTERM")
        except Exception as exc:  # noqa: BLE001
            # Best effort: the database state below is still updated.
            logger.warning("Failed to revoke celery task %s: %s", celery_id, exc)

    task.status = TASK_STATUS_CANCELLED
    task.progress = 100
    task.message = "任务已取消"
    task.error = "Cancelled by user"
    task.finished_at = _now()

    project = task.project
    if project:
        project.status = _project_status_after_stop(project)

    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
    return task
|
||||
|
||||
|
||||
@router.post(
    "/{task_id}/retry",
    response_model=ProcessingTaskOut,
    status_code=status.HTTP_202_ACCEPTED,
    summary="Retry processing task",
)
def retry_task(task_id: int, db: Session = Depends(get_db)) -> ProcessingTask:
    """Create a fresh queued task from a failed or cancelled task.

    The new task copies the previous task's type and payload, records the
    source task id under ``payload["retry_of"]`` and defaults
    ``payload["source_type"]`` to the project's source type (or ``"video"``).
    The project is set to the parsing status, the task is committed and
    announced, then dispatched via ``parse_project_media.delay`` and announced
    again once its Celery id is stored.

    Raises:
        HTTPException: 404 when the task or its project does not exist; 409
            when the task is neither failed nor cancelled; 400 when the task
            has no project id or the project has no media uploaded.
    """
    previous = _get_task_or_404(task_id, db)
    retryable_statuses = {TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}
    if previous.status not in retryable_statuses:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Task is not retryable in status: {previous.status}",
        )
    if previous.project_id is None:
        raise HTTPException(status_code=400, detail="Task has no project_id")

    project = db.query(Project).filter(Project.id == previous.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    if not project.video_path:
        raise HTTPException(status_code=400, detail="Project has no media uploaded")

    # Work on a copy so the original task's payload is left untouched.
    retry_payload = dict(previous.payload or {})
    retry_payload.setdefault("source_type", project.source_type or "video")
    retry_payload["retry_of"] = previous.id

    task = ProcessingTask(
        task_type=previous.task_type,
        status=TASK_STATUS_QUEUED,
        progress=0,
        message=f"重试任务已入队(源任务 #{previous.id})",
        project_id=project.id,
        payload=retry_payload,
    )
    project.status = PROJECT_STATUS_PARSING
    db.add(task)
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)

    # The task needs a database id before it can be handed to the worker.
    dispatched = parse_project_media.delay(task.id)
    task.celery_task_id = dispatched.id
    db.commit()
    db.refresh(task)
    publish_task_progress_event(task)
    return task
|
||||
|
||||
@@ -180,6 +180,7 @@ class PredictRequest(BaseModel):
|
||||
prompt_type: str # point / box / semantic
|
||||
prompt_data: Any
|
||||
model: Optional[str] = None
|
||||
options: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
@@ -201,6 +202,8 @@ class AiModelStatus(BaseModel):
|
||||
python_ok: bool = True
|
||||
torch_ok: bool = True
|
||||
cuda_required: bool = False
|
||||
external_available: bool = False
|
||||
external_python: Optional[str] = None
|
||||
|
||||
|
||||
class GpuStatus(BaseModel):
|
||||
|
||||
@@ -20,9 +20,11 @@ from services.frame_parser import (
|
||||
upload_frames_to_minio,
|
||||
)
|
||||
from statuses import (
|
||||
PROJECT_STATUS_PENDING,
|
||||
PROJECT_STATUS_ERROR,
|
||||
PROJECT_STATUS_PARSING,
|
||||
PROJECT_STATUS_READY,
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_RUNNING,
|
||||
TASK_STATUS_SUCCESS,
|
||||
@@ -31,6 +33,10 @@ from statuses import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskCancelled(RuntimeError):
    """Internal signal that the task's stored status is cancelled.

    Raised inside the worker to unwind processing; it is not meant to be
    surfaced to callers as an error.
    """
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
@@ -66,12 +72,29 @@ def _set_task_state(
|
||||
publish_task_progress_event(task)
|
||||
|
||||
|
||||
def _project_status_after_stop(project: Project) -> str:
    """Return the status a project falls back to when its task stops.

    Projects with at least one frame are ready; projects without frames go
    back to pending.
    """
    has_frames = bool(project.frames)
    if has_frames:
        return PROJECT_STATUS_READY
    return PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload the task from the database and abort if it was cancelled.

    Raises:
        TaskCancelled: when the refreshed task status is the cancelled status.
    """
    # Refresh first so a status written by another session is seen.
    db.refresh(task)
    was_cancelled = task.status == TASK_STATUS_CANCELLED
    if was_cancelled:
        raise TaskCancelled("Task was cancelled")
|
||||
|
||||
|
||||
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"""Parse one project's media and update task progress in the database."""
|
||||
task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
|
||||
if not task:
|
||||
raise ValueError(f"Task not found: {task_id}")
|
||||
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"status": TASK_STATUS_CANCELLED,
|
||||
"message": task.message or "任务已取消",
|
||||
}
|
||||
|
||||
if task.project_id is None:
|
||||
_set_task_state(
|
||||
db,
|
||||
@@ -111,6 +134,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
db.commit()
|
||||
raise ValueError("Project has no media uploaded")
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
||||
|
||||
@@ -121,6 +145,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=15, message="正在下载媒体文件")
|
||||
if effective_source == "dicom":
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
@@ -129,20 +154,24 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
client = get_minio_client()
|
||||
objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
|
||||
for obj in objects:
|
||||
_ensure_not_cancelled(db, task)
|
||||
if obj.object_name.lower().endswith(".dcm"):
|
||||
data = download_file(obj.object_name)
|
||||
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
||||
with open(local_dcm, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
else:
|
||||
_ensure_not_cancelled(db, task)
|
||||
media_bytes = download_file(project.video_path)
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||
project.original_fps = original_fps
|
||||
@@ -158,12 +187,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=85, message="正在写入帧索引")
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
_ensure_not_cancelled(db, task)
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
@@ -203,6 +235,23 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
)
|
||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
|
||||
return result
|
||||
except TaskCancelled:
|
||||
project.status = _project_status_after_stop(project)
|
||||
task.status = TASK_STATUS_CANCELLED
|
||||
task.progress = 100
|
||||
task.message = task.message or "任务已取消"
|
||||
task.error = task.error or "Cancelled by user"
|
||||
task.finished_at = task.finished_at or _now()
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
publish_task_progress_event(task)
|
||||
logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id)
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"project_id": project.id,
|
||||
"status": TASK_STATUS_CANCELLED,
|
||||
"message": task.message,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
project.status = PROJECT_STATUS_ERROR
|
||||
_set_task_state(
|
||||
|
||||
@@ -9,8 +9,14 @@ the package.
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -41,6 +47,8 @@ class SAM3Engine:
|
||||
self._processor: Any | None = None
|
||||
self._model_loaded = False
|
||||
self._last_error: str | None = None
|
||||
self._external_status_cache: dict[str, Any] | None = None
|
||||
self._external_status_checked_at = 0.0
|
||||
|
||||
def _python_ok(self) -> bool:
    """Return True when the running interpreter is Python 3.12 or newer."""
    minimum_version = (3, 12)
    return sys.version_info >= minimum_version
|
||||
@@ -51,6 +59,81 @@ class SAM3Engine:
|
||||
def _can_load(self) -> bool:
    """Return True when SAM 3 can be loaded inside the current process.

    Requires the sam3 package and torch to be importable, plus a passing
    Python version check and GPU check.
    """
    if not SAM3_PACKAGE_AVAILABLE or not TORCH_AVAILABLE:
        return False
    return bool(self._python_ok() and self._gpu_ok())
|
||||
|
||||
def _worker_path(self) -> Path:
    """Return the path of the external worker script that sits next to this module."""
    module_dir = Path(__file__).parent
    return module_dir / "sam3_external_worker.py"
|
||||
|
||||
def _external_python_exists(self) -> bool:
    """Return True when the external runtime is enabled and its interpreter file exists."""
    if not settings.sam3_external_enabled:
        return False
    return bool(os.path.isfile(settings.sam3_external_python))
|
||||
|
||||
def _external_status(self, force: bool = False) -> dict[str, Any]:
    """Return the status dict reported by the external SAM 3 worker.

    The result is cached on the instance and reused while it is younger than
    ``settings.sam3_status_cache_seconds``; pass ``force=True`` to bypass the
    cache. When the external runtime is disabled, its interpreter is missing,
    or the ``--status`` probe fails, an "unavailable" status dict carrying an
    explanatory message is returned (and cached) instead.
    """
    checked_at = time.monotonic()
    cached = self._external_status_cache
    if not force and cached is not None:
        age = checked_at - self._external_status_checked_at
        if age < settings.sam3_status_cache_seconds:
            return cached

    def unavailable(message: str) -> dict[str, Any]:
        # Common shape for every "cannot use the external runtime" outcome.
        return {
            "available": False,
            "package_available": False,
            "python_ok": False,
            "torch_ok": False,
            "cuda_available": False,
            "device": "unavailable",
            "message": message,
        }

    if not settings.sam3_external_enabled:
        status = unavailable("SAM 3 external runtime is disabled.")
    elif not self._external_python_exists():
        status = unavailable(f"SAM 3 external Python not found: {settings.sam3_external_python}")
    else:
        try:
            env = os.environ.copy()
            env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
            probe = subprocess.run(
                [settings.sam3_external_python, str(self._worker_path()), "--status"],
                capture_output=True,
                text=True,
                # The status probe is capped at 30 seconds regardless of the inference timeout.
                timeout=min(settings.sam3_timeout_seconds, 30),
                check=False,
                env=env,
            )
            if probe.returncode == 0:
                status = json.loads(probe.stdout)
            else:
                detail = probe.stderr.strip() or probe.stdout.strip()
                status = unavailable(f"SAM 3 external status failed: {detail}")
        except Exception as exc:  # noqa: BLE001
            status = unavailable(f"SAM 3 external status failed: {exc}")

    self._external_status_cache = status
    self._external_status_checked_at = checked_at
    return status
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model_loaded:
|
||||
return
|
||||
@@ -92,26 +175,86 @@ class SAM3Engine:
|
||||
return "SAM 3 dependencies are present; model will load on first inference."
|
||||
|
||||
def status(self) -> dict:
    """Describe SAM 3 availability for the in-process and external runtimes.

    Returns:
        A dict with the model id/label, whether SAM 3 is usable either
        in-process or through the external worker, whether a processor is
        loaded in this process, the device, a human-readable message and the
        individual dependency flags. Flags are true when either the local
        environment or the external worker status reports them as satisfied.
    """
    external_status = self._external_status()
    external_ready = bool(external_status.get("available"))
    local_ready = self._can_load()
    loaded = self._processor is not None

    # Message priority: loaded model > ready external runtime > external
    # failure reason (only when the local runtime cannot load) > local state.
    message = self._last_error or self._status_message()
    if loaded:
        message = "SAM 3 model loaded and ready."
    elif external_ready:
        message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
    elif external_status.get("message") and not local_ready:
        message = str(external_status["message"])

    return {
        "id": "sam3",
        "label": "SAM 3",
        "available": bool(local_ready or external_ready),
        "loaded": loaded,
        "device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
        "supports": ["semantic"],
        "message": message,
        "package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
        "checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
        "checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
        "python_ok": bool(self._python_ok() or external_status.get("python_ok")),
        "torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
        "cuda_required": True,
        "external_available": external_ready,
        "external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
    }
|
||||
|
||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
    """Run a text-prompted SAM 3 prediction in the external worker process.

    The image and request parameters are written to a temporary directory,
    the worker script is executed with ``--request`` and its JSON output is
    parsed.

    Args:
        image: Image array accepted by ``Image.fromarray``.
        text: Semantic prompt; surrounding whitespace is stripped.

    Returns:
        ``(polygons, scores)`` as reported by the worker (empty lists when a
        key is missing from its output).

    Raises:
        RuntimeError: If the external runtime is unavailable, the worker
            times out, exits with a non-zero code, reports an error, or
            produces output that is not a JSON object.
    """

    def parse_json_object(raw: str) -> dict[str, Any] | None:
        # Try the whole text first, then only its last non-empty line: the
        # worker prints one JSON document, but imported libraries may have
        # logged other lines to the same stream before it.
        lines = [line for line in raw.splitlines() if line.strip()]
        for candidate in (raw, lines[-1] if lines else ""):
            try:
                parsed = json.loads(candidate)
            except ValueError:
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    status = self._external_status(force=True)
    if not status.get("available"):
        raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")

    with tempfile.TemporaryDirectory(prefix="sam3_") as tmpdir:
        tmp_path = Path(tmpdir)
        image_path = tmp_path / "image.png"
        request_path = tmp_path / "request.json"
        Image.fromarray(image).save(image_path)
        request_path.write_text(
            json.dumps(
                {
                    "image_path": str(image_path),
                    "text": text.strip(),
                    "model_version": settings.sam3_model_version,
                    "confidence_threshold": settings.sam3_confidence_threshold,
                },
                ensure_ascii=False,
            ),
            encoding="utf-8",
        )
        env = os.environ.copy()
        env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
        try:
            completed = subprocess.run(
                [settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
                capture_output=True,
                text=True,
                timeout=settings.sam3_timeout_seconds,
                check=False,
                env=env,
            )
        except subprocess.TimeoutExpired as exc:
            # Surface timeouts with the same exception type as every other
            # worker failure so callers need a single error path.
            raise RuntimeError(
                f"SAM 3 external inference timed out after {settings.sam3_timeout_seconds} seconds."
            ) from exc

    if completed.returncode != 0:
        detail = completed.stderr.strip() or completed.stdout.strip()
        parsed_error = parse_json_object(detail)
        if parsed_error is not None:
            detail = parsed_error.get("error", detail)
        raise RuntimeError(f"SAM 3 external inference failed: {detail}")

    payload = parse_json_object(completed.stdout)
    if payload is None:
        # Without this guard a JSONDecodeError (a ValueError) would escape and
        # be indistinguishable from the "empty prompt" ValueError.
        raise RuntimeError("SAM 3 external inference returned no JSON result.")
    if payload.get("error"):
        raise RuntimeError(str(payload["error"]))
    return payload.get("polygons", []), payload.get("scores", [])
|
||||
|
||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not text.strip():
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
if not self._can_load() and self._external_status().get("available"):
|
||||
return self._predict_semantic_external(image, text)
|
||||
if not self._ensure_ready():
|
||||
raise RuntimeError(self.status()["message"])
|
||||
|
||||
|
||||
190
backend/services/sam3_external_worker.py
Normal file
190
backend/services/sam3_external_worker.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime.
|
||||
|
||||
The main FastAPI backend can keep running in the existing Python 3.11/SAM 2
|
||||
environment while this helper is executed with a separate conda env that meets
|
||||
SAM 3's stricter runtime requirements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def _torch_status() -> tuple[bool, str | None, str | None, str | None]:
|
||||
try:
|
||||
import torch
|
||||
|
||||
cuda_available = bool(torch.cuda.is_available())
|
||||
return (
|
||||
cuda_available,
|
||||
getattr(torch, "__version__", None),
|
||||
getattr(torch.version, "cuda", None),
|
||||
torch.cuda.get_device_name(0) if cuda_available else None,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
return False, None, None, None
|
||||
|
||||
|
||||
def _compact_error(exc: Exception) -> str:
    """Reduce an exception message to a single informative line.

    A line mentioning gated-model access is preferred; otherwise the first
    non-blank line is used, falling back to the exception class name when the
    message has no non-blank lines.
    """
    stripped = (raw.strip() for raw in str(exc).splitlines())
    candidates = [text for text in stripped if text]
    access_markers = ("Access to model", "Cannot access gated repo")
    for text in candidates:
        if any(marker in text for marker in access_markers):
            return text
    if candidates:
        return candidates[0]
    return exc.__class__.__name__
|
||||
|
||||
|
||||
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
    """Check whether the Hugging Face checkpoint repository can be read.

    Attempts to fetch ``config.json`` from ``facebook/sam3.1`` when
    ``model_version`` is ``"sam3.1"`` and from ``facebook/sam3`` otherwise.

    Returns:
        ``(True, None)`` on success, or ``(False, message)`` where ``message``
        is the compacted error of whatever exception was raised.
    """
    try:
        from huggingface_hub import hf_hub_download

        if model_version == "sam3.1":
            repo_id = "facebook/sam3.1"
        else:
            repo_id = "facebook/sam3"
        hf_hub_download(repo_id=repo_id, filename="config.json")
    except Exception as exc:  # noqa: BLE001
        return False, _compact_error(exc)
    return True, None
|
||||
|
||||
|
||||
def runtime_status() -> dict[str, Any]:
    """Collect the readiness report printed by ``--status``.

    Checks the Python version, that the ``sam3`` package can be found and
    imported, torch/CUDA availability, and (only when the package imports)
    checkpoint access for the model version named by the
    ``SAM3_MODEL_VERSION`` environment variable (default ``"sam3"``).

    Returns:
        A dict of individual flags, version strings and a message that either
        confirms readiness or lists every missing requirement.
    """
    model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")

    package_error = None
    package_available = importlib.util.find_spec("sam3") is not None
    if package_available:
        # The spec can exist while the import itself still fails.
        try:
            import sam3  # noqa: F401
        except Exception as exc:  # noqa: BLE001
            package_available = False
            package_error = str(exc)

    cuda_available, torch_version, cuda_version, device_name = _torch_status()
    torch_ok = torch_version is not None
    python_ok = sys.version_info >= (3, 12)

    if package_available:
        checkpoint_access, checkpoint_error = _checkpoint_access(model_version)
    else:
        checkpoint_access, checkpoint_error = False, None

    available = bool(package_available and python_ok and cuda_available and checkpoint_access)

    # (satisfied, label) pairs, in the order they are reported.
    requirements = [
        (python_ok, "Python 3.12+ runtime"),
        (package_available, f"sam3 package ({package_error})" if package_error else "sam3 package"),
        (torch_ok, "PyTorch"),
        (cuda_available, "CUDA GPU"),
        # Checkpoint access is only reported when the package itself is usable.
        (not package_available or checkpoint_access, f"Hugging Face checkpoint access ({checkpoint_error})"),
    ]
    missing = [label for satisfied, label in requirements if not satisfied]

    if available:
        message = "SAM 3 external runtime is ready."
    else:
        message = f"SAM 3 external runtime unavailable: missing {', '.join(missing)}."

    return {
        "available": available,
        "package_available": package_available,
        "checkpoint_access": checkpoint_access,
        "python_ok": python_ok,
        "torch_ok": torch_ok,
        "torch_version": torch_version,
        "cuda_version": cuda_version,
        "cuda_available": cuda_available,
        "device": "cuda" if cuda_available else "unavailable",
        "device_name": device_name,
        "message": message,
    }
|
||||
|
||||
|
||||
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
    """Convert a 2-D mask into the normalized outline of its largest region.

    Args:
        mask: 2-D array. Non-``uint8`` input is binarized with ``mask > 0``.

    Returns:
        ``[x, y]`` points of the external contour enclosing the largest area,
        with x divided by the mask width and y by the mask height. Returns an
        empty list when no contour has at least three points.
    """
    import cv2

    if mask.dtype != np.uint8:
        mask = (mask > 0).astype(np.uint8)
    mask = np.ascontiguousarray(mask)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # A polygon needs at least three vertices.
    candidates = [contour for contour in contours if len(contour) >= 3]
    if not candidates:
        return []

    # Pick by enclosed area: with CHAIN_APPROX_SIMPLE the vertex count does not
    # track region size (a large rectangle has 4 points, a small jagged blob
    # can have many more).
    largest = max(candidates, key=cv2.contourArea)
    height, width = mask.shape[:2]
    return [[float(point[0][0]) / width, float(point[0][1]) / height] for point in largest]
|
||||
|
||||
|
||||
def _to_numpy(value: Any) -> np.ndarray:
    """Convert a tensor-like or array-like value to a NumPy array.

    Objects exposing ``detach``/``cpu`` (torch tensors) are detached, moved to
    the CPU and converted with ``.numpy()``; anything else goes through
    ``np.asarray``.
    """
    if hasattr(value, "detach"):
        value = value.detach()
    if hasattr(value, "cpu"):
        tensor = value.cpu()
        # NumPy has no bfloat16 dtype, so Tensor.numpy() raises for bfloat16
        # tensors; widen those to float32 first. Other dtypes are unchanged.
        if str(getattr(tensor, "dtype", "")) == "torch.bfloat16":
            tensor = tensor.float()
        value = tensor.numpy()
    return np.asarray(value)
|
||||
|
||||
|
||||
def predict(request_path: Path) -> dict[str, Any]:
    """Run one text-prompted SAM 3 prediction described by a JSON request file.

    Args:
        request_path: JSON file with ``image_path``, ``text`` and an optional
            ``confidence_threshold`` (default 0.5).

    Returns:
        ``{"polygons": [...], "scores": [...]}``. Masks that yield no polygon
        are dropped together with their score, so the two lists stay aligned
        index by index.

    Raises:
        ValueError: If the prompt text is empty after stripping.
    """
    import torch
    from sam3.model.sam3_image_processor import Sam3Processor
    from sam3.model_builder import build_sam3_image_model

    payload = json.loads(request_path.read_text(encoding="utf-8"))
    image_path = Path(payload["image_path"])
    text = str(payload["text"]).strip()
    threshold = float(payload.get("confidence_threshold", 0.5))
    if not text:
        raise ValueError("SAM 3 semantic prompt requires non-empty text.")

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # convert() returns a fully loaded copy, so the file handle can be closed.
    with Image.open(image_path) as handle:
        image = handle.convert("RGB")

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        model = build_sam3_image_model()
        processor = Sam3Processor(model, confidence_threshold=threshold)
        state = processor.set_image(image)
        output = processor.set_text_prompt(state=state, prompt=text)

    raw_scores = output.get("scores", [])
    if hasattr(raw_scores, "float"):
        # Under bfloat16 autocast the scores may be bfloat16, which NumPy
        # cannot represent; widen tensors to float32 before conversion.
        raw_scores = raw_scores.float()
    scores = _to_numpy(raw_scores).reshape(-1)

    masks = _to_numpy(output.get("masks", []))
    if masks.ndim == 4:
        # Drop the second axis, keeping one 2-D mask per entry.
        masks = masks[:, 0]
    elif masks.ndim == 2:
        # A single 2-D mask: wrap it so the loop sees one mask, not its rows.
        masks = masks[None]

    polygons = []
    kept_scores = []
    for index, mask in enumerate(masks):
        polygon = _mask_to_polygon(mask)
        if not polygon:
            continue
        polygons.append(polygon)
        if index < scores.size:
            kept_scores.append(float(scores[index]))

    return {
        "polygons": polygons,
        "scores": kept_scores,
    }
|
||||
|
||||
|
||||
def main() -> int:
    """Command-line entry point for the external SAM 3 helper.

    ``--status`` prints the runtime status as JSON; ``--request PATH`` runs a
    prediction and prints its result as JSON. Both return 0. Any exception is
    reported as ``{"error": ...}`` JSON on stderr with return code 1. Calling
    without either option ends in an argparse usage error.
    """
    parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
    parser.add_argument("--status", action="store_true")
    parser.add_argument("--request", type=Path)
    args = parser.parse_args()

    try:
        if not args.status and not args.request:
            # parser.error() exits the process; the return only documents the code.
            parser.error("Use --status or --request")
            return 2
        # --status takes precedence when both options are given.
        result = runtime_status() if args.status else predict(args.request)
        print(json.dumps(result, ensure_ascii=False))
        return 0
    except Exception as exc:  # noqa: BLE001
        print(json.dumps({"error": str(exc)}, ensure_ascii=False), file=sys.stderr)
        return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
24
backend/setup_sam3_env.sh
Executable file
24
backend/setup_sam3_env.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env bash
# Create the dedicated SAM 3 runtime used by backend/services/sam3_external_worker.py.
# Keep Hugging Face tokens outside this repository, for example:
#   export HF_TOKEN=...
#   huggingface-cli login --token "$HF_TOKEN"
#
# Optional environment overrides:
#   SAM3_CONDA_ENV  name of the conda environment (default: sam3)
#   SAM3_CONDA_SH   path to conda's profile.d/conda.sh

set -euo pipefail

ENV_NAME="${SAM3_CONDA_ENV:-sam3}"
CONDA_SH="${SAM3_CONDA_SH:-/home/wkmgc/miniconda3/etc/profile.d/conda.sh}"
# Resolve the worker relative to this script so the checkout can live anywhere.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

if [[ ! -f "$CONDA_SH" ]]; then
  echo "conda.sh not found at $CONDA_SH; set SAM3_CONDA_SH to its location." >&2
  exit 1
fi
# shellcheck disable=SC1090
source "$CONDA_SH"

# awk consumes the whole listing and compares the name literally, which avoids
# both the SIGPIPE failure that `grep -q` can trigger under pipefail and
# treating the environment name as a regular expression.
if ! conda env list | awk -v name="$ENV_NAME" '$1 == name { found = 1 } END { exit !found }'; then
  conda create -y -n "$ENV_NAME" python=3.12
fi

conda activate "$ENV_NAME"
python -m pip install --upgrade pip
python -m pip install "setuptools<81"
python -m pip install torch==2.10.0 torchvision --index-url https://download.pytorch.org/whl/cu128
python -m pip install opencv-python pillow numpy huggingface_hub einops pycocotools psutil
python -m pip install git+https://github.com/facebookresearch/sam3.git

python "$SCRIPT_DIR/services/sam3_external_worker.py" --status
|
||||
@@ -9,3 +9,7 @@ TASK_STATUS_QUEUED = "queued"
|
||||
TASK_STATUS_RUNNING = "running"
|
||||
TASK_STATUS_SUCCESS = "success"
|
||||
TASK_STATUS_FAILED = "failed"
|
||||
TASK_STATUS_CANCELLED = "cancelled"
|
||||
|
||||
TASK_ACTIVE_STATUSES = {TASK_STATUS_QUEUED, TASK_STATUS_RUNNING}
|
||||
TASK_TERMINAL_STATUSES = {TASK_STATUS_SUCCESS, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def _create_project_and_frame(client):
|
||||
@@ -46,6 +47,46 @@ def test_predict_accepts_point_object_with_labels(client, monkeypatch):
|
||||
assert calls["args"] == ([[0.5, 0.5], [0.1, 0.1]], [1, 0])
|
||||
|
||||
|
||||
def test_predict_applies_crop_and_background_filter_options(client, monkeypatch):
    """Point prediction with crop/filter options sees a cropped image and drops the low-score result."""
    _, frame, _ = _create_project_and_frame(client)
    calls = {}

    def blank_image(frame):
        return np.zeros((100, 200, 3), dtype=np.uint8)

    def fake_predict_points(model, image, points, labels):
        calls.update(shape=image.shape, points=points, labels=labels)
        polygons = [
            [[0.0, 0.0], [0.2, 0.0], [0.2, 0.2]],
            [[0.45, 0.45], [0.55, 0.45], [0.55, 0.55]],
        ]
        return (polygons, [0.9, 0.01])

    monkeypatch.setattr("routers.ai._load_frame_image", blank_image)
    monkeypatch.setattr("routers.ai.sam_registry.predict_points", fake_predict_points)

    request_body = {
        "image_id": frame["id"],
        "prompt_type": "point",
        "prompt_data": {"points": [[0.5, 0.5], [0.52, 0.52]], "labels": [1, 0]},
        "options": {
            "crop_to_prompt": True,
            "crop_margin": 0.1,
            "auto_filter_background": True,
            "min_score": 0.05,
        },
    }
    response = client.post("/api/ai/predict", json=request_body)

    assert response.status_code == 200
    # The engine must have received an image smaller than the 100x200 frame.
    seen_height, seen_width = calls["shape"][0], calls["shape"][1]
    assert seen_height < 100
    assert seen_width < 200
    assert calls["labels"] == [1, 0]

    body = response.json()
    assert body["scores"] == [0.9]
    first_polygon = body["polygons"][0]
    assert all(0.0 <= coord <= 1.0 for point in first_polygon for coord in point)
|
||||
|
||||
|
||||
def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
@@ -246,3 +287,62 @@ def test_update_and_delete_annotation_validation(client):
|
||||
f"/api/ai/annotations/{saved['id']}",
|
||||
json={"template_id": 999},
|
||||
).status_code == 404
|
||||
|
||||
|
||||
def test_import_gt_mask_creates_annotations_with_seed_points(client):
    """Importing a binary mask yields one annotation with a polygon and one normalized seed point."""
    project, frame, template = _create_project_and_frame(client)
    gt_mask = np.zeros((360, 640), dtype=np.uint8)
    cv2.rectangle(gt_mask, (100, 80), (260, 220), 255, thickness=-1)
    ok, encoded = cv2.imencode(".png", gt_mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "template_id": str(template["id"]),
        "label": "Imported GT",
        "color": "#22c55e",
    }
    upload = {"file": ("mask.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201
    body = response.json()
    assert len(body) == 1

    annotation = body[0]
    assert annotation["project_id"] == project["id"]
    assert annotation["frame_id"] == frame["id"]
    assert annotation["template_id"] == template["id"]

    mask_data = annotation["mask_data"]
    assert mask_data["label"] == "Imported GT"
    assert mask_data["source"] == "gt_mask"
    assert mask_data["gt_label_value"] == 255
    assert len(mask_data["polygons"][0]) >= 3

    seed_points = annotation["points"]
    assert len(seed_points) == 1
    assert 0.0 <= seed_points[0][0] <= 1.0
    assert 0.0 <= seed_points[0][1] <= 1.0
|
||||
|
||||
|
||||
def test_import_gt_mask_splits_label_values(client):
    """A mask holding label values 1 and 2 is imported as two separately labelled annotations."""
    project, frame, _ = _create_project_and_frame(client)
    label_mask = np.zeros((360, 640), dtype=np.uint8)
    regions = (
        ((20, 20), (120, 120), 1),
        ((220, 80), (320, 180), 2),
    )
    for top_left, bottom_right, label_value in regions:
        cv2.rectangle(label_mask, top_left, bottom_right, label_value, thickness=-1)
    ok, encoded = cv2.imencode(".png", label_mask)
    assert ok

    form_fields = {
        "project_id": str(project["id"]),
        "frame_id": str(frame["id"]),
        "label": "GT Class",
    }
    upload = {"file": ("labels.png", encoded.tobytes(), "image/png")}
    response = client.post("/api/ai/import-gt-mask", data=form_fields, files=upload)

    assert response.status_code == 201
    annotations = sorted(response.json(), key=lambda item: item["mask_data"]["gt_label_value"])
    label_values = [item["mask_data"]["gt_label_value"] for item in annotations]
    label_names = [item["mask_data"]["label"] for item in annotations]
    assert label_values == [1, 2]
    assert label_names == ["GT Class 1", "GT Class 2"]
    assert all(len(item["points"]) == 1 for item in annotations)
|
||||
|
||||
@@ -59,7 +59,9 @@ def test_dashboard_overview_uses_persisted_records(client, db_session):
|
||||
"name": "Pending Project",
|
||||
"progress": 35,
|
||||
"status": "正在使用 FFmpeg/OpenCV 拆帧",
|
||||
"raw_status": "running",
|
||||
"frame_count": 0,
|
||||
"error": None,
|
||||
"updated_at": body["tasks"][0]["updated_at"],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import zipfile
|
||||
import json
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _seed_export_data(client):
|
||||
project = client.post("/api/projects", json={"name": "Export Project"}).json()
|
||||
@@ -58,7 +62,55 @@ def test_export_masks_zip(client):
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/zip")
|
||||
with zipfile.ZipFile(BytesIO(response.content)) as archive:
|
||||
assert archive.namelist() == [f"mask_{annotation['id']:06d}.png"]
|
||||
assert archive.namelist() == [
|
||||
f"mask_{annotation['id']:06d}.png",
|
||||
"semantic_frame_000000.png",
|
||||
"semantic_classes.json",
|
||||
]
|
||||
|
||||
|
||||
def test_export_masks_uses_z_index_for_semantic_fusion(client):
    """Where two annotations overlap, the fused semantic frame carries the value of the higher zIndex class."""
    project = client.post("/api/projects", json={"name": "Fusion Project"}).json()
    frame_payload = {
        "project_id": project["id"],
        "frame_index": 0,
        "image_url": "frames/0.jpg",
        "width": 20,
        "height": 20,
    }
    frame = client.post(f"/api/projects/{project['id']}/frames", json=frame_payload).json()

    def annotate(name, class_id, color, z_index, polygon):
        # Save one single-polygon annotation whose class carries the given zIndex.
        return client.post("/api/ai/annotate", json={
            "project_id": project["id"],
            "frame_id": frame["id"],
            "mask_data": {
                "polygons": [polygon],
                "label": name,
                "color": color,
                "class": {"id": class_id, "name": name, "color": color, "zIndex": z_index},
            },
        }).json()

    low = annotate("Low", "low", "#00ff00", 10, [[0.1, 0.1], [0.8, 0.1], [0.8, 0.8], [0.1, 0.8]])
    high = annotate("High", "high", "#ff0000", 20, [[0.4, 0.4], [0.9, 0.4], [0.9, 0.9], [0.4, 0.9]])

    response = client.get(f"/api/export/{project['id']}/masks")

    assert response.status_code == 200
    with zipfile.ZipFile(BytesIO(response.content)) as archive:
        for saved in (low, high):
            assert f"mask_{saved['id']:06d}.png" in archive.namelist()
        legend = json.loads(archive.read("semantic_classes.json"))
        encoded_frame = archive.read("semantic_frame_000000.png")

    high_value = next(item["value"] for item in legend["classes"] if item["key"] == "class:high")
    semantic = cv2.imdecode(np.frombuffer(encoded_frame, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)

    # Pixel (10, 10) lies inside both polygons on the 20x20 frame.
    assert semantic[10, 10] == high_value
|
||||
|
||||
|
||||
def test_export_missing_project_returns_404(client):
|
||||
|
||||
@@ -140,3 +140,25 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
assert project_detail["status"] == "ready"
|
||||
frames = client.get(f"/api/projects/{project['id']}/frames").json()
|
||||
assert "frame_000001.jpg" in frames[0]["image_url"]
|
||||
|
||||
|
||||
def test_parse_task_runner_skips_already_cancelled_task(db_session):
    """Running the parse task for an already cancelled record reports the cancelled state."""
    from models import ProcessingTask
    from services.media_task_runner import run_parse_media_task

    task_fields = {
        "task_type": "parse_video",
        "status": "cancelled",
        "progress": 100,
        "message": "任务已取消",
        "project_id": 1,
        "payload": {"source_type": "video"},
    }
    cancelled_task = ProcessingTask(**task_fields)
    db_session.add(cancelled_task)
    db_session.commit()
    db_session.refresh(cancelled_task)

    outcome = run_parse_media_task(db_session, cancelled_task.id)

    assert outcome["status"] == "cancelled"
    assert outcome["message"] == "任务已取消"
|
||||
|
||||
@@ -26,6 +26,25 @@ def test_task_progress_payload_uses_dashboard_task_id_and_project_name():
|
||||
assert payload["status"] == "解析完成"
|
||||
|
||||
|
||||
def test_task_progress_payload_marks_cancelled_tasks():
    """A cancelled task produces a payload of type "cancelled" carrying its message and error."""
    task_fields = {
        "id": 13,
        "project_id": 7,
        "project": SimpleNamespace(name="demo.mp4"),
        "status": "cancelled",
        "progress": 100,
        "message": "任务已取消",
        "error": "Cancelled by user",
        "updated_at": None,
    }
    cancelled_task = SimpleNamespace(**task_fields)

    payload = task_progress_payload(cancelled_task)

    expected = {
        "type": "cancelled",
        "status": "任务已取消",
        "error": "Cancelled by user",
    }
    for key, value in expected.items():
        assert payload[key] == value
|
||||
|
||||
|
||||
def test_publish_progress_event_writes_json_to_redis(monkeypatch):
|
||||
calls = []
|
||||
|
||||
|
||||
112
backend/tests/test_sam3_engine.py
Normal file
112
backend/tests/test_sam3_engine.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from services.sam3_engine import SAM3Engine
|
||||
|
||||
|
||||
class _Completed:
    """Plain result object exposing ``returncode``, ``stdout`` and ``stderr``."""

    def __init__(self, returncode=0, stdout="", stderr=""):
        self.returncode, self.stdout, self.stderr = returncode, stdout, stderr
|
||||
|
||||
|
||||
def _external_settings(monkeypatch, python_path: Path):
    """Create a placeholder interpreter file and patch the engine module to use the external runtime."""
    python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8")
    python_path.chmod(0o755)

    # Attribute (relative to services.sam3_engine) -> patched value, applied in order.
    overrides = {
        "SAM3_PACKAGE_AVAILABLE": False,
        "TORCH_AVAILABLE": False,
        "settings.sam3_external_enabled": True,
        "settings.sam3_external_python": str(python_path),
        "settings.sam3_timeout_seconds": 10,
        "settings.sam3_status_cache_seconds": 30,
        "settings.sam3_confidence_threshold": 0.4,
    }
    for attribute, value in overrides.items():
        monkeypatch.setattr(f"services.sam3_engine.{attribute}", value)
|
||||
|
||||
|
||||
def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
    """status() reports SAM 3 as available when the external status probe says it is ready."""
    _external_settings(monkeypatch, tmp_path / "python")
    ready_payload = {
        "available": True,
        "package_available": True,
        "python_ok": True,
        "torch_ok": True,
        "cuda_available": True,
        "device": "cuda",
        "message": "ready",
    }

    def fake_run(args, **_kwargs):
        assert "--status" in args
        return _Completed(stdout=json.dumps(ready_payload))

    monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)

    report = SAM3Engine().status()

    for flag in ("available", "external_available", "package_available", "python_ok"):
        assert report[flag] is True
    assert report["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||
|
||||
|
||||
def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
    """predict_semantic() sends a request file to the external worker and returns its polygons and scores."""
    _external_settings(monkeypatch, tmp_path / "python")
    invocations = []
    ready_payload = {
        "available": True,
        "package_available": True,
        "python_ok": True,
        "torch_ok": True,
        "cuda_available": True,
        "device": "cuda",
        "message": "ready",
    }
    worker_result = {
        "polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
        "scores": [0.91],
    }

    def fake_run(args, **_kwargs):
        invocations.append(args)
        if "--status" in args:
            return _Completed(stdout=json.dumps(ready_payload))
        # Inference call: the request file path is the last argument.
        request = json.loads(Path(args[-1]).read_text(encoding="utf-8"))
        assert request["text"] == "vessel"
        assert request["confidence_threshold"] == 0.4
        assert Path(request["image_path"]).exists()
        return _Completed(stdout=json.dumps(worker_result))

    monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)

    blank = np.zeros((8, 8, 3), dtype=np.uint8)
    polygons, scores = SAM3Engine().predict_semantic(blank, " vessel ")

    assert polygons == [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]]
    assert scores == [0.91]
    assert any("--request" in args for args in invocations)
|
||||
|
||||
|
||||
def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
    """A failing worker run surfaces as a RuntimeError that carries the worker's error text."""
    _external_settings(monkeypatch, tmp_path / "python")
    ready_payload = {
        "available": True,
        "package_available": True,
        "python_ok": True,
        "torch_ok": True,
        "cuda_available": True,
        "device": "cuda",
        "message": "ready",
    }

    def fake_run(args, **_kwargs):
        if "--status" in args:
            return _Completed(stdout=json.dumps(ready_payload))
        return _Completed(returncode=1, stderr=json.dumps({"error": "HF access denied"}))

    monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)

    failure = None
    try:
        SAM3Engine().predict_semantic(np.zeros((8, 8, 3), dtype=np.uint8), "vessel")
    except RuntimeError as exc:
        failure = exc

    if failure is None:
        raise AssertionError("Expected SAM 3 external inference failure.")
    assert "HF access denied" in str(failure)
|
||||
104
backend/tests/test_tasks.py
Normal file
104
backend/tests/test_tasks.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from models import ProcessingTask
|
||||
|
||||
|
||||
def test_cancel_task_revokes_celery_and_updates_project(client, db_session, monkeypatch):
    """Cancelling a running task revokes its Celery job, publishes the event and resets the project status."""
    project_payload = {
        "name": "Cancelable",
        "video_path": "uploads/1/clip.mp4",
        "status": "parsing",
    }
    project = client.post("/api/projects", json=project_payload).json()

    running_task = ProcessingTask(
        task_type="parse_video",
        status="running",
        progress=35,
        message="正在使用 FFmpeg/OpenCV 拆帧",
        project_id=project["id"],
        celery_task_id="celery-1",
        payload={"source_type": "video"},
    )
    db_session.add(running_task)
    db_session.commit()
    db_session.refresh(running_task)

    revoke_calls = []
    published_statuses = []

    def record_revoke(celery_id, terminate, signal):
        revoke_calls.append((celery_id, terminate, signal))

    def record_publish(event_task):
        published_statuses.append(event_task.status)

    monkeypatch.setattr("routers.tasks.celery_app.control.revoke", record_revoke)
    monkeypatch.setattr("routers.tasks.publish_task_progress_event", record_publish)

    response = client.post(f"/api/tasks/{running_task.id}/cancel")

    assert response.status_code == 200
    body = response.json()
    expected_fields = {
        "status": "cancelled",
        "progress": 100,
        "message": "任务已取消",
        "error": "Cancelled by user",
    }
    for field, value in expected_fields.items():
        assert body[field] == value
    assert revoke_calls == [("celery-1", True, "SIGTERM")]
    assert published_statuses == ["cancelled"]

    project_after = client.get(f"/api/projects/{project['id']}").json()
    assert project_after["status"] == "pending"
|
||||
|
||||
|
||||
def test_retry_task_creates_fresh_parse_task(client, db_session, monkeypatch):
    """Retry a failed parse task and check that a new queued task is created.

    The test posts to the retry endpoint for a task with status "failed"
    and asserts that the response describes a different task that is
    queued at progress 0, carries the stubbed Celery id, references the
    original task via ``payload["retry_of"]``, and that the project
    reports status "parsing" afterwards.
    """
    project_payload = {
        "name": "Retryable",
        "video_path": "uploads/2/clip.mp4",
        "source_type": "video",
        "status": "error",
    }
    project = client.post("/api/projects", json=project_payload).json()

    task = ProcessingTask(
        task_type="parse_video",
        status="failed",
        progress=100,
        message="解析失败",
        error="ffmpeg failed",
        project_id=project["id"],
        payload={"source_type": "video"},
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    class FakeAsyncResult:
        # Stand-in for the object returned by ``delay``; only ``id`` is set.
        id = "celery-retry"

    queued = []
    published = []

    # Parameter names are kept as in the original stubs so that keyword
    # calls from the router keep binding correctly.
    def fake_delay(task_id):
        queued.append(task_id)
        return FakeAsyncResult()

    def record_publish(event_task):
        published.append((event_task.id, event_task.status))

    monkeypatch.setattr("routers.tasks.parse_project_media.delay", fake_delay)
    monkeypatch.setattr("routers.tasks.publish_task_progress_event", record_publish)

    response = client.post(f"/api/tasks/{task.id}/retry")

    assert response.status_code == 202
    body = response.json()
    new_task_id = body["id"]
    assert new_task_id != task.id
    assert body["status"] == "queued"
    assert body["progress"] == 0
    assert body["celery_task_id"] == "celery-retry"
    assert body["payload"]["retry_of"] == task.id
    assert queued == [new_task_id]

    # Both the first and the last published event must describe the new
    # task in the queued state.
    queued_event = (new_task_id, "queued")
    assert published[0] == queued_event
    assert published[-1] == queued_event

    project_after = client.get(f"/api/projects/{project['id']}").json()
    assert project_after["status"] == "parsing"
|
||||
|
||||
|
||||
def test_task_actions_reject_invalid_states(client, db_session):
    """Check that cancel and retry both return 409 for a successful task."""
    project_payload = {
        "name": "Done",
        "video_path": "uploads/3/clip.mp4",
    }
    project = client.post("/api/projects", json=project_payload).json()

    task = ProcessingTask(
        task_type="parse_video",
        status="success",
        progress=100,
        project_id=project["id"],
        payload={"source_type": "video"},
    )
    db_session.add(task)
    db_session.commit()
    db_session.refresh(task)

    # Cancel is attempted first, then retry, as in the original ordering.
    cancel_response = client.post(f"/api/tasks/{task.id}/cancel")
    assert cancel_response.status_code == 409

    retry_response = client.post(f"/api/tasks/{task.id}/retry")
    assert retry_response.status_code == 409
|
||||
Reference in New Issue
Block a user