- 接入 SAM2 视频传播能力:新增 /api/ai/propagate,支持用当前帧 mask/polygon/bbox 作为 seed,通过 SAM2 video predictor 向前、向后或双向传播,并可保存为真实 annotation。 - 接入 SAM3 video tracker:通过独立 Python 3.12 external worker 调用 SAM3 video predictor/tracker,使用本地 checkpoint 与 bbox seed 执行视频级跟踪,并在模型状态中标记 video_track 能力。 - 完善 SAM 模型分发:sam_registry 按 model_id 明确区分 sam2 propagation 与 sam3 video_track,避免两个模型链路混用。 - 打通前端“传播片段”:VideoWorkspace 使用当前选中 mask 和当前 AI 模型调用后端传播接口,传播结果回写并刷新工作区已保存标注。 - 增强 SAM3 本地 checkpoint 配置:新增 sam3_checkpoint_path 配置和 .env.example 示例,状态检查改为基于本地 checkpoint/独立环境/模型包可用性。 - 完善视频拆帧参数:/api/media/parse 支持 parse_fps、max_frames、target_width,后端任务保存帧时间戳、源帧号和 frame_sequence 元数据。 - 增加运行时 schema 兼容处理:启动时为旧 frames 表补充 timestamp_ms 和 source_frame_number 列,避免旧库升级后缺字段。 - 强化 Canvas 标注编辑:补齐多边形闭合、点工具、顶点拖拽、边中点插入、Delete/Backspace 删除、区域合并和重叠去除等交互。 - 增强语义分类联动:选中 mask 后可通过右侧语义分类树更新标签、颜色和 class metadata,并同步到保存/导出链路。 - 增加关键帧时间轴体验:FrameTimeline 显示具体时间信息,并支持键盘左右方向键切换关键帧。 - 完善 AI 交互分割参数:前端保留正向点、反向点、框选和 interactive prompt 的调用状态,支持 SAM2 细化候选区域与 SAM3 bbox 入口。 - 扩展后端/前端 API 类型:新增 propagateMasks、传播请求/响应 schema,并补齐 annotation、导出、模型状态和任务接口的测试覆盖。 - 更新项目文档:同步 README、AGENTS、接口契约、需求冻结、设计冻结、前端元素审计、实施计划和测试计划,标明真实功能边界与剩余风险。 - 增加测试覆盖:补充 SAM2/SAM3 传播、SAM3 状态、媒体拆帧参数、Canvas 编辑、语义标签切换、时间轴、工作区传播和 API 合约测试。 - 加强仓库安全边界:将 sam3权重/ 加入 .gitignore,避免本地模型权重被误提交。 验证:npm run test:run;pytest backend/tests;npm run lint;npm run build;python -m py_compile;git diff --check。
358 lines
13 KiB
Python
358 lines
13 KiB
Python
"""FastAPI application entrypoint."""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import shutil
|
||
import tempfile
|
||
from contextlib import asynccontextmanager, suppress
|
||
from datetime import datetime, timezone
|
||
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from sqlalchemy import inspect, text
|
||
|
||
from config import settings
|
||
from database import Base, engine, SessionLocal
|
||
from minio_client import ensure_bucket_exists, upload_file
|
||
from progress_events import PROGRESS_CHANNEL
|
||
from redis_client import get_redis_client, ping as redis_ping
|
||
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
|
||
|
||
from routers import projects, templates, media, ai, export, auth, dashboard, tasks
|
||
|
||
# Process-wide logging setup: INFO level, pipe-separated record layout.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)

# Absolute path of the demo video imported as the default project on first
# startup; the seeding routine skips itself when this file does not exist.
DEFAULT_VIDEO_PATH = "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
|
||
|
||
|
||
def _ensure_runtime_schema_columns() -> None:
    """Backfill ``frames`` columns that post-date the first ``create_all`` run.

    Reads the current column names of the ``frames`` table and issues an
    ``ALTER TABLE ... ADD COLUMN`` for ``timestamp_ms`` and
    ``source_frame_number`` when either one is missing. Any failure is logged
    as a warning and swallowed, so this function never raises.
    """
    # (column name, DDL that adds it) -- applied in this order.
    additions = (
        ("timestamp_ms", "ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"),
        ("source_frame_number", "ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"),
    )
    try:
        present = {col["name"] for col in inspect(engine).get_columns("frames")}
        # One transaction for all additions; it is opened even when nothing
        # needs to be added.
        with engine.begin() as connection:
            for column_name, ddl in additions:
                if column_name in present:
                    continue
                connection.execute(text(ddl))
    except Exception as exc:  # noqa: BLE001
        logger.warning("Runtime schema column check failed: %s", exc)
|
||
|
||
|
||
def _seed_default_project_sync() -> None:
    """Synchronously seed the default video project on first startup.

    Does nothing when a project named ``Data_MyVideo_1`` already exists or
    when ``DEFAULT_VIDEO_PATH`` is missing. Otherwise it creates the project
    row, uploads the video, calls ``parse_video`` with ``fps=30`` and
    ``max_frames=100``, uploads the resulting frames and a thumbnail, stores
    one ``Frame`` row per uploaded frame and marks the project ready.

    Thumbnail extraction is best-effort: a failure there is logged as a
    warning and seeding continues.

    Every other failure is logged with its traceback and swallowed. If the
    failure happens after the project row was committed, that row is deleted
    again; otherwise the existence check at the top would treat the
    half-built project as already seeded and no later startup would retry.
    """
    import cv2
    from models import Project, Frame
    from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail

    project_name = "Data_MyVideo_1"
    # Sampling rate requested from parse_video; also the basis for the
    # per-frame timestamp and source-frame-number estimates below.
    seed_fps = 30
    max_seed_frames = 100

    db = SessionLocal()
    project_id = None  # set once the project row has been committed
    completed = False
    try:
        if db.query(Project).filter(Project.name == project_name).first() is not None:
            return

        if not os.path.exists(DEFAULT_VIDEO_PATH):
            logger.warning("Default video not found at %s", DEFAULT_VIDEO_PATH)
            return

        project = Project(
            name=project_name,
            description="默认演示视频",
            status=PROJECT_STATUS_PENDING,
            source_type="video",
            parse_fps=float(seed_fps),
        )
        db.add(project)
        db.commit()
        db.refresh(project)
        project_id = project.id

        with open(DEFAULT_VIDEO_PATH, "rb") as f:
            data = f.read()
        object_name = f"uploads/{project_id}/{project_name}.mp4"
        upload_file(object_name, data, content_type="video/mp4", length=len(data))

        project.video_path = object_name
        db.commit()

        # Parse frames inside a scratch directory that is always removed.
        tmp_dir = tempfile.mkdtemp(prefix=f"seg_seed_{project_id}_")
        try:
            local_path = os.path.join(tmp_dir, "video.mp4")
            with open(local_path, "wb") as f:
                f.write(data)
            output_dir = os.path.join(tmp_dir, "frames")
            os.makedirs(output_dir, exist_ok=True)
            frame_files, original_fps = parse_video(
                local_path, output_dir, fps=seed_fps, max_frames=max_seed_frames
            )
            project.original_fps = original_fps

            # Extract thumbnail (best-effort).
            thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
            try:
                extract_thumbnail(local_path, thumbnail_path)
                with open(thumbnail_path, "rb") as f:
                    thumb_data = f.read()
                thumb_object = f"projects/{project_id}/thumbnail.jpg"
                upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
                project.thumbnail_url = thumb_object
            except Exception as exc:  # noqa: BLE001
                logger.warning("Thumbnail extraction failed: %s", exc)

            object_names = upload_frames_to_minio(frame_files, project_id)

            for idx, (frame_file, obj_name) in enumerate(zip(frame_files, object_names)):
                # Read the frame back only to record its pixel dimensions;
                # an unreadable image is stored with unknown width/height.
                img = cv2.imread(frame_file)
                h, w = img.shape[:2] if img is not None else (None, None)
                timestamp_ms = idx * 1000.0 / seed_fps
                source_frame_number = (
                    int(round(idx * original_fps / seed_fps)) if original_fps else None
                )
                db.add(
                    Frame(
                        project_id=project_id,
                        frame_index=idx,
                        image_url=obj_name,
                        width=w,
                        height=h,
                        timestamp_ms=timestamp_ms,
                        source_frame_number=source_frame_number,
                    )
                )

            project.status = PROJECT_STATUS_READY
            db.commit()
            completed = True
            logger.info("Seeded default project id=%s with %d frames", project_id, len(object_names))
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
    except Exception as exc:
        logger.exception("Failed to seed default project: %s", exc)
        if project_id is not None and not completed:
            # Remove the partially built project so the next startup retries
            # instead of finding a stale row and returning early.
            try:
                db.rollback()
                orphan = db.query(Project).filter(Project.id == project_id).first()
                if orphan is not None:
                    db.delete(orphan)
                    db.commit()
            except Exception as cleanup_exc:  # noqa: BLE001
                with suppress(Exception):
                    db.rollback()
                logger.warning(
                    "Could not remove partially seeded default project id=%s: %s",
                    project_id,
                    cleanup_exc,
                )
    finally:
        db.close()
|
||
|
||
|
||
def _seed_default_templates_sync() -> None:
    """Insert the built-in ontology template when the template table is empty.

    Builds one class entry per (colour, name) pair -- id ``cls-lap-<n>``, hex
    colour, and a ``zIndex`` that decreases in steps of 10 down the list --
    and stores them under ``mapping_rules["classes"]`` of a single
    ``Template`` row. Errors are logged and swallowed; the session is always
    closed.
    """
    from models import Template

    session = SessionLocal()
    try:
        # A non-empty table means there is nothing to seed.
        if session.query(Template).first() is not None:
            return

        # Laparoscopic cholecystectomy template (35 classes): RGB palette and
        # class names, paired by position.
        palette = [
            (134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
            (255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
            (255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
            (52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
            (42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
            (113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
        ]
        labels = [
            '针', '线', '肿瘤', '血管阻断夹', '棉球', '双极电凝',
            '肝脏', '胆囊', '分离钳', '脂肪', '止血海绵', '肝总管',
            '吸引器', '剪刀', '超声刀', '止血纱布', '胆总管', '生物夹',
            '无损伤钳', '钳夹', '喷洒', '胆囊管', '动脉', '电凝',
            '静脉', '标本袋', '引流管', '纱布', '金属钛夹', '韧带',
            '肝蒂', '推结器', '乳胶管-血管阻断', '吻合器', '术中超声',
        ]
        label_count = len(labels)
        class_entries = [
            {
                "id": f"cls-lap-{position}",
                "name": label,
                "color": "#{:02x}{:02x}{:02x}".format(*rgb),
                "zIndex": (label_count - position) * 10,
                "category": "腹腔镜胆囊切除术",
            }
            for position, (rgb, label) in enumerate(zip(palette, labels))
        ]

        session.add(
            Template(
                name="腹腔镜胆囊切除术",
                description="腹腔镜胆囊切除术(LC)手术器械与解剖结构语义分割模板,共35个分类",
                color="#06b6d4",
                z_index=0,
                mapping_rules={"classes": class_entries, "rules": []},
            )
        )
        session.commit()
        logger.info("Seeded default template '腹腔镜胆囊切除术' with %d classes", len(class_entries))
    except Exception as exc:
        logger.error("Failed to seed default templates: %s", exc)
    finally:
        session.close()
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup and shutdown hooks.

    Startup (each step logs its own failure and does not abort the others):
    create tables and backfill late-added columns, ensure the MinIO bucket
    exists, ping Redis, start the Redis-to-WebSocket progress forwarder, and
    kick off template and default-project seeding in worker threads.

    Shutdown: cancel and await the progress forwarder, then dispose the
    database engine.
    """
    progress_listener: asyncio.Task | None = None
    # Strong references to the fire-and-forget seeding tasks. The event loop
    # itself only keeps weak references to tasks, so a task nobody else holds
    # may be garbage-collected before it finishes.
    background_tasks: set[asyncio.Task] = set()

    def _spawn_in_thread(func, failure_message: str) -> None:
        """Run *func* in a worker thread without blocking startup."""
        try:
            task = asyncio.create_task(asyncio.to_thread(func))
        except Exception as exc:  # noqa: BLE001
            logger.error(failure_message, exc)
            return
        background_tasks.add(task)
        # Drop the reference once the task is done.
        task.add_done_callback(background_tasks.discard)

    # Startup
    logger.info("Starting up SegServer backend...")

    # Initialize database tables
    try:
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_schema_columns()
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)

    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    # Check Redis
    if redis_ping():
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")

    try:
        progress_listener = asyncio.create_task(_progress_pubsub_loop())
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start Redis progress subscription: %s", exc)

    # Seed default templates
    _spawn_in_thread(
        _seed_default_templates_sync,
        "Failed to start default template seeding: %s",
    )

    # Seed default project in background thread so it doesn't block startup
    _spawn_in_thread(
        _seed_default_project_sync,
        "Failed to start default project seeding: %s",
    )

    yield

    # Shutdown
    logger.info("Shutting down SegServer backend...")
    if progress_listener is not None:
        progress_listener.cancel()
        with suppress(asyncio.CancelledError):
            await progress_listener
    engine.dispose()
|
||
|
||
|
||
# ASGI application instance; `lifespan` supplies the startup/shutdown hooks.
app = FastAPI(
    lifespan=lifespan,
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
)

# CORS: origins come from settings; all methods and headers are allowed and
# credentials are enabled.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=settings.cors_origins,
)

# Routers, registered in this order (auth first, tasks last).
for _router_module in (auth, projects, templates, media, ai, export, dashboard, tasks):
    app.include_router(_router_module.router)
|
||
|
||
|
||
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Liveness probe: return a static OK payload naming this service."""
    payload = {"status": "ok", "service": "SegServer"}
    return payload
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# WebSocket: real-time progress push
|
||
# ---------------------------------------------------------------------------
|
||
class ConnectionManager:
    """Registry of open WebSocket clients used to fan out progress events."""

    def __init__(self):
        # Accepted sockets, in the order they connected.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake, then register the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)
        total = len(self.active_connections)
        logger.info("WebSocket client connected. Total: %d", total)

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket* if it is registered.

        The remaining-client count is logged whether or not the socket was
        found, so calling this twice for the same socket is harmless.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Not registered (e.g. already removed after a failed send).
            pass
        total = len(self.active_connections)
        logger.info("WebSocket client disconnected. Total: %d", total)

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every registered client, one at a time.

        Works on a snapshot of the registry so that removing a client during
        the loop is safe. A client whose send raises is logged and removed.
        """
        snapshot = list(self.active_connections)
        for client in snapshot:
            try:
                await client.send_json(message)
            except Exception as exc:
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(client)
|
||
|
||
|
||
# Module-level singleton shared by the /ws/progress endpoint and the Redis
# progress forwarder.
manager = ConnectionManager()
|
||
|
||
|
||
async def _progress_pubsub_loop() -> None:
    """Forward Redis task-progress events to connected WebSocket clients.

    Subscribes to ``PROGRESS_CHANNEL`` and polls it with a one-second timeout
    in a worker thread. Each message whose data decodes to a JSON object is
    broadcast through ``manager``; anything else is ignored. A message that
    is not valid JSON is logged and skipped without tearing down the
    subscription.

    Runs until cancelled. Any other error closes the pub/sub handle, waits
    five seconds and subscribes again.

    Raises:
        asyncio.CancelledError: re-raised so the owning task ends cleanly.
    """
    while True:
        pubsub = None
        try:
            pubsub = get_redis_client().pubsub()
            await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
            logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
            while True:
                # Positional args: ignore subscribe confirmations, 1 s timeout.
                message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
                if message is None:
                    await asyncio.sleep(0)
                    continue

                raw_data = message.get("data")
                if isinstance(raw_data, (str, bytes, bytearray)):
                    # The client may hand back text or raw bytes depending on
                    # its response-decoding setting; json.loads accepts both.
                    try:
                        payload = json.loads(raw_data)
                    except ValueError as exc:
                        # Covers JSONDecodeError and UnicodeDecodeError.
                        logger.warning("Discarding malformed progress event: %s", exc)
                        continue
                else:
                    payload = raw_data

                if isinstance(payload, dict):
                    await manager.broadcast(payload)
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001
            logger.error("Redis progress subscription failed: %s", exc)
            await asyncio.sleep(5)
        finally:
            if pubsub is not None:
                with suppress(Exception):
                    await asyncio.to_thread(pubsub.close)
|
||
|
||
|
||
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """WebSocket endpoint for real-time parsing/AI progress updates.

    Registers the socket with ``manager`` and then answers every text
    message from the client with a small "connected" status frame. The
    socket is unregistered in ``finally``, so it is removed on a normal
    disconnect, on an error, and on task cancellation alike.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Receive client messages (heartbeat / subscription requests)
            data = await websocket.receive_text()
            logger.debug("WebSocket received: %s", data)

            # Reply with a status frame to keep the connection alive
            await websocket.send_json({
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in ``finally``.
        pass
    except Exception as exc:
        logger.error("WebSocket error: %s", exc)
    finally:
        manager.disconnect(websocket)
|