Files
Pre_Seg_Server/backend/main.py
admin 481ffa5b67 完善项目导入、模板与分割工作区交互
- 增强 DICOM/视频项目导入与演示数据:DICOM 按文件名自然顺序处理,导入后展示上传与解析任务进度,恢复演示出厂设置保留演示视频和演示 DICOM 项目,并补充 demo media seed 逻辑。

- 完善项目管理:项目支持重命名、删除、复制,删除使用站内确认弹窗,复制支持新项目重置和全内容复制,DICOM 项目不显示生成帧入口。

- 完善 GT Mask 与导出链路:只支持 8-bit maskid 图导入,非法/全背景图明确拒绝,尺寸自动适配,高精度 polygon 回显;统一导出默认当前帧,GT_label 使用 uint8 和真实 maskid,待分类 maskid 0 与背景一致。

- 完善分割工作区交互:新增画笔和橡皮擦并支持尺寸控制,移除创建点/线段入口,工具栏按类别分隔,AI 智能分割使用明确 AI 图标,取消黄色 seed point,清空/删除传播 mask 后同步清理空帧时间轴状态。

- 完善传播与时间轴:自动传播使用 SAM 2.1 权重任务,参考帧无遮罩时提示,传播历史按同一蓝色系递进变暗,删除/清空传播链时保留人工或独立 AI 标注来源。

- 完善模板库:新增头颈部 CT 分割默认模板,所有模板保留 maskid 0 待分类,支持鼠标复制模板、拖拽层级、JSON 批量导入预览、删除 label 和站内删除确认。

- 完善用户与高风险确认:用户改密码、删除用户、恢复演示出厂设置和清空人工/AI 标注帧均改为站内确认交互,避免浏览器原生 prompt/confirm。

- 补充前后端测试与文档:更新项目、模板、GT 导入、导出、传播、DICOM、用户管理等测试,并同步 README、AGENTS 和 doc 下实现/契约/测试计划文档。
2026-05-03 17:11:59 +08:00

430 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""FastAPI application entrypoint."""
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timezone
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import inspect, text
from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists
from progress_events import PROGRESS_CHANNEL
from redis_client import get_redis_client, ping as redis_ping
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
# Root logging for the backend: INFO and above, pipe-separated fields.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the startup, seeding and WebSocket code.
logger = logging.getLogger(__name__)
RESERVED_UNCLASSIFIED_CLASS = {
"id": "reserved-unclassified",
"name": "待分类",
"color": "#000000",
"zIndex": 0,
"maskId": 0,
"category": "系统保留",
}
def _with_reserved_unclassified_class(classes: list[dict]) -> list[dict]:
filtered = [
item for item in classes
if item.get("id") != RESERVED_UNCLASSIFIED_CLASS["id"]
and item.get("name") != RESERVED_UNCLASSIFIED_CLASS["name"]
and item.get("maskId") != 0
]
return [*filtered, dict(RESERVED_UNCLASSIFIED_CLASS)]
def _ensure_runtime_schema_columns() -> None:
    """Add nullable columns introduced after initial create_all deployments.

    Inspects the ``frames``, ``projects`` and ``templates`` tables. Each missing
    column is added with ``ALTER TABLE ... ADD COLUMN``, and all ALTERs run in a
    single ``engine.begin()`` transaction. This is best effort: any failure is
    logged as a warning and never raised.
    """
    # (table, column, SQL type) for every column added after the first release,
    # in the order the ALTER statements should run.
    late_columns = (
        ("frames", "timestamp_ms", "FLOAT"),
        ("frames", "source_frame_number", "INTEGER"),
        ("projects", "owner_user_id", "INTEGER"),
        ("templates", "owner_user_id", "INTEGER"),
    )
    try:
        inspector = inspect(engine)
        present = {
            table: {col["name"] for col in inspector.get_columns(table)}
            for table in ("frames", "projects", "templates")
        }
        with engine.begin() as connection:
            for table, column, sql_type in late_columns:
                if column in present[table]:
                    continue
                connection.execute(text(f"ALTER TABLE {table} ADD COLUMN {column} {sql_type}"))
    except Exception as exc:  # noqa: BLE001
        logger.warning("Runtime schema column check failed: %s", exc)
def _seed_default_admin_and_ownership_sync() -> None:
    """Make sure the default admin exists and adopts every project without an owner.

    All projects whose ``owner_user_id`` is NULL are reassigned to the admin in a
    single bulk UPDATE and then committed. Errors are logged, not raised, and the
    session is always closed.
    """
    from models import Project
    from routers.auth import ensure_default_admin

    session = SessionLocal()
    try:
        admin_user = ensure_default_admin(session)
        unowned_projects = session.query(Project).filter(Project.owner_user_id.is_(None))
        unowned_projects.update({"owner_user_id": admin_user.id}, synchronize_session=False)
        session.commit()
        logger.info("Default admin ready; legacy projects assigned to user id=%s", admin_user.id)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to seed default admin or ownership: %s", exc)
    finally:
        session.close()
def _seed_demo_video_project(db, admin) -> None:
    """Create the bundled demo video project, or adopt an existing unowned one.

    If a project with the demo name exists, it is only assigned to ``admin`` when
    it has no owner. If it does not exist and ``settings.demo_video_path`` exists
    on disk, an unparsed video demo project is created.
    """
    from models import Project
    from services.demo_media import (
        DEMO_VIDEO_PROJECT_NAME,
        create_unparsed_video_demo_project,
    )

    existing = db.query(Project).filter(Project.name == DEMO_VIDEO_PROJECT_NAME).first()
    if existing is not None:
        if existing.owner_user_id is None:
            existing.owner_user_id = admin.id
            db.commit()
        return
    if not os.path.exists(settings.demo_video_path):
        return
    video_project = create_unparsed_video_demo_project(
        db,
        owner=admin,
        video_path=settings.demo_video_path,
        project_name=DEMO_VIDEO_PROJECT_NAME,
    )
    logger.info("Seeded default video project id=%s", video_project.id)


def _seed_demo_dicom_project(db, admin) -> None:
    """Create the bundled demo DICOM project, or adopt an existing unowned one.

    If a project with the demo name exists, it is only assigned to ``admin`` when
    it has no owner. Otherwise a parsed DICOM demo project is created from
    ``settings.demo_dicom_dir``. If that directory yields no DICOM files, a
    warning is logged instead.
    """
    from models import Project
    from services.demo_media import (
        DEMO_DICOM_PROJECT_NAME,
        create_parsed_dicom_demo_project,
        demo_dicom_files,
    )

    existing = db.query(Project).filter(Project.name == DEMO_DICOM_PROJECT_NAME).first()
    if existing is not None:
        if existing.owner_user_id is None:
            existing.owner_user_id = admin.id
            db.commit()
        return
    if not demo_dicom_files(settings.demo_dicom_dir):
        logger.warning("Default DICOM series not found at %s", settings.demo_dicom_dir)
        return
    project = create_parsed_dicom_demo_project(
        db,
        owner=admin,
        dicom_dir=settings.demo_dicom_dir,
        project_name=DEMO_DICOM_PROJECT_NAME,
    )
    logger.info("Seeded default DICOM project id=%s with %d frames", project.id, len(project.frames))


def _seed_default_project_sync() -> None:
    """Synchronously seed the bundled demo video and DICOM projects on first startup.

    The two demos are seeded independently. If one step raises, the session is
    rolled back and the error is logged, and the other step still runs. All
    errors are logged rather than raised, and the session is always closed.
    """
    from routers.auth import ensure_default_admin

    db = SessionLocal()
    try:
        admin = ensure_default_admin(db)
        for seed_step in (_seed_demo_video_project, _seed_demo_dicom_project):
            try:
                seed_step(db, admin)
            except Exception as exc:  # noqa: BLE001
                # Discard the failed step's pending changes so the next step
                # starts from a clean session state.
                db.rollback()
                logger.error("Failed to seed default project: %s", exc)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to seed default project: %s", exc)
    finally:
        db.close()
def _seed_default_templates_sync() -> None:
    """Seed default ontology templates on first startup (best effort).

    Opens its own session and calls ``ensure_default_templates`` with it. Any
    failure is logged rather than raised, and the session is always closed.
    """
    # The previous local ``from models import Template`` was unused here;
    # ensure_default_templates performs its own import.
    db = SessionLocal()
    try:
        ensure_default_templates(db)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to seed default templates: %s", exc)
    finally:
        db.close()
def _template_classes(
template_name: str,
names: list[str],
colors: list[tuple[int, int, int]],
*,
id_prefix: str,
) -> list[dict]:
classes = []
for idx, (rgb, name) in enumerate(zip(colors, names)):
color_hex = f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}"
classes.append({
"id": f"{id_prefix}-{idx}",
"name": name,
"color": color_hex,
"zIndex": (len(names) - idx) * 10,
"maskId": idx + 1,
"category": template_name,
})
return classes
def ensure_default_templates(db) -> None:
    """Ensure all bundled system templates exist.

    For each bundled definition, this looks for a system template with the same
    name (``owner_user_id`` IS NULL). If none exists, it inserts one whose
    ``mapping_rules`` holds the class list and an empty ``rules`` list. Existing
    templates are skipped, so their class lists are NOT updated from these
    definitions. Commits once at the end.

    Args:
        db: an open session supporting ``query``/``add``/``commit``. This
            function does not close it.
    """
    from models import Template
    # Each class list is built by _template_classes (maskId 1..N) and then has
    # the reserved maskId-0 "unclassified" entry appended.
    default_templates = [
        {
            "name": "腹腔镜胆囊切除术",
            "description": "腹腔镜胆囊切除术LC手术器械与解剖结构语义分割模板共35个分类",
            "color": "#06b6d4",
            "z_index": 0,
            "classes": _with_reserved_unclassified_class(_template_classes(
                "腹腔镜胆囊切除术",
                # NOTE(review): the first class name is an empty string. Confirm
                # whether this is intentional or text lost along the way.
                [
                    '', '线', '肿瘤', '血管阻断夹', '棉球', '双极电凝',
                    '肝脏', '胆囊', '分离钳', '脂肪', '止血海绵', '肝总管',
                    '吸引器', '剪刀', '超声刀', '止血纱布', '胆总管', '生物夹',
                    '无损伤钳', '钳夹', '喷洒', '胆囊管', '动脉', '电凝',
                    '静脉', '标本袋', '引流管', '纱布', '金属钛夹', '韧带',
                    '肝蒂', '推结器', '乳胶管-血管阻断', '吻合器', '术中超声',
                ],
                # RGB colors, paired positionally with the names above (35 each).
                [
                    (134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159), (146, 175, 236), (155, 62, 0),
                    (255, 91, 0), (255, 234, 0), (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
                    (255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95), (0, 155, 33), (0, 160, 233),
                    (52, 184, 178), (66, 115, 82), (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
                    (42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255), (0, 255, 255), (181, 85, 105),
                    (113, 102, 140), (202, 202, 200), (197, 83, 181), (136, 162, 196), (138, 251, 213),
                ],
                id_prefix="cls-lap",
            )),
        },
        {
            "name": "头颈部CT分割",
            "description": "头颈部CT分割",
            "color": "#ef4444",
            "z_index": 10,
            "classes": _with_reserved_unclassified_class(_template_classes(
                "头颈部CT分割",
                [
                    "肿瘤/结节 (Tumor/Nodule)",
                    "下颌骨 (Mandible)",
                    "甲状腺 (Thyroid)",
                    "气管 (Trachea)",
                    "颈椎 (Cervical Spine)",
                    "颈动脉 (Carotid Artery)",
                    "颈静脉 (Jugular Vein)",
                    "腮腺 (Parotid Gland)",
                    "下颌下腺 (Submandibular Gland)",
                    "舌骨 (Hyoid Bone)",
                ],
                # RGB colors, paired positionally with the names above (10 each).
                [
                    (255, 0, 0),
                    (0, 255, 0),
                    (0, 0, 255),
                    (255, 255, 0),
                    (255, 0, 255),
                    (0, 255, 255),
                    (255, 128, 0),
                    (128, 0, 128),
                    (0, 128, 128),
                    (128, 128, 0),
                ],
                id_prefix="cls-head-neck-ct",
            )),
        },
    ]
    for definition in default_templates:
        # Only system-owned templates (no owner) count. A user-owned template
        # with the same name does not suppress seeding.
        existing = db.query(Template).filter(
            Template.name == definition["name"],
            Template.owner_user_id.is_(None),
        ).first()
        if existing is not None:
            continue
        template = Template(
            name=definition["name"],
            description=definition["description"],
            color=definition["color"],
            z_index=definition["z_index"],
            mapping_rules={"classes": definition["classes"], "rules": []},
            owner_user_id=None,
        )
        db.add(template)
        # Logged before the commit below, so a failed commit still leaves this line.
        logger.info("Seeded default template '%s' with %d classes", definition["name"], len(definition["classes"]))
    db.commit()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup and shutdown hooks.

    Startup logs and continues on every failure. It:

    * creates tables, adds late-introduced columns and ensures the default admin;
    * checks the MinIO bucket and pings Redis;
    * starts the Redis -> WebSocket progress forwarder;
    * seeds default templates and demo projects in worker threads
      (``asyncio.to_thread``) so startup is not blocked.

    Shutdown cancels the tasks started here and disposes the database engine.
    Cleanup runs in a ``finally`` block, so it also happens if the lifespan is
    exited with an exception.
    """
    progress_listener: asyncio.Task | None = None
    # The event loop keeps only weak references to tasks. Fire-and-forget
    # tasks must be referenced here or they may be garbage-collected mid-run.
    seeding_tasks: list[asyncio.Task] = []

    # Startup
    logger.info("Starting up SegServer backend...")

    # Initialize database tables, late-added columns and the default admin.
    try:
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_schema_columns()
        _seed_default_admin_and_ownership_sync()
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)

    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    # Check Redis
    if redis_ping():
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")
    try:
        progress_listener = asyncio.create_task(_progress_pubsub_loop())
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start Redis progress subscription: %s", exc)

    # Seed default templates in a worker thread (the seeder logs its own errors).
    try:
        seeding_tasks.append(asyncio.create_task(asyncio.to_thread(_seed_default_templates_sync)))
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default template seeding: %s", exc)

    # Seed default projects in a worker thread so it doesn't block startup.
    try:
        seeding_tasks.append(asyncio.create_task(asyncio.to_thread(_seed_default_project_sync)))
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default project seeding: %s", exc)

    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down SegServer backend...")
        if progress_listener is not None:
            progress_listener.cancel()
            with suppress(asyncio.CancelledError):
                await progress_listener
        for task in seeding_tasks:
            # Cancelling only detaches the awaiting task. A worker thread that
            # is already running continues until its function returns.
            task.cancel()
        for task in seeding_tasks:
            with suppress(asyncio.CancelledError, Exception):
                await task
        engine.dispose()
app = FastAPI(
    lifespan=lifespan,
    title="SegServer API",
    version="1.0.0",
    description="Semantic Segmentation System Backend",
)

# CORS: the configured origins may call the API with credentials, using any
# method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers, registered in a fixed order.
for _router_module in (auth, projects, templates, media, ai, export, dashboard, tasks, admin):
    app.include_router(_router_module.router)
del _router_module
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Return a static status payload identifying the SegServer service."""
    return dict(status="ok", service="SegServer")
# ---------------------------------------------------------------------------
# WebSocket: 实时进度推送
# ---------------------------------------------------------------------------
class ConnectionManager:
    """Keep the accepted progress WebSockets and fan messages out to all of them."""

    def __init__(self):
        # Accepted sockets, in the order they connected.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Complete the WebSocket handshake, then start tracking the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket`` (no-op if unknown); always logs the new total."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass
        logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict):
        """Send ``message`` to every tracked client, dropping any whose send fails.

        Iterates over a snapshot, so clients removed mid-broadcast do not
        disturb the loop.
        """
        snapshot = list(self.active_connections)
        for client in snapshot:
            try:
                await client.send_json(message)
            except Exception as exc:
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(client)


manager = ConnectionManager()
async def _progress_pubsub_loop() -> None:
    """Forward Redis task-progress events to connected WebSocket clients.

    Runs until cancelled. It subscribes to ``PROGRESS_CHANNEL`` and broadcasts
    every payload that decodes to a dict via ``manager.broadcast``. Blocking
    Redis calls run in worker threads via ``asyncio.to_thread``.

    * A message whose JSON cannot be decoded is logged and skipped; it no longer
      tears down the subscription.
    * Any other failure is logged, the pubsub is closed, and a new subscription
      is attempted after 5 seconds.
    """
    while True:
        pubsub = None
        try:
            pubsub = get_redis_client().pubsub()
            await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
            logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
            while True:
                # Positional args: assumes redis-py's
                # get_message(ignore_subscribe_messages, timeout) signature,
                # i.e. skip subscribe confirmations and wait up to 1s.
                message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
                if message is None:
                    await asyncio.sleep(0)
                    continue
                raw_data = message.get("data")
                if isinstance(raw_data, (str, bytes, bytearray)):
                    # json.loads accepts bytes as well, so this covers clients
                    # created without decode_responses.
                    try:
                        payload = json.loads(raw_data)
                    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
                        logger.warning("Ignoring malformed progress event: %s", exc)
                        continue
                else:
                    payload = raw_data
                if isinstance(payload, dict):
                    await manager.broadcast(payload)
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001
            logger.error("Redis progress subscription failed: %s", exc)
        finally:
            if pubsub is not None:
                with suppress(Exception):
                    await asyncio.to_thread(pubsub.close)
        # The inner loop never exits normally, so this is only reached after a
        # handled failure. Back off (with the broken pubsub already closed)
        # before resubscribing.
        await asyncio.sleep(5)
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """WebSocket endpoint for real-time parsing/AI progress updates.

    Registers the socket with ``manager`` so it receives broadcast progress
    events. Every text message from the client (e.g. a heartbeat) is answered
    with a ``status``/``connected`` message carrying a UTC timestamp. The socket
    is always unregistered on exit, including on disconnect, on errors and on
    task cancellation.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Receive client messages (heartbeat / subscription requests)
            data = await websocket.receive_text()
            logger.debug("WebSocket received: %s", data)
            # Echo heartbeat to keep connection alive
            await websocket.send_json({
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
    except WebSocketDisconnect:
        # Normal client close; cleanup happens in ``finally``.
        pass
    except Exception as exc:  # noqa: BLE001
        logger.error("WebSocket error: %s", exc)
    finally:
        # Runs even for asyncio.CancelledError (a BaseException), which the
        # handlers above do not catch, so the socket never lingers in
        # manager.active_connections.
        manager.disconnect(websocket)