feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
This commit is contained in:
@@ -23,6 +23,7 @@ class Settings(BaseSettings):
|
||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
|
||||
sam3_model_version: str = "sam3"
|
||||
sam3_checkpoint_path: str = "/home/wkmgc/Desktop/Seg_Server/sam3权重/sam3.pt"
|
||||
sam3_external_enabled: bool = True
|
||||
sam3_external_python: str = "/home/wkmgc/miniconda3/envs/sam3/bin/python"
|
||||
sam3_timeout_seconds: int = 300
|
||||
|
||||
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import inspect, text
|
||||
|
||||
from config import settings
|
||||
from database import Base, engine, SessionLocal
|
||||
@@ -30,6 +31,20 @@ logger = logging.getLogger(__name__)
|
||||
DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||||
|
||||
|
||||
def _ensure_runtime_schema_columns() -> None:
|
||||
"""Add nullable columns introduced after initial create_all deployments."""
|
||||
try:
|
||||
inspector = inspect(engine)
|
||||
frame_columns = {column["name"] for column in inspector.get_columns("frames")}
|
||||
with engine.begin() as connection:
|
||||
if "timestamp_ms" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"))
|
||||
if "source_frame_number" not in frame_columns:
|
||||
connection.execute(text("ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Runtime schema column check failed: %s", exc)
|
||||
|
||||
|
||||
def _seed_default_project_sync() -> None:
|
||||
"""Synchronously seed the default video project on first startup."""
|
||||
import cv2
|
||||
@@ -93,12 +108,16 @@ def _seed_default_project_sync() -> None:
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
img = cv2.imread(frame_files[idx])
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
timestamp_ms = idx * 1000.0 / 30.0
|
||||
source_frame_number = int(round(idx * original_fps / 30.0)) if original_fps else None
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
timestamp_ms=timestamp_ms,
|
||||
source_frame_number=source_frame_number,
|
||||
)
|
||||
db.add(frame)
|
||||
|
||||
@@ -176,6 +195,7 @@ async def lifespan(app: FastAPI):
|
||||
# Initialize database tables
|
||||
try:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_ensure_runtime_schema_columns()
|
||||
logger.info("Database tables initialized.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Database initialization failed: %s", exc)
|
||||
|
||||
@@ -56,6 +56,8 @@ class Frame(Base):
|
||||
image_url = Column(String(512), nullable=False)
|
||||
width = Column(Integer, nullable=True)
|
||||
height = Column(Integer, nullable=True)
|
||||
timestamp_ms = Column(Float, nullable=True)
|
||||
source_frame_number = Column(Integer, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
project = relationship("Project", back_populates="frames")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""AI inference endpoints using selectable SAM runtimes."""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
@@ -15,6 +17,8 @@ from schemas import (
|
||||
AiRuntimeStatus,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
PropagateRequest,
|
||||
PropagateResponse,
|
||||
AnnotationOut,
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
@@ -66,6 +70,48 @@ def _contour_bbox(contour: np.ndarray, width: int, height: int) -> list[float]:
|
||||
]
|
||||
|
||||
|
||||
def _polygon_bbox(polygon: list[list[float]]) -> list[float]:
|
||||
xs = [_clamp01(point[0]) for point in polygon]
|
||||
ys = [_clamp01(point[1]) for point in polygon]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 0.0), max(bottom - top, 0.0)]
|
||||
|
||||
|
||||
def _frame_window(
|
||||
frames: list[Frame],
|
||||
source_position: int,
|
||||
direction: str,
|
||||
max_frames: int,
|
||||
) -> tuple[list[Frame], int]:
|
||||
count = max(1, min(max_frames, len(frames)))
|
||||
if direction == "backward":
|
||||
start = max(0, source_position - count + 1)
|
||||
return frames[start:source_position + 1], source_position - start
|
||||
if direction == "both":
|
||||
before = (count - 1) // 2
|
||||
after = count - 1 - before
|
||||
start = max(0, source_position - before)
|
||||
end = min(len(frames), source_position + after + 1)
|
||||
while end - start < count and start > 0:
|
||||
start -= 1
|
||||
while end - start < count and end < len(frames):
|
||||
end += 1
|
||||
return frames[start:end], source_position - start
|
||||
end = min(len(frames), source_position + count)
|
||||
return frames[source_position:end], 0
|
||||
|
||||
|
||||
def _write_frame_sequence(frames: list[Frame], directory: Path) -> list[str]:
|
||||
paths = []
|
||||
for index, frame in enumerate(frames):
|
||||
data = download_file(frame.image_url)
|
||||
path = directory / f"frame_{index:06d}.jpg"
|
||||
path.write_bytes(data)
|
||||
paths.append(str(path))
|
||||
return paths
|
||||
|
||||
|
||||
def _component_seed_point(component_mask: np.ndarray, width: int, height: int) -> list[float]:
|
||||
"""Reduce a binary component to one positive prompt point using distance transform."""
|
||||
dist = cv2.distanceTransform(component_mask.astype(np.uint8), cv2.DIST_L2, 5)
|
||||
@@ -184,6 +230,7 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
- **point**: `prompt_data` is either a list of `[[x, y], ...]` normalized
|
||||
coordinates or `{ "points": [[x, y], ...], "labels": [1, 0, ...] }`.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **interactive**: `prompt_data` is `{ "box": [...], "points": [[x, y]], "labels": [1, 0] }`.
|
||||
- **semantic**: SAM 3 text prompt when model=`sam3`; SAM 2 falls back to auto.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
@@ -246,6 +293,51 @@ def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "interactive":
|
||||
prompt = payload.prompt_data
|
||||
if not isinstance(prompt, dict):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive prompt data")
|
||||
box = prompt.get("box")
|
||||
points = prompt.get("points") or []
|
||||
labels = prompt.get("labels")
|
||||
if box is not None and (not isinstance(box, list) or len(box) != 4):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive box prompt data")
|
||||
if not isinstance(points, list):
|
||||
raise HTTPException(status_code=400, detail="Invalid interactive point prompt data")
|
||||
if not box and len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Interactive prompt requires a box or points")
|
||||
if not isinstance(labels, list) or len(labels) != len(points):
|
||||
labels = [1] * len(points)
|
||||
negative_points = [
|
||||
point for point, label in zip(points, labels) if label == 0
|
||||
]
|
||||
inference_image = image
|
||||
inference_box = box
|
||||
inference_points = points
|
||||
crop_bounds = None
|
||||
if options.get("crop_to_prompt"):
|
||||
margin = float(options.get("crop_margin", 0.05) or 0.05)
|
||||
crop_points = list(points)
|
||||
if box:
|
||||
crop_points.extend([[box[0], box[1]], [box[2], box[3]]])
|
||||
crop_bounds = _crop_bounds_from_points(crop_points, margin)
|
||||
inference_image = _crop_image(image, crop_bounds)
|
||||
inference_points = [_to_crop_point(point, crop_bounds) for point in points]
|
||||
if box:
|
||||
inference_box = [
|
||||
*_to_crop_point([box[0], box[1]], crop_bounds),
|
||||
*_to_crop_point([box[2], box[3]], crop_bounds),
|
||||
]
|
||||
polygons, scores = sam_registry.predict_interactive(
|
||||
payload.model,
|
||||
inference_image,
|
||||
inference_box,
|
||||
inference_points,
|
||||
labels,
|
||||
)
|
||||
if crop_bounds:
|
||||
polygons = [_from_crop_polygon(polygon, crop_bounds) for polygon in polygons]
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
text = payload.prompt_data if isinstance(payload.prompt_data, str) else ""
|
||||
polygons, scores = sam_registry.predict_semantic(payload.model, image, text)
|
||||
@@ -276,6 +368,124 @@ def model_status(selected_model: str | None = None) -> dict:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/propagate",
|
||||
response_model=PropagateResponse,
|
||||
summary="Propagate one current-frame region across a video frame segment",
|
||||
)
|
||||
def propagate(payload: PropagateRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Track one selected region from the current frame across nearby frames.
|
||||
|
||||
SAM 2 uses the official video predictor with the selected mask as the seed.
|
||||
SAM 3 uses the external Python 3.12 video tracker with the seed bbox.
|
||||
"""
|
||||
direction = payload.direction.lower()
|
||||
if direction not in {"forward", "backward", "both"}:
|
||||
raise HTTPException(status_code=400, detail="direction must be forward, backward, or both")
|
||||
max_frames = max(1, min(int(payload.max_frames or 30), 500))
|
||||
|
||||
project = db.query(Project).filter(Project.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
source_frame = db.query(Frame).filter(
|
||||
Frame.id == payload.frame_id,
|
||||
Frame.project_id == payload.project_id,
|
||||
).first()
|
||||
if not source_frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
seed = payload.seed.model_dump(exclude_none=True)
|
||||
polygons = seed.get("polygons") or []
|
||||
bbox = seed.get("bbox")
|
||||
points = seed.get("points") or []
|
||||
if not polygons and not bbox and not points:
|
||||
raise HTTPException(status_code=400, detail="Propagation requires seed polygons, bbox, or points")
|
||||
|
||||
frames = db.query(Frame).filter(Frame.project_id == payload.project_id).order_by(Frame.frame_index).all()
|
||||
source_position = next((index for index, frame in enumerate(frames) if frame.id == source_frame.id), None)
|
||||
if source_position is None:
|
||||
raise HTTPException(status_code=404, detail="Source frame is not in project frame sequence")
|
||||
|
||||
selected_frames, source_relative_index = _frame_window(frames, source_position, direction, max_frames)
|
||||
if len(selected_frames) == 0:
|
||||
raise HTTPException(status_code=400, detail="No frames available for propagation")
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix=f"seg_propagate_{payload.project_id}_") as tmpdir:
|
||||
frame_paths = _write_frame_sequence(selected_frames, Path(tmpdir))
|
||||
propagated = sam_registry.propagate_video(
|
||||
payload.model,
|
||||
frame_paths,
|
||||
source_relative_index,
|
||||
seed,
|
||||
direction,
|
||||
len(selected_frames),
|
||||
)
|
||||
except ModelUnavailableError as exc:
|
||||
raise HTTPException(status_code=503, detail=str(exc)) from exc
|
||||
except NotImplementedError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Video propagation failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail=f"Video propagation failed: {exc}") from exc
|
||||
|
||||
created: list[Annotation] = []
|
||||
if payload.save_annotations:
|
||||
class_metadata = seed.get("class_metadata")
|
||||
template_id = seed.get("template_id")
|
||||
label = seed.get("label") or "Propagated Mask"
|
||||
color = seed.get("color") or "#06b6d4"
|
||||
model_id = sam_registry.normalize_model_id(payload.model)
|
||||
|
||||
for frame_result in propagated:
|
||||
relative_index = int(frame_result.get("frame_index", -1))
|
||||
if relative_index < 0 or relative_index >= len(selected_frames):
|
||||
continue
|
||||
frame = selected_frames[relative_index]
|
||||
if not payload.include_source and frame.id == source_frame.id:
|
||||
continue
|
||||
result_polygons = frame_result.get("polygons") or []
|
||||
scores = frame_result.get("scores") or []
|
||||
for polygon_index, polygon in enumerate(result_polygons):
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
annotation = Annotation(
|
||||
project_id=payload.project_id,
|
||||
frame_id=frame.id,
|
||||
template_id=template_id,
|
||||
mask_data={
|
||||
"polygons": [polygon],
|
||||
"label": label,
|
||||
"color": color,
|
||||
"source": f"{model_id}_propagation",
|
||||
"propagated_from_frame_id": source_frame.id,
|
||||
"propagated_from_frame_index": source_frame.frame_index,
|
||||
"score": scores[polygon_index] if polygon_index < len(scores) else None,
|
||||
**({"class": class_metadata} if class_metadata else {}),
|
||||
},
|
||||
points=None,
|
||||
bbox=_polygon_bbox(polygon),
|
||||
)
|
||||
db.add(annotation)
|
||||
created.append(annotation)
|
||||
|
||||
db.commit()
|
||||
for annotation in created:
|
||||
db.refresh(annotation)
|
||||
|
||||
return {
|
||||
"model": sam_registry.normalize_model_id(payload.model),
|
||||
"direction": direction,
|
||||
"source_frame_id": source_frame.id,
|
||||
"processed_frame_count": len(selected_frames),
|
||||
"created_annotation_count": len(created),
|
||||
"annotations": created,
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
@@ -169,6 +169,9 @@ async def upload_dicom_batch(
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: Optional[str] = None,
|
||||
parse_fps: Optional[float] = Query(None, gt=0, le=120),
|
||||
max_frames: Optional[int] = Query(None, gt=0),
|
||||
target_width: int = Query(640, ge=64, le=4096),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProcessingTask:
|
||||
"""Create a background task for media frame extraction.
|
||||
@@ -184,14 +187,21 @@ def parse_media(
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
effective_source = source_type or project.source_type or "video"
|
||||
effective_parse_fps = parse_fps or project.parse_fps or 30.0
|
||||
task = ProcessingTask(
|
||||
task_type=f"parse_{effective_source}",
|
||||
status=TASK_STATUS_QUEUED,
|
||||
progress=0,
|
||||
message="解析任务已入队",
|
||||
project_id=project_id,
|
||||
payload={"source_type": effective_source},
|
||||
payload={
|
||||
"source_type": effective_source,
|
||||
"parse_fps": effective_parse_fps,
|
||||
"max_frames": max_frames,
|
||||
"target_width": target_width,
|
||||
},
|
||||
)
|
||||
project.parse_fps = effective_parse_fps
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
db.add(task)
|
||||
db.commit()
|
||||
|
||||
@@ -51,6 +51,8 @@ class FrameBase(BaseModel):
|
||||
image_url: str
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
timestamp_ms: Optional[float] = None
|
||||
source_frame_number: Optional[int] = None
|
||||
|
||||
|
||||
class FrameCreate(FrameBase):
|
||||
@@ -188,6 +190,37 @@ class PredictResponse(BaseModel):
|
||||
scores: Optional[list[float]] = None
|
||||
|
||||
|
||||
class PropagationSeed(BaseModel):
|
||||
polygons: Optional[list[list[list[float]]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
labels: Optional[list[int]] = None
|
||||
label: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
class_metadata: Optional[dict[str, Any]] = None
|
||||
template_id: Optional[int] = None
|
||||
|
||||
|
||||
class PropagateRequest(BaseModel):
|
||||
project_id: int
|
||||
frame_id: int
|
||||
model: Optional[str] = "sam2"
|
||||
seed: PropagationSeed
|
||||
direction: str = "forward"
|
||||
max_frames: int = 30
|
||||
include_source: bool = False
|
||||
save_annotations: bool = True
|
||||
|
||||
|
||||
class PropagateResponse(BaseModel):
|
||||
model: str
|
||||
direction: str
|
||||
source_frame_id: int
|
||||
processed_frame_count: int
|
||||
created_annotation_count: int
|
||||
annotations: list[AnnotationOut]
|
||||
|
||||
|
||||
class AiModelStatus(BaseModel):
|
||||
id: str
|
||||
label: str
|
||||
|
||||
@@ -52,6 +52,7 @@ def parse_video(
|
||||
output_dir: str,
|
||||
fps: int = 30,
|
||||
max_frames: Optional[int] = None,
|
||||
target_width: int = 640,
|
||||
) -> Tuple[List[str], float]:
|
||||
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
|
||||
|
||||
@@ -60,6 +61,7 @@ def parse_video(
|
||||
output_dir: Directory to save extracted frames.
|
||||
fps: Target frame extraction rate.
|
||||
max_frames: Optional maximum number of frames to extract.
|
||||
target_width: Output frame width for model-friendly frame sequences.
|
||||
|
||||
Returns:
|
||||
Tuple of (frame_paths, original_fps).
|
||||
@@ -67,6 +69,8 @@ def parse_video(
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
frame_paths: List[str] = []
|
||||
original_fps = get_video_fps(video_path)
|
||||
safe_fps = max(int(fps), 1)
|
||||
safe_width = max(int(target_width), 1)
|
||||
|
||||
# Try FFmpeg first
|
||||
if shutil.which("ffmpeg"):
|
||||
@@ -75,7 +79,8 @@ def parse_video(
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i", video_path,
|
||||
"-vf", f"fps={fps},scale=640:-1",
|
||||
"-vf", f"fps={safe_fps},scale={safe_width}:-1",
|
||||
"-start_number", "0",
|
||||
"-q:v", "5",
|
||||
"-y",
|
||||
pattern,
|
||||
@@ -102,7 +107,7 @@ def parse_video(
|
||||
raise RuntimeError(f"Cannot open video: {video_path}")
|
||||
|
||||
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
||||
interval = max(1, int(round(video_fps / fps)))
|
||||
interval = max(1, int(round(video_fps / safe_fps)))
|
||||
count = 0
|
||||
saved = 0
|
||||
|
||||
@@ -112,6 +117,10 @@ def parse_video(
|
||||
break
|
||||
if count % interval == 0:
|
||||
path = os.path.join(output_dir, f"frame_{saved:06d}.jpg")
|
||||
h, w = frame.shape[:2]
|
||||
if safe_width > 0 and w != safe_width:
|
||||
scale = safe_width / max(w, 1)
|
||||
frame = cv2.resize(frame, (safe_width, max(1, int(round(h * scale)))), interpolation=cv2.INTER_AREA)
|
||||
cv2.imwrite(path, frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
||||
frame_paths.append(path)
|
||||
saved += 1
|
||||
|
||||
@@ -76,6 +76,38 @@ def _project_status_after_stop(project: Project) -> str:
|
||||
return PROJECT_STATUS_READY if project.frames else PROJECT_STATUS_PENDING
|
||||
|
||||
|
||||
def _positive_int(value: Any, default: int | None = None) -> int | None:
|
||||
try:
|
||||
parsed = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return parsed if parsed > 0 else default
|
||||
|
||||
|
||||
def _positive_float(value: Any, default: float) -> float:
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return parsed if parsed > 0 else default
|
||||
|
||||
|
||||
def _frame_sequence_metadata(
|
||||
index: int,
|
||||
parse_fps: float,
|
||||
original_fps: float | None,
|
||||
) -> dict[str, float | int | None]:
|
||||
safe_parse_fps = max(float(parse_fps or 1.0), 1e-6)
|
||||
timestamp_ms = index * 1000.0 / safe_parse_fps
|
||||
source_frame_number = None
|
||||
if original_fps and original_fps > 0:
|
||||
source_frame_number = int(round(index * original_fps / safe_parse_fps))
|
||||
return {
|
||||
"timestamp_ms": timestamp_ms,
|
||||
"source_frame_number": source_frame_number,
|
||||
}
|
||||
|
||||
|
||||
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
|
||||
db.refresh(task)
|
||||
if task.status == TASK_STATUS_CANCELLED:
|
||||
@@ -138,8 +170,12 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
project.status = PROJECT_STATUS_PARSING
|
||||
_set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)
|
||||
|
||||
effective_source = (task.payload or {}).get("source_type") or project.source_type or "video"
|
||||
parse_fps = project.parse_fps or 30.0
|
||||
payload = task.payload or {}
|
||||
effective_source = payload.get("source_type") or project.source_type or "video"
|
||||
parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0)
|
||||
max_frames = _positive_int(payload.get("max_frames"))
|
||||
target_width = _positive_int(payload.get("target_width"), 640) or 640
|
||||
project.parse_fps = parse_fps
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -163,7 +199,7 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames)
|
||||
else:
|
||||
_ensure_not_cancelled(db, task)
|
||||
media_bytes = download_file(project.video_path)
|
||||
@@ -173,7 +209,13 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
|
||||
_ensure_not_cancelled(db, task)
|
||||
_set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||
frame_files, original_fps = parse_video(
|
||||
local_path,
|
||||
output_dir,
|
||||
fps=int(parse_fps),
|
||||
max_frames=max_frames,
|
||||
target_width=target_width,
|
||||
)
|
||||
project.original_fps = original_fps
|
||||
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
@@ -205,12 +247,15 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps)
|
||||
frame = Frame(
|
||||
project_id=project.id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
timestamp_ms=sequence_meta["timestamp_ms"],
|
||||
source_frame_number=sequence_meta["source_frame_number"],
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
@@ -223,6 +268,17 @@ def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": PROJECT_STATUS_READY,
|
||||
"message": "Frame extraction completed successfully.",
|
||||
"frame_sequence": {
|
||||
"original_fps": project.original_fps,
|
||||
"parse_fps": parse_fps,
|
||||
"frame_count": len(frames_out),
|
||||
"duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0,
|
||||
"target_width": target_width,
|
||||
"frame_width": frames_out[0].width if frames_out else None,
|
||||
"frame_height": frames_out[0].height if frames_out else None,
|
||||
"max_frames": max_frames,
|
||||
"object_prefix": f"projects/{project.id}/frames",
|
||||
},
|
||||
}
|
||||
_set_task_state(
|
||||
db,
|
||||
|
||||
@@ -24,6 +24,7 @@ except Exception as exc: # noqa: BLE001
|
||||
|
||||
try:
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
SAM2_AVAILABLE = True
|
||||
@@ -38,9 +39,12 @@ class SAM2Engine:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._predictor: Optional[SAM2ImagePredictor] = None
|
||||
self._video_predictor = None
|
||||
self._model_loaded = False
|
||||
self._video_model_loaded = False
|
||||
self._loaded_device: str | None = None
|
||||
self._last_error: str | None = None
|
||||
self._video_last_error: str | None = None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
@@ -85,6 +89,40 @@ class SAM2Engine:
|
||||
logger.error("Failed to load SAM 2 model: %s", exc)
|
||||
self._model_loaded = True # Prevent repeated load attempts
|
||||
|
||||
def _load_video_model(self) -> None:
|
||||
"""Load the SAM 2 video predictor on first propagation use."""
|
||||
if self._video_model_loaded:
|
||||
return
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self._video_last_error = "PyTorch is not installed."
|
||||
self._video_model_loaded = True
|
||||
return
|
||||
if not SAM2_AVAILABLE:
|
||||
self._video_last_error = "sam2 package is not installed."
|
||||
self._video_model_loaded = True
|
||||
return
|
||||
if not os.path.isfile(settings.sam_model_path):
|
||||
self._video_last_error = f"SAM2 checkpoint not found: {settings.sam_model_path}"
|
||||
self._video_model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
device = self._best_device()
|
||||
self._video_predictor = build_sam2_video_predictor(
|
||||
settings.sam_model_config,
|
||||
settings.sam_model_path,
|
||||
device=device,
|
||||
)
|
||||
self._video_model_loaded = True
|
||||
self._loaded_device = device
|
||||
self._video_last_error = None
|
||||
logger.info("SAM 2 video predictor loaded from %s on %s", settings.sam_model_path, device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
self._video_last_error = str(exc)
|
||||
self._video_model_loaded = True
|
||||
logger.error("Failed to load SAM 2 video predictor: %s", exc)
|
||||
|
||||
def _best_device(self) -> str:
|
||||
if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
@@ -95,6 +133,11 @@ class SAM2Engine:
|
||||
self._load_model()
|
||||
return SAM2_AVAILABLE and self._predictor is not None
|
||||
|
||||
def _ensure_video_ready(self) -> bool:
|
||||
"""Ensure the video predictor is loaded; return whether it is usable."""
|
||||
self._load_video_model()
|
||||
return SAM2_AVAILABLE and self._video_predictor is not None
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Return lightweight, real runtime status without forcing model load."""
|
||||
checkpoint_exists = os.path.isfile(settings.sam_model_path)
|
||||
@@ -121,7 +164,7 @@ class SAM2Engine:
|
||||
"available": available,
|
||||
"loaded": self._predictor is not None,
|
||||
"device": device,
|
||||
"supports": ["point", "box", "auto"],
|
||||
"supports": ["point", "box", "interactive", "auto", "propagate"],
|
||||
"message": message,
|
||||
"package_available": SAM2_AVAILABLE,
|
||||
"checkpoint_exists": checkpoint_exists,
|
||||
@@ -221,6 +264,52 @@ class SAM2Engine:
|
||||
logger.error("SAM2 box prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_interactive(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
box: list[float] | None,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run combined box and point prompt segmentation for refinement."""
|
||||
if not self._ensure_ready():
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
h, w = image.shape[:2]
|
||||
bbox = None
|
||||
if box:
|
||||
bbox = np.array(
|
||||
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
|
||||
dtype=np.float32,
|
||||
)
|
||||
pts = None
|
||||
lbls = None
|
||||
if points:
|
||||
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
|
||||
lbls = np.array(labels, dtype=np.int32)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
self._predictor.set_image(image)
|
||||
masks, scores, _ = self._predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
box=bbox,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 interactive prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run automatic mask generation (grid of points).
|
||||
|
||||
@@ -260,6 +349,89 @@ class SAM2Engine:
|
||||
logger.error("SAM2 auto prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict,
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Propagate one seed mask across a prepared frame directory with SAM 2 video."""
|
||||
if not self._ensure_video_ready():
|
||||
raise RuntimeError(self._video_last_error or self.status()["message"])
|
||||
if not frame_paths:
|
||||
return []
|
||||
if source_frame_index < 0 or source_frame_index >= len(frame_paths):
|
||||
raise ValueError("source_frame_index is outside the frame sequence.")
|
||||
|
||||
import cv2
|
||||
|
||||
source_image = cv2.imread(frame_paths[source_frame_index])
|
||||
if source_image is None:
|
||||
raise RuntimeError("Failed to decode source frame for SAM 2 propagation.")
|
||||
height, width = source_image.shape[:2]
|
||||
seed_mask = self._polygons_to_mask(seed.get("polygons") or [], width, height)
|
||||
if not seed_mask.any():
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
seed_mask = self._bbox_to_mask(bbox, width, height)
|
||||
if not seed_mask.any():
|
||||
raise ValueError("SAM 2 propagation requires a non-empty seed polygon or bbox.")
|
||||
|
||||
inference_state = self._video_predictor.init_state(
|
||||
video_path=os.path.dirname(frame_paths[0]),
|
||||
offload_video_to_cpu=True,
|
||||
offload_state_to_cpu=True,
|
||||
)
|
||||
self._video_predictor.add_new_mask(
|
||||
inference_state,
|
||||
frame_idx=source_frame_index,
|
||||
obj_id=1,
|
||||
mask=seed_mask,
|
||||
)
|
||||
|
||||
results: dict[int, dict] = {}
|
||||
|
||||
def collect(reverse: bool) -> None:
|
||||
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
|
||||
inference_state,
|
||||
start_frame_idx=source_frame_index,
|
||||
max_frame_num_to_track=max_frames,
|
||||
reverse=reverse,
|
||||
):
|
||||
masks = out_mask_logits
|
||||
if hasattr(masks, "detach"):
|
||||
masks = masks.detach().cpu().numpy()
|
||||
masks = np.asarray(masks)
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
polygons = []
|
||||
scores = []
|
||||
for mask in masks:
|
||||
polygon = self._mask_to_polygon(mask > 0)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
scores.append(1.0)
|
||||
results[int(out_frame_idx)] = {
|
||||
"frame_index": int(out_frame_idx),
|
||||
"polygons": polygons,
|
||||
"scores": scores,
|
||||
"object_ids": [int(obj_id) for obj_id in list(out_obj_ids)],
|
||||
}
|
||||
|
||||
normalized_direction = direction.lower()
|
||||
if normalized_direction in {"forward", "both"}:
|
||||
collect(reverse=False)
|
||||
if normalized_direction in {"backward", "both"}:
|
||||
collect(reverse=True)
|
||||
|
||||
try:
|
||||
self._video_predictor.reset_state(inference_state)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
return [results[index] for index in sorted(results)]
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -292,6 +464,38 @@ class SAM2Engine:
|
||||
]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _polygons_to_mask(polygons: list[list[list[float]]], width: int, height: int) -> np.ndarray:
|
||||
import cv2
|
||||
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
for polygon in polygons:
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
pts = np.array(
|
||||
[
|
||||
[
|
||||
int(round(min(max(float(x), 0.0), 1.0) * max(width - 1, 1))),
|
||||
int(round(min(max(float(y), 0.0), 1.0) * max(height - 1, 1))),
|
||||
]
|
||||
for x, y in polygon
|
||||
],
|
||||
dtype=np.int32,
|
||||
)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
return mask.astype(bool)
|
||||
|
||||
@staticmethod
|
||||
def _bbox_to_mask(bbox: list[float], width: int, height: int) -> np.ndarray:
|
||||
x, y, w, h = [min(max(float(value), 0.0), 1.0) for value in bbox]
|
||||
left = int(round(x * max(width - 1, 1)))
|
||||
top = int(round(y * max(height - 1, 1)))
|
||||
right = int(round(min(x + w, 1.0) * max(width - 1, 1)))
|
||||
bottom = int(round(min(y + h, 1.0) * max(height - 1, 1)))
|
||||
mask = np.zeros((height, width), dtype=bool)
|
||||
mask[top:max(bottom + 1, top + 1), left:max(right + 1, left + 1)] = True
|
||||
return mask
|
||||
|
||||
|
||||
# Singleton instance
|
||||
sam_engine = SAM2Engine()
|
||||
|
||||
@@ -56,8 +56,22 @@ class SAM3Engine:
|
||||
def _gpu_ok(self) -> bool:
|
||||
return bool(TORCH_AVAILABLE and torch is not None and torch.cuda.is_available())
|
||||
|
||||
def _checkpoint_path(self) -> str | None:
|
||||
path = settings.sam3_checkpoint_path.strip()
|
||||
return path if path else None
|
||||
|
||||
def _checkpoint_exists(self) -> bool:
|
||||
path = self._checkpoint_path()
|
||||
return bool(path and os.path.isfile(path))
|
||||
|
||||
def _can_load(self) -> bool:
|
||||
return bool(SAM3_PACKAGE_AVAILABLE and TORCH_AVAILABLE and self._python_ok() and self._gpu_ok())
|
||||
return bool(
|
||||
SAM3_PACKAGE_AVAILABLE
|
||||
and TORCH_AVAILABLE
|
||||
and self._python_ok()
|
||||
and self._gpu_ok()
|
||||
and self._checkpoint_exists()
|
||||
)
|
||||
|
||||
def _worker_path(self) -> Path:
|
||||
return Path(__file__).with_name("sam3_external_worker.py")
|
||||
@@ -98,6 +112,8 @@ class SAM3Engine:
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--status"],
|
||||
capture_output=True,
|
||||
@@ -146,7 +162,10 @@ class SAM3Engine:
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
self._model = build_sam3_image_model()
|
||||
self._model = build_sam3_image_model(
|
||||
checkpoint_path=self._checkpoint_path(),
|
||||
load_from_HF=False,
|
||||
)
|
||||
self._processor = Sam3Processor(self._model)
|
||||
self._model_loaded = True
|
||||
self._last_error = None
|
||||
@@ -170,6 +189,8 @@ class SAM3Engine:
|
||||
missing.append("PyTorch")
|
||||
if not self._gpu_ok():
|
||||
missing.append("CUDA GPU")
|
||||
if not self._checkpoint_exists():
|
||||
missing.append(f"local checkpoint ({settings.sam3_checkpoint_path})")
|
||||
if missing:
|
||||
return f"SAM 3 unavailable: missing {', '.join(missing)}."
|
||||
return "SAM 3 dependencies are present; model will load on first inference."
|
||||
@@ -182,7 +203,7 @@ class SAM3Engine:
|
||||
if self._processor is not None:
|
||||
message = "SAM 3 model loaded and ready."
|
||||
elif external_ready:
|
||||
message = "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||
message = "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
|
||||
elif external_status.get("message") and not self._can_load():
|
||||
message = str(external_status["message"])
|
||||
return {
|
||||
@@ -191,11 +212,11 @@ class SAM3Engine:
|
||||
"available": available,
|
||||
"loaded": self._processor is not None,
|
||||
"device": "cuda" if self._gpu_ok() else str(external_status.get("device", "unavailable")),
|
||||
"supports": ["semantic"],
|
||||
"supports": ["semantic", "box", "video_track"],
|
||||
"message": message,
|
||||
"package_available": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("package_available")),
|
||||
"checkpoint_exists": bool(SAM3_PACKAGE_AVAILABLE or external_status.get("checkpoint_access")),
|
||||
"checkpoint_path": f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"checkpoint_exists": bool(self._checkpoint_exists() or external_status.get("checkpoint_access")),
|
||||
"checkpoint_path": self._checkpoint_path() or f"official/HuggingFace ({settings.sam3_model_version})",
|
||||
"python_ok": bool(self._python_ok() or external_status.get("python_ok")),
|
||||
"torch_ok": bool(TORCH_AVAILABLE or external_status.get("torch_ok")),
|
||||
"cuda_required": True,
|
||||
@@ -203,7 +224,43 @@ class SAM3Engine:
|
||||
"external_python": settings.sam3_external_python if settings.sam3_external_enabled else None,
|
||||
}
|
||||
|
||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
def _xyxy_to_cxcywh(self, box: list[float]) -> list[float]:
|
||||
if len(box) != 4:
|
||||
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
|
||||
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
|
||||
left, right = sorted([x1, x2])
|
||||
top, bottom = sorted([y1, y2])
|
||||
width = max(right - left, 1e-6)
|
||||
height = max(bottom - top, 1e-6)
|
||||
return [left + width / 2, top + height / 2, width, height]
|
||||
|
||||
def _prediction_to_polygons(self, output: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
|
||||
def _predict_external(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
prompt_type: str,
|
||||
*,
|
||||
text: str = "",
|
||||
box: list[float] | None = None,
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
@@ -217,8 +274,11 @@ class SAM3Engine:
|
||||
json.dumps(
|
||||
{
|
||||
"image_path": str(image_path),
|
||||
"prompt_type": prompt_type,
|
||||
"text": text.strip(),
|
||||
"box": box,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
@@ -227,6 +287,8 @@ class SAM3Engine:
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
@@ -250,6 +312,72 @@ class SAM3Engine:
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("polygons", []), payload.get("scores", [])
|
||||
|
||||
def _predict_semantic_external(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(image, "semantic", text=text)
|
||||
|
||||
def _predict_box_external(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||
return self._predict_external(image, "box", box=box)
|
||||
|
||||
def _propagate_video_external(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
max_frames: int | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
status = self._external_status(force=True)
|
||||
if not status.get("available"):
|
||||
raise RuntimeError(status.get("message") or "SAM 3 external runtime is unavailable.")
|
||||
if not frame_paths:
|
||||
return []
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="sam3_video_") as tmpdir:
|
||||
request_path = Path(tmpdir) / "request.json"
|
||||
request_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt_type": "video_track",
|
||||
"frame_dir": str(Path(frame_paths[0]).parent),
|
||||
"source_frame_index": source_frame_index,
|
||||
"seed": seed,
|
||||
"direction": direction,
|
||||
"max_frames": max_frames,
|
||||
"model_version": settings.sam3_model_version,
|
||||
"checkpoint_path": self._checkpoint_path(),
|
||||
"confidence_threshold": settings.sam3_confidence_threshold,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["SAM3_MODEL_VERSION"] = settings.sam3_model_version
|
||||
if self._checkpoint_path():
|
||||
env["SAM3_CHECKPOINT_PATH"] = self._checkpoint_path() or ""
|
||||
completed = subprocess.run(
|
||||
[settings.sam3_external_python, str(self._worker_path()), "--request", str(request_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=settings.sam3_timeout_seconds,
|
||||
check=False,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if completed.returncode != 0:
|
||||
detail = completed.stderr.strip() or completed.stdout.strip()
|
||||
try:
|
||||
parsed = json.loads(detail)
|
||||
detail = parsed.get("error", detail)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
raise RuntimeError(f"SAM 3 external video tracking failed: {detail}")
|
||||
|
||||
payload = json.loads(completed.stdout)
|
||||
if payload.get("error"):
|
||||
raise RuntimeError(str(payload["error"]))
|
||||
return payload.get("frames", [])
|
||||
|
||||
def predict_semantic(self, image: np.ndarray, text: str) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not text.strip():
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
@@ -263,29 +391,37 @@ class SAM3Engine:
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.set_text_prompt(state=state, prompt=text.strip())
|
||||
|
||||
masks = output.get("masks", [])
|
||||
scores = output.get("scores", [])
|
||||
polygons = []
|
||||
for mask in masks:
|
||||
if hasattr(mask, "detach"):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
poly = SAM2Engine._mask_to_polygon(mask)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
if hasattr(scores, "detach"):
|
||||
scores = scores.detach().cpu().tolist()
|
||||
elif hasattr(scores, "tolist"):
|
||||
scores = scores.tolist()
|
||||
return polygons, list(scores)
|
||||
return self._prediction_to_polygons(output)
|
||||
|
||||
def predict_points(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for point prompts.")
|
||||
|
||||
def predict_box(self, *_args: Any, **_kwargs: Any) -> tuple[list[list[list[float]]], list[float]]:
|
||||
raise NotImplementedError("This backend currently exposes SAM 3 semantic text inference; use SAM 2 for box prompts.")
|
||||
def predict_box(self, image: np.ndarray, box: list[float]) -> tuple[list[list[list[float]]], list[float]]:
|
||||
if not self._can_load() and self._external_status().get("available"):
|
||||
return self._predict_box_external(image, box)
|
||||
if not self._ensure_ready():
|
||||
raise RuntimeError(self.status()["message"])
|
||||
|
||||
pil_image = Image.fromarray(image)
|
||||
with torch.inference_mode(): # type: ignore[union-attr]
|
||||
state = self._processor.set_image(pil_image)
|
||||
output = self._processor.add_geometric_prompt(
|
||||
state=state,
|
||||
box=self._xyxy_to_cxcywh(box),
|
||||
label=True,
|
||||
)
|
||||
|
||||
return self._prediction_to_polygons(output)
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str = "forward",
|
||||
max_frames: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
return self._propagate_video_external(frame_paths, source_frame_index, seed, direction, max_frames)
|
||||
|
||||
|
||||
sam3_engine = SAM3Engine()
|
||||
|
||||
@@ -43,6 +43,13 @@ def _compact_error(exc: Exception) -> str:
|
||||
|
||||
|
||||
def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
|
||||
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip()
|
||||
if checkpoint_path:
|
||||
path = Path(checkpoint_path)
|
||||
if path.is_file():
|
||||
return True, None
|
||||
return False, f"local checkpoint not found: {checkpoint_path}"
|
||||
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@@ -55,6 +62,7 @@ def _checkpoint_access(model_version: str) -> tuple[bool, str | None]:
|
||||
|
||||
def runtime_status() -> dict[str, Any]:
|
||||
model_version = os.environ.get("SAM3_MODEL_VERSION", "sam3")
|
||||
checkpoint_path = os.environ.get("SAM3_CHECKPOINT_PATH", "").strip() or None
|
||||
package_error = None
|
||||
package_available = importlib.util.find_spec("sam3") is not None
|
||||
if package_available:
|
||||
@@ -85,6 +93,7 @@ def runtime_status() -> dict[str, Any]:
|
||||
"available": available,
|
||||
"package_available": package_available,
|
||||
"checkpoint_access": checkpoint_access,
|
||||
"checkpoint_path": checkpoint_path or f"official/HuggingFace ({model_version})",
|
||||
"python_ok": python_ok,
|
||||
"torch_ok": torch_version is not None,
|
||||
"torch_version": torch_version,
|
||||
@@ -118,34 +127,67 @@ def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
|
||||
|
||||
def _to_numpy(value: Any) -> np.ndarray:
|
||||
if hasattr(value, "detach"):
|
||||
value = value.detach().cpu().numpy()
|
||||
elif hasattr(value, "cpu"):
|
||||
value = value.detach()
|
||||
if hasattr(value, "is_floating_point") and value.is_floating_point():
|
||||
value = value.float()
|
||||
value = value.cpu().numpy()
|
||||
elif hasattr(value, "cpu"):
|
||||
value = value.cpu()
|
||||
if hasattr(value, "is_floating_point") and value.is_floating_point():
|
||||
value = value.float()
|
||||
value = value.numpy()
|
||||
return np.asarray(value)
|
||||
|
||||
|
||||
def predict(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
def _xyxy_to_cxcywh(box: list[float]) -> list[float]:
|
||||
if len(box) != 4:
|
||||
raise ValueError("SAM 3 box prompt requires [x1, y1, x2, y2].")
|
||||
x1, y1, x2, y2 = [min(max(float(value), 0.0), 1.0) for value in box]
|
||||
left, right = sorted([x1, x2])
|
||||
top, bottom = sorted([y1, y2])
|
||||
width = max(right - left, 1e-6)
|
||||
height = max(bottom - top, 1e-6)
|
||||
return [left + width / 2, top + height / 2, width, height]
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
image_path = Path(payload["image_path"])
|
||||
text = str(payload["text"]).strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
if not text:
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
def _bbox_from_seed(seed: dict[str, Any]) -> list[float]:
|
||||
bbox = seed.get("bbox")
|
||||
if isinstance(bbox, list) and len(bbox) == 4:
|
||||
return [min(max(float(value), 0.0), 1.0) for value in bbox]
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
model = build_sam3_image_model()
|
||||
processor = Sam3Processor(model, confidence_threshold=threshold)
|
||||
state = processor.set_image(image)
|
||||
output = processor.set_text_prompt(state=state, prompt=text)
|
||||
polygons = seed.get("polygons") or []
|
||||
points = [point for polygon in polygons for point in polygon if len(point) >= 2]
|
||||
if not points:
|
||||
raise ValueError("SAM 3 video tracking requires seed bbox or polygons.")
|
||||
xs = [min(max(float(point[0]), 0.0), 1.0) for point in points]
|
||||
ys = [min(max(float(point[1]), 0.0), 1.0) for point in points]
|
||||
left, right = min(xs), max(xs)
|
||||
top, bottom = min(ys), max(ys)
|
||||
return [left, top, max(right - left, 1e-6), max(bottom - top, 1e-6)]
|
||||
|
||||
|
||||
def _video_outputs_to_response(outputs: dict[str, Any]) -> dict[str, Any]:
|
||||
masks = _to_numpy(outputs.get("out_binary_masks", []))
|
||||
scores = _to_numpy(outputs.get("out_probs", []))
|
||||
obj_ids = _to_numpy(outputs.get("out_obj_ids", []))
|
||||
if masks.ndim == 4:
|
||||
masks = masks[:, 0]
|
||||
elif masks.ndim == 2:
|
||||
masks = masks[None, ...]
|
||||
|
||||
polygons = []
|
||||
out_scores = []
|
||||
out_ids = []
|
||||
for index, mask in enumerate(masks):
|
||||
polygon = _mask_to_polygon(mask)
|
||||
if polygon:
|
||||
polygons.append(polygon)
|
||||
out_scores.append(float(scores[index]) if scores.size > index else 1.0)
|
||||
out_ids.append(int(obj_ids[index]) if obj_ids.size > index else index + 1)
|
||||
return {"polygons": polygons, "scores": out_scores, "object_ids": out_ids}
|
||||
|
||||
|
||||
def _prediction_to_response(output: dict[str, Any]) -> dict[str, Any]:
|
||||
masks = _to_numpy(output.get("masks", []))
|
||||
scores = _to_numpy(output.get("scores", []))
|
||||
if masks.ndim == 4:
|
||||
@@ -165,6 +207,115 @@ def predict(request_path: Path) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def predict_video(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model_builder import build_sam3_video_predictor
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
frame_dir = Path(payload["frame_dir"])
|
||||
source_frame_index = int(payload.get("source_frame_index", 0))
|
||||
seed = payload.get("seed") or {}
|
||||
direction = str(payload.get("direction") or "forward").lower()
|
||||
max_frames = payload.get("max_frames")
|
||||
max_frames = int(max_frames) if max_frames else None
|
||||
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
if direction not in {"forward", "backward", "both"}:
|
||||
raise ValueError(f"Unsupported propagation direction: {direction}")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
predictor = build_sam3_video_predictor(
|
||||
checkpoint_path=checkpoint_path or None,
|
||||
async_loading_frames=False,
|
||||
)
|
||||
session_id = None
|
||||
try:
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
session = predictor.handle_request(
|
||||
{
|
||||
"type": "start_session",
|
||||
"resource_path": str(frame_dir),
|
||||
"offload_video_to_cpu": True,
|
||||
"offload_state_to_cpu": True,
|
||||
}
|
||||
)
|
||||
session_id = session["session_id"]
|
||||
predictor.handle_request(
|
||||
{
|
||||
"type": "add_prompt",
|
||||
"session_id": session_id,
|
||||
"frame_index": source_frame_index,
|
||||
"bounding_boxes": [_bbox_from_seed(seed)],
|
||||
"bounding_box_labels": [1],
|
||||
"output_prob_thresh": threshold,
|
||||
"rel_coordinates": True,
|
||||
}
|
||||
)
|
||||
frames = []
|
||||
for item in predictor.handle_stream_request(
|
||||
{
|
||||
"type": "propagate_in_video",
|
||||
"session_id": session_id,
|
||||
"propagation_direction": direction,
|
||||
"start_frame_index": source_frame_index,
|
||||
"max_frame_num_to_track": max_frames,
|
||||
"output_prob_thresh": threshold,
|
||||
}
|
||||
):
|
||||
frame_response = _video_outputs_to_response(item.get("outputs") or {})
|
||||
frame_response["frame_index"] = int(item["frame_index"])
|
||||
frames.append(frame_response)
|
||||
finally:
|
||||
if session_id:
|
||||
predictor.handle_request({"type": "close_session", "session_id": session_id})
|
||||
|
||||
return {"frames": frames}
|
||||
|
||||
|
||||
def predict(request_path: Path) -> dict[str, Any]:
|
||||
import torch
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
|
||||
payload = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
if str(payload.get("prompt_type") or "").strip().lower() == "video_track":
|
||||
return predict_video(request_path)
|
||||
|
||||
image_path = Path(payload["image_path"])
|
||||
prompt_type = str(payload.get("prompt_type") or "semantic").strip().lower()
|
||||
text = str(payload.get("text") or "").strip()
|
||||
threshold = float(payload.get("confidence_threshold", 0.5))
|
||||
checkpoint_path = str(payload.get("checkpoint_path") or os.environ.get("SAM3_CHECKPOINT_PATH", "")).strip()
|
||||
if prompt_type == "semantic" and not text:
|
||||
raise ValueError("SAM 3 semantic prompt requires non-empty text.")
|
||||
if prompt_type not in {"semantic", "box"}:
|
||||
raise ValueError(f"Unsupported SAM 3 prompt type: {prompt_type}")
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
model = build_sam3_image_model(
|
||||
checkpoint_path=checkpoint_path or None,
|
||||
load_from_HF=not bool(checkpoint_path),
|
||||
)
|
||||
processor = Sam3Processor(model, confidence_threshold=threshold)
|
||||
state = processor.set_image(image)
|
||||
if prompt_type == "box":
|
||||
output = processor.add_geometric_prompt(
|
||||
state=state,
|
||||
box=_xyxy_to_cxcywh(payload.get("box") or []),
|
||||
label=True,
|
||||
)
|
||||
else:
|
||||
output = processor.set_text_prompt(state=state, prompt=text)
|
||||
|
||||
return _prediction_to_response(output)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="SAM 3 external runtime helper")
|
||||
parser.add_argument("--status", action="store_true")
|
||||
|
||||
@@ -67,6 +67,19 @@ class SAMRegistry:
|
||||
def predict_box(self, model_id: str | None, image: Any, box: list[float]):
|
||||
return self._ensure_available(model_id).predict_box(image, box)
|
||||
|
||||
def predict_interactive(
|
||||
self,
|
||||
model_id: str | None,
|
||||
image: Any,
|
||||
box: list[float] | None,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
):
|
||||
model = self.normalize_model_id(model_id)
|
||||
if model != "sam2":
|
||||
raise NotImplementedError("Interactive box + point refinement is currently supported by SAM 2.")
|
||||
return self._ensure_available(model).predict_interactive(image, box, points, labels)
|
||||
|
||||
def predict_auto(self, model_id: str | None, image: Any):
|
||||
return self._ensure_available(model_id).predict_auto(image)
|
||||
|
||||
@@ -76,5 +89,22 @@ class SAMRegistry:
|
||||
return self._ensure_available(model).predict_semantic(image, text)
|
||||
return self._ensure_available(model).predict_auto(image)
|
||||
|
||||
def propagate_video(
|
||||
self,
|
||||
model_id: str | None,
|
||||
frame_paths: list[str],
|
||||
source_frame_index: int,
|
||||
seed: dict[str, Any],
|
||||
direction: str,
|
||||
max_frames: int | None,
|
||||
):
|
||||
return self._ensure_available(model_id).propagate_video(
|
||||
frame_paths,
|
||||
source_frame_index,
|
||||
seed,
|
||||
direction=direction,
|
||||
max_frames=max_frames,
|
||||
)
|
||||
|
||||
|
||||
sam_registry = SAMRegistry()
|
||||
|
||||
@@ -116,6 +116,44 @@ def test_predict_box_and_semantic_fallback(client, monkeypatch):
|
||||
assert semantic_response.json()["scores"] == [0.5]
|
||||
|
||||
|
||||
def test_predict_interactive_combines_box_and_points(client, monkeypatch):
|
||||
_, frame, _ = _create_project_and_frame(client)
|
||||
calls = {}
|
||||
monkeypatch.setattr("routers.ai._load_frame_image", lambda frame: np.zeros((10, 10, 3), dtype=np.uint8))
|
||||
|
||||
def fake_predict_interactive(model, image, box, points, labels):
|
||||
calls["model"] = model
|
||||
calls["box"] = box
|
||||
calls["points"] = points
|
||||
calls["labels"] = labels
|
||||
return (
|
||||
[[[0.2, 0.2], [0.8, 0.2], [0.8, 0.8]]],
|
||||
[0.88],
|
||||
)
|
||||
|
||||
monkeypatch.setattr("routers.ai.sam_registry.predict_interactive", fake_predict_interactive)
|
||||
|
||||
response = client.post("/api/ai/predict", json={
|
||||
"image_id": frame["id"],
|
||||
"prompt_type": "interactive",
|
||||
"prompt_data": {
|
||||
"box": [0.1, 0.1, 0.9, 0.9],
|
||||
"points": [[0.5, 0.5], [0.2, 0.2]],
|
||||
"labels": [1, 0],
|
||||
},
|
||||
"model": "sam2",
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["scores"] == [0.88]
|
||||
assert calls == {
|
||||
"model": "sam2",
|
||||
"box": [0.1, 0.1, 0.9, 0.9],
|
||||
"points": [[0.5, 0.5], [0.2, 0.2]],
|
||||
"labels": [1, 0],
|
||||
}
|
||||
|
||||
|
||||
def test_model_status_reports_runtime(client, monkeypatch):
|
||||
monkeypatch.setattr("routers.ai.sam_registry.runtime_status", lambda selected_model=None: {
|
||||
"selected_model": selected_model or "sam2",
|
||||
@@ -170,6 +208,80 @@ def test_model_status_reports_runtime(client, monkeypatch):
|
||||
assert body["models"][1]["available"] is False
|
||||
|
||||
|
||||
def test_propagate_saves_tracked_annotations(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={"name": "Video Project"}).json()
|
||||
frames = [
|
||||
client.post(f"/api/projects/{project['id']}/frames", json={
|
||||
"project_id": project["id"],
|
||||
"frame_index": idx,
|
||||
"image_url": f"frames/{idx}.jpg",
|
||||
"width": 640,
|
||||
"height": 360,
|
||||
}).json()
|
||||
for idx in range(3)
|
||||
]
|
||||
calls = {}
|
||||
monkeypatch.setattr("routers.ai.download_file", lambda object_name: b"jpeg")
|
||||
|
||||
def fake_propagate_video(model, frame_paths, source_frame_index, seed, direction, max_frames):
|
||||
calls["model"] = model
|
||||
calls["source_frame_index"] = source_frame_index
|
||||
calls["seed"] = seed
|
||||
calls["direction"] = direction
|
||||
calls["max_frames"] = max_frames
|
||||
calls["frame_count"] = len(frame_paths)
|
||||
return [
|
||||
{
|
||||
"frame_index": 0,
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"scores": [0.9],
|
||||
"object_ids": [1],
|
||||
},
|
||||
{
|
||||
"frame_index": 1,
|
||||
"polygons": [[[0.15, 0.15], [0.25, 0.15], [0.25, 0.25]]],
|
||||
"scores": [0.8],
|
||||
"object_ids": [1],
|
||||
},
|
||||
]
|
||||
|
||||
monkeypatch.setattr("routers.ai.sam_registry.propagate_video", fake_propagate_video)
|
||||
|
||||
response = client.post("/api/ai/propagate", json={
|
||||
"project_id": project["id"],
|
||||
"frame_id": frames[0]["id"],
|
||||
"model": "sam2",
|
||||
"direction": "forward",
|
||||
"max_frames": 2,
|
||||
"include_source": False,
|
||||
"seed": {
|
||||
"polygons": [[[0.1, 0.1], [0.2, 0.1], [0.2, 0.2]]],
|
||||
"bbox": [0.1, 0.1, 0.1, 0.1],
|
||||
"label": "胆囊",
|
||||
"color": "#ff0000",
|
||||
"class_metadata": {"id": "c1", "name": "胆囊", "color": "#ff0000", "zIndex": 20},
|
||||
"template_id": None,
|
||||
},
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["created_annotation_count"] == 1
|
||||
assert body["processed_frame_count"] == 2
|
||||
assert calls["model"] == "sam2"
|
||||
assert calls["source_frame_index"] == 0
|
||||
assert calls["direction"] == "forward"
|
||||
assert calls["frame_count"] == 2
|
||||
saved = body["annotations"][0]
|
||||
assert saved["frame_id"] == frames[1]["id"]
|
||||
assert saved["mask_data"]["source"] == "sam2_propagation"
|
||||
assert saved["mask_data"]["class"]["name"] == "胆囊"
|
||||
assert saved["mask_data"]["score"] == 0.8
|
||||
|
||||
listing = client.get(f"/api/ai/annotations?project_id={project['id']}")
|
||||
assert len(listing.json()) == 1
|
||||
|
||||
|
||||
def test_predict_validation_errors(client, monkeypatch):
|
||||
project, _, _ = _create_project_and_frame(client)
|
||||
|
||||
|
||||
@@ -84,6 +84,12 @@ def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
assert data["progress"] == 0
|
||||
assert data["project_id"] == project["id"]
|
||||
assert data["celery_task_id"] == "celery-1"
|
||||
assert data["payload"] == {
|
||||
"source_type": "video",
|
||||
"parse_fps": 5.0,
|
||||
"max_frames": None,
|
||||
"target_width": 640,
|
||||
}
|
||||
assert queued == [data["id"]]
|
||||
assert published == [data["id"]]
|
||||
|
||||
@@ -94,6 +100,35 @@ def test_parse_media_queues_background_task(client, monkeypatch):
|
||||
assert project_detail["status"] == "parsing"
|
||||
|
||||
|
||||
def test_parse_media_accepts_frame_sequence_options(client, monkeypatch):
|
||||
project = client.post("/api/projects", json={
|
||||
"name": "Parse Options",
|
||||
"video_path": "uploads/1/clip.mp4",
|
||||
"source_type": "video",
|
||||
"parse_fps": 30,
|
||||
}).json()
|
||||
|
||||
class FakeAsyncResult:
|
||||
id = "celery-options"
|
||||
|
||||
monkeypatch.setattr("routers.media.parse_project_media.delay", lambda task_id: FakeAsyncResult())
|
||||
monkeypatch.setattr("routers.media.publish_task_progress_event", lambda task: None)
|
||||
|
||||
response = client.post(
|
||||
f"/api/media/parse?project_id={project['id']}&parse_fps=15&max_frames=120&target_width=960"
|
||||
)
|
||||
|
||||
assert response.status_code == 202
|
||||
data = response.json()
|
||||
assert data["payload"] == {
|
||||
"source_type": "video",
|
||||
"parse_fps": 15.0,
|
||||
"max_frames": 120,
|
||||
"target_width": 960,
|
||||
}
|
||||
assert client.get(f"/api/projects/{project['id']}").json()["parse_fps"] == 15.0
|
||||
|
||||
|
||||
def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp_path):
|
||||
from models import ProcessingTask
|
||||
from services.media_task_runner import run_parse_media_task
|
||||
@@ -118,10 +153,14 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
frame_file.write_bytes(b"fake image")
|
||||
|
||||
monkeypatch.setattr("services.media_task_runner.download_file", lambda object_name: b"video")
|
||||
monkeypatch.setattr("services.media_task_runner.parse_video", lambda local_path, output_dir, fps: ([str(frame_file)], 25.0))
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.parse_video",
|
||||
lambda local_path, output_dir, fps, max_frames=None, target_width=640: ([str(frame_file)], 25.0),
|
||||
)
|
||||
monkeypatch.setattr("services.media_task_runner.extract_thumbnail", lambda local_path, thumbnail_path: open(thumbnail_path, "wb").write(b"thumb"))
|
||||
monkeypatch.setattr("services.media_task_runner.upload_file", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("services.media_task_runner.upload_frames_to_minio", lambda frame_files, project_id: [f"projects/{project_id}/frames/frame_000001.jpg"])
|
||||
monkeypatch.setattr("routers.projects.get_presigned_url", lambda object_name, expires=3600: f"http://storage/{object_name}")
|
||||
published = []
|
||||
monkeypatch.setattr(
|
||||
"services.media_task_runner.publish_task_progress_event",
|
||||
@@ -131,6 +170,17 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
result = run_parse_media_task(db_session, task.id)
|
||||
|
||||
assert result["frames_extracted"] == 1
|
||||
assert result["frame_sequence"] == {
|
||||
"original_fps": 25.0,
|
||||
"parse_fps": 5.0,
|
||||
"frame_count": 1,
|
||||
"duration_ms": 0.0,
|
||||
"target_width": 640,
|
||||
"frame_width": None,
|
||||
"frame_height": None,
|
||||
"max_frames": None,
|
||||
"object_prefix": f"projects/{project['id']}/frames",
|
||||
}
|
||||
db_session.refresh(task)
|
||||
assert task.status == "success"
|
||||
assert task.progress == 100
|
||||
@@ -140,6 +190,8 @@ def test_parse_task_runner_registers_frames(client, db_session, monkeypatch, tmp
|
||||
assert project_detail["status"] == "ready"
|
||||
frames = client.get(f"/api/projects/{project['id']}/frames").json()
|
||||
assert "frame_000001.jpg" in frames[0]["image_url"]
|
||||
assert frames[0]["timestamp_ms"] == 0.0
|
||||
assert frames[0]["source_frame_number"] == 0
|
||||
|
||||
|
||||
def test_parse_task_runner_skips_already_cancelled_task(db_session):
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from services.sam3_engine import SAM3Engine
|
||||
from services.sam3_external_worker import _to_numpy
|
||||
|
||||
|
||||
class _Completed:
|
||||
@@ -14,6 +15,8 @@ class _Completed:
|
||||
|
||||
|
||||
def _external_settings(monkeypatch, python_path: Path):
|
||||
checkpoint_path = python_path.with_name("sam3.pt")
|
||||
checkpoint_path.write_bytes(b"checkpoint")
|
||||
python_path.write_text("#!/usr/bin/env python\n", encoding="utf-8")
|
||||
python_path.chmod(0o755)
|
||||
monkeypatch.setattr("services.sam3_engine.SAM3_PACKAGE_AVAILABLE", False)
|
||||
@@ -23,6 +26,7 @@ def _external_settings(monkeypatch, python_path: Path):
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_timeout_seconds", 10)
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_status_cache_seconds", 30)
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_confidence_threshold", 0.4)
|
||||
monkeypatch.setattr("services.sam3_engine.settings.sam3_checkpoint_path", str(checkpoint_path))
|
||||
|
||||
|
||||
def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
|
||||
@@ -30,9 +34,12 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
|
||||
|
||||
def fake_run(args, **_kwargs):
|
||||
assert "--status" in args
|
||||
assert _kwargs["env"]["SAM3_CHECKPOINT_PATH"].endswith("sam3.pt")
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"checkpoint_path": _kwargs["env"]["SAM3_CHECKPOINT_PATH"],
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
@@ -48,7 +55,10 @@ def test_sam3_status_reports_external_runtime_ready(tmp_path, monkeypatch):
|
||||
assert status["external_available"] is True
|
||||
assert status["package_available"] is True
|
||||
assert status["python_ok"] is True
|
||||
assert status["message"] == "SAM 3 external runtime is ready; model will load in the helper process on inference."
|
||||
assert status["checkpoint_exists"] is True
|
||||
assert status["checkpoint_path"].endswith("sam3.pt")
|
||||
assert status["supports"] == ["semantic", "box", "video_track"]
|
||||
assert status["message"] == "SAM 3 external runtime is ready; local checkpoint will load in the helper process on inference."
|
||||
|
||||
|
||||
def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
@@ -61,6 +71,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
@@ -71,6 +82,7 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
assert request["text"] == "vessel"
|
||||
assert request["confidence_threshold"] == 0.4
|
||||
assert request["checkpoint_path"].endswith("sam3.pt")
|
||||
assert Path(request["image_path"]).exists()
|
||||
return _Completed(stdout=json.dumps({
|
||||
"polygons": [[[0.1, 0.1], [0.9, 0.1], [0.9, 0.9]]],
|
||||
@@ -86,6 +98,97 @@ def test_sam3_predict_semantic_uses_external_worker(tmp_path, monkeypatch):
|
||||
assert any("--request" in args for args in calls)
|
||||
|
||||
|
||||
def test_sam3_predict_box_uses_external_worker(tmp_path, monkeypatch):
|
||||
_external_settings(monkeypatch, tmp_path / "python")
|
||||
|
||||
def fake_run(args, **_kwargs):
|
||||
if "--status" in args:
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
"device": "cuda",
|
||||
"message": "ready",
|
||||
}))
|
||||
request_path = Path(args[-1])
|
||||
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
assert request["prompt_type"] == "box"
|
||||
assert request["box"] == [0.1, 0.2, 0.7, 0.8]
|
||||
assert request["text"] == ""
|
||||
return _Completed(stdout=json.dumps({
|
||||
"polygons": [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]],
|
||||
"scores": [0.88],
|
||||
}))
|
||||
|
||||
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||
|
||||
polygons, scores = SAM3Engine().predict_box(
|
||||
np.zeros((8, 8, 3), dtype=np.uint8),
|
||||
[0.1, 0.2, 0.7, 0.8],
|
||||
)
|
||||
|
||||
assert polygons == [[[0.1, 0.2], [0.7, 0.2], [0.7, 0.8]]]
|
||||
assert scores == [0.88]
|
||||
|
||||
|
||||
def test_sam3_propagate_video_uses_external_worker(tmp_path, monkeypatch):
|
||||
_external_settings(monkeypatch, tmp_path / "python")
|
||||
frame_dir = tmp_path / "frames"
|
||||
frame_dir.mkdir()
|
||||
frame_paths = []
|
||||
for index in range(2):
|
||||
frame_path = frame_dir / f"frame_{index:06d}.jpg"
|
||||
frame_path.write_bytes(b"jpeg")
|
||||
frame_paths.append(str(frame_path))
|
||||
|
||||
def fake_run(args, **_kwargs):
|
||||
if "--status" in args:
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
"device": "cuda",
|
||||
"message": "ready",
|
||||
}))
|
||||
request_path = Path(args[-1])
|
||||
request = json.loads(request_path.read_text(encoding="utf-8"))
|
||||
assert request["prompt_type"] == "video_track"
|
||||
assert request["frame_dir"] == str(frame_dir)
|
||||
assert request["source_frame_index"] == 0
|
||||
assert request["direction"] == "forward"
|
||||
assert request["max_frames"] == 2
|
||||
assert request["seed"]["bbox"] == [0.1, 0.1, 0.2, 0.2]
|
||||
return _Completed(stdout=json.dumps({
|
||||
"frames": [
|
||||
{
|
||||
"frame_index": 1,
|
||||
"polygons": [[[0.2, 0.2], [0.4, 0.2], [0.4, 0.4]]],
|
||||
"scores": [0.7],
|
||||
"object_ids": [1],
|
||||
}
|
||||
]
|
||||
}))
|
||||
|
||||
monkeypatch.setattr("services.sam3_engine.subprocess.run", fake_run)
|
||||
|
||||
frames = SAM3Engine().propagate_video(
|
||||
frame_paths,
|
||||
0,
|
||||
{"bbox": [0.1, 0.1, 0.2, 0.2]},
|
||||
direction="forward",
|
||||
max_frames=2,
|
||||
)
|
||||
|
||||
assert frames[0]["frame_index"] == 1
|
||||
assert frames[0]["scores"] == [0.7]
|
||||
|
||||
|
||||
def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
|
||||
_external_settings(monkeypatch, tmp_path / "python")
|
||||
|
||||
@@ -94,6 +197,7 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
|
||||
return _Completed(stdout=json.dumps({
|
||||
"available": True,
|
||||
"package_available": True,
|
||||
"checkpoint_access": True,
|
||||
"python_ok": True,
|
||||
"torch_ok": True,
|
||||
"cuda_available": True,
|
||||
@@ -110,3 +214,32 @@ def test_sam3_predict_semantic_reports_external_errors(tmp_path, monkeypatch):
|
||||
assert "HF access denied" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected SAM 3 external inference failure.")
|
||||
|
||||
|
||||
def test_sam3_worker_casts_floating_tensors_before_numpy():
|
||||
class FakeTensor:
|
||||
def __init__(self):
|
||||
self.float_called = False
|
||||
|
||||
def detach(self):
|
||||
return self
|
||||
|
||||
def is_floating_point(self):
|
||||
return True
|
||||
|
||||
def float(self):
|
||||
self.float_called = True
|
||||
return self
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
return np.array([1.0], dtype=np.float32)
|
||||
|
||||
tensor = FakeTensor()
|
||||
|
||||
result = _to_numpy(tensor)
|
||||
|
||||
assert tensor.float_called is True
|
||||
assert result.tolist() == [1.0]
|
||||
|
||||
Reference in New Issue
Block a user