220 lines
7.0 KiB
Python
220 lines
7.0 KiB
Python
"""FastAPI application entrypoint."""
|
|
|
|
import asyncio
import logging
import os
import shutil
import tempfile
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from config import settings
from database import Base, engine, SessionLocal
from minio_client import ensure_bucket_exists, upload_file
from redis_client import ping as redis_ping

from routers import projects, templates, media, ai, export, auth
|
|
|
|
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)

logger = logging.getLogger(__name__)

# Video used to seed the demo project on first startup. The built-in default
# is a developer-machine path, so deployments can point elsewhere through the
# SEG_DEFAULT_VIDEO_PATH environment variable (an empty value means "unset").
DEFAULT_VIDEO_PATH = (
    os.environ.get("SEG_DEFAULT_VIDEO_PATH")
    or "/home/wkmgc/Desktop/Seg_Server/Data_MyVideo_1.mp4"
)
|
|
|
|
|
|
def _seed_default_project_sync() -> None:
    """Create the "Data_MyVideo_1" demo project on first startup.

    Blocking; the lifespan hook runs it via ``asyncio.to_thread``.

    Steps: return early if a project with that name already exists or the
    video at ``DEFAULT_VIDEO_PATH`` is missing; otherwise insert the project
    row, upload the video with ``upload_file``, extract up to 100 frames with
    ``parse_video``, upload them with ``upload_frames_to_minio``, insert one
    ``Frame`` row per uploaded frame and mark the project "ready".

    Errors are logged with a traceback instead of being raised. If seeding
    fails after the project row was committed, that partial row is removed
    so the next startup retries instead of finding a stuck "pending" project.
    Objects already uploaded to object storage are NOT cleaned up.
    """
    db = SessionLocal()
    project = None
    try:
        # Deferred imports live inside the try so that a missing dependency
        # is logged here instead of escaping the worker thread silently.
        import cv2

        from models import Frame, Project
        from services.frame_parser import parse_video, upload_frames_to_minio

        existing = db.query(Project).filter(Project.name == "Data_MyVideo_1").first()
        if existing is not None:
            return

        if not os.path.exists(DEFAULT_VIDEO_PATH):
            logger.warning("Default video not found at %s", DEFAULT_VIDEO_PATH)
            return

        project = Project(
            name="Data_MyVideo_1",
            description="默认演示视频",
            status="pending",
        )
        db.add(project)
        db.commit()
        db.refresh(project)

        with open(DEFAULT_VIDEO_PATH, "rb") as f:
            data = f.read()
        object_name = f"uploads/{project.id}/Data_MyVideo_1.mp4"
        upload_file(object_name, data, content_type="video/mp4", length=len(data))

        project.video_path = object_name
        db.commit()

        # Parse frames inside a scratch directory that is always removed.
        tmp_dir = tempfile.mkdtemp(prefix=f"seg_seed_{project.id}_")
        try:
            local_path = os.path.join(tmp_dir, "video.mp4")
            with open(local_path, "wb") as f:
                f.write(data)
            output_dir = os.path.join(tmp_dir, "frames")
            os.makedirs(output_dir, exist_ok=True)
            frame_files = parse_video(local_path, output_dir, fps=30, max_frames=100)

            object_names = upload_frames_to_minio(frame_files, project.id)

            # assumes upload_frames_to_minio returns one object name per input
            # file, in the same order — TODO confirm against frame_parser
            for idx, (frame_file, obj_name) in enumerate(zip(frame_files, object_names)):
                img = cv2.imread(frame_file)
                # cv2.imread yields None for unreadable files; size stays unknown.
                h, w = img.shape[:2] if img is not None else (None, None)
                db.add(
                    Frame(
                        project_id=project.id,
                        frame_index=idx,
                        image_url=obj_name,
                        width=w,
                        height=h,
                    )
                )

            project.status = "ready"
            db.commit()
            logger.info("Seeded default project id=%s with %d frames", project.id, len(object_names))
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
    except Exception as exc:
        logger.exception("Failed to seed default project: %s", exc)
        _discard_partial_seed(db, project)
    finally:
        db.close()


def _discard_partial_seed(db, project) -> None:
    """Roll back *db* and delete *project* if it was already committed (best effort)."""
    try:
        db.rollback()
        # A project whose first commit failed is expunged by the rollback and
        # has no id; only a committed row needs deleting.
        project_id = getattr(project, "id", None) if project is not None else None
        if project_id is not None:
            db.delete(project)
            db.commit()
            logger.warning("Removed partially seeded default project id=%s", project_id)
    except Exception as exc:
        logger.exception("Failed to clean up partially seeded default project: %s", exc)
|
|
|
|
|
|
def _log_seed_task_failure(task: asyncio.Task) -> None:
    """Done-callback for the seeding task: log any exception it finished with.

    Without it, an exception escaping the worker thread would surface at most
    as an asyncio "exception was never retrieved" message.
    """
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Default project seeding task failed: %s", exc, exc_info=exc)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: best-effort startup checks, then shutdown cleanup.

    Startup creates DB tables, checks the MinIO bucket and Redis, and schedules
    demo-project seeding in a worker thread. Each step logs its own failure and
    startup continues. Shutdown disposes the SQLAlchemy engine.
    """
    # Startup
    logger.info("Starting up SegServer backend...")

    # Initialize database tables
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database tables initialized.")
    except Exception as exc:  # noqa: BLE001
        logger.exception("Database initialization failed: %s", exc)

    # Check MinIO bucket
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.exception("MinIO bucket check failed: %s", exc)

    # Check Redis. Guarded like the checks above so a raised client error,
    # not just a falsy return, cannot abort startup.
    try:
        redis_ok = redis_ping()
    except Exception as exc:  # noqa: BLE001
        logger.warning("Redis connection failed: %s", exc)
    else:
        if redis_ok:
            logger.info("Redis connection OK.")
        else:
            logger.warning("Redis connection failed.")

    # Seed the default project in a worker thread so startup is not blocked.
    # The event loop keeps only weak references to tasks, so hold a strong
    # one on app.state; otherwise the task may be garbage-collected mid-run.
    seed_task = None
    try:
        seed_task = asyncio.create_task(asyncio.to_thread(_seed_default_project_sync))
        seed_task.add_done_callback(_log_seed_task_failure)
        app.state.seed_task = seed_task
    except Exception as exc:  # noqa: BLE001
        logger.exception("Failed to start default project seeding: %s", exc)

    yield

    # Shutdown
    logger.info("Shutting down SegServer backend...")
    if seed_task is not None and not seed_task.done():
        # Work running in asyncio.to_thread cannot be interrupted; just make
        # the still-running seed visible in the logs.
        logger.warning("Default project seeding still running at shutdown.")
    engine.dispose()
|
|
|
|
|
|
app = FastAPI(
    lifespan=lifespan,
    title="SegServer API",
    version="1.0.0",
    description="Semantic Segmentation System Backend",
)

# Cross-origin policy: allowed origins come from settings; credentials,
# every method and every header are permitted for those origins.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=settings.cors_origins,
    allow_headers=["*"],
    allow_methods=["*"],
)

# Mount each feature router, keeping the original registration order.
for _router_module in (auth, projects, templates, media, ai, export):
    app.include_router(_router_module.router)
del _router_module
|
|
|
|
|
|
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Liveness probe.

    Returns a constant payload naming the service; it performs no checks
    against the database, object storage or Redis.
    """
    status_payload = {"status": "ok", "service": "SegServer"}
    return status_payload
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# WebSocket: real-time progress push (实时进度推送)
|
|
# ---------------------------------------------------------------------------
|
|
class ConnectionManager:
    """Track open progress WebSockets and fan JSON messages out to them."""

    def __init__(self):
        # Accepted, not-yet-removed sockets, in connection order.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and start tracking it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info("WebSocket client connected. Total: %d", len(self.active_connections))

    def disconnect(self, websocket: WebSocket):
        """Stop tracking *websocket*; calling it again for the same socket is a no-op.

        Only logs when a socket was actually removed. A client dropped by
        ``broadcast`` and later reported again by its endpoint is therefore
        not logged twice.
        """
        if websocket not in self.active_connections:
            return
        self.active_connections.remove(websocket)
        logger.info("WebSocket client disconnected. Total: %d", len(self.active_connections))

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every tracked client, dropping clients that fail.

        Iterates over a snapshot, so removals (and connections made while a
        send is awaited) do not disturb the loop.
        """
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as exc:
                logger.warning("WebSocket send failed: %s", exc)
                self.disconnect(connection)
|
|
|
|
|
|
# Single module-level registry of clients connected to /ws/progress.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
@app.websocket("/ws/progress")
async def websocket_progress(websocket: WebSocket):
    """WebSocket endpoint for real-time parsing/AI progress updates.

    Registers the socket with ``manager`` so broadcasts reach it, then answers
    every text message from the client (heartbeat / subscription request)
    with a ``status`` frame. The socket is always unregistered when the loop
    ends, including on task cancellation.
    """
    await manager.connect(websocket)
    try:
        while True:
            # Receive client messages (heartbeat / subscription requests)
            data = await websocket.receive_text()
            logger.debug("WebSocket received: %s", data)

            # Reply so the client can tell the stream is still alive.
            await websocket.send_json({
                "type": "status",
                "status": "connected",
                "message": "Progress stream active",
                "timestamp": str(time.time()),
            })
    except WebSocketDisconnect:
        # Normal client close; unregistration happens in the finally below.
        pass
    except Exception as exc:
        logger.exception("WebSocket error: %s", exc)
    finally:
        # Runs for every exit path, including CancelledError (a BaseException
        # the handlers above do not catch), so no stale socket stays registered.
        manager.disconnect(websocket)
|