from __future__ import annotations import shutil import uuid from pathlib import Path from typing import Any import cv2 from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from backend.segmentation import METHOD_DESCRIPTIONS, compare_frame, segment_frame ROOT = Path(__file__).resolve().parents[1] FRONTEND_DIR = ROOT / "frontend" STORAGE_DIR = ROOT / "storage" UPLOAD_DIR = STORAGE_DIR / "uploads" JOB_DIR = STORAGE_DIR / "jobs" SAMPLE_DIR = STORAGE_DIR / "samples" IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"} VIDEO_SUFFIXES = {".mp4", ".avi", ".mov", ".mkv", ".webm"} app = FastAPI(title="ISISeg Guidewire Segmentation", version="0.1.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) def ensure_dirs() -> None: for directory in (UPLOAD_DIR, JOB_DIR, SAMPLE_DIR): directory.mkdir(parents=True, exist_ok=True) ensure_dirs() app.mount("/storage", StaticFiles(directory=STORAGE_DIR), name="storage") @app.get("/api/health") def health() -> dict[str, str]: return {"status": "ok", "service": "ISISeg", "version": app.version} @app.get("/api/methods") def methods() -> dict[str, Any]: return {"methods": METHOD_DESCRIPTIONS} def _public(path: Path) -> str: return "/" + path.relative_to(ROOT).as_posix() def _save_upload(file: UploadFile, job_path: Path) -> Path: suffix = Path(file.filename or "").suffix.lower() if suffix not in IMAGE_SUFFIXES | VIDEO_SUFFIXES: raise HTTPException(status_code=400, detail="仅支持图片或视频文件") destination = job_path / f"upload{suffix}" with destination.open("wb") as buffer: shutil.copyfileobj(file.file, buffer) return destination def _save_frame_outputs( job_path: Path, frame_index: int, method: str, frame, output, ) -> dict[str, Any]: method_path = job_path / method method_path.mkdir(exist_ok=True) original_path = method_path / f"frame_{frame_index:04d}_original.png" mask_path = method_path / f"frame_{frame_index:04d}_mask.png" overlay_path = method_path / f"frame_{frame_index:04d}_overlay.png" cv2.imwrite(str(original_path), frame) cv2.imwrite(str(mask_path), output.mask) cv2.imwrite(str(overlay_path), output.overlay) return { "frame_index": frame_index, "method": method, "original_url": _public(original_path), "mask_url": _public(mask_path), "overlay_url": _public(overlay_path), "metrics": output.metrics, } def _process_image(job_path: Path, source: Path, method: str, sensitivity: float) -> dict[str, Any]: frame = cv2.imread(str(source), cv2.IMREAD_COLOR) if frame is None: raise HTTPException(status_code=400, detail="无法读取图片") if method == "compare": results = [ _save_frame_outputs(job_path, 0, item.method, frame, item) for item in compare_frame(frame, None, sensitivity) ] else: output = segment_frame(frame, method, None, sensitivity) results = [_save_frame_outputs(job_path, 0, method, frame, output)] return {"kind": "image", "frames": results, "video_url": None} def _selected_frame(index: int, stride: int, selected_count: int, max_frames: int) -> bool: return index % max(1, stride) == 0 and selected_count < max_frames def _process_video( job_path: Path, source: Path, method: str, sensitivity: float, frame_stride: int, max_frames: int, ) -> dict[str, Any]: capture = cv2.VideoCapture(str(source)) if not capture.isOpened(): raise HTTPException(status_code=400, detail="无法读取视频") frames: list[dict[str, Any]] = [] previous = None frame_index = 0 selected_count = 0 writer = None video_path = job_path / f"{method}_overlay.mp4" try: while True: ok, frame = capture.read() if not ok: break if not _selected_frame(frame_index, frame_stride, selected_count, max_frames): previous = frame frame_index += 1 continue if method == "compare": outputs = compare_frame(frame, previous, sensitivity) for output in outputs: frames.append(_save_frame_outputs(job_path, frame_index, output.method, frame, output)) video_output = next(item for item in outputs if item.method == "fusion") else: video_output = segment_frame(frame, method, previous, sensitivity) frames.append(_save_frame_outputs(job_path, frame_index, method, frame, video_output)) if writer is None: height, width = video_output.overlay.shape[:2] fourcc = cv2.VideoWriter_fourcc(*"mp4v") writer = cv2.VideoWriter(str(video_path), fourcc, 8.0, (width, height)) writer.write(video_output.overlay) selected_count += 1 previous = frame frame_index += 1 finally: capture.release() if writer is not None: writer.release() if not frames: raise HTTPException(status_code=400, detail="视频没有可处理帧") return {"kind": "video", "frames": frames, "video_url": _public(video_path) if video_path.exists() else None} @app.post("/api/segment") def segment( file: UploadFile = File(...), method: str = Form("fusion"), sensitivity: float = Form(0.56), frame_stride: int = Form(5), max_frames: int = Form(12), ) -> dict[str, Any]: ensure_dirs() if method not in METHOD_DESCRIPTIONS: raise HTTPException(status_code=400, detail="未知分割方法") sensitivity = max(0.05, min(float(sensitivity), 0.95)) frame_stride = max(1, min(int(frame_stride), 90)) max_frames = max(1, min(int(max_frames), 80)) job_id = uuid.uuid4().hex[:12] job_path = JOB_DIR / job_id job_path.mkdir(parents=True, exist_ok=True) source = _save_upload(file, job_path) suffix = source.suffix.lower() if suffix in IMAGE_SUFFIXES: result = _process_image(job_path, source, method, sensitivity) elif suffix in VIDEO_SUFFIXES: result = _process_video(job_path, source, method, sensitivity, frame_stride, max_frames) else: raise HTTPException(status_code=400, detail="不支持的文件类型") return { "job_id": job_id, "method": method, "sensitivity": sensitivity, "frame_stride": frame_stride, "max_frames": max_frames, **result, } @app.get("/") def index() -> FileResponse: return FileResponse(FRONTEND_DIR / "index.html") app.mount("/", StaticFiles(directory=FRONTEND_DIR), name="frontend")