Files
ISISeg/backend/main.py

238 lines
7.9 KiB
Python

from __future__ import annotations
import shutil
import uuid
from pathlib import Path
from typing import Any
import cv2
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from backend.segmentation import METHOD_DESCRIPTIONS, compare_frame, segment_frame
# Directory two levels above this file, i.e. the folder that contains backend/.
ROOT = Path(__file__).resolve().parents[1]
FRONTEND_DIR = ROOT / "frontend"
STORAGE_DIR = ROOT / "storage"
# storage/uploads, storage/jobs and storage/samples, in that order.
UPLOAD_DIR, JOB_DIR, SAMPLE_DIR = (STORAGE_DIR / sub for sub in ("uploads", "jobs", "samples"))
# Lower-case file extensions (with the leading dot) accepted as still images / videos.
IMAGE_SUFFIXES = set(".png .jpg .jpeg .bmp .tif .tiff".split())
VIDEO_SUFFIXES = set(".mp4 .avi .mov .mkv .webm".split())
app = FastAPI(version="0.1.0", title="ISISeg Guidewire Segmentation")
# Wide-open CORS: any origin may call the API with any method and header.
# NOTE(review): fine for a local tool; tighten allow_origins if ever exposed publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def ensure_dirs() -> None:
    """Create the upload, job and sample storage folders (with parents) if missing."""
    for folder in [UPLOAD_DIR, JOB_DIR, SAMPLE_DIR]:
        folder.mkdir(exist_ok=True, parents=True)
# Create the storage folders at import time. This must run before the mount
# below: Starlette's StaticFiles checks by default that its directory exists.
ensure_dirs()
# Serve uploads, job outputs and samples under /storage; the URLs returned by
# _public() ("/" + path relative to ROOT) resolve through this mount.
app.mount("/storage", StaticFiles(directory=STORAGE_DIR), name="storage")
@app.get("/api/health")
def health() -> dict[str, str]:
    """Report that the service is up, together with its name and app version."""
    payload = {"status": "ok", "service": "ISISeg"}
    payload["version"] = app.version
    return payload
@app.get("/api/methods")
def methods() -> dict[str, Any]:
    """Expose METHOD_DESCRIPTIONS from backend.segmentation under the "methods" key."""
    return dict(methods=METHOD_DESCRIPTIONS)
@app.get("/api/samples")
def samples() -> dict[str, Any]:
    """List the sample media files found in SAMPLE_DIR.

    Returns ``{"samples": [...]}``, sorted by path. Each entry holds the file
    ``name``, its public ``url``, its ``kind`` ("image" or "video", decided by
    extension) and its ``size`` in bytes. Only regular files whose lower-cased
    extension is in IMAGE_SUFFIXES or VIDEO_SUFFIXES are listed.
    """
    ensure_dirs()
    media_suffixes = IMAGE_SUFFIXES | VIDEO_SUFFIXES  # loop-invariant
    items = []
    for path in sorted(SAMPLE_DIR.glob("*")):
        suffix = path.suffix.lower()
        # is_file() drops sub-directories with media-like names and dangling
        # symlinks; the latter would make stat() below raise and turn the
        # whole listing into a 500.
        if suffix not in media_suffixes or not path.is_file():
            continue
        items.append(
            {
                "name": path.name,
                "url": _public(path),
                "kind": "image" if suffix in IMAGE_SUFFIXES else "video",
                "size": path.stat().st_size,
            }
        )
    return {"samples": items}
@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
    """Serve an inline SVG as the browser tab icon, so no icon file is needed on disk."""
    icon_markup = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64">
<rect width="64" height="64" rx="12" fill="#10211d"/>
<path d="M18 40V18h6v22h-6Zm14 0c-4.9 0-8.2-2.4-8.5-6.4h5.8c.3 1.4 1.4 2.2 3.1 2.2 1.6 0 2.6-.6 2.6-1.8 0-1.3-1.3-1.6-4.3-2.3-3.4-.8-6.6-2-6.6-6.5 0-4.3 3.3-7.2 8.1-7.2 4.7 0 7.8 2.4 8.2 6.4h-5.7c-.2-1.2-1.1-1.9-2.6-1.9-1.4 0-2.3.6-2.3 1.7 0 1.1 1.2 1.5 3.8 2.1 3.8.9 7.2 2.1 7.2 6.9 0 4.2-3.4 6.8-8.8 6.8Z" fill="#38d8b8"/>
</svg>"""
    return Response(media_type="image/svg+xml", content=icon_markup)
def _public(path: Path) -> str:
    """Return the URL path under which ``path`` is served by the static mounts.

    ``path`` must lie inside ROOT; ``Path.relative_to`` raises ValueError
    otherwise. The relative path is percent-encoded with "/" kept as the
    separator. This makes file names containing spaces, ``#``, ``?``, ``%`` or
    non-ASCII characters (e.g. user-supplied sample files) produce URLs that
    resolve, instead of being cut off at ``#``/``?``. Names made only of
    URL-safe characters are returned unchanged.
    """
    from urllib.parse import quote

    return "/" + quote(path.relative_to(ROOT).as_posix())
def _save_upload(file: UploadFile, job_path: Path) -> Path:
    """Copy an uploaded image/video into ``job_path`` as ``upload<ext>``.

    Only the lower-cased extension of the client-supplied filename is kept, so
    the rest of that name never influences the destination path.

    Raises:
        HTTPException: 400 when the extension is not an accepted image/video type.
    """
    extension = Path(file.filename or "").suffix.lower()
    allowed = IMAGE_SUFFIXES | VIDEO_SUFFIXES
    if extension in allowed:
        target = job_path / f"upload{extension}"
        with target.open("wb") as sink:
            shutil.copyfileobj(file.file, sink)
        return target
    raise HTTPException(status_code=400, detail="仅支持图片或视频文件")
def _save_frame_outputs(
    job_path: Path,
    frame_index: int,
    method: str,
    frame,
    output,
) -> dict[str, Any]:
    """Write one frame's original/mask/overlay PNGs and describe them.

    Files are written to ``job_path/<method>/frame_<NNNN>_{original,mask,overlay}.png``.
    ``output`` is presumably a segmentation result from backend.segmentation
    exposing ``mask``, ``overlay`` and ``metrics`` (the attributes read here);
    confirm there.

    Returns:
        A dict with the frame index, method, public URLs of the three images
        and ``output.metrics``.

    Raises:
        HTTPException: 500 if OpenCV reports that an image could not be written.
        Previously the failure was ignored and the response pointed at missing files.
    """
    method_path = job_path / method
    method_path.mkdir(exist_ok=True)
    stem = f"frame_{frame_index:04d}"
    original_path = method_path / f"{stem}_original.png"
    mask_path = method_path / f"{stem}_mask.png"
    overlay_path = method_path / f"{stem}_overlay.png"
    # cv2.imwrite reports write failures through its boolean return value.
    for target, image in (
        (original_path, frame),
        (mask_path, output.mask),
        (overlay_path, output.overlay),
    ):
        if not cv2.imwrite(str(target), image):
            raise HTTPException(status_code=500, detail="无法保存结果图像")
    return {
        "frame_index": frame_index,
        "method": method,
        "original_url": _public(original_path),
        "mask_url": _public(mask_path),
        "overlay_url": _public(overlay_path),
        "metrics": output.metrics,
    }
def _process_image(job_path: Path, source: Path, method: str, sensitivity: float) -> dict[str, Any]:
    """Segment a single still image and save its per-method outputs.

    For ``method == "compare"`` every result returned by ``compare_frame`` is
    saved under its own method name; otherwise ``segment_frame`` runs once.
    A still image has no previous frame, so ``None`` is passed for it.

    Raises:
        HTTPException: 400 when OpenCV cannot decode the file.
    """
    image = cv2.imread(str(source), cv2.IMREAD_COLOR)
    if image is None:
        raise HTTPException(status_code=400, detail="无法读取图片")
    if method != "compare":
        single = segment_frame(image, method, None, sensitivity)
        saved = [_save_frame_outputs(job_path, 0, method, image, single)]
    else:
        saved = []
        for candidate in compare_frame(image, None, sensitivity):
            saved.append(_save_frame_outputs(job_path, 0, candidate.method, image, candidate))
    return {"kind": "image", "frames": saved, "video_url": None}
def _selected_frame(index: int, stride: int, selected_count: int, max_frames: int) -> bool:
return index % max(1, stride) == 0 and selected_count < max_frames
def _process_video(
    job_path: Path,
    source: Path,
    method: str,
    sensitivity: float,
    frame_stride: int,
    max_frames: int,
) -> dict[str, Any]:
    """Segment a sampled subset of a video's frames and encode an overlay video.

    Every ``frame_stride``-th frame is processed, up to ``max_frames`` frames.
    Each processed frame is handed to the segmentation helpers together with
    the frame decoded immediately before it (``None`` for the first frame).
    Per-frame PNGs are written via ``_save_frame_outputs``. The overlays are
    encoded into ``<method>_overlay.mp4`` at 8 fps; for ``method == "compare"``
    the video uses the "fusion" output.

    Returns:
        ``{"kind": "video", "frames": [...], "video_url": url or None}``.

    Raises:
        HTTPException: 400 if the video cannot be opened or yields no frame to process.
    """
    capture = cv2.VideoCapture(str(source))
    if not capture.isOpened():
        capture.release()
        raise HTTPException(status_code=400, detail="无法读取视频")

    frames: list[dict[str, Any]] = []
    previous = None
    frame_index = 0
    selected_count = 0
    writer = None
    video_path = job_path / f"{method}_overlay.mp4"
    try:
        # Stop decoding once the frame budget is spent. No later frame could be
        # selected, so reading the rest of the file would be wasted work.
        while selected_count < max_frames:
            ok, frame = capture.read()
            if not ok:
                break
            if _selected_frame(frame_index, frame_stride, selected_count, max_frames):
                if method == "compare":
                    outputs = compare_frame(frame, previous, sensitivity)
                    for output in outputs:
                        frames.append(
                            _save_frame_outputs(job_path, frame_index, output.method, frame, output)
                        )
                    video_output = next(item for item in outputs if item.method == "fusion")
                else:
                    video_output = segment_frame(frame, method, previous, sensitivity)
                    frames.append(
                        _save_frame_outputs(job_path, frame_index, method, frame, video_output)
                    )
                if writer is None:
                    # Size the encoder from the first overlay produced.
                    height, width = video_output.overlay.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    writer = cv2.VideoWriter(str(video_path), fourcc, 8.0, (width, height))
                writer.write(video_output.overlay)
                selected_count += 1
            # The next frame is paired with this one, whether or not it was processed.
            previous = frame
            frame_index += 1
    finally:
        capture.release()
        if writer is not None:
            writer.release()
    if not frames:
        raise HTTPException(status_code=400, detail="视频没有可处理帧")
    return {
        "kind": "video",
        "frames": frames,
        "video_url": _public(video_path) if video_path.exists() else None,
    }
@app.post("/api/segment")
def segment(
    file: UploadFile = File(...),
    method: str = Form("fusion"),
    sensitivity: float = Form(0.56),
    frame_stride: int = Form(5),
    max_frames: int = Form(12),
) -> dict[str, Any]:
    """Run segmentation on one uploaded image or video and return result URLs.

    Form fields:
        file: the upload; its filename extension selects the image or video pipeline.
        method: must be a key of METHOD_DESCRIPTIONS.
        sensitivity: clamped to [0.05, 0.95].
        frame_stride: clamped to [1, 90]; only used for videos.
        max_frames: clamped to [1, 80]; only used for videos.

    Returns:
        The new job id and the effective (clamped) parameters, merged with the
        ``kind``/``frames``/``video_url`` entries from the image or video processor.

    Raises:
        HTTPException: 400 for an unknown method, an unsupported file type or
        unreadable media. On any failure the job directory is removed, so no
        orphaned upload or partial outputs are left in storage.
    """
    ensure_dirs()
    if method not in METHOD_DESCRIPTIONS:
        raise HTTPException(status_code=400, detail="未知分割方法")
    sensitivity = max(0.05, min(float(sensitivity), 0.95))
    frame_stride = max(1, min(int(frame_stride), 90))
    max_frames = max(1, min(int(max_frames), 80))

    job_id = uuid.uuid4().hex[:12]
    job_path = JOB_DIR / job_id
    # No exist_ok: the failure path below deletes job_path, so it must never be
    # a directory another request already owns. A 48-bit id collision is
    # improbable, but failing loudly beats deleting someone else's results.
    job_path.mkdir(parents=True)
    try:
        source = _save_upload(file, job_path)
        suffix = source.suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            result = _process_image(job_path, source, method, sensitivity)
        elif suffix in VIDEO_SUFFIXES:
            result = _process_video(job_path, source, method, sensitivity, frame_stride, max_frames)
        else:
            # Defensive only: _save_upload already rejects other suffixes.
            raise HTTPException(status_code=400, detail="不支持的文件类型")
    except Exception:
        # The job id is never returned after a failure, so nothing can reference
        # these files. Drop them instead of leaking them into storage/jobs.
        shutil.rmtree(job_path, ignore_errors=True)
        raise
    return {
        "job_id": job_id,
        "method": method,
        "sensitivity": sensitivity,
        "frame_stride": frame_stride,
        "max_frames": max_frames,
        **result,
    }
@app.get("/")
def index() -> FileResponse:
    """Serve the frontend's index.html at the site root."""
    page = FRONTEND_DIR / "index.html"
    return FileResponse(page)
# Catch-all static mount for the frontend assets. It is registered last because
# Starlette matches routes in registration order; the /api routes, /favicon.ico,
# "/" and /storage defined above must be tried before this "/" mount.
app.mount("/", StaticFiles(directory=FRONTEND_DIR), name="frontend")