Files
Pre_Seg_Server/backend/main.py
admin cadacef04d 修复演示恢复默认模板覆盖逻辑
- 新增后端默认模板服务,集中维护腹腔镜胆囊切除术和头颈部CT分割的权威分类树、颜色、maskid 和层级定义。

- 演示恢复出厂设置时强制恢复系统默认模板,缺失模板会重建,已修改或删减的默认语义分类树会覆盖回默认状态。

- 清理 main.py 中重复的默认模板定义,让启动 seed 复用同一套服务逻辑,避免后续默认模板定义漂移。

- 扩展管理员恢复出厂设置测试,覆盖头颈部CT模板被改坏和腹腔镜模板缺失后的恢复结果。

- 更新 AGENTS、README 和需求/API/测试/前端审计文档,明确恢复出厂设置会权威恢复系统默认模板。
2026-05-03 17:54:19 +08:00

311 lines
11 KiB
Python

"""FastAPI application entrypoint."""
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timezone
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import inspect, text
from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists
from progress_events import PROGRESS_CHANNEL
from redis_client import get_redis_client, ping as redis_ping
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
# Root logging configuration for the backend process; module loggers created
# with logging.getLogger(__name__) propagate to this root configuration.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def _ensure_runtime_schema_columns() -> None:
    """Backfill nullable columns added after the initial ``create_all`` deployment.

    ``create_all`` only creates missing tables and does not alter existing
    ones, so databases created by an earlier release can lack the columns
    listed below. Each is added with ``ALTER TABLE`` when the live schema does
    not report it. Any failure is logged as a warning and swallowed so that
    startup continues.
    """
    # (table, column, DDL) in the order the ALTER statements are applied.
    required_columns = (
        ("frames", "timestamp_ms", "ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT"),
        ("frames", "source_frame_number", "ALTER TABLE frames ADD COLUMN source_frame_number INTEGER"),
        ("projects", "owner_user_id", "ALTER TABLE projects ADD COLUMN owner_user_id INTEGER"),
        ("templates", "owner_user_id", "ALTER TABLE templates ADD COLUMN owner_user_id INTEGER"),
    )
    try:
        schema = inspect(engine)
        existing = {
            table_name: {col["name"] for col in schema.get_columns(table_name)}
            for table_name in ("frames", "projects", "templates")
        }
        with engine.begin() as conn:
            for table_name, column_name, ddl in required_columns:
                if column_name not in existing[table_name]:
                    conn.execute(text(ddl))
    except Exception as exc:  # noqa: BLE001
        logger.warning("Runtime schema column check failed: %s", exc)
def _seed_default_admin_and_ownership_sync() -> None:
    """Make sure the default admin account exists and adopt ownerless projects.

    Every project whose ``owner_user_id`` is NULL (legacy, unassigned rows) is
    bulk-assigned to the default admin in a single UPDATE. Errors are logged
    and swallowed so startup is not aborted.
    """
    from models import Project
    from routers.auth import ensure_default_admin

    session = SessionLocal()
    try:
        default_admin = ensure_default_admin(session)
        ownerless_projects = session.query(Project).filter(Project.owner_user_id.is_(None))
        ownerless_projects.update({"owner_user_id": default_admin.id}, synchronize_session=False)
        session.commit()
        logger.info("Default admin ready; legacy projects assigned to user id=%s", default_admin.id)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to seed default admin or ownership: %s", exc)
    finally:
        session.close()
def _seed_demo_video_project(db, admin) -> None:
    """Create the unparsed demo video project unless one with its name exists.

    An existing project of that name with no owner is assigned to ``admin``.
    Creation is skipped, with a warning, when ``settings.demo_video_path``
    does not exist.
    """
    from models import Project
    from services.demo_media import DEMO_VIDEO_PROJECT_NAME, create_unparsed_video_demo_project

    existing = db.query(Project).filter(Project.name == DEMO_VIDEO_PROJECT_NAME).first()
    if existing is not None:
        if existing.owner_user_id is None:
            existing.owner_user_id = admin.id
            db.commit()
        return
    if not os.path.exists(settings.demo_video_path):
        logger.warning("Default demo video not found at %s", settings.demo_video_path)
        return
    project = create_unparsed_video_demo_project(
        db,
        owner=admin,
        video_path=settings.demo_video_path,
        project_name=DEMO_VIDEO_PROJECT_NAME,
    )
    logger.info("Seeded default video project id=%s", project.id)


def _seed_demo_dicom_project(db, admin) -> None:
    """Create the parsed demo DICOM project unless one with its name exists.

    An existing project of that name with no owner is assigned to ``admin``.
    Creation is skipped, with a warning, when ``demo_dicom_files`` finds no
    series under ``settings.demo_dicom_dir``.
    """
    from models import Project
    from services.demo_media import (
        DEMO_DICOM_PROJECT_NAME,
        create_parsed_dicom_demo_project,
        demo_dicom_files,
    )

    existing = db.query(Project).filter(Project.name == DEMO_DICOM_PROJECT_NAME).first()
    if existing is not None:
        if existing.owner_user_id is None:
            existing.owner_user_id = admin.id
            db.commit()
        return
    if not demo_dicom_files(settings.demo_dicom_dir):
        logger.warning("Default DICOM series not found at %s", settings.demo_dicom_dir)
        return
    project = create_parsed_dicom_demo_project(
        db,
        owner=admin,
        dicom_dir=settings.demo_dicom_dir,
        project_name=DEMO_DICOM_PROJECT_NAME,
    )
    logger.info("Seeded default DICOM project id=%s with %d frames", project.id, len(project.frames))


def _seed_default_project_sync() -> None:
    """Synchronously seed the bundled demo video and DICOM projects on first startup.

    The two demos are seeded independently. A failure while seeding one is
    logged and its transaction rolled back, without preventing the other from
    being seeded. Errors are never raised to the caller.
    """
    from routers.auth import ensure_default_admin

    db = SessionLocal()
    try:
        try:
            admin = ensure_default_admin(db)
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to seed default project: %s", exc)
            return
        seeders = (
            ("video", _seed_demo_video_project),
            ("DICOM", _seed_demo_dicom_project),
        )
        for label, seed in seeders:
            try:
                seed(db, admin)
            except Exception as exc:  # noqa: BLE001
                logger.error("Failed to seed default %s project: %s", label, exc)
                # Discard the failed transaction so the next seeder starts on a clean session.
                with suppress(Exception):
                    db.rollback()
    finally:
        db.close()
def _seed_default_templates_sync() -> None:
    """Seed the bundled default ontology templates using a dedicated session.

    Intended for a worker thread at startup: any failure is logged rather
    than raised, and the session is always closed.
    """
    session = SessionLocal()
    try:
        ensure_default_templates(session)
    except Exception as exc:
        logger.error("Failed to seed default templates: %s", exc)
    finally:
        session.close()
def ensure_default_templates(db) -> None:
    """Ensure all bundled system templates exist, using session ``db``.

    Thin wrapper that delegates to ``services.default_templates``. The import
    is deferred to call time (presumably to keep module import cheap / avoid
    cycles — unverified).
    """
    import services.default_templates as default_templates_service

    default_templates_service.ensure_default_templates(db)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: run startup hooks, serve (``yield``), then shut down.

    Startup does the following:
    - Creates and patches the database schema and seeds the default admin
      synchronously.
    - Checks MinIO and Redis; failures are logged, not fatal.
    - Starts the Redis -> WebSocket progress relay.
    - Launches default-template and demo-project seeding in worker threads,
      so serving is not blocked on them.

    Shutdown cancels the progress relay and disposes the database engine.
    """
    progress_listener: asyncio.Task | None = None
    # asyncio keeps only weak references to running tasks, so a task whose
    # handle is dropped can be garbage-collected before it finishes. Hold strong
    # references here; this generator frame stays alive (suspended at ``yield``)
    # for as long as the application runs.
    background_tasks: set[asyncio.Task] = set()

    def _run_in_background_thread(func) -> None:
        """Schedule blocking ``func`` in a worker thread as a tracked task."""
        task = asyncio.create_task(asyncio.to_thread(func))
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    # Startup
    logger.info("Starting up SegServer backend...")
    # Initialize database tables
    try:
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_schema_columns()
        _seed_default_admin_and_ownership_sync()
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)
    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)
    # Check Redis
    if redis_ping():
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")
    # The relay retries its subscription on failure, so start it even if the ping failed.
    try:
        progress_listener = asyncio.create_task(_progress_pubsub_loop())
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start Redis progress subscription: %s", exc)
    # Seed default templates
    try:
        _run_in_background_thread(_seed_default_templates_sync)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default template seeding: %s", exc)
    # Seed default project in background thread so it doesn't block startup
    try:
        _run_in_background_thread(_seed_default_project_sync)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default project seeding: %s", exc)
    yield
    # Shutdown
    logger.info("Shutting down SegServer backend...")
    if progress_listener is not None:
        progress_listener.cancel()
        with suppress(asyncio.CancelledError):
            await progress_listener
    engine.dispose()
app = FastAPI(
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: allowed origins come from settings; credentials and every method/header are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers, registered in this exact order.
for _router_module in (auth, projects, templates, media, ai, export, dashboard, tasks, admin):
    app.include_router(_router_module.router)
del _router_module
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Liveness probe: return a constant payload without touching any dependency."""
    return dict(status="ok", service="SegServer")
# ---------------------------------------------------------------------------
# WebSocket: 实时进度推送
# ---------------------------------------------------------------------------
class ConnectionManager:
    """Keep track of open progress WebSockets and fan messages out to them."""

    def __init__(self):
        # Accepted connections, in the order they connected.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking ``websocket``."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; does nothing (and logs nothing) if untracked."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            return
        logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every client, dropping any client whose send fails."""
        # Iterate over a snapshot because failing clients are removed from the live list.
        for client in list(self.active_connections):
            try:
                await client.send_json(message)
            except Exception as exc:
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(client)


manager = ConnectionManager()
def _decode_progress_payload(raw_data: object) -> "dict | None":
    """Turn the ``data`` field of a Redis pub/sub message into a progress event.

    ``str``/``bytes``/``bytearray`` data is parsed as JSON (``json.loads``
    accepts all three), so this works whether or not the Redis client decodes
    responses.

    Returns:
        The event dict, or ``None`` when the data is not valid JSON or does not
        decode to a JSON object.
    """
    if isinstance(raw_data, (str, bytes, bytearray)):
        try:
            raw_data = json.loads(raw_data)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
            return None
    return raw_data if isinstance(raw_data, dict) else None


async def _progress_pubsub_loop() -> None:
    """Forward Redis task-progress events to connected WebSocket clients.

    Runs until cancelled. Subscribes to ``PROGRESS_CHANNEL`` and broadcasts
    every JSON-object message to all clients via ``manager``. A malformed
    message is logged and skipped without dropping the subscription. Any other
    failure closes the pub/sub connection, waits 5 seconds and re-subscribes.
    """
    while True:
        pubsub = None
        try:
            pubsub = get_redis_client().pubsub()
            await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
            logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
            while True:
                # Positional args map to redis-py's get_message(ignore_subscribe_messages, timeout).
                message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
                if message is None:
                    await asyncio.sleep(0)
                    continue
                raw_data = message.get("data")
                payload = _decode_progress_payload(raw_data)
                if payload is None:
                    # One bad message must not tear down the subscription (which would
                    # drop every event published during the 5 s reconnect back-off).
                    logger.warning("Ignoring malformed progress message: %r", raw_data)
                    continue
                await manager.broadcast(payload)
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001
            logger.error("Redis progress subscription failed: %s", exc)
            await asyncio.sleep(5)
        finally:
            if pubsub is not None:
                with suppress(Exception):
                    await asyncio.to_thread(pubsub.close)
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """Serve one real-time progress client.

    The socket is registered with the shared ``manager`` so it receives
    broadcast progress events. Every text frame the client sends is answered
    with a status heartbeat. The socket is unregistered when the client
    disconnects or any other error ends the loop; non-disconnect errors are
    logged.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Client frames (heartbeats / subscription requests) are only logged.
            incoming = await websocket.receive_text()
            logger.debug("WebSocket received: %s", incoming)
            heartbeat = {
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
            await websocket.send_json(heartbeat)
    except Exception as exc:
        if not isinstance(exc, WebSocketDisconnect):
            logger.error("WebSocket error: %s", exc)
        manager.disconnect(websocket)