"""Standalone SAM 3 helper for the dedicated Python 3.12 runtime. The main FastAPI backend can keep running in the existing Python 3.11/SAM 2 environment while this helper is executed with a separate conda env that meets SAM 3's stricter runtime requirements. """ from __future__ import annotations import argparse import importlib.util import json import os import sys from pathlib import Path from typing import Any import numpy as np from PIL import Image def _torch_status() -> tuple[bool, str | None, str | None, str | None]: try: import torch cuda_available = bool(torch.cuda.is_available()) return ( cuda_available, getattr(torch, "__version__", None), getattr(torch.version, "cuda", None), torch.cuda.get_device_name(0) if cuda_available else None, ) except Exception: # noqa: BLE001 return False, None, None, None def _compact_error(exc: Exception) -> str: lines = [line.strip() for line in str(exc).splitlines() if line.strip()] for line in lines: if "Access to model" in line or "Cannot access gated repo" in line: return line return lines[0] if lines else exc.__class__.__name__ def _checkpoint_access(model_version: str) -> tuple[bool, str | None]: checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() if checkpoint_path: path = Path(checkpoint_path) if path.is_file(): return True, None return False, f"local checkpoint not found: {checkpoint_path}" try: from huggingface_hub import hf_hub_download repo_id = "facebook/sam3.1" if model_version == "sam3.1" else "facebook/sam3" hf_hub_download(repo_id=repo_id, filename="config.json") return True, None except Exception as exc: # noqa: BLE001 return False, _compact_error(exc) def runtime_status() -> dict[str, Any]: model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3") checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None package_error = None package_available = importlib.util.find_spec("sam3") is not None if package_available: try: import sam3 # noqa: F401 except Exception as exc: # noqa: BLE001 package_available = False package_error = str(exc) cuda_available, torch_version, cuda_version, device_name = _torch_status() python_ok = sys.version_info >= (3, 12) checkpoint_access = False checkpoint_error = None if package_available: checkpoint_access, checkpoint_error = _checkpoint_access(model_version) available = bool(package_available and python_ok and cuda_available and checkpoint_access) missing = [] if not python_ok: missing.append("Python 3.12+ runtime") if not package_available: missing.append(f"sam3 package ({package_error})" if package_error else "sam3 package") if torch_version is None: missing.append("PyTorch") if not cuda_available: missing.append("CUDA GPU") if package_available and not checkpoint_access: missing.append(f"Hugging Face checkpoint access ({checkpoint_error})") return { "available": available, "package_available": package_available, "checkpoint_access": checkpoint_access, "checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})", "python_ok": python_ok, "torch_ok": torch_version is not None, "torch_version": torch_version, "cuda_version": cuda_version, "cuda_available": cuda_available, "device": "cuda" if cuda_available else "unavailable", "device_name": device_name, "message": ( "SAM 3 external runtime is ready." if available else f"SAM 3 external runtime unavailable: missing {', '.join(missing)}." ), } def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]: import cv2 if mask.dtype != np.uint8: mask = (mask > 0).astype(np.uint8) contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) height, width = mask.shape[:2] largest = [] for contour in contours: if len(contour) > len(largest): largest = contour if len(largest) < 3: return [] return [[float(point[0][0]) / width, float(point[0][1]) / height] for point in largest] def _to_numpy(value: Any) -> np.ndarray: if hasattr(value, "detach"): value = value.detach() if hasattr(value, "is_floating_point") and value.is_floating_point(): value = value.float() value = value.cpu().numpy() elif hasattr(value, "cpu"): value = value.cpu() if hasattr(value, "is_floating_point") and value.is_floating_point(): value = value.float() value = value.numpy() return np.asarray(value) def _xyxy_to_cxcywh(box: list[float]) -> list[float]: if len(box) != 4: raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].") x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box] left, right = sorted([x1, x2]) top, bottom = sorted([y1, y2]) width = max(right - left, 1e-6) height = max(bottom - top, 1e-6) return [left + width / 2, top + height / 2, width, height] def _bbox_from_seed(seed: dict[str, Any]) -> list[float]: bbox = seed.get("bbox") if isinstance(bbox, list) and len(bbox) == 4: return [min(max(float(value), 0.0), 1.0) for value in bbox] polygons = seed.get("polygons") or [] points = [point for polygon in polygons for point in polygon if len(point) >= 2] if not points: raise ValueError("SAM 3 video tracking requires seed bbox or polygons.") xs = [min(max(float(point[0]), 0.0), 1.0) for point in points] ys = [min(max(float(point[1]), 0.0), 1.0) for point in points] left, right = min(xs), max(xs) top, bottom = min(ys), max(ys) return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)] def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]: masks = _to_numpy(outputs.get("out_binary_masks", [])) scores = _to_numpy(outputs.get("out_probs", [])) obj_ids = _to_numpy(outputs.get("out_obj_ids", [])) if masks.ndim == 4: masks = masks[:, 0] elif masks.ndim == 2: masks = masks[None, ...] polygons = [] out_scores = [] out_ids = [] for index, mask in enumerate(masks): polygon = _mask_to_polygon(mask) if polygon: polygons.append(polygon) out_scores.append(float(scores[index]) if scores.size > index else 1.0) out_ids.append(int(obj_ids[index]) if obj_ids.size > index else index + 1) return {"polygons": polygons, "scores": out_scores, "object_ids": out_ids} def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]: masks = _to_numpy(output.get("masks", [])) scores = _to_numpy(output.get("scores", [])) if masks.ndim == 4: masks = masks[:, 0] elif masks.ndim == 3 and masks.shape[0] == 1: masks = masks[None, 0] polygons = [] for mask in masks: polygon = _mask_to_polygon(mask) if polygon: polygons.append(polygon) return { "polygons": polygons, "scores": scores.astype(float).tolist() if scores.size else [], } def predict_video(request_path: Path) -> dict[str, Any]: import torch from sam3.model_builder import build_sam3_video_predictor payload = json.loads(request_path.read_text(encoding="utf-8")) frame_dir = Path(payload["frame_dir"]) source_frame_index = int(payload.get("source_frame_index", 0)) seed = payload.get("seed") or {} direction = str(payload.get("direction") or "forward").lower() max_frames = payload.get("max_frames") max_frames = int(max_frames) if max_frames else None checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip() threshold = float(payload.get("confidence_threshold", 0.5)) if direction not in {"forward", "backward", "both"}: raise ValueError(f"Unsupported propagation direction: {direction}") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True predictor = build_sam3_video_predictor( checkpoint_path=checkpoint_path or None, async_loading_frames=False, ) session_id = None try: with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): session = predictor.handle_request( { "type": "start_session", "resource_path": str(frame_dir), "offload_video_to_cpu": True, "offload_state_to_cpu": True, } ) session_id = session["session_id"] predictor.handle_request( { "type": "add_prompt", "session_id": session_id, "frame_index": source_frame_index, "bounding_boxes": [_bbox_from_seed(seed)], "bounding_box_labels": [1], "output_prob_thresh": threshold, "rel_coordinates": True, } ) frames = [] for item in predictor.handle_stream_request( { "type": "propagate_in_video", "session_id": session_id, "propagation_direction": direction, "start_frame_index": source_frame_index, "max_frame_num_to_track": max_frames, "output_prob_thresh": threshold, } ): frame_response = _video_outputs_to_response(item.get("outputs") or {}) frame_response["frame_index"] = int(item["frame_index"]) frames.append(frame_response) finally: if session_id: predictor.handle_request({"type": "close_session", "session_id": session_id}) return {"frames": frames} def predict(request_path: Path) -> dict[str, Any]: import torch from sam3.model.sam3_image_processor import Sam3Processor from sam3.model_builder import build_sam3_image_model payload = json.loads(request_path.read_text(encoding="utf-8")) if str(payload.get("prompt_type") or "").strip().lower() == "video_track": return predict_video(request_path) image_path = Path(payload["image_path"]) prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower() text = str(payload.get("text") or "").strip() threshold = float(payload.get("confidence_threshold", 0.5)) checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip() if prompt_type == "semantic" and not text: raise ValueError("SAM 3 semantic prompt requires non-empty text.") if prompt_type not in {"semantic", "box"}: raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True image = Image.open(image_path).convert("RGB") with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): model = build_sam3_image_model( checkpoint_path=checkpoint_path or None, load_from_HF=not bool(checkpoint_path), ) processor = Sam3Processor(model, confidence_threshold=threshold) state = processor.set_image(image) if prompt_type == "box": output = processor.add_geometric_prompt( state=state, box=_xyxy_to_cxcywh(payload.get("box") or []), label=True, ) else: output = processor.set_text_prompt(state=state, prompt=text) return _prediction_to_response(output) def main() -> int: parser = argparse.ArgumentParser(description="SAM 3 external runtime helper") parser.add_argument("--status", action="store_true") parser.add_argument("--request", type=Path) args = parser.parse_args() try: if args.status: print(json.dumps(runtime_status(), ensure_ascii=False)) return 0 if args.request: print(json.dumps(predict(args.request), ensure_ascii=False)) return 0 parser.error("Use --status or --request") except Exception as exc: # noqa: BLE001 print(json.dumps({"error": str(exc)}, ensure_ascii=False), file=sys.stderr) return 1 return 2 if __name__ == "__main__": raise SystemExit(main())