- 登录页和侧栏统一使用根目录 logo_square.png,并更新登录系统名称与副标题。 - 更新 Dashboard、项目库和工作区时间轴文案,移除底层时序视频图层说明。 - 演示视频项目显示名改为“演视LC视频序列”,启动时兼容迁移旧 Data_MyVideo_1 名称,恢复出厂设置使用新名。 - 调整侧栏用户管理入口为用户图标,底部当前用户入口为退出图标,并让退出提示不接收鼠标事件。 - 补充前端组件测试、后端演示重置测试和文档说明。
309 lines
11 KiB
Python
309 lines
11 KiB
Python
"""FastAPI application entrypoint."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager, suppress
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from sqlalchemy import inspect, text
|
|
|
|
from config import settings
|
|
from database import Base, engine, SessionLocal
|
|
from minio_client import ensure_bucket_exists
|
|
from progress_events import PROGRESS_CHANNEL
|
|
from redis_client import get_redis_client, ping as redis_ping
|
|
|
|
from routers import projects, templates, media, ai, export, auth, dashboard, tasks, admin
|
|
|
|
# Process-wide logging configuration: INFO and above, pipe-separated fields.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    level=logging.INFO,
)

# Logger for this module (startup/seeding helpers and the WebSocket layer).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _ensure_runtime_schema_columns() -> None:
    """Add nullable columns introduced after initial create_all deployments.

    The live schema of ``frames``, ``projects`` and ``templates`` is inspected
    first. Then, inside a single ``engine.begin()`` transaction, each listed
    column that is missing is added with ``ALTER TABLE ... ADD COLUMN``. Any
    failure is logged as a warning and swallowed so startup continues.
    """
    # (table, column, SQL type) for every column added after the first release,
    # in the order the ALTER statements are issued.
    late_columns = (
        ("frames", "timestamp_ms", "FLOAT"),
        ("frames", "source_frame_number", "INTEGER"),
        ("projects", "owner_user_id", "INTEGER"),
        ("templates", "owner_user_id", "INTEGER"),
    )
    try:
        inspector = inspect(engine)
        present = {
            table_name: {col["name"] for col in inspector.get_columns(table_name)}
            for table_name in ("frames", "projects", "templates")
        }
        with engine.begin() as connection:
            for table_name, column_name, sql_type in late_columns:
                if column_name in present[table_name]:
                    continue
                # Identifiers come from the constant tuple above, never from input.
                connection.execute(
                    text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {sql_type}")
                )
    except Exception as exc:  # noqa: BLE001
        logger.warning("Runtime schema column check failed: %s", exc)
|
|
|
|
|
|
def _seed_default_admin_sync() -> None:
    """Ensure the single default admin exists without rewriting project ownership metadata.

    Opens a dedicated session, delegates to ``routers.auth.ensure_default_admin``
    and commits. On failure, the error is logged together with its traceback and
    then swallowed, so the remaining startup steps still run. The session is
    always closed.
    """
    from routers.auth import ensure_default_admin

    db = SessionLocal()
    try:
        admin = ensure_default_admin(db)
        db.commit()
        logger.info("Default admin ready id=%s", admin.id)
    except Exception as exc:  # noqa: BLE001
        # logger.exception keeps the traceback; a bare message made startup
        # seeding failures hard to diagnose.
        logger.exception("Failed to seed default admin: %s", exc)
    finally:
        db.close()
|
|
|
|
|
|
def _seed_default_project_sync() -> None:
    """Synchronously seed the bundled demo video and DICOM projects on first startup.

    Steps, all in one session:

    1. If no project is named ``DEMO_VIDEO_PROJECT_NAME`` yet, rename the first
       project found with a name in ``LEGACY_DEMO_VIDEO_PROJECT_NAMES`` to it.
       This is skipped when the new name is already taken, so the migration can
       never create two projects with the demo name.
    2. If there is still no demo video project, create an unparsed one, provided
       ``settings.demo_video_path`` exists.
    3. If no project is named ``DEMO_DICOM_PROJECT_NAME``, create a parsed DICOM
       project, provided ``demo_dicom_files`` finds files in
       ``settings.demo_dicom_dir``.

    Any exception is logged with its traceback and swallowed. The session is
    always closed.
    """
    from models import Project
    from routers.auth import ensure_default_admin
    from services.demo_media import (
        DEMO_DICOM_PROJECT_NAME,
        DEMO_VIDEO_PROJECT_NAME,
        LEGACY_DEMO_VIDEO_PROJECT_NAMES,
        create_parsed_dicom_demo_project,
        create_unparsed_video_demo_project,
        demo_dicom_files,
    )

    db = SessionLocal()
    try:
        admin = ensure_default_admin(db)

        existing_video = db.query(Project).filter(Project.name == DEMO_VIDEO_PROJECT_NAME).first()
        if existing_video is None:
            # Migrate a demo project created under an old name instead of seeding
            # a second copy. Only done while the new name is still free.
            legacy_video = (
                db.query(Project)
                .filter(Project.name.in_(LEGACY_DEMO_VIDEO_PROJECT_NAMES))
                .first()
            )
            if legacy_video is not None:
                legacy_video.name = DEMO_VIDEO_PROJECT_NAME
                db.commit()
                existing_video = legacy_video

        if existing_video is None and os.path.exists(settings.demo_video_path):
            video_project = create_unparsed_video_demo_project(
                db,
                owner=admin,
                video_path=settings.demo_video_path,
                project_name=DEMO_VIDEO_PROJECT_NAME,
            )
            logger.info("Seeded default video project id=%s", video_project.id)

        existing_dicom = db.query(Project).filter(Project.name == DEMO_DICOM_PROJECT_NAME).first()
        if existing_dicom is not None:
            return

        if not demo_dicom_files(settings.demo_dicom_dir):
            logger.warning("Default DICOM series not found at %s", settings.demo_dicom_dir)
            return

        project = create_parsed_dicom_demo_project(
            db,
            owner=admin,
            dicom_dir=settings.demo_dicom_dir,
            project_name=DEMO_DICOM_PROJECT_NAME,
        )
        logger.info("Seeded default DICOM project id=%s with %d frames", project.id, len(project.frames))
    except Exception as exc:  # noqa: BLE001
        # Keep the traceback: this runs in a background thread at startup.
        logger.exception("Failed to seed default project: %s", exc)
    finally:
        db.close()
|
|
|
|
|
|
def _seed_default_templates_sync() -> None:
    """Seed default ontology templates on first startup.

    Opens a dedicated session and calls ``ensure_default_templates`` with it.
    Failures are logged with their traceback and swallowed. The session is
    always closed.
    """
    db = SessionLocal()
    try:
        ensure_default_templates(db)
    except Exception as exc:  # noqa: BLE001
        # logger.exception keeps the traceback for background-thread failures.
        logger.exception("Failed to seed default templates: %s", exc)
    finally:
        db.close()
|
|
|
|
|
|
def ensure_default_templates(db) -> None:
    """Create any bundled system templates that are missing, using session *db*.

    Delegates to ``services.default_templates.ensure_default_templates`` and
    discards its return value. The import is done at call time rather than at
    module load; presumably this avoids an import cycle or startup cost (unverified).
    """
    from services.default_templates import ensure_default_templates as seed_bundled_templates

    seed_bundled_templates(db)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup and shutdown hooks.

    Startup (each step logs and continues on failure):
      1. create tables, backfill late-added columns, seed the default admin;
      2. ensure the MinIO bucket exists;
      3. ping Redis and start the Redis -> WebSocket progress forwarder task;
      4. start template and demo-project seeding in worker threads.

    Shutdown cancels the progress forwarder and any seeding task still pending,
    then disposes of the database engine.
    """
    progress_listener: asyncio.Task | None = None
    # The event loop keeps only weak references to tasks, so an unreferenced
    # fire-and-forget task can be garbage-collected mid-run. Hold strong
    # references until each seeding task finishes.
    seed_tasks: set[asyncio.Task] = set()

    def _keep_reference(task: asyncio.Task) -> None:
        seed_tasks.add(task)
        task.add_done_callback(seed_tasks.discard)

    # Startup
    logger.info("Starting up SegServer backend...")

    # Initialize database tables
    try:
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_schema_columns()
        _seed_default_admin_sync()
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)

    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    # Check Redis
    if redis_ping():
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")

    try:
        progress_listener = asyncio.create_task(_progress_pubsub_loop())
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start Redis progress subscription: %s", exc)

    # Seed default templates
    try:
        _keep_reference(asyncio.create_task(asyncio.to_thread(_seed_default_templates_sync)))
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default template seeding: %s", exc)

    # Seed default project in background thread so it doesn't block startup
    try:
        _keep_reference(asyncio.create_task(asyncio.to_thread(_seed_default_project_sync)))
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default project seeding: %s", exc)

    yield

    # Shutdown
    logger.info("Shutting down SegServer backend...")
    if progress_listener is not None:
        progress_listener.cancel()
        with suppress(asyncio.CancelledError):
            await progress_listener

    # Cancelling a to_thread task only stops awaiting it; the worker thread
    # itself is not interrupted. This avoids leaving pending tasks on the loop.
    pending_seeds = list(seed_tasks)
    for task in pending_seeds:
        task.cancel()
    if pending_seeds:
        logger.info("Cancelling %d pending seeding task(s).", len(pending_seeds))
        await asyncio.gather(*pending_seeds, return_exceptions=True)

    engine.dispose()
|
|
|
|
|
|
# ASGI application; startup/shutdown work is delegated to ``lifespan``.
app = FastAPI(
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: allowed origins come from settings; credentials, all methods and all
# headers are permitted for those origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount every API router. The order is kept fixed because route registration
# order decides matching precedence.
for _router_module in (auth, projects, templates, media, ai, export, dashboard, tasks, admin):
    app.include_router(_router_module.router)
del _router_module
|
|
|
|
|
|
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Report that the SegServer API process is up and serving requests.

    This is a static response; it does not probe the database, MinIO or Redis.
    """
    body = dict(status="ok", service="SegServer")
    return body
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# WebSocket: real-time progress push
|
|
# ---------------------------------------------------------------------------
|
|
class ConnectionManager:
    """Registry of open WebSocket clients that fans progress messages out to all of them."""

    def __init__(self):
        # Clients that completed the handshake and have not been dropped yet.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; an untracked socket is ignored."""
        # list.remove raises ValueError for an absent item, which here means
        # "already removed" and is not an error.
        with suppress(ValueError):
            self.active_connections.remove(websocket)
        logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict):
        """Broadcast a message to all connected clients."""
        # Walk a snapshot, because failing clients are removed from the live
        # list while we iterate.
        recipients = list(self.active_connections)
        for client in recipients:
            try:
                await client.send_json(message)
            except Exception as exc:
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(client)


# Shared registry used by the Redis progress forwarder and the /ws/progress endpoint.
manager = ConnectionManager()
|
|
|
|
|
|
async def _progress_pubsub_loop() -> None:
    """Forward Redis task-progress events to connected WebSocket clients.

    Subscribes to ``PROGRESS_CHANNEL`` and broadcasts, via ``manager``, every
    message whose JSON payload decodes to a dict. Blocking Redis client calls
    run in worker threads through ``asyncio.to_thread``.

    Error handling:
    - A single malformed message is logged and skipped; the subscription stays up.
    - Any other failure closes the pub/sub object and resubscribes after 5 seconds.
    - Cancellation propagates, so the lifespan shutdown can stop the loop.
    """
    while True:
        pubsub = None
        try:
            pubsub = get_redis_client().pubsub()
            await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
            logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
            while True:
                # Positional args are presumably redis-py's
                # (ignore_subscribe_messages=True, timeout=1.0) -- TODO confirm.
                message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
                if message is None:
                    await asyncio.sleep(0)
                    continue
                raw_data = message.get("data")
                if isinstance(raw_data, (str, bytes, bytearray)):
                    # json.loads accepts bytes too, so clients without decoded
                    # responses still work.
                    try:
                        payload = json.loads(raw_data)
                    except ValueError:
                        # Covers JSONDecodeError and UnicodeDecodeError. Skip only
                        # this message instead of tearing down the subscription.
                        logger.warning("Ignoring malformed progress message: %r", raw_data[:200])
                        continue
                else:
                    payload = raw_data
                if isinstance(payload, dict):
                    await manager.broadcast(payload)
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001
            logger.error("Redis progress subscription failed: %s", exc)
            await asyncio.sleep(5)
        finally:
            if pubsub is not None:
                with suppress(Exception):
                    await asyncio.to_thread(pubsub.close)
|
|
|
|
|
|
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """WebSocket endpoint for real-time parsing/AI progress updates.

    Registers the client with ``manager``; progress events are pushed to it by
    the Redis forwarder through ``manager.broadcast``. This handler only reads
    incoming text frames (heartbeats / requests) and answers each one with a
    ``status`` message. The client is unregistered on every exit path.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Receive client messages (heartbeat / subscription requests)
            data = await websocket.receive_text()
            logger.debug("WebSocket received: %s", data)

            # Echo heartbeat to keep connection alive
            await websocket.send_json({
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
    except WebSocketDisconnect:
        pass
    except Exception as exc:
        logger.error("WebSocket error: %s", exc)
    finally:
        # Also runs on task cancellation (e.g. server shutdown). CancelledError is
        # not an Exception subclass, so the handlers above would miss it and leave
        # a stale client registered.
        manager.disconnect(websocket)
|