Files
ISISeg/backend/main.py

404 lines
14 KiB
Python

from __future__ import annotations

import logging
import shutil
import subprocess
import uuid
from pathlib import Path
from typing import Any
from urllib.parse import quote

import cv2
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles

from backend.segmentation import METHOD_DESCRIPTIONS, compare_frame, segment_frame
# Project root: the parent of the directory that contains this module.
ROOT = Path(__file__).resolve().parents[1]
FRONTEND_DIR = ROOT.joinpath("frontend")
# Everything the service writes lives under STORAGE_DIR (served at /storage).
STORAGE_DIR = ROOT.joinpath("storage")
UPLOAD_DIR = STORAGE_DIR.joinpath("uploads")
JOB_DIR = STORAGE_DIR.joinpath("jobs")
SAMPLE_DIR = STORAGE_DIR.joinpath("samples")
# Accepted file extensions (compared after lower-casing).
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
VIDEO_SUFFIXES = {".mp4", ".avi", ".mov", ".mkv", ".webm"}

app = FastAPI(title="ISISeg Guidewire Segmentation", version="0.1.0")
# Fully permissive CORS: any origin may call the API with any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def ensure_dirs() -> None:
    """Create the upload, job and sample storage directories when missing.

    Safe to call repeatedly: existing directories are left untouched.
    """
    for required in [UPLOAD_DIR, JOB_DIR, SAMPLE_DIR]:
        required.mkdir(exist_ok=True, parents=True)
# Create the storage tree at import time. NOTE: this must precede the mount
# below — Starlette's StaticFiles checks by default that its directory exists
# when it is constructed.
ensure_dirs()
# Expose uploads, job outputs and samples under /storage; the URLs built by
# _public() for files inside STORAGE_DIR point into this mount.
app.mount("/storage", StaticFiles(directory=STORAGE_DIR), name="storage")
@app.get("/api/health")
def health() -> dict[str, str]:
    """Liveness probe reporting status, service name and app version."""
    report = {"status": "ok", "service": "ISISeg"}
    report["version"] = app.version
    return report
@app.get("/api/methods")
def methods() -> dict[str, Any]:
    """Return the segmentation methods known to the backend with their descriptions."""
    return dict(methods=METHOD_DESCRIPTIONS)
@app.get("/api/samples")
def samples() -> dict[str, Any]:
    """List the sample images/videos stored in ``SAMPLE_DIR``.

    Only regular files whose (lower-cased) suffix is a supported image or
    video extension are listed, sorted by path.

    Returns:
        ``{"samples": [...]}`` where each item has ``name``, ``url``,
        ``kind`` (``"image"`` or ``"video"``), ``size`` (bytes) and
        ``version`` (the file's mtime in nanoseconds, as a string).
    """
    ensure_dirs()
    media_suffixes = IMAGE_SUFFIXES | VIDEO_SUFFIXES
    items = []
    for path in sorted(SAMPLE_DIR.glob("*")):
        suffix = path.suffix.lower()
        # Skip unsupported extensions and directories that merely have a
        # media-like name (e.g. a folder called "clips.mp4").
        if suffix not in media_suffixes or not path.is_file():
            continue
        try:
            # One stat() so that "size" and "version" describe the same file state.
            info = path.stat()
        except FileNotFoundError:
            # Removed between glob() and stat(); just leave it out.
            continue
        version = str(info.st_mtime_ns)
        items.append(
            {
                "name": path.name,
                # The mtime-based query string changes whenever the file
                # changes, so clients do not reuse a stale cached copy.
                "url": f"{_public(path)}?v={version}",
                "kind": "image" if suffix in IMAGE_SUFFIXES else "video",
                "size": info.st_size,
                "version": version,
            }
        )
    return {"samples": items}
@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
    """Serve a small inline SVG icon for the browser's /favicon.ico request."""
    icon_markup = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64">
<rect width="64" height="64" rx="12" fill="#10211d"/>
<path d="M18 40V18h6v22h-6Zm14 0c-4.9 0-8.2-2.4-8.5-6.4h5.8c.3 1.4 1.4 2.2 3.1 2.2 1.6 0 2.6-.6 2.6-1.8 0-1.3-1.3-1.6-4.3-2.3-3.4-.8-6.6-2-6.6-6.5 0-4.3 3.3-7.2 8.1-7.2 4.7 0 7.8 2.4 8.2 6.4h-5.7c-.2-1.2-1.1-1.9-2.6-1.9-1.4 0-2.3.6-2.3 1.7 0 1.1 1.2 1.5 3.8 2.1 3.8.9 7.2 2.1 7.2 6.9 0 4.2-3.4 6.8-8.8 6.8Z" fill="#38d8b8"/>
</svg>"""
    # Response's first positional parameter is the body content.
    return Response(icon_markup, media_type="image/svg+xml")
def _public(path: Path) -> str:
return "/" + path.relative_to(ROOT).as_posix()
def _save_upload(file: UploadFile, job_path: Path) -> Path:
    """Persist an uploaded image/video into ``job_path`` as ``upload<suffix>``.

    Only the (lower-cased) extension of the client-supplied file name is
    kept, so the client name never influences the on-disk path.

    Raises:
        HTTPException: 400 when the extension is not a supported image or
            video suffix (including uploads with no file name).
    """
    extension = Path(file.filename or "").suffix.lower()
    allowed = IMAGE_SUFFIXES | VIDEO_SUFFIXES
    if extension not in allowed:
        raise HTTPException(status_code=400, detail="仅支持图片或视频文件")
    target = job_path / ("upload" + extension)
    with open(target, "wb") as sink:
        shutil.copyfileobj(file.file, sink)
    return target
def _save_frame_outputs(
    job_path: Path,
    frame_index: int,
    method: str,
    frame,
    output,
    source_time: float | None = None,
    result_time: float | None = None,
    result_index: int | None = None,
) -> dict[str, Any]:
    """Write one frame's original, mask and overlay PNGs and describe them.

    Files go to ``job_path/<method>/frame_<NNNN>_{original,mask,overlay}.png``
    where ``NNNN`` is ``frame_index`` zero-padded to four digits.

    Args:
        job_path: Job directory; the per-method sub-directory is created in it.
        frame_index: Frame number used in file names and in the payload.
        method: Method name; also the sub-directory name.
        frame: Source image written as the "original" PNG.
        output: Segmentation result; its ``mask``, ``overlay`` and
            ``metrics`` attributes are used.
        source_time: Optional timestamp, included rounded to 4 decimals.
        result_time: Optional timestamp, included rounded to 4 decimals.
        result_index: Optional index, included as an int.

    Returns:
        Payload with ``frame_index``, ``method``, the three ``*_url`` entries,
        ``metrics`` and whichever optional fields were supplied.
    """
    method_dir = job_path / method
    method_dir.mkdir(exist_ok=True)
    prefix = f"frame_{frame_index:04d}"
    layers = (("original", frame), ("mask", output.mask), ("overlay", output.overlay))
    targets = {kind: method_dir / f"{prefix}_{kind}.png" for kind, _ in layers}
    for kind, image in layers:
        # NOTE(review): cv2.imwrite's boolean result is ignored, so a failed
        # write still yields a URL to a missing file.
        cv2.imwrite(str(targets[kind]), image)
    payload: dict[str, Any] = {"frame_index": frame_index, "method": method}
    for kind, _ in layers:
        payload[f"{kind}_url"] = _public(targets[kind])
    payload["metrics"] = output.metrics
    for key, value in (("source_time", source_time), ("result_time", result_time)):
        if value is not None:
            payload[key] = round(float(value), 4)
    if result_index is not None:
        payload["result_index"] = int(result_index)
    return payload
def _browser_video(raw_path: Path, final_path: Path) -> Path:
ffmpeg = shutil.which("ffmpeg")
if ffmpeg:
subprocess.run(
[
ffmpeg,
"-y",
"-i",
str(raw_path),
"-c:v",
"libx264",
"-pix_fmt",
"yuv420p",
"-movflags",
"+faststart",
"-preset",
"veryfast",
"-crf",
"20",
str(final_path),
],
check=True,
capture_output=True,
text=True,
)
raw_path.unlink(missing_ok=True)
else:
raw_path.replace(final_path)
return final_path
def _process_image(job_path: Path, source: Path, method: str, sensitivity: float) -> dict[str, Any]:
    """Segment a single uploaded image and save its outputs.

    With ``method == "compare"`` every output of ``compare_frame`` is saved
    under its own method name; otherwise only ``segment_frame`` runs. No
    previous frame is available for a still image, so ``None`` is passed.

    Raises:
        HTTPException: 400 when OpenCV cannot decode the image.
    """
    image = cv2.imread(str(source), cv2.IMREAD_COLOR)
    if image is None:
        raise HTTPException(status_code=400, detail="无法读取图片")
    if method == "compare":
        saved = [
            _save_frame_outputs(
                job_path, 0, variant.method, image, variant,
                source_time=0.0, result_time=0.0, result_index=0,
            )
            for variant in compare_frame(image, None, sensitivity)
        ]
    else:
        single = segment_frame(image, method, None, sensitivity)
        saved = [
            _save_frame_outputs(
                job_path, 0, method, image, single,
                source_time=0.0, result_time=0.0, result_index=0,
            )
        ]
    return {
        "kind": "image",
        "frames": saved,
        "video_url": None,
        "source_fps": 1.0,
        "result_fps": 1.0,
        "duration": 0.0,
    }
def _selected_frame(index: int, stride: int, selected_count: int, max_frames: int) -> bool:
return index % max(1, stride) == 0 and selected_count < max_frames
def _selected_frame_indices(frame_count: int, stride: int, max_frames: int) -> set[int]:
if frame_count <= 0:
return set()
candidates = list(range(0, frame_count, max(1, stride)))
if len(candidates) <= max_frames:
return set(candidates)
if max_frames <= 1:
return {candidates[0]}
last = len(candidates) - 1
return {
candidates[round(position * last / (max_frames - 1))]
for position in range(max_frames)
}
def _process_video(
    job_path: Path,
    source: Path,
    method: str,
    sensitivity: float,
    frame_stride: int,
    max_frames: int,
) -> dict[str, Any]:
    """Segment a sampled subset of a video's frames and render an overlay video.

    Every decoded frame is appended to an overlay video. Frames chosen by the
    sampling logic are written as their segmentation overlay (and also saved
    as PNGs through ``_save_frame_outputs``); all other frames are written
    unchanged.

    Args:
        job_path: Job directory; every output file is written here.
        source: Path of the uploaded video.
        method: Segmentation method, or ``"compare"`` to run ``compare_frame``
            (the overlay video then shows the ``"fusion"`` output).
        sensitivity: Forwarded to the segmentation functions.
        frame_stride: Only every ``frame_stride``-th frame is a candidate.
        max_frames: Upper bound on the number of processed frames.

    Returns:
        Dict with ``kind``, per-frame ``frames`` payloads, ``video_url``
        (``None`` if no overlay video exists), ``source_fps``, ``result_fps``,
        ``duration`` and ``result_duration`` (both ``written frames / fps``).

    Raises:
        HTTPException: 400 if the video cannot be opened or no frame was
            processed.
    """
    capture = cv2.VideoCapture(str(source))
    if not capture.isOpened():
        raise HTTPException(status_code=400, detail="无法读取视频")
    frames: list[dict[str, Any]] = []
    previous = None
    frame_index = 0
    selected_count = 0
    written_count = 0
    writer = None
    # OpenCV writes the raw file; _browser_video later produces the served one.
    raw_video_path = job_path / f"{method}_overlay.raw.mp4"
    video_path = job_path / f"{method}_overlay.mp4"
    source_fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
    # Fall back to 12 fps when the container reports no (or a non-positive) rate.
    if source_fps <= 0:
        source_fps = 12.0
    result_fps = source_fps
    # NOTE(review): the reported frame count may be 0 or approximate for some
    # containers. If it over-reports, some selected indices are never reached
    # and fewer than max_frames frames get processed.
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    # Known frame count -> samples spread over the whole video; an empty set
    # makes the loop fall back to streaming selection via _selected_frame.
    selected_indices = _selected_frame_indices(frame_count, frame_stride, max_frames)
    try:
        while True:
            ok, frame = capture.read()
            if not ok:
                break
            should_process = (
                frame_index in selected_indices
                if selected_indices
                else _selected_frame(frame_index, frame_stride, selected_count, max_frames)
            )
            # Unprocessed frames go into the overlay video unchanged.
            video_frame = frame
            if should_process and method == "compare":
                outputs = compare_frame(frame, previous, sensitivity)
                for output in outputs:
                    frames.append(
                        _save_frame_outputs(
                            job_path,
                            frame_index,
                            output.method,
                            frame,
                            output,
                            frame_index / source_fps,
                            frame_index / source_fps,
                            frame_index,
                        )
                    )
                # NOTE(review): next() raises StopIteration if compare_frame
                # returns no "fusion" output; assumes it always does.
                video_output = next(item for item in outputs if item.method == "fusion")
                video_frame = video_output.overlay
                selected_count += 1
            elif should_process:
                video_output = segment_frame(frame, method, previous, sensitivity)
                frames.append(
                    _save_frame_outputs(
                        job_path,
                        frame_index,
                        method,
                        frame,
                        video_output,
                        frame_index / source_fps,
                        frame_index / source_fps,
                        frame_index,
                    )
                )
                video_frame = video_output.overlay
                selected_count += 1
            # Open the writer lazily, sized from the first decoded frame.
            if writer is None:
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                writer = cv2.VideoWriter(str(raw_video_path), fourcc, result_fps, (width, height))
            writer.write(video_frame)
            written_count += 1
            # The next iteration's "previous" is the immediately preceding
            # decoded frame, whether or not it was processed.
            previous = frame
            frame_index += 1
    finally:
        capture.release()
        if writer is not None:
            writer.release()
    if not frames:
        raise HTTPException(status_code=400, detail="视频没有可处理帧")
    # The raw file may be missing if the writer could not create it.
    if raw_video_path.exists():
        _browser_video(raw_video_path, video_path)
    duration = round(written_count / source_fps, 4) if written_count else 0.0
    return {
        "kind": "video",
        "frames": frames,
        "video_url": _public(video_path) if video_path.exists() else None,
        "source_fps": round(source_fps, 4),
        "result_fps": round(result_fps, 4),
        "duration": duration,
        "result_duration": duration,
    }
def _read_video_frame(source: Path, frame_index: int) -> tuple[Any, Any | None, float]:
    """Decode one frame of a video together with the frame just before it.

    Args:
        source: Path of the video file.
        frame_index: Index of the wanted frame; negative values are clamped to 0.

    Returns:
        ``(frame, previous, source_fps)``. ``previous`` is ``None`` for frame 0
        or when it cannot be read. ``source_fps`` falls back to 12.0 when the
        container reports no positive rate.

    Raises:
        HTTPException: 400 if the video cannot be opened or the frame cannot
            be read (e.g. an index past the end).
    """
    capture = cv2.VideoCapture(str(source))
    # try/finally: the capture is released on every path, including the
    # not-opened error and any exception raised while seeking/decoding.
    try:
        if not capture.isOpened():
            raise HTTPException(status_code=400, detail="无法读取视频")
        source_fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
        if source_fps <= 0:
            source_fps = 12.0
        frame_index = max(0, int(frame_index))
        previous = None
        if frame_index > 0:
            capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index - 1)
            ok_prev, previous = capture.read()
            if not ok_prev:
                previous = None
        # Seek explicitly instead of relying on the read above having left the
        # position at frame_index (it is not attempted for frame 0 anyway).
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ok, frame = capture.read()
    finally:
        capture.release()
    if not ok:
        raise HTTPException(status_code=400, detail="无法读取指定帧")
    return frame, previous, source_fps
@app.post("/api/segment")
def segment(
    file: UploadFile = File(...),
    method: str = Form("fusion"),
    sensitivity: float = Form(0.56),
    frame_stride: int = Form(5),
    max_frames: int = Form(12),
) -> dict[str, Any]:
    """Segment an uploaded image or video with the requested method.

    Inputs are clamped: ``sensitivity`` to [0.05, 0.95], ``frame_stride`` to
    [1, 90] and ``max_frames`` to [1, 80]. Outputs are written to a fresh job
    directory ``JOB_DIR/<job_id>``. If the request fails, that directory is
    removed again: no URLs are returned for it, so its contents would
    otherwise be unreachable leftovers.

    Returns:
        Job metadata (``job_id``, ``method`` and the clamped parameters)
        merged with the result of ``_process_image`` / ``_process_video``.

    Raises:
        HTTPException: 400 for an unknown method or an unsupported/unreadable
            file (propagated unchanged after cleanup).
    """
    ensure_dirs()
    if method not in METHOD_DESCRIPTIONS:
        raise HTTPException(status_code=400, detail="未知分割方法")
    sensitivity = max(0.05, min(float(sensitivity), 0.95))
    frame_stride = max(1, min(int(frame_stride), 90))
    max_frames = max(1, min(int(max_frames), 80))
    job_id = uuid.uuid4().hex[:12]
    job_path = JOB_DIR / job_id
    job_path.mkdir(parents=True, exist_ok=True)
    try:
        source = _save_upload(file, job_path)
        suffix = source.suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            result = _process_image(job_path, source, method, sensitivity)
        elif suffix in VIDEO_SUFFIXES:
            result = _process_video(job_path, source, method, sensitivity, frame_stride, max_frames)
        else:
            raise HTTPException(status_code=400, detail="不支持的文件类型")
    except Exception:
        # Best-effort cleanup of the failed job, then re-raise the original error.
        shutil.rmtree(job_path, ignore_errors=True)
        raise
    return {
        "job_id": job_id,
        "method": method,
        "sensitivity": sensitivity,
        "frame_stride": frame_stride,
        "max_frames": max_frames,
        **result,
    }
@app.post("/api/compare-frame")
def compare_single_frame(
    file: UploadFile = File(...),
    frame_index: int = Form(0),
    sensitivity: float = Form(0.56),
) -> dict[str, Any]:
    """Run every segmentation method (``compare_frame``) on one frame.

    For an image the whole image is used. For a video, frame ``frame_index``
    is used, with the frame before it as ``previous``. ``frame_index`` is
    clamped to >= 0 up front, so the echoed index, ``source_time`` and output
    file names match the frame that was actually decoded.
    ``sensitivity`` is clamped to [0.05, 0.95]. If the request fails, the
    job directory is removed again.

    Returns:
        ``job_id``, ``kind="compare"``, ``frame_index``, ``source_time``
        (seconds, rounded to 4 decimals) and one ``frames`` payload per method.

    Raises:
        HTTPException: 400 for unsupported or unreadable files/frames.
    """
    ensure_dirs()
    frame_index = max(0, int(frame_index))
    sensitivity = max(0.05, min(float(sensitivity), 0.95))
    job_id = uuid.uuid4().hex[:12]
    job_path = JOB_DIR / job_id
    job_path.mkdir(parents=True, exist_ok=True)
    try:
        source = _save_upload(file, job_path)
        suffix = source.suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            frame = cv2.imread(str(source), cv2.IMREAD_COLOR)
            previous = None
            source_time = 0.0
        elif suffix in VIDEO_SUFFIXES:
            frame, previous, source_fps = _read_video_frame(source, frame_index)
            source_time = frame_index / source_fps
        else:
            raise HTTPException(status_code=400, detail="不支持的文件类型")
        if frame is None:
            raise HTTPException(status_code=400, detail="无法读取文件")
        outputs = compare_frame(frame, previous, sensitivity)
        frames = [
            _save_frame_outputs(job_path, frame_index, output.method, frame, output, source_time, 0.0, 0)
            for output in outputs
        ]
    except Exception:
        # Nothing from a failed job is reachable by the client; drop it.
        shutil.rmtree(job_path, ignore_errors=True)
        raise
    return {
        "job_id": job_id,
        "kind": "compare",
        "frame_index": frame_index,
        "source_time": round(float(source_time), 4),
        "frames": frames,
    }
@app.get("/")
def index() -> FileResponse:
    """Serve the frontend's ``index.html`` at the site root."""
    entry_page = FRONTEND_DIR.joinpath("index.html")
    return FileResponse(entry_page)
# Catch-all mount for the frontend's static assets. NOTE: keep this as the
# last registration — Starlette matches routes in registration order, so a
# "/" mount added earlier would shadow the /api, /storage, /favicon.ico and
# "/" routes declared above.
app.mount("/", StaticFiles(directory=FRONTEND_DIR), name="frontend")