Files
Pre_Seg_Server/backend/services/media_task_runner.py
admin 5ab4602535 feat: 完善视频传播、标注编辑和拆帧闭环
- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。
- 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。
- 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。
- 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。
- 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。
- 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。
- 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。
- 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。
- 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。
- 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。
- 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。
- 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。
- 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。
- 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。
- 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。

验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
2026-05-01 20:27:33 +08:00

326 lines
11 KiB
Python

"""Background media parsing runner used by Celery workers."""
import logging
import os
import shutil
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlalchemy.orm import Session
from minio_client import BUCKET_NAME, download_file, get_minio_client, upload_file
from models import Frame, ProcessingTask, Project
from progress_events import publish_task_progress_event
from services.frame_parser import (
extract_thumbnail,
parse_dicom,
parse_video,
upload_frames_to_minio,
)
from statuses import (
PROJECT_STATUS_PENDING,
PROJECT_STATUS_ERROR,
PROJECT_STATUS_PARSING,
PROJECT_STATUS_READY,
TASK_STATUS_CANCELLED,
TASK_STATUS_FAILED,
TASK_STATUS_RUNNING,
TASK_STATUS_SUCCESS,
)
# Module-level logger, named after this module so log records can be filtered by origin.
logger = logging.getLogger(__name__)
class TaskCancelled(RuntimeError):
    """Internal signal that the persisted task row is now marked as cancelled.

    Raised by the cancellation check in this module so the parse runner can
    unwind and finalise the task as cancelled.
    """
def _now() -> datetime:
return datetime.now(timezone.utc)
def _set_task_state(
    db: Session,
    task: ProcessingTask,
    *,
    status: str | None = None,
    progress: int | None = None,
    message: str | None = None,
    result: dict[str, Any] | None = None,
    error: str | None = None,
    started: bool = False,
    finished: bool = False,
) -> None:
    """Apply the given field updates to *task*, commit, and publish the change.

    Only keyword arguments that are not ``None`` are written, so a field
    cannot be cleared through this helper.

    Args:
        db: Session used to commit and refresh the task.
        task: Task row to update.
        status: New task status, if any.
        progress: New progress value; clamped into the 0..100 range.
        message: New human-readable message, if any.
        result: New result payload, if any.
        error: New error text, if any.
        started: When true, stamp ``started_at`` with the current UTC time.
        finished: When true, stamp ``finished_at`` with the current UTC time.
    """
    bounded_progress = None if progress is None else max(0, min(100, progress))
    field_updates = (
        ("status", status),
        ("progress", bounded_progress),
        ("message", message),
        ("result", result),
        ("error", error),
    )
    for attr_name, new_value in field_updates:
        if new_value is not None:
            setattr(task, attr_name, new_value)

    if started:
        task.started_at = _now()
    if finished:
        task.finished_at = _now()

    db.commit()
    # Reload the row after commit before handing it to the event publisher.
    db.refresh(task)
    publish_task_progress_event(task)
def _project_status_after_stop(project: Project) -> str:
    """Choose the project status to restore after a parse run stops early.

    Returns ``PROJECT_STATUS_READY`` when the project has any frames,
    otherwise ``PROJECT_STATUS_PENDING``.
    """
    if project.frames:
        return PROJECT_STATUS_READY
    return PROJECT_STATUS_PENDING
def _positive_int(value: Any, default: int | None = None) -> int | None:
try:
parsed = int(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _positive_float(value: Any, default: float) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
return default
return parsed if parsed > 0 else default
def _frame_sequence_metadata(
index: int,
parse_fps: float,
original_fps: float | None,
) -> dict[str, float | int | None]:
safe_parse_fps = max(float(parse_fps or 1.0), 1e-6)
timestamp_ms = index * 1000.0 / safe_parse_fps
source_frame_number = None
if original_fps and original_fps > 0:
source_frame_number = int(round(index * original_fps / safe_parse_fps))
return {
"timestamp_ms": timestamp_ms,
"source_frame_number": source_frame_number,
}
def _ensure_not_cancelled(db: Session, task: ProcessingTask) -> None:
    """Reload *task* from the database and abort if it has been cancelled.

    Raises:
        TaskCancelled: When the refreshed task status is
            ``TASK_STATUS_CANCELLED``.
    """
    # Re-read the row so a status change made elsewhere becomes visible here.
    db.refresh(task)
    if task.status != TASK_STATUS_CANCELLED:
        return
    raise TaskCancelled("Task was cancelled")
def run_parse_media_task(db: Session, task_id: int) -> dict[str, Any]:
    """Parse one project's media and update task progress in the database.

    Steps: validate the task and its project, download the media into a
    temporary directory, extract frames (DICOM series or video), upload the
    frames to object storage, write one ``Frame`` row per frame, and mark the
    task as successful. Cancellation is polled between steps.

    Args:
        db: Session used for all reads, writes and commits.
        task_id: Primary key of the ``ProcessingTask`` to run.

    Returns:
        On success, a result dict with the project id, number of extracted
        frames and a ``frame_sequence`` summary. If the task is (or becomes)
        cancelled, a short dict with ``status`` set to
        ``TASK_STATUS_CANCELLED``.

    Raises:
        ValueError: The task does not exist, has no ``project_id``, its
            project does not exist, or the project has no media path.
        Exception: Any failure during parsing is recorded on the task
            (status failed) and then re-raised unchanged.
    """
    task = db.query(ProcessingTask).filter(ProcessingTask.id == task_id).first()
    if not task:
        raise ValueError(f"Task not found: {task_id}")
    if task.status == TASK_STATUS_CANCELLED:
        return {
            "task_id": task.id,
            "status": TASK_STATUS_CANCELLED,
            "message": task.message or "任务已取消",
        }
    if task.project_id is None:
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="任务缺少 project_id",
            error="Task has no project_id",
            finished=True,
        )
        raise ValueError("Task has no project_id")

    project = db.query(Project).filter(Project.id == task.project_id).first()
    if not project:
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="项目不存在",
            error="Project not found",
            finished=True,
        )
        raise ValueError(f"Project not found: {task.project_id}")
    if not project.video_path:
        # Set the project status first so the single commit inside
        # _set_task_state persists task and project together, matching the
        # ordering used by the failure handler below.
        project.status = PROJECT_STATUS_ERROR
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="项目没有可解析媒体",
            error="Project has no media uploaded",
            finished=True,
        )
        raise ValueError("Project has no media uploaded")

    # A cancel that lands between the first status check and this point must
    # produce the same "cancelled" result as above, not leak TaskCancelled.
    try:
        _ensure_not_cancelled(db, task)
    except TaskCancelled:
        return {
            "task_id": task.id,
            "status": TASK_STATUS_CANCELLED,
            "message": task.message or "任务已取消",
        }

    project.status = PROJECT_STATUS_PARSING
    _set_task_state(db, task, status=TASK_STATUS_RUNNING, progress=5, message="后台解析已启动", started=True)

    payload = task.payload or {}
    effective_source = payload.get("source_type") or project.source_type or "video"
    parse_fps = _positive_float(payload.get("parse_fps"), project.parse_fps or 30.0)
    max_frames = _positive_int(payload.get("max_frames"))
    target_width = _positive_int(payload.get("target_width"), 640) or 640
    project.parse_fps = parse_fps

    # Created inside the try block so a failure to create it is still recorded
    # on the task instead of leaving the task stuck in the running state.
    tmp_dir: str | None = None
    try:
        tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project.id}_")
        output_dir = os.path.join(tmp_dir, "frames")
        os.makedirs(output_dir, exist_ok=True)

        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, progress=15, message="正在下载媒体文件")

        if effective_source == "dicom":
            dcm_dir = os.path.join(tmp_dir, "dcm")
            os.makedirs(dcm_dir, exist_ok=True)
            client = get_minio_client()
            objects = list(client.list_objects(BUCKET_NAME, prefix=project.video_path, recursive=True))
            for obj in objects:
                _ensure_not_cancelled(db, task)
                if obj.object_name.lower().endswith(".dcm"):
                    data = download_file(obj.object_name)
                    local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
                    with open(local_dcm, "wb") as f:
                        f.write(data)
            _ensure_not_cancelled(db, task)
            _set_task_state(db, task, progress=35, message="正在解析 DICOM 序列")
            frame_files = parse_dicom(dcm_dir, output_dir, max_frames=max_frames)
        else:
            _ensure_not_cancelled(db, task)
            media_bytes = download_file(project.video_path)
            local_path = os.path.join(tmp_dir, Path(project.video_path).name)
            with open(local_path, "wb") as f:
                f.write(media_bytes)
            _ensure_not_cancelled(db, task)
            _set_task_state(db, task, progress=35, message="正在使用 FFmpeg/OpenCV 拆帧")
            # NOTE(review): int() truncates fractional rates (0.5 -> 0,
            # 2.5 -> 2) while the timestamps below use the untruncated
            # parse_fps -- confirm what parse_video expects before changing.
            frame_files, original_fps = parse_video(
                local_path,
                output_dir,
                fps=int(parse_fps),
                max_frames=max_frames,
                target_width=target_width,
            )
            project.original_fps = original_fps

            # Thumbnail generation is best-effort: a failure is logged and
            # parsing continues without a thumbnail.
            thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
            try:
                extract_thumbnail(local_path, thumbnail_path)
                with open(thumbnail_path, "rb") as f:
                    thumb_data = f.read()
                thumb_object = f"projects/{project.id}/thumbnail.jpg"
                upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
                project.thumbnail_url = thumb_object
            except Exception as exc:  # noqa: BLE001
                logger.warning("Thumbnail extraction failed: %s", exc)

        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, progress=70, message="正在上传帧到对象存储")
        object_names = upload_frames_to_minio(frame_files, project.id)

        _ensure_not_cancelled(db, task)
        _set_task_state(db, task, progress=85, message="正在写入帧索引")

        # Import once instead of on every iteration; frame dimensions are
        # optional, so an unavailable OpenCV only leaves width/height empty.
        try:
            import cv2
        except Exception:  # noqa: BLE001
            cv2 = None

        frames_out = []
        for idx, obj_name in enumerate(object_names):
            _ensure_not_cancelled(db, task)
            local_frame = frame_files[idx]
            width = height = None
            if cv2 is not None:
                try:
                    image = cv2.imread(local_frame)
                    if image is not None:
                        height, width = image.shape[:2]
                except Exception:  # noqa: BLE001
                    width = height = None
            sequence_meta = _frame_sequence_metadata(idx, parse_fps, project.original_fps)
            frame = Frame(
                project_id=project.id,
                frame_index=idx,
                image_url=obj_name,
                width=width,
                height=height,
                timestamp_ms=sequence_meta["timestamp_ms"],
                source_frame_number=sequence_meta["source_frame_number"],
            )
            db.add(frame)
            frames_out.append(frame)

        project.status = PROJECT_STATUS_READY
        db.commit()

        result = {
            "project_id": project.id,
            "frames_extracted": len(frames_out),
            "status": PROJECT_STATUS_READY,
            "message": "Frame extraction completed successfully.",
            "frame_sequence": {
                "original_fps": project.original_fps,
                "parse_fps": parse_fps,
                "frame_count": len(frames_out),
                "duration_ms": (len(frames_out) - 1) * 1000.0 / parse_fps if frames_out else 0,
                "target_width": target_width,
                "frame_width": frames_out[0].width if frames_out else None,
                "frame_height": frames_out[0].height if frames_out else None,
                "max_frames": max_frames,
                "object_prefix": f"projects/{project.id}/frames",
            },
        }
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_SUCCESS,
            progress=100,
            message="解析完成",
            result=result,
            finished=True,
        )
        logger.info("Parsed %d frames for project_id=%s", len(frames_out), project.id)
        return result
    except TaskCancelled:
        # Drop uncommitted work first (e.g. Frame rows added before the cancel
        # was noticed); otherwise the commit below would persist a partial
        # frame index for a cancelled run.
        db.rollback()
        project.status = _project_status_after_stop(project)
        task.status = TASK_STATUS_CANCELLED
        task.progress = 100
        task.message = task.message or "任务已取消"
        task.error = task.error or "Cancelled by user"
        task.finished_at = task.finished_at or _now()
        db.commit()
        db.refresh(task)
        publish_task_progress_event(task)
        logger.info("Parse task cancelled: task_id=%s project_id=%s", task.id, project.id)
        return {
            "task_id": task.id,
            "project_id": project.id,
            "status": TASK_STATUS_CANCELLED,
            "message": task.message,
        }
    except Exception as exc:  # noqa: BLE001
        # Log before touching the database so the root cause is never lost.
        logger.exception("Frame extraction failed: %s", exc)
        # Roll back so (a) partially written Frame rows are not committed with
        # the failure state and (b) a session left unusable by a failed flush
        # can accept the status update below.
        db.rollback()
        project.status = PROJECT_STATUS_ERROR
        _set_task_state(
            db,
            task,
            status=TASK_STATUS_FAILED,
            progress=100,
            message="解析失败",
            error=str(exc),
            finished=True,
        )
        raise
    finally:
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)