"""FastAPI application entrypoint.

Builds the SegServer ``FastAPI`` app (CORS middleware + API routers) and its
lifespan hooks: database table creation and late-column patching, MinIO and
Redis checks, the Redis progress listener, and first-run seeding of a default
ontology template and a default demo video project.
"""
import asyncio
import json
import logging
import os
import shutil
import tempfile
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timezone

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import inspect, text

from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists, upload_file
from progress_events import PROGRESS_CHANNEL
from redis_client import get_redis_client, ping as redis_ping
from statuses import PROJECT_STATUS_PENDING, PROJECT_STATUS_READY
from routers import projects, templates, media, ai, export, auth, dashboard, tasks

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)

# Source video for the demo project seeded on first startup. The
# SEG_DEFAULT_VIDEO_PATH environment variable overrides the original
# developer-machine default below.
DEFAULT_VIDEO_PATH = os.environ.get(
    "SEG_DEFAULT_VIDEO_PATH",
    "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4",
)

# Name of the seeded demo project; also the stem of its uploaded video object.
_DEFAULT_PROJECT_NAME = "Data_MyVideo_1"
# Frame-extraction rate (frames per second) used when seeding the demo project.
_SEED_PARSE_FPS = 30
# Upper bound on the number of frames extracted for the demo project.
_SEED_MAX_FRAMES = 100

# Nullable ``frames`` columns introduced after the initial ``create_all``
# deployments, mapped to the DDL statement that adds each one.
_RUNTIME_FRAME_COLUMNS = {
    "timestamp_ms": "ALTER TABLE frames ADD COLUMN timestamp_ms FLOAT",
    "source_frame_number": "ALTER TABLE frames ADD COLUMN source_frame_number INTEGER",
}

# Strong references to fire-and-forget startup tasks. The event loop keeps only
# weak references to tasks, so an otherwise unreferenced task may be
# garbage-collected before it finishes (see ``asyncio.create_task`` docs).
_background_tasks: set[asyncio.Task] = set()


def _spawn_background(coro, description: str) -> asyncio.Task | None:
    """Schedule *coro* as a tracked background task.

    The task is held in ``_background_tasks`` until it completes. If scheduling
    fails, the error is logged as ``"Failed to start <description>: ..."``, the
    coroutine is closed (avoiding a "never awaited" warning) and ``None`` is
    returned.
    """
    try:
        task = asyncio.create_task(coro)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start %s: %s", description, exc)
        coro.close()
        return None
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _ensure_runtime_schema_columns() -> None:
    """Add nullable columns introduced after initial create_all deployments.

    Inspects the ``frames`` table and runs the ``ALTER TABLE`` statement for
    every column in ``_RUNTIME_FRAME_COLUMNS`` that is missing, all in one
    transaction. Any failure is logged as a warning and otherwise ignored.
    """
    try:
        inspector = inspect(engine)
        existing = {column["name"] for column in inspector.get_columns("frames")}
        pending_ddl = [
            ddl for column, ddl in _RUNTIME_FRAME_COLUMNS.items() if column not in existing
        ]
        if not pending_ddl:
            return
        with engine.begin() as connection:
            for ddl in pending_ddl:
                connection.execute(text(ddl))
    except Exception as exc:  # noqa: BLE001
        logger.warning("Runtime schema column check failed: %s", exc)


def _seed_default_project_sync() -> None:
    """Synchronously seed the default demo video project on first startup.

    Returns early if a project named ``_DEFAULT_PROJECT_NAME`` already exists or
    the video at ``DEFAULT_VIDEO_PATH`` is missing. Otherwise it:

    1. creates a PENDING ``Project`` row;
    2. uploads the video under ``uploads/<project id>/``;
    3. extracts up to ``_SEED_MAX_FRAMES`` frames at ``_SEED_PARSE_FPS`` into a
       temp dir, uploads a thumbnail (best effort) and the frames;
    4. adds one ``Frame`` row per uploaded frame and marks the project READY.

    Blocking; intended to run in a worker thread. All errors are logged (with
    traceback), never raised.
    """
    import cv2

    from models import Frame, Project
    from services.frame_parser import extract_thumbnail, parse_video, upload_frames_to_minio

    db = SessionLocal()
    try:
        existing = db.query(Project).filter(Project.name == _DEFAULT_PROJECT_NAME).first()
        if existing is not None:
            return
        if not os.path.exists(DEFAULT_VIDEO_PATH):
            logger.warning("Default video not found at %s", DEFAULT_VIDEO_PATH)
            return

        project = Project(
            name=_DEFAULT_PROJECT_NAME,
            description="默认演示视频",
            status=PROJECT_STATUS_PENDING,
            source_type="video",
            parse_fps=float(_SEED_PARSE_FPS),
        )
        db.add(project)
        db.commit()
        db.refresh(project)
        # NOTE(review): the row is committed before upload/parsing, so a failure
        # below leaves the project PENDING and the existence check above will
        # skip re-seeding on later startups.

        with open(DEFAULT_VIDEO_PATH, "rb") as f:
            data = f.read()
        object_name = f"uploads/{project.id}/{_DEFAULT_PROJECT_NAME}.mp4"
        upload_file(object_name, data, content_type="video/mp4", length=len(data))
        project.video_path = object_name
        db.commit()

        # Parse frames in a scratch directory that is always removed afterwards.
        tmp_dir = tempfile.mkdtemp(prefix=f"seg_seed_{project.id}_")
        try:
            local_path = os.path.join(tmp_dir, "video.mp4")
            with open(local_path, "wb") as f:
                f.write(data)
            output_dir = os.path.join(tmp_dir, "frames")
            os.makedirs(output_dir, exist_ok=True)

            frame_files, original_fps = parse_video(
                local_path, output_dir, fps=_SEED_PARSE_FPS, max_frames=_SEED_MAX_FRAMES
            )
            project.original_fps = original_fps

            thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
            try:
                extract_thumbnail(local_path, thumbnail_path)
                with open(thumbnail_path, "rb") as f:
                    thumb_data = f.read()
                thumb_object = f"projects/{project.id}/thumbnail.jpg"
                upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
                project.thumbnail_url = thumb_object
            except Exception as exc:  # noqa: BLE001
                # Thumbnail is optional: log and keep seeding without it.
                logger.warning("Thumbnail extraction failed: %s", exc)

            object_names = upload_frames_to_minio(frame_files, project.id)
            for idx, (frame_file, obj_name) in enumerate(zip(frame_files, object_names)):
                img = cv2.imread(frame_file)
                # cv2.imread yields None for unreadable images; store unknown size then.
                h, w = img.shape[:2] if img is not None else (None, None)
                # Timestamps assume frames were sampled at exactly _SEED_PARSE_FPS;
                # source_frame_number maps back to the video's native frame rate.
                source_frame_number = (
                    int(round(idx * original_fps / _SEED_PARSE_FPS)) if original_fps else None
                )
                db.add(Frame(
                    project_id=project.id,
                    frame_index=idx,
                    image_url=obj_name,
                    width=w,
                    height=h,
                    timestamp_ms=idx * 1000.0 / _SEED_PARSE_FPS,
                    source_frame_number=source_frame_number,
                ))

            project.status = PROJECT_STATUS_READY
            db.commit()
            logger.info("Seeded default project id=%s with %d frames", project.id, len(object_names))
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
    except Exception as exc:  # noqa: BLE001
        logger.exception("Failed to seed default project: %s", exc)
    finally:
        db.close()


def _seed_default_templates_sync() -> None:
    """Seed default ontology templates on first startup.

    Does nothing if any ``Template`` row exists. Otherwise inserts one
    laparoscopic cholecystectomy template with 35 classes; each class gets a
    hex colour from ``colors`` and a ``zIndex`` that decreases by 10 per class
    (the first class is drawn on top). Blocking; errors are logged with
    traceback, never raised.
    """
    from models import Template

    db = SessionLocal()
    try:
        if db.query(Template).first() is not None:
            return

        # Laparoscopic cholecystectomy template (35 classes). ``colors[i]`` is the
        # RGB display colour of the class named ``names[i]``.
        colors = [
            (134, 124, 118), (0, 157, 142), (245, 161, 0), (255, 172, 159),
            (146, 175, 236), (155, 62, 0), (255, 91, 0), (255, 234, 0),
            (85, 111, 181), (155, 132, 0), (181, 227, 14), (72, 0, 255),
            (255, 0, 255), (29, 32, 136), (240, 16, 116), (160, 15, 95),
            (0, 155, 33), (0, 160, 233), (52, 184, 178), (66, 115, 82),
            (90, 120, 41), (255, 0, 0), (117, 0, 0), (167, 24, 233),
            (42, 8, 66), (112, 113, 150), (0, 255, 0), (255, 255, 255),
            (0, 255, 255), (181, 85, 105), (113, 102, 140), (202, 202, 200),
            (197, 83, 181), (136, 162, 196), (138, 251, 213),
        ]
        names = [
            '针', '线', '肿瘤', '血管阻断夹', '棉球', '双极电凝', '肝脏',
            '胆囊', '分离钳', '脂肪', '止血海绵', '肝总管', '吸引器', '剪刀',
            '超声刀', '止血纱布', '胆总管', '生物夹', '无损伤钳', '钳夹', '喷洒',
            '胆囊管', '动脉', '电凝', '静脉', '标本袋', '引流管', '纱布',
            '金属钛夹', '韧带', '肝蒂', '推结器', '乳胶管-血管阻断', '吻合器', '术中超声',
        ]
        classes = [
            {
                "id": f"cls-lap-{idx}",
                "name": name,
                "color": f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}",
                "zIndex": (len(names) - idx) * 10,
                "category": "腹腔镜胆囊切除术",
            }
            for idx, (rgb, name) in enumerate(zip(colors, names))
        ]

        template = Template(
            name="腹腔镜胆囊切除术",
            description="腹腔镜胆囊切除术(LC)手术器械与解剖结构语义分割模板,共35个分类",
            color="#06b6d4",
            z_index=0,
            mapping_rules={"classes": classes, "rules": []},
        )
        db.add(template)
        db.commit()
        logger.info("Seeded default template '腹腔镜胆囊切除术' with %d classes", len(classes))
    except Exception as exc:  # noqa: BLE001
        logger.exception("Failed to seed default templates: %s", exc)
    finally:
        db.close()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup and shutdown hooks.

    Startup creates tables and patches late columns, checks the MinIO bucket,
    pings Redis, and starts the Redis progress listener. It also schedules the
    blocking template and demo-project seeders in worker threads so they do not
    delay startup. Startup failures are logged, not raised, so the app still
    comes up.

    Shutdown cancels the progress listener and disposes the DB engine.
    """
    progress_listener: asyncio.Task | None = None

    logger.info("Starting up SegServer backend...")

    try:
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_schema_columns()
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)

    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    if redis_ping():
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")

    # All tasks go through _spawn_background so they stay referenced until done.
    progress_listener = _spawn_background(_progress_pubsub_loop(), "Redis progress subscription")
    _spawn_background(asyncio.to_thread(_seed_default_templates_sync), "default template seeding")
    _spawn_background(asyncio.to_thread(_seed_default_project_sync), "default project seeding")

    yield

    logger.info("Shutting down SegServer backend...")
    if progress_listener is not None:
        progress_listener.cancel()
        with suppress(asyncio.CancelledError):
            await progress_listener
    engine.dispose()


app = FastAPI(
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: allowed origins come from settings; credentials, all methods and all
# headers are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# API routers.
app.include_router(auth.router)
app.include_router(projects.router)
app.include_router(templates.router)
app.include_router(media.router)
app.include_router(ai.router)
app.include_router(export.router)
app.include_router(dashboard.router)
app.include_router(tasks.router)
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Health check endpoint.

    Returns a static payload. It does not probe the database, MinIO or Redis.
    """
    return {"status": "ok", "service": "SegServer"}


# ---------------------------------------------------------------------------
# WebSocket: real-time progress push
# ---------------------------------------------------------------------------


class ConnectionManager:
    """Track open WebSocket connections and fan progress messages out to them."""

    def __init__(self) -> None:
        # Currently connected clients, in connection order.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake and start tracking *websocket*."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket) -> None:
        """Stop tracking *websocket*; a no-op if it is not tracked."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict) -> None:
        """Send *message* as JSON to every client, dropping clients whose send fails."""
        # Iterate over a snapshot because failing clients are removed mid-loop.
        for connection in self.active_connections.copy():
            try:
                await connection.send_json(message)
            except Exception as exc:  # noqa: BLE001
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(connection)


manager = ConnectionManager()


def _decode_progress_payload(raw_data: object) -> object:
    """Decode a Redis pub/sub ``data`` field into a Python object.

    ``bytes``/``bytearray`` are decoded as UTF-8 first; this covers Redis
    clients created without ``decode_responses``. Strings are then parsed as
    JSON. Anything else is returned unchanged.

    Raises:
        ValueError: on invalid UTF-8 or invalid JSON.
    """
    if isinstance(raw_data, (bytes, bytearray)):
        raw_data = raw_data.decode("utf-8")
    if isinstance(raw_data, str):
        return json.loads(raw_data)
    return raw_data


async def _progress_pubsub_loop() -> None:
    """Forward Redis task-progress events to connected WebSocket clients.

    Subscribes to ``PROGRESS_CHANNEL`` and broadcasts every JSON-object message
    to all clients. Blocking redis-py calls run in worker threads.

    A malformed message is logged and skipped. Any other failure closes the
    pub/sub and resubscribes after 5 seconds. Runs until cancelled.
    """
    while True:
        pubsub = None
        try:
            pubsub = get_redis_client().pubsub()
            await asyncio.to_thread(pubsub.subscribe, PROGRESS_CHANNEL)
            logger.info("Subscribed to Redis progress channel: %s", PROGRESS_CHANNEL)
            while True:
                # ignore_subscribe_messages=True, timeout=1.0 s
                message = await asyncio.to_thread(pubsub.get_message, True, 1.0)
                if message is None:
                    await asyncio.sleep(0)
                    continue
                try:
                    payload = _decode_progress_payload(message.get("data"))
                except ValueError as exc:
                    # Skip one bad event instead of tearing down the subscription.
                    logger.warning("Dropping malformed progress event: %s", exc)
                    continue
                if isinstance(payload, dict):
                    await manager.broadcast(payload)
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001
            logger.error("Redis progress subscription failed: %s", exc)
        finally:
            if pubsub is not None:
                with suppress(Exception):
                    await asyncio.to_thread(pubsub.close)
        # The inner loop only exits via an exception; back off (with the old
        # pub/sub already closed) before resubscribing.
        await asyncio.sleep(5)


@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """WebSocket endpoint for real-time parsing/AI progress updates.

    Progress events are pushed by ``manager.broadcast``. Every text message
    received from the client is answered with a ``connected`` status heartbeat.
    The client is always unregistered when the handler exits, including on
    cancellation.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Receive client messages (heartbeat / subscription requests).
            data = await websocket.receive_text()
            logger.debug("WebSocket received: %s", data)
            # Reply with a status heartbeat to keep the connection alive.
            await websocket.send_json({
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
    except WebSocketDisconnect:
        pass
    except Exception as exc:  # noqa: BLE001
        logger.error("WebSocket error: %s", exc)
    finally:
        manager.disconnect(websocket)