from __future__ import annotations

import logging
import shutil
import subprocess
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import Any

import cv2
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles

from backend.segmentation import METHOD_DESCRIPTIONS, compare_frame, segment_frame

logger = logging.getLogger(__name__)

ROOT = Path(__file__).resolve().parents[1]
FRONTEND_DIR = ROOT / "frontend"
STORAGE_DIR = ROOT / "storage"
UPLOAD_DIR = STORAGE_DIR / "uploads"
JOB_DIR = STORAGE_DIR / "jobs"
SAMPLE_DIR = STORAGE_DIR / "samples"

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
VIDEO_SUFFIXES = {".mp4", ".avi", ".mov", ".mkv", ".webm"}

# Frame rate assumed when a video container reports no usable FPS value.
FALLBACK_FPS = 12.0

app = FastAPI(title="ISISeg Guidewire Segmentation", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def ensure_dirs() -> None:
    """Create the upload, job and sample storage directories if they are missing."""
    for directory in (UPLOAD_DIR, JOB_DIR, SAMPLE_DIR):
        directory.mkdir(parents=True, exist_ok=True)


# The storage tree must exist before it is mounted as a static directory.
ensure_dirs()
app.mount("/storage", StaticFiles(directory=STORAGE_DIR), name="storage")


# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #


def _public(path: Path) -> str:
    """Return the URL path for a file under ROOT (e.g. ``/storage/jobs/...``).

    Raises:
        ValueError: if *path* is not inside ROOT.
    """
    return "/" + path.relative_to(ROOT).as_posix()


def _new_job() -> tuple[str, Path]:
    """Create a fresh job directory under JOB_DIR and return ``(job_id, job_path)``."""
    job_id = uuid.uuid4().hex[:12]
    job_path = JOB_DIR / job_id
    job_path.mkdir(parents=True, exist_ok=True)
    return job_id, job_path


@contextmanager
def _discard_job_on_error(job_path: Path) -> Iterator[None]:
    """Delete *job_path* if the wrapped block raises, then re-raise.

    Rejected uploads and failed runs would otherwise leave orphaned job
    directories behind in the publicly served storage tree.
    """
    try:
        yield
    except BaseException:
        shutil.rmtree(job_path, ignore_errors=True)
        raise


def _clamp_sensitivity(sensitivity: float) -> float:
    """Clamp a user-supplied sensitivity into ``[0.05, 0.95]``."""
    return max(0.05, min(float(sensitivity), 0.95))


def _video_fps(capture: cv2.VideoCapture) -> float:
    """Return the capture's FPS, or FALLBACK_FPS when it is missing, <= 0 or NaN."""
    fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
    # `fps > 0` is False for NaN as well, unlike the negated `fps <= 0` test.
    return fps if fps > 0 else FALLBACK_FPS


def _save_upload(file: UploadFile, job_path: Path) -> Path:
    """Persist an uploaded image/video as ``job_path/upload<suffix>``.

    Only the client filename's suffix is used, so the stored name cannot
    escape the job directory.

    Raises:
        HTTPException: 400 if the suffix is not a supported image/video type.
    """
    suffix = Path(file.filename or "").suffix.lower()
    if suffix not in IMAGE_SUFFIXES | VIDEO_SUFFIXES:
        raise HTTPException(status_code=400, detail="仅支持图片或视频文件")
    destination = job_path / f"upload{suffix}"
    with destination.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return destination


def _save_frame_outputs(
    job_path: Path,
    frame_index: int,
    method: str,
    frame,
    output,
    source_time: float | None = None,
    result_time: float | None = None,
    result_index: int | None = None,
) -> dict[str, Any]:
    """Write the original frame, mask and overlay PNGs and describe them.

    Files go to ``job_path/<method>/frame_<index>_{original,mask,overlay}.png``.
    *output* must expose ``mask``, ``overlay`` and ``metrics`` attributes.
    The optional timing/index fields are only included when provided.
    """
    method_path = job_path / method
    method_path.mkdir(exist_ok=True)
    stem = f"frame_{frame_index:04d}"
    original_path = method_path / f"{stem}_original.png"
    mask_path = method_path / f"{stem}_mask.png"
    overlay_path = method_path / f"{stem}_overlay.png"
    cv2.imwrite(str(original_path), frame)
    cv2.imwrite(str(mask_path), output.mask)
    cv2.imwrite(str(overlay_path), output.overlay)

    payload: dict[str, Any] = {
        "frame_index": frame_index,
        "method": method,
        "original_url": _public(original_path),
        "mask_url": _public(mask_path),
        "overlay_url": _public(overlay_path),
        "metrics": output.metrics,
    }
    if source_time is not None:
        payload["source_time"] = round(float(source_time), 4)
    if result_time is not None:
        payload["result_time"] = round(float(result_time), 4)
    if result_index is not None:
        payload["result_index"] = int(result_index)
    return payload


def _browser_video(raw_path: Path, final_path: Path) -> Path:
    """Make the OpenCV-written video at *raw_path* available at *final_path*.

    When ``ffmpeg`` is on PATH the video is re-encoded to H.264 / yuv420p with
    ``+faststart`` so browsers can play it. If ffmpeg is missing or fails, the
    raw file is moved into place instead, so the job still yields a video rather
    than failing after all frames were processed.
    """
    ffmpeg = shutil.which("ffmpeg")
    if ffmpeg:
        command = [
            ffmpeg,
            "-y",
            "-i",
            str(raw_path),
            "-c:v",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-movflags",
            "+faststart",
            "-preset",
            "veryfast",
            "-crf",
            "20",
            str(final_path),
        ]
        try:
            # Captured as bytes: ffmpeg's stderr is not guaranteed to be valid text.
            subprocess.run(command, check=True, capture_output=True)
        except (OSError, subprocess.CalledProcessError) as exc:
            stderr = getattr(exc, "stderr", None)
            detail = stderr.decode("utf-8", errors="replace")[-2000:] if stderr else str(exc)
            logger.warning("ffmpeg could not transcode %s; serving raw video: %s", raw_path, detail)
            final_path.unlink(missing_ok=True)  # drop any partially written output
        else:
            raw_path.unlink(missing_ok=True)
            return final_path
    raw_path.replace(final_path)
    return final_path


def _process_image(job_path: Path, source: Path, method: str, sensitivity: float) -> dict[str, Any]:
    """Segment a still image with *method* (or every method for ``"compare"``).

    Raises:
        HTTPException: 400 if OpenCV cannot decode the image.
    """
    frame = cv2.imread(str(source), cv2.IMREAD_COLOR)
    if frame is None:
        raise HTTPException(status_code=400, detail="无法读取图片")
    if method == "compare":
        results = [
            _save_frame_outputs(job_path, 0, item.method, frame, item, 0.0, 0.0, 0)
            for item in compare_frame(frame, None, sensitivity)
        ]
    else:
        output = segment_frame(frame, method, None, sensitivity)
        results = [_save_frame_outputs(job_path, 0, method, frame, output, 0.0, 0.0, 0)]
    return {
        "kind": "image",
        "frames": results,
        "video_url": None,
        "source_fps": 1.0,
        "result_fps": 1.0,
        "duration": 0.0,
    }


def _selected_frame(index: int, stride: int, selected_count: int, max_frames: int) -> bool:
    """Streaming selection: every *stride*-th frame until *max_frames* are taken."""
    return index % max(1, stride) == 0 and selected_count < max_frames


def _selected_frame_indices(frame_count: int, stride: int, max_frames: int) -> set[int]:
    """Pick up to *max_frames* stride-aligned indices spread evenly over the video.

    Returns an empty set when *frame_count* is unknown (<= 0); callers then
    fall back to :func:`_selected_frame`.
    """
    if frame_count <= 0:
        return set()
    candidates = list(range(0, frame_count, max(1, stride)))
    if len(candidates) <= max_frames:
        return set(candidates)
    if max_frames <= 1:
        return {candidates[0]}
    last = len(candidates) - 1
    return {candidates[round(position * last / (max_frames - 1))] for position in range(max_frames)}


def _segment_video_frame(
    job_path: Path,
    frame_index: int,
    frame,
    previous,
    method: str,
    sensitivity: float,
    timestamp: float,
) -> tuple[list[dict[str, Any]], Any | None]:
    """Segment one selected video frame and persist its outputs.

    Returns the saved frame payloads and the overlay to write into the preview
    video, or ``None`` when no overlay is available (the raw frame is used then).
    """
    if method == "compare":
        # Materialised because the outputs are iterated twice below.
        outputs = list(compare_frame(frame, previous, sensitivity))
        payloads = [
            _save_frame_outputs(
                job_path, frame_index, output.method, frame, output, timestamp, timestamp, frame_index
            )
            for output in outputs
        ]
        # The preview video shows the fusion result; fall back to the last output
        # instead of raising StopIteration if "fusion" is ever absent.
        shown = next(
            (item for item in outputs if item.method == "fusion"),
            outputs[-1] if outputs else None,
        )
    else:
        shown = segment_frame(frame, method, previous, sensitivity)
        payloads = [
            _save_frame_outputs(
                job_path, frame_index, method, frame, shown, timestamp, timestamp, frame_index
            )
        ]
    return payloads, (shown.overlay if shown is not None else None)


def _process_video(
    job_path: Path,
    source: Path,
    method: str,
    sensitivity: float,
    frame_stride: int,
    max_frames: int,
) -> dict[str, Any]:
    """Segment selected frames of a video and render an overlay preview video.

    Every decoded frame is written to the preview; selected frames are replaced
    by their segmentation overlay. Each frame's predecessor is passed to the
    segmentation as ``previous``.

    Raises:
        HTTPException: 400 if the video cannot be opened or yields no frames.
    """
    capture = cv2.VideoCapture(str(source))
    if not capture.isOpened():
        capture.release()
        raise HTTPException(status_code=400, detail="无法读取视频")

    raw_video_path = job_path / f"{method}_overlay.raw.mp4"
    video_path = job_path / f"{method}_overlay.mp4"
    source_fps = _video_fps(capture)
    result_fps = source_fps
    # Container metadata: may be unknown (0/-1) or approximate. When unknown,
    # selection falls back to streaming stride-based picking.
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    selected_indices = _selected_frame_indices(frame_count, frame_stride, max_frames)

    frames: list[dict[str, Any]] = []
    previous = None
    frame_index = 0
    selected_count = 0
    written_count = 0
    writer = None
    try:
        while True:
            ok, frame = capture.read()
            if not ok:
                break
            if selected_indices:
                should_process = frame_index in selected_indices
            else:
                should_process = _selected_frame(frame_index, frame_stride, selected_count, max_frames)

            video_frame = frame
            if should_process:
                payloads, overlay = _segment_video_frame(
                    job_path,
                    frame_index,
                    frame,
                    previous,
                    method,
                    sensitivity,
                    frame_index / source_fps,
                )
                frames.extend(payloads)
                if overlay is not None:
                    video_frame = overlay
                selected_count += 1

            if writer is None:
                # The writer is sized from the first decoded frame.
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                writer = cv2.VideoWriter(str(raw_video_path), fourcc, result_fps, (width, height))
            writer.write(video_frame)
            written_count += 1
            previous = frame
            frame_index += 1
    finally:
        capture.release()
        if writer is not None:
            writer.release()

    if not frames:
        raise HTTPException(status_code=400, detail="视频没有可处理帧")
    if raw_video_path.exists():
        _browser_video(raw_video_path, video_path)
    duration = round(written_count / source_fps, 4) if written_count else 0.0
    return {
        "kind": "video",
        "frames": frames,
        "video_url": _public(video_path) if video_path.exists() else None,
        "source_fps": round(source_fps, 4),
        "result_fps": round(result_fps, 4),
        "duration": duration,
        "result_duration": duration,
    }


def _read_video_frame(source: Path, frame_index: int) -> tuple[Any, Any | None, float]:
    """Read frame *frame_index* (clamped to >= 0) and its predecessor from a video.

    Returns ``(frame, previous_or_None, fps)``; ``previous`` is None for frame 0
    or when the preceding frame cannot be read.

    Raises:
        HTTPException: 400 if the video cannot be opened or the frame read fails.
    """
    capture = cv2.VideoCapture(str(source))
    try:
        if not capture.isOpened():
            raise HTTPException(status_code=400, detail="无法读取视频")
        source_fps = _video_fps(capture)
        frame_index = max(0, int(frame_index))
        previous = None
        if frame_index > 0:
            capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index - 1)
            ok_prev, previous = capture.read()
            if not ok_prev:
                previous = None
        # Seek explicitly instead of relying on the read above having advanced.
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ok, frame = capture.read()
    finally:
        capture.release()
    if not ok:
        raise HTTPException(status_code=400, detail="无法读取指定帧")
    return frame, previous, source_fps


# --------------------------------------------------------------------------- #
# Routes
# --------------------------------------------------------------------------- #


@app.get("/api/health")
def health() -> dict[str, str]:
    """Liveness probe reporting the service name and version."""
    return {"status": "ok", "service": "ISISeg", "version": app.version}


@app.get("/api/methods")
def methods() -> dict[str, Any]:
    """List the available segmentation methods and their descriptions."""
    return {"methods": METHOD_DESCRIPTIONS}


@app.get("/api/samples")
def samples() -> dict[str, Any]:
    """List the image/video files in SAMPLE_DIR, sorted by path.

    Each URL carries the file's mtime as a ``?v=`` cache-busting parameter.
    """
    ensure_dirs()
    media_suffixes = IMAGE_SUFFIXES | VIDEO_SUFFIXES
    items = []
    for path in sorted(SAMPLE_DIR.glob("*")):
        suffix = path.suffix.lower()
        if suffix not in media_suffixes or not path.is_file():
            continue
        stat = path.stat()
        version = str(stat.st_mtime_ns)
        items.append(
            {
                "name": path.name,
                "url": f"{_public(path)}?v={version}",
                "kind": "image" if suffix in IMAGE_SUFFIXES else "video",
                "size": stat.st_size,
                "version": version,
            }
        )
    return {"samples": items}


@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
    """Serve the site icon as inline SVG."""
    # NOTE(review): the SVG markup appears to be missing from this literal, so the
    # response body is only whitespace — confirm the intended icon markup.
    svg = """ """
    return Response(content=svg, media_type="image/svg+xml")


@app.post("/api/segment")
def segment(
    file: UploadFile = File(...),
    method: str = Form("fusion"),
    sensitivity: float = Form(0.56),
    frame_stride: int = Form(5),
    max_frames: int = Form(12),
) -> dict[str, Any]:
    """Segment an uploaded image or video with *method*.

    ``sensitivity`` is clamped to [0.05, 0.95], ``frame_stride`` to [1, 90] and
    ``max_frames`` to [1, 80]. The job directory is removed if processing fails.

    Raises:
        HTTPException: 400 for unknown methods, unsupported or unreadable files.
    """
    ensure_dirs()
    if method not in METHOD_DESCRIPTIONS:
        raise HTTPException(status_code=400, detail="未知分割方法")
    sensitivity = _clamp_sensitivity(sensitivity)
    frame_stride = max(1, min(int(frame_stride), 90))
    max_frames = max(1, min(int(max_frames), 80))

    job_id, job_path = _new_job()
    with _discard_job_on_error(job_path):
        source = _save_upload(file, job_path)
        suffix = source.suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            result = _process_image(job_path, source, method, sensitivity)
        elif suffix in VIDEO_SUFFIXES:
            result = _process_video(job_path, source, method, sensitivity, frame_stride, max_frames)
        else:
            raise HTTPException(status_code=400, detail="不支持的文件类型")

    return {
        "job_id": job_id,
        "method": method,
        "sensitivity": sensitivity,
        "frame_stride": frame_stride,
        "max_frames": max_frames,
        **result,
    }


@app.post("/api/compare-frame")
def compare_single_frame(
    file: UploadFile = File(...),
    frame_index: int = Form(0),
    sensitivity: float = Form(0.56),
) -> dict[str, Any]:
    """Run every comparison method on one frame of an uploaded image or video.

    Raises:
        HTTPException: 400 for unsupported or unreadable files / frames.
    """
    ensure_dirs()
    sensitivity = _clamp_sensitivity(sensitivity)
    # Clamp once so the frame that is read, the reported index, the saved file
    # names and source_time all agree (negative indices read frame 0).
    frame_index = max(0, int(frame_index))

    job_id, job_path = _new_job()
    with _discard_job_on_error(job_path):
        source = _save_upload(file, job_path)
        suffix = source.suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            frame = cv2.imread(str(source), cv2.IMREAD_COLOR)
            previous = None
            frame_index = 0  # a still image only has frame 0
            source_time = 0.0
        elif suffix in VIDEO_SUFFIXES:
            frame, previous, source_fps = _read_video_frame(source, frame_index)
            source_time = frame_index / source_fps
        else:
            raise HTTPException(status_code=400, detail="不支持的文件类型")
        if frame is None:
            raise HTTPException(status_code=400, detail="无法读取文件")

        outputs = compare_frame(frame, previous, sensitivity)
        frames = [
            _save_frame_outputs(job_path, frame_index, output.method, frame, output, source_time, 0.0, 0)
            for output in outputs
        ]

    return {
        "job_id": job_id,
        "kind": "compare",
        "frame_index": frame_index,
        "source_time": round(float(source_time), 4),
        "frames": frames,
    }


@app.get("/")
def index() -> FileResponse:
    """Serve the single-page frontend."""
    return FileResponse(FRONTEND_DIR / "index.html")


# Mounted last so the catch-all "/" static mount does not shadow the API routes.
app.mount("/", StaticFiles(directory=FRONTEND_DIR), name="frontend")