"""FastAPI application entrypoint.""" import logging from contextlib import asynccontextmanager from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from config import settings from database import Base, engine from minio_client import ensure_bucket_exists from redis_client import ping as redis_ping from routers import projects, templates, media, ai, export, auth logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", ) logger = logging.getLogger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan: startup and shutdown hooks.""" # Startup logger.info("Starting up SegServer backend...") # Initialize database tables try: Base.metadata.create_all(bind=engine) logger.info("Database tables initialized.") except Exception as exc: # noqa: BLE001 logger.error("Database initialization failed: %s", exc) # Check MinIO bucket try: ensure_bucket_exists() except Exception as exc: # noqa: BLE001 logger.error("MinIO bucket check failed: %s", exc) # Check Redis if redis_ping(): logger.info("Redis connection OK.") else: logger.warning("Redis connection failed.") yield # Shutdown logger.info("Shutting down SegServer backend...") engine.dispose() app = FastAPI( title="SegServer API", description="Semantic Segmentation System Backend", version="1.0.0", lifespan=lifespan, ) # CORS app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Routers app.include_router(auth.router) app.include_router(projects.router) app.include_router(templates.router) app.include_router(media.router) app.include_router(ai.router) app.include_router(export.router) @app.get("/health", tags=["Health"]) def health_check() -> dict: """Health check endpoint.""" return {"status": "ok", "service": "SegServer"} # --------------------------------------------------------------------------- # 
# WebSocket: real-time progress push
# ---------------------------------------------------------------------------


class ConnectionManager:
    """Track open WebSocket connections and fan progress messages out to them."""

    def __init__(self):
        # Accepted connections, kept in the order they connected.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and start tracking ``websocket``."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; a no-op if it is not currently tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every tracked client.

        A client whose send raises is logged and dropped. The remaining
        clients still receive the message.
        """
        # Iterate over a snapshot, because disconnect() mutates the list mid-loop.
        for connection in self.active_connections.copy():
            try:
                await connection.send_json(message)
            except Exception as exc:  # noqa: BLE001
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(connection)


manager = ConnectionManager()


@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """WebSocket endpoint for real-time parsing/AI progress updates.

    Registers the client with ``manager`` so that ``broadcast`` reaches it.
    Every text frame received from the client (e.g. a heartbeat) is answered
    with a ``status`` message. The client is always unregistered when the
    loop ends, whether by a normal disconnect or by an error.
    """
    import time  # local import: the module header does not import ``time``

    await manager.connect(websocket)
    try:
        while True:
            # Receive client messages (heartbeat / subscription requests)
            data = await websocket.receive_text()
            logger.debug("WebSocket received: %s", data)
            # Reply with a status frame so the client knows the stream is alive.
            await websocket.send_json({
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": str(time.time()),
            })
    except WebSocketDisconnect:
        pass  # Normal client-initiated close; cleanup happens in ``finally``.
    except Exception as exc:  # noqa: BLE001
        logger.exception("WebSocket error: %s", exc)
    finally:
        manager.disconnect(websocket)