feat: 完善视频传播、标注编辑和拆帧闭环

- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。
- 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。
- 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。
- 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。
- 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。
- 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。
- 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。
- 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。
- 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。
- 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。
- 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。
- 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。
- 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。
- 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。
- 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。

验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
2026-05-01 20:27:33 +08:00
parent 689a9ba283
commit 5ab4602535
43 changed files with 2722 additions and 216 deletions

View File

@@ -52,6 +52,7 @@ def parse_video(
output_dir: str,
fps: int = 30,
max_frames: Optional[int] = None,
target_width: int = 640,
) -> Tuple[List[str], float]:
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
@@ -60,6 +61,7 @@ def parse_video(
output_dir: Directory to save extracted frames.
fps: Target frame extraction rate.
max_frames: Optional maximum number of frames to extract.
target_width: Output frame width for model-friendly frame sequences.
Returns:
Tuple of (frame_paths, original_fps).
@@ -67,6 +69,8 @@ def parse_video(
os.makedirs(output_dir, exist_ok=True)
frame_paths: List[str] = []
original_fps = get_video_fps(video_path)
safe_fps = max(int(fps), 1)
safe_width = max(int(target_width), 1)
# Try FFmpeg first
if shutil.which("ffmpeg"):
@@ -75,7 +79,8 @@ def parse_video(
cmd = [
"ffmpeg",
"-i", video_path,
"-vf", f"fps={fps},scale=640:-1",
"-vf", f"fps={safe_fps},scale={safe_width}:-1",
"-start_number", "0",
"-q:v", "5",
"-y",
pattern,
@@ -102,7 +107,7 @@ def parse_video(
raise RuntimeError(f"Cannot open video: {video_path}")
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
interval = max(1, int(round(video_fps / fps)))
interval = max(1, int(round(video_fps / safe_fps)))
count = 0
saved = 0
@@ -112,6 +117,10 @@ def parse_video(
break
if count % interval == 0:
path = os.path.join(output_dir, f"frame_{saved:06d}.jpg")
h, w = frame.shape[:2]
if safe_width > 0 and w != safe_width:
scale = safe_width / max(w, 1)
frame = cv2.resize(frame, (safe_width, max(1, int(round(h * scale)))), interpolation=cv2.INTER_AREA)
cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
frame_paths.append(path)
saved += 1

View File

@@ -76,6 +76,38 @@ def _project_status_after_stop(project: Project) -> str:
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
def _positive_int(value: Any, default: int | None = None) -> int | None:
try:
parsed = int(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _positive_float(value: Any, default: float) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _frame_sequence_metadata(
index: int,
parse_fps: float,
original_fps: float | None,
) -> dict[str, float | int | None]:
safe_parse_fps = max(float(parse_fps or 1.0), 1e-6)
timestamp_ms = index * 1000.0 / safe_parse_fps
source_frame_number = None
if original_fps and original_fps > 0:
source_frame_number = int(round(index * original_fps / safe_parse_fps))
return {
"timestamp_ms": timestamp_ms,
"source_frame_number": source_frame_number,
}
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
db.refresh(task)
if task.status == TASK_STATUS_CANCELLED:
@@ -138,8 +170,12 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
project.status = PROJECT_STATUS_PARSING
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
parse_fps = project.parse_fps or 30.0
payload = task.payload or {}
effective_source = payload.get("source_type") or project.source_type or "video"
parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0)
max_frames = _positive_int(payload.get("max_frames"))
target_width = _positive_int(payload.get("target_width"), 640) or 640
project.parse_fps = parse_fps
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
@@ -163,7 +199,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
frame_files = parse_dicom(dcm_dir, output_dir)
frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames)
else:
_ensure_not_cancelled(db, task)
media_bytes = download_file(project.video_path)
@@ -173,7 +209,13 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
_ensure_not_cancelled(db, task)
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
frame_files, original_fps = parse_video(
local_path,
output_dir,
fps=int(parse_fps),
max_frames=max_frames,
target_width=target_width,
)
project.original_fps = original_fps
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
@@ -205,12 +247,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
except Exception: # noqa: BLE001
h, w = None, None
sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps)
frame = Frame(
project_id=project.id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
timestamp_ms=sequence_meta["timestamp_ms"],
source_frame_number=sequence_meta["source_frame_number"],
)
db.add(frame)
frames_out.append(frame)
@@ -223,6 +268,17 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
"frames_extracted": len(frames_out),
"status": PROJECT_STATUS_READY,
"message": "Frame extraction completed successfully.",
"frame_sequence": {
"original_fps": project.original_fps,
"parse_fps": parse_fps,
"frame_count": len(frames_out),
"duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0,
"target_width": target_width,
"frame_width": frames_out[0].width if frames_out else None,
"frame_height": frames_out[0].height if frames_out else None,
"max_frames": max_frames,
"object_prefix": f"projects/{project.id}/frames",
},
}
_set_task_state(
db,

View File

@@ -24,6 +24,7 @@ except Exception as exc: # noqa: BLE001
try:
from sam2.build_sam import build_sam2
from sam2.build_sam import build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
SAM2_AVAILABLE = True
@@ -38,9 +39,12 @@ class SAM2Engine:
def __init__(self) -> None:
# Create the engine with no predictor built yet.
# Image predictor handle; stays None until a load succeeds.
self._predictor: Optional[SAM2ImagePredictor] = None
# Video predictor handle, kept separate from the image predictor.
self._video_predictor = None
# "Load attempted" flags: they are also set after a failed attempt so the
# load is not retried on every call.
self._model_loaded = False
self._video_model_loaded = False
# Device string recorded when a predictor is built.
self._loaded_device: str | None = None
# Most recent load error for the image and video predictors respectively.
self._last_error: str | None = None
self._video_last_error: str | None = None
# -----------------------------------------------------------------------
# Internal helpers
@@ -85,6 +89,40 @@ class SAM2Engine:
logger.error("Failed to load SAM 2 model: %s", exc)
self._model_loaded = True # Prevent repeated load attempts
def _load_video_model(self) -> None:
"""Load the SAM 2 video predictor on first propagation use."""
if self._video_model_loaded:
return
if not TORCH_AVAILABLE:
self._video_last_error = "PyTorch is not installed."
self._video_model_loaded = True
return
if not SAM2_AVAILABLE:
self._video_last_error = "sam2 package is not installed."
self._video_model_loaded = True
return
if not os.path.isfile(settings.sam_model_path):
self._video_last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
self._video_model_loaded = True
return
try:
device = self._best_device()
self._video_predictor = build_sam2_video_predictor(
settings.sam_model_config,
settings.sam_model_path,
device=device,
)
self._video_model_loaded = True
self._loaded_device = device
self._video_last_error = None
logger.info("SAM 2 video predictor loaded from %s on %s", settings.sam_model_path, device)
except Exception as exc: # noqa: BLE001
self._video_last_error = str(exc)
self._video_model_loaded = True
logger.error("Failed to load SAM 2 video predictor: %s", exc)
def _best_device(self) -> str:
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
return "cuda"
@@ -95,6 +133,11 @@ class SAM2Engine:
self._load_model()
return SAM2_AVAILABLE and self._predictor is not None
def _ensure_video_ready(self) -> bool:
    """Trigger the lazy video-predictor load and report whether it is usable."""
    self._load_video_model()
    if not SAM2_AVAILABLE:
        return SAM2_AVAILABLE
    return self._video_predictor is not None
def status(self) -> dict:
"""Return lightweight, real runtime status without forcing model load."""
checkpoint_exists = os.path.isfile(settings.sam_model_path)
@@ -121,7 +164,7 @@ class SAM2Engine:
"available": available,
"loaded": self._predictor is not None,
"device": device,
"supports": ["point", "box", "auto"],
"supports": ["point", "box", "interactive", "auto", "propagate"],
"message": message,
"package_available": SAM2_AVAILABLE,
"checkpoint_exists": checkpoint_exists,
@@ -221,6 +264,52 @@ class SAM2Engine:
logger.error("SAM2 box prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_interactive(
    self,
    image: np.ndarray,
    box: list[float] | None,
    points: list[list[float]],
    labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]:
    """Segment with a box prompt and point prompts combined, for refinement.

    Args:
        image: Image array; its first two dimensions are read as height, width.
        box: Optional 4-value box whose values are scaled by image width/height.
        points: Points whose coordinates are scaled by image width/height.
        labels: One label per point, passed to the predictor as int32.

    Returns:
        ``(polygons, scores)``. When the model is not ready, or prediction
        raises, placeholder polygons and a score of ``[0.5]`` are returned.
    """
    if not self._ensure_ready():
        logger.warning("SAM2 not ready; returning dummy masks.")
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    try:
        img_h, img_w = image.shape[:2]

        # Prompts are multiplied up to pixel units before calling the predictor.
        box_px = None
        if box:
            box_px = np.array(
                [box[0] * img_w, box[1] * img_h, box[2] * img_w, box[3] * img_h],
                dtype=np.float32,
            )

        point_px = None
        point_labels = None
        if points:
            point_px = np.array([[pt[0] * img_w, pt[1] * img_h] for pt in points], dtype=np.float32)
            point_labels = np.array(labels, dtype=np.int32)

        with torch.inference_mode():  # type: ignore[name-defined]
            self._predictor.set_image(image)
            masks, scores, _ = self._predictor.predict(
                point_coords=point_px,
                point_labels=point_labels,
                box=box_px,
                multimask_output=False,
            )

        # Masks that yield no polygon are dropped.
        outlines = []
        for single_mask in masks:
            outline = self._mask_to_polygon(single_mask)
            if outline:
                outlines.append(outline)
        return outlines, scores.tolist()
    except Exception as exc:  # noqa: BLE001
        logger.error("SAM2 interactive prediction failed: %s", exc)
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
"""Run automatic mask generation (grid of points).
@@ -260,6 +349,89 @@ class SAM2Engine:
logger.error("SAM2 auto prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def propagate_video(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict,
direction: str = "forward",
max_frames: int | None = None,
) -> list[dict]:
"""Propagate one seed mask across a prepared frame directory with SAM 2 video."""
if not self._ensure_video_ready():
raise RuntimeError(self._video_last_error or self.status()["message"])
if not frame_paths:
return []
if source_frame_index < 0 or source_frame_index >= len(frame_paths):
raise ValueError("source_frame_index is outside the frame sequence.")
import cv2
source_image = cv2.imread(frame_paths[source_frame_index])
if source_image is None:
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
height, width = source_image.shape[:2]
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
if not seed_mask.any():
bbox = seed.get("bbox")
if isinstance(bbox, list) and len(bbox) == 4:
seed_mask = self._bbox_to_mask(bbox, width, height)
if not seed_mask.any():
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
inference_state = self._video_predictor.init_state(
video_path=os.path.dirname(frame_paths[0]),
offload_video_to_cpu=True,
offload_state_to_cpu=True,
)
self._video_predictor.add_new_mask(
inference_state,
frame_idx=source_frame_index,
obj_id=1,
mask=seed_mask,
)
results: dict[int, dict] = {}
def collect(reverse: bool) -> None:
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
inference_state,
start_frame_idx=source_frame_index,
max_frame_num_to_track=max_frames,
reverse=reverse,
):
masks = out_mask_logits
if hasattr(masks, "detach"):
masks = masks.detach().cpu().numpy()
masks = np.asarray(masks)
if masks.ndim == 4:
masks = masks[:, 0]
polygons = []
scores = []
for mask in masks:
polygon = self._mask_to_polygon(mask > 0)
if polygon:
polygons.append(polygon)
scores.append(1.0)
results[int(out_frame_idx)] = {
"frame_index": int(out_frame_idx),
"polygons": polygons,
"scores": scores,
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
}
normalized_direction = direction.lower()
if normalized_direction in {"forward", "both"}:
collect(reverse=False)
if normalized_direction in {"backward", "both"}:
collect(reverse=True)
try:
self._video_predictor.reset_state(inference_state)
except Exception: # noqa: BLE001
pass
return [results[index] for index in sorted(results)]
# -----------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------
@@ -292,6 +464,38 @@ class SAM2Engine:
]
]
@staticmethod
def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
    """Rasterize polygons into a boolean ``(height, width)`` mask.

    Each coordinate is clamped to ``[0, 1]`` and scaled onto the pixel grid;
    polygons with fewer than three points are ignored.
    """
    import cv2

    canvas = np.zeros((height, width), dtype=np.uint8)
    x_scale = max(width - 1, 1)
    y_scale = max(height - 1, 1)

    def to_pixel(value: float, scale: int) -> int:
        # Clamp to the unit interval, then scale to the last valid pixel index.
        return int(round(min(max(float(value), 0.0), 1.0) * scale))

    for ring in polygons:
        if len(ring) < 3:
            continue
        vertices = np.array(
            [[to_pixel(x, x_scale), to_pixel(y, y_scale)] for x, y in ring],
            dtype=np.int32,
        )
        cv2.fillPoly(canvas, [vertices], 1)
    return canvas.astype(bool)
@staticmethod
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
left = int(round(x * max(width - 1, 1)))
top = int(round(y * max(height - 1, 1)))
right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
mask = np.zeros((height, width), dtype=bool)
mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
return mask
# Singleton instance
sam_engine = SAM2Engine()

View File

@@ -56,8 +56,22 @@ class SAM3Engine:
def _gpu_ok(self) -> bool:
    """Return whether torch is importable and reports an available CUDA device."""
    if not TORCH_AVAILABLE or torch is None:
        return False
    return bool(torch.cuda.is_available())
def _checkpoint_path(self) -> str | None:
    """Return the configured SAM 3 checkpoint path, or ``None`` when blank."""
    configured = settings.sam3_checkpoint_path.strip()
    if not configured:
        return None
    return configured
def _checkpoint_exists(self) -> bool:
path = self._checkpoint_path()
return bool(path and os.path.isfile(path))
def _can_load(self) -> bool:
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
return bool(
SAM3_PACKAGE_AVAILABLE
and TORCH_AVAILABLE
and self._python_ok()
and self._gpu_ok()
and self._checkpoint_exists()
)
def _worker_path(self) -> Path:
return Path(__file__).with_name("sam3_external_worker.py")
@@ -98,6 +112,8 @@ class SAM3Engine:
try:
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--status"],
capture_output=True,
@@ -146,7 +162,10 @@ class SAM3Engine:
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
self._model = build_sam3_image_model()
self._model = build_sam3_image_model(
checkpoint_path=self._checkpoint_path(),
load_from_HF=False,
)
self._processor = Sam3Processor(self._model)
self._model_loaded = True
self._last_error = None
@@ -170,6 +189,8 @@ class SAM3Engine:
missing.append("PyTorch")
if not self._gpu_ok():
missing.append("CUDA GPU")
if not self._checkpoint_exists():
missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
if missing:
return f"SAM 3 unavailable: missing {', '.join(missing)}."
return "SAM 3 dependencies are present; model will load on first inference."
@@ -182,7 +203,7 @@ class SAM3Engine:
if self._processor is not None:
message = "SAM 3 model loaded and ready."
elif external_ready:
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
elif external_status.get("message") and not self._can_load():
message = str(external_status["message"])
return {
@@ -191,11 +212,11 @@ class SAM3Engine:
"available": available,
"loaded": self._processor is not None,
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
"supports": ["semantic"],
"supports": ["semantic", "box", "video_track"],
"message": message,
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
"checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
"checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
"checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
"cuda_required": True,
@@ -203,7 +224,43 @@ class SAM3Engine:
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
}
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
if len(box) != 4:
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
left, right = sorted([x1, x2])
top, bottom = sorted([y1, y2])
width = max(right - left, 1e-6)
height = max(bottom - top, 1e-6)
return [left + width / 2, top + height / 2, width, height]
def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
masks = output.get("masks", [])
scores = output.get("scores", [])
polygons = []
for mask in masks:
if hasattr(mask, "detach"):
mask = mask.detach().cpu().numpy()
if mask.ndim == 3:
mask = mask[0]
poly = SAM2Engine._mask_to_polygon(mask)
if poly:
polygons.append(poly)
if hasattr(scores, "detach"):
scores = scores.detach().cpu().tolist()
elif hasattr(scores, "tolist"):
scores = scores.tolist()
return polygons, list(scores)
def _predict_external(
self,
image: np.ndarray,
prompt_type: str,
*,
text: str = "",
box: list[float] | None = None,
) -> tuple[list[list[list[float]]], list[float]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
@@ -217,8 +274,11 @@ class SAM3Engine:
json.dumps(
{
"image_path": str(image_path),
"prompt_type": prompt_type,
"text": text.strip(),
"box": box,
"model_version": settings.sam3_model_version,
"checkpoint_path": self._checkpoint_path(),
"confidence_threshold": settings.sam3_confidence_threshold,
},
ensure_ascii=False,
@@ -227,6 +287,8 @@ class SAM3Engine:
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
@@ -250,6 +312,72 @@ class SAM3Engine:
raise RuntimeError(str(payload["error"]))
return payload.get("polygons", []), payload.get("scores", [])
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(image, "semantic", text=text)
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
return self._predict_external(image, "box", box=box)
def _propagate_video_external(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str,
max_frames: int | None,
) -> list[dict[str, Any]]:
status = self._external_status(force=True)
if not status.get("available"):
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
if not frame_paths:
return []
with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
request_path = Path(tmpdir) / "request.json"
request_path.write_text(
json.dumps(
{
"prompt_type": "video_track",
"frame_dir": str(Path(frame_paths[0]).parent),
"source_frame_index": source_frame_index,
"seed": seed,
"direction": direction,
"max_frames": max_frames,
"model_version": settings.sam3_model_version,
"checkpoint_path": self._checkpoint_path(),
"confidence_threshold": settings.sam3_confidence_threshold,
},
ensure_ascii=False,
),
encoding="utf-8",
)
env = os.environ.copy()
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
if self._checkpoint_path():
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
completed = subprocess.run(
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
capture_output=True,
text=True,
timeout=settings.sam3_timeout_seconds,
check=False,
env=env,
)
if completed.returncode != 0:
detail = completed.stderr.strip() or completed.stdout.strip()
try:
parsed = json.loads(detail)
detail = parsed.get("error", detail)
except Exception: # noqa: BLE001
pass
raise RuntimeError(f"SAM 3 external video tracking failed: {detail}")
payload = json.loads(completed.stdout)
if payload.get("error"):
raise RuntimeError(str(payload["error"]))
return payload.get("frames", [])
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
if not text.strip():
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
@@ -263,29 +391,37 @@ class SAM3Engine:
state = self._processor.set_image(pil_image)
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
masks = output.get("masks", [])
scores = output.get("scores", [])
polygons = []
for mask in masks:
if hasattr(mask, "detach"):
mask = mask.detach().cpu().numpy()
if mask.ndim == 3:
mask = mask[0]
poly = SAM2Engine._mask_to_polygon(mask)
if poly:
polygons.append(poly)
if hasattr(scores, "detach"):
scores = scores.detach().cpu().tolist()
elif hasattr(scores, "tolist"):
scores = scores.tolist()
return polygons, list(scores)
return self._prediction_to_polygons(output)
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
    """Reject point prompts: this engine does not implement them.

    Raises:
        NotImplementedError: Always.
    """
    reason = "This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts."
    raise NotImplementedError(reason)
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
    """Segment the region indicated by a corner box ``[x1, y1, x2, y2]``.

    The external helper is used when the in-process model cannot be loaded
    but the external runtime reports itself available; otherwise the
    in-process processor handles the prompt.

    Raises:
        RuntimeError: Neither path is usable (message taken from ``status()``).
    """
    in_process_ok = self._can_load()
    if not in_process_ok and self._external_status().get("available"):
        return self._predict_box_external(image, box)

    if not self._ensure_ready():
        raise RuntimeError(self.status()["message"])

    pil_image = Image.fromarray(image)
    with torch.inference_mode():  # type: ignore[union-attr]
        state = self._processor.set_image(pil_image)
        # The processor is given the box as center/size rather than corners.
        prompt_box = self._xyxy_to_cxcywh(box)
        output = self._processor.add_geometric_prompt(
            state=state,
            box=prompt_box,
            label=True,
        )
    return self._prediction_to_polygons(output)
def propagate_video(
self,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str = "forward",
max_frames: int | None = None,
) -> list[dict[str, Any]]:
return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
sam3_engine = SAM3Engine()

View File

@@ -43,6 +43,13 @@ def _compact_error(exc: Exception) -> str:
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip()
if checkpoint_path:
path = Path(checkpoint_path)
if path.is_file():
return True, None
return False, f"local checkpoint not found: {checkpoint_path}"
try:
from huggingface_hub import hf_hub_download
@@ -55,6 +62,7 @@ def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
def runtime_status() -> dict[str, Any]:
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None
package_error = None
package_available = importlib.util.find_spec("sam3") is not None
if package_available:
@@ -85,6 +93,7 @@ def runtime_status() -> dict[str, Any]:
"available": available,
"package_available": package_available,
"checkpoint_access": checkpoint_access,
"checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})",
"python_ok": python_ok,
"torch_ok": torch_version is not None,
"torch_version": torch_version,
@@ -118,34 +127,67 @@ def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
def _to_numpy(value: Any) -> np.ndarray:
if hasattr(value, "detach"):
value = value.detach().cpu().numpy()
elif hasattr(value, "cpu"):
value = value.detach()
if hasattr(value, "is_floating_point") and value.is_floating_point():
value = value.float()
value = value.cpu().numpy()
elif hasattr(value, "cpu"):
value = value.cpu()
if hasattr(value, "is_floating_point") and value.is_floating_point():
value = value.float()
value = value.numpy()
return np.asarray(value)
def predict(request_path: Path) -> dict[str, Any]:
import torch
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.model_builder import build_sam3_image_model
def _xyxy_to_cxcywh(box: list[float]) -> list[float]:
if len(box) != 4:
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
left, right = sorted([x1, x2])
top, bottom = sorted([y1, y2])
width = max(right - left, 1e-6)
height = max(bottom - top, 1e-6)
return [left + width / 2, top + height / 2, width, height]
payload = json.loads(request_path.read_text(encoding="utf-8"))
image_path = Path(payload["image_path"])
text = str(payload["text"]).strip()
threshold = float(payload.get("confidence_threshold", 0.5))
if not text:
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def _bbox_from_seed(seed: dict[str, Any]) -> list[float]:
bbox = seed.get("bbox")
if isinstance(bbox, list) and len(bbox) == 4:
return [min(max(float(value), 0.0), 1.0) for value in bbox]
image = Image.open(image_path).convert("RGB")
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
model = build_sam3_image_model()
processor = Sam3Processor(model, confidence_threshold=threshold)
state = processor.set_image(image)
output = processor.set_text_prompt(state=state, prompt=text)
polygons = seed.get("polygons") or []
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
if not points:
raise ValueError("SAM 3 video tracking requires seed bbox or polygons.")
xs = [min(max(float(point[0]), 0.0), 1.0) for point in points]
ys = [min(max(float(point[1]), 0.0), 1.0) for point in points]
left, right = min(xs), max(xs)
top, bottom = min(ys), max(ys)
return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)]
def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
    """Convert one frame's tracker outputs into polygons, scores and ids.

    Reads ``out_binary_masks``, ``out_probs`` and ``out_obj_ids`` from
    *outputs*. Only masks that yield a polygon contribute an entry, so the
    three returned lists stay the same length. A missing score defaults to
    1.0 and a missing object id to the mask's 1-based position.
    """
    masks = _to_numpy(outputs.get("out_binary_masks", []))
    scores = _to_numpy(outputs.get("out_probs", []))
    obj_ids = _to_numpy(outputs.get("out_obj_ids", []))

    # Bring the masks to a (count, H, W) layout before iterating.
    if masks.ndim == 4:
        masks = masks[:, 0]
    elif masks.ndim == 2:
        masks = masks[None, ...]

    response: dict[str, Any] = {"polygons": [], "scores": [], "object_ids": []}
    for position, mask in enumerate(masks):
        outline = _mask_to_polygon(mask)
        if not outline:
            continue
        response["polygons"].append(outline)
        response["scores"].append(float(scores[position]) if scores.size > position else 1.0)
        response["object_ids"].append(int(obj_ids[position]) if obj_ids.size > position else position + 1)
    return response
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
masks = _to_numpy(output.get("masks", []))
scores = _to_numpy(output.get("scores", []))
if masks.ndim == 4:
@@ -165,6 +207,115 @@ def predict(request_path: Path) -> dict[str, Any]:
}
def predict_video(request_path: Path) -> dict[str, Any]:
    """Run SAM 3 video tracking for the request stored at *request_path*.

    The JSON request supplies ``frame_dir`` and optionally
    ``source_frame_index``, ``seed``, ``direction``, ``max_frames``,
    ``checkpoint_path`` and ``confidence_threshold``. The checkpoint path
    falls back to the ``SAM3_CHECKPOINT_PATH`` environment variable.

    Returns:
        ``{"frames": [...]}`` with one entry per frame yielded by the
        predictor, each carrying ``frame_index`` plus the converted outputs.

    Raises:
        ValueError: If ``direction`` is not forward, backward or both.
    """
    import torch
    from sam3.model_builder import build_sam3_video_predictor

    payload = json.loads(request_path.read_text(encoding="utf-8"))
    frame_dir = Path(payload["frame_dir"])
    source_frame_index = int(payload.get("source_frame_index", 0))
    seed = payload.get("seed") or {}
    direction = str(payload.get("direction") or "forward").lower()
    max_frames = payload.get("max_frames")
    max_frames = int(max_frames) if max_frames else None
    checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
    threshold = float(payload.get("confidence_threshold", 0.5))
    # Validate before the predictor is built.
    if direction not in {"forward", "backward", "both"}:
        raise ValueError(f"Unsupported propagation direction: {direction}")

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    predictor = build_sam3_video_predictor(
        checkpoint_path=checkpoint_path or None,
        async_loading_frames=False,
    )
    session_id = None
    try:
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            session = predictor.handle_request(
                {
                    "type": "start_session",
                    "resource_path": str(frame_dir),
                    "offload_video_to_cpu": True,
                    "offload_state_to_cpu": True,
                }
            )
            session_id = session["session_id"]
            predictor.handle_request(
                {
                    "type": "add_prompt",
                    "session_id": session_id,
                    "frame_index": source_frame_index,
                    "bounding_boxes": [_bbox_from_seed(seed)],
                    "bounding_box_labels": [1],
                    "output_prob_thresh": threshold,
                    "rel_coordinates": True,
                }
            )

            frames = []
            for item in predictor.handle_stream_request(
                {
                    "type": "propagate_in_video",
                    "session_id": session_id,
                    "propagation_direction": direction,
                    "start_frame_index": source_frame_index,
                    "max_frame_num_to_track": max_frames,
                    "output_prob_thresh": threshold,
                }
            ):
                frame_response = _video_outputs_to_response(item.get("outputs") or {})
                frame_response["frame_index"] = int(item["frame_index"])
                frames.append(frame_response)
    finally:
        # Compare against None rather than truthiness so a falsy but valid
        # session id (for example 0) is still closed.
        if session_id is not None:
            predictor.handle_request({"type": "close_session", "session_id": session_id})
    return {"frames": frames}
def predict(request_path: Path) -> dict[str, Any]:
    """Serve one request file: image prompts here, video tracking delegated.

    A ``prompt_type`` of ``video_track`` is handed to ``predict_video``.
    Otherwise the image named by ``image_path`` is segmented with either a
    text prompt (``semantic``, the default) or a ``box`` prompt.

    Raises:
        ValueError: Empty text for a semantic prompt, or an unsupported
            prompt type.
    """
    import torch
    from sam3.model.sam3_image_processor import Sam3Processor
    from sam3.model_builder import build_sam3_image_model

    payload = json.loads(request_path.read_text(encoding="utf-8"))
    requested_type = str(payload.get("prompt_type") or "").strip().lower()
    if requested_type == "video_track":
        return predict_video(request_path)

    image_path = Path(payload["image_path"])
    prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower()
    text = str(payload.get("text") or "").strip()
    threshold = float(payload.get("confidence_threshold", 0.5))
    checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()

    if prompt_type == "semantic" and not text:
        raise ValueError("SAM 3 semantic prompt requires non-empty text.")
    if prompt_type not in {"semantic", "box"}:
        raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}")

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    image = Image.open(image_path).convert("RGB")
    # Without a local checkpoint the builder is told to load from HF instead.
    has_local_checkpoint = bool(checkpoint_path)
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        model = build_sam3_image_model(
            checkpoint_path=checkpoint_path or None,
            load_from_HF=not has_local_checkpoint,
        )
        processor = Sam3Processor(model, confidence_threshold=threshold)
        state = processor.set_image(image)
        if prompt_type != "box":
            output = processor.set_text_prompt(state=state, prompt=text)
        else:
            output = processor.add_geometric_prompt(
                state=state,
                box=_xyxy_to_cxcywh(payload.get("box") or []),
                label=True,
            )
    return _prediction_to_response(output)
def main() -> int:
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
parser.add_argument("--status", action="store_true")

View File

@@ -67,6 +67,19 @@ class SAMRegistry:
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
return self._ensure_available(model_id).predict_box(image, box)
def predict_interactive(
self,
model_id: str | None,
image: Any,
box: list[float] | None,
points: list[list[float]],
labels: list[int],
):
model = self.normalize_model_id(model_id)
if model != "sam2":
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
return self._ensure_available(model).predict_interactive(image, box, points, labels)
def predict_auto(self, model_id: str | None, image: Any):
return self._ensure_available(model_id).predict_auto(image)
@@ -76,5 +89,22 @@ class SAMRegistry:
return self._ensure_available(model).predict_semantic(image, text)
return self._ensure_available(model).predict_auto(image)
def propagate_video(
self,
model_id: str | None,
frame_paths: list[str],
source_frame_index: int,
seed: dict[str, Any],
direction: str,
max_frames: int | None,
):
return self._ensure_available(model_id).propagate_video(
frame_paths,
source_frame_index,
seed,
direction=direction,
max_frames=max_frames,
)
sam_registry = SAMRegistry()