Files
Pre_Seg_Server/backend/main.py

220 lines
7.0 KiB
Python

"""FastAPI application entrypoint."""
import asyncio
import logging
import os
import shutil
import tempfile
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists, upload_file
from redis_client import ping as redis_ping
from routers import projects, templates, media, ai, export, auth
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)

# Source file for the "Data_MyVideo_1" demo project seeded on first startup.
# The built-in default is a path on one developer machine; deployments can
# point at their own copy by setting SEG_DEFAULT_VIDEO_PATH instead of
# editing the code.
DEFAULT_VIDEO_PATH = os.environ.get(
    "SEG_DEFAULT_VIDEO_PATH",
    "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4",
)
def _discard_partial_project(db, project_model, project_id) -> None:
    """Best-effort delete of a project row left behind by a failed seed attempt."""
    try:
        # The session may be in a failed-flush state; reset it before querying.
        db.rollback()
        stale = db.query(project_model).filter(project_model.id == project_id).first()
        if stale is not None:
            db.delete(stale)
            db.commit()
    except Exception:  # noqa: BLE001 - cleanup must never mask the original failure
        # The caller closes the session, which discards any open transaction.
        logger.exception("Failed to remove partially seeded project id=%s", project_id)


def _seed_default_project_sync() -> None:
    """Create the demo video project on first startup (blocking call).

    Steps:

    1. Return early if a project named ``Data_MyVideo_1`` already exists, or
       (with a warning) if no file exists at ``DEFAULT_VIDEO_PATH``.
    2. Insert the project with status ``"pending"``.
    3. Upload the raw video to ``uploads/<project_id>/Data_MyVideo_1.mp4`` and
       store that object name in ``project.video_path``.
    4. Extract up to 100 frames (``fps=30``) into a temporary directory,
       upload them, and add one ``Frame`` row per uploaded image, with the
       image's width/height when OpenCV can read it.
    5. Mark the project ``"ready"``. The frames and the status are committed
       together.

    Failures inside the seeding steps are logged with a traceback, not raised.
    If the project row had already been committed, it is deleted again. This
    lets the next startup retry, because otherwise the name check in step 1
    would skip seeding forever and leave the project stuck in ``"pending"``.
    Objects already uploaded to object storage are not removed.
    """
    import cv2
    from models import Project, Frame
    from services.frame_parser import parse_video, upload_frames_to_minio

    db = SessionLocal()
    # Id of a committed project whose seeding has not finished yet; cleared on success.
    pending_project_id = None
    try:
        existing = db.query(Project).filter(Project.name == "Data_MyVideo_1").first()
        if existing is not None:
            return
        if not os.path.exists(DEFAULT_VIDEO_PATH):
            logger.warning("Default video not found at %s", DEFAULT_VIDEO_PATH)
            return

        project = Project(
            name="Data_MyVideo_1",
            description="默认演示视频",
            status="pending",
        )
        db.add(project)
        db.commit()
        db.refresh(project)
        project_id = project.id
        pending_project_id = project_id

        with open(DEFAULT_VIDEO_PATH, "rb") as f:
            data = f.read()
        object_name = f"uploads/{project_id}/Data_MyVideo_1.mp4"
        upload_file(object_name, data, content_type="video/mp4", length=len(data))
        # Drop the in-memory copy of the whole video before frame extraction.
        del data
        project.video_path = object_name
        db.commit()

        # Frames are parsed straight from the source file; only the extracted
        # frames need a scratch directory.
        tmp_dir = tempfile.mkdtemp(prefix=f"seg_seed_{project_id}_")
        try:
            output_dir = os.path.join(tmp_dir, "frames")
            os.makedirs(output_dir, exist_ok=True)
            frame_files = parse_video(DEFAULT_VIDEO_PATH, output_dir, fps=30, max_frames=100)
            object_names = upload_frames_to_minio(frame_files, project_id)
            for idx, obj_name in enumerate(object_names):
                img = cv2.imread(frame_files[idx])
                # imread returns None for unreadable images; keep the frame with unknown size.
                h, w = img.shape[:2] if img is not None else (None, None)
                db.add(
                    Frame(
                        project_id=project_id,
                        frame_index=idx,
                        image_url=obj_name,
                        width=w,
                        height=h,
                    )
                )
            project.status = "ready"
            db.commit()
            pending_project_id = None  # seeding finished; nothing to undo
            logger.info("Seeded default project id=%s with %d frames", project_id, len(object_names))
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
    except Exception as exc:
        logger.exception("Failed to seed default project: %s", exc)
        if pending_project_id is not None:
            _discard_partial_project(db, Project, pending_project_id)
    finally:
        db.close()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup checks, start background seeding, and release resources on shutdown.

    Startup is best-effort. A failure in database table creation, the MinIO
    bucket check, or the Redis ping is logged, and the application still starts.
    Default-project seeding runs in a worker thread via ``asyncio.to_thread``,
    so it does not delay serving requests.
    """
    logger.info("Starting up SegServer backend...")

    # Initialize database tables
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)

    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    # Check Redis. Like the checks above, an error here must not abort startup.
    try:
        redis_ok = redis_ping()
    except Exception as exc:  # noqa: BLE001
        logger.error("Redis ping raised: %s", exc)
        redis_ok = False
    if redis_ok:
        logger.info("Redis connection OK.")
    else:
        logger.warning("Redis connection failed.")

    def _log_seed_failure(task: asyncio.Task) -> None:
        # Report exceptions that escape the seeding function (e.g. a failed
        # import); otherwise they only surface as "Task exception was never retrieved".
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Default project seeding crashed: %s", exc, exc_info=exc)

    # Seed default project in background thread so it doesn't block startup.
    # The local variable keeps a strong reference to the task for the whole
    # app lifetime: the event loop only holds weak references to tasks.
    seed_task = None
    try:
        seed_task = asyncio.create_task(asyncio.to_thread(_seed_default_project_sync))
        seed_task.add_done_callback(_log_seed_failure)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to start default project seeding: %s", exc)

    yield

    logger.info("Shutting down SegServer backend...")
    if seed_task is not None and not seed_task.done():
        # Work handed to to_thread cannot be interrupted. The seeding thread
        # may still be using the engine that is disposed below.
        logger.warning("Default project seeding still running at shutdown.")
    engine.dispose()
# The ASGI application; startup and shutdown hooks are wired in via `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="SegServer API",
    version="1.0.0",
    description="Semantic Segmentation System Backend",
)

# Let the configured browser origins call the API, including with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register every feature router, in this fixed order.
for _router_module in (auth, projects, templates, media, ai, export):
    app.include_router(_router_module.router)
del _router_module
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Return a static liveness payload for the SegServer API.

    The payload is constant; this endpoint does not probe the database,
    object storage, or Redis.
    """
    return dict(status="ok", service="SegServer")
# ---------------------------------------------------------------------------
# WebSocket: 实时进度推送
# ---------------------------------------------------------------------------
class ConnectionManager:
    """Track open WebSocket connections and fan progress messages out to them."""

    def __init__(self):
        # Sockets accepted by connect() and not yet removed by disconnect().
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register the socket for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Unregister a socket; calling it again for the same socket is a no-op.

        The disconnect is logged only when the socket was actually registered.
        A client dropped by broadcast() and later reported again by the
        endpoint is therefore logged once, not twice.
        """
        if websocket not in self.active_connections:
            return
        self.active_connections.remove(websocket)
        logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to every registered client, one after another.

        A client whose send raises is logged and unregistered. Delivery to the
        remaining clients continues.
        """
        # Iterate over a snapshot, because disconnect() mutates the list.
        for connection in self.active_connections.copy():
            try:
                await connection.send_json(message)
            except Exception as exc:
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(connection)


# Module-level manager shared by the /ws/progress endpoint.
manager = ConnectionManager()
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """Serve one progress-stream client until it disconnects.

    The socket is registered with ``manager`` so it receives broadcasts. Every
    text message from the client (e.g. a heartbeat) is answered with a
    ``status`` message whose ``timestamp`` is the current Unix time as a string.
    The socket is always unregistered on exit, including on task cancellation.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Receive client messages (heartbeat / subscription requests)
            data = await websocket.receive_text()
            logger.debug("WebSocket received: %s", data)
            # Acknowledge each client message so idle connections see traffic.
            await websocket.send_json({
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": str(time.time()),
            })
    except WebSocketDisconnect:
        pass  # normal client close; cleanup happens in `finally`
    except Exception as exc:
        logger.exception("WebSocket error: %s", exc)
    finally:
        # `finally` also covers CancelledError (a BaseException), which the
        # except clauses above do not catch, so the socket never stays registered.
        manager.disconnect(websocket)