404 lines
14 KiB
Python
404 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import subprocess
|
|
import shutil
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import cv2
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse, Response
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from backend.segmentation import METHOD_DESCRIPTIONS, compare_frame, segment_frame
|
|
|
|
|
|
# Repository root: two levels up from this module's resolved path.
ROOT = Path(__file__).resolve().parents[1]
FRONTEND_DIR = ROOT / "frontend"

# Everything generated or uploaded lives under storage/, which is also
# exposed over HTTP at /storage.
STORAGE_DIR = ROOT / "storage"
UPLOAD_DIR, JOB_DIR, SAMPLE_DIR = (
    STORAGE_DIR / sub for sub in ("uploads", "jobs", "samples")
)

# Lower-case file extensions accepted as still images and as videos.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
VIDEO_SUFFIXES = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
|
|
|
|
app = FastAPI(title="ISISeg Guidewire Segmentation", version="0.1.0")

# NOTE(review): CORS is fully open (any origin, method and header). Fine for a
# local demo; restrict allow_origins before exposing the service publicly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
def ensure_dirs() -> None:
    """Create the uploads, jobs and samples directories (and parents) if missing."""
    for folder in [UPLOAD_DIR, JOB_DIR, SAMPLE_DIR]:
        folder.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
# Create the storage tree up front so the /storage static mount has a
# directory to serve from the very first request.
ensure_dirs()
_storage_files = StaticFiles(directory=STORAGE_DIR)
app.mount("/storage", _storage_files, name="storage")
|
|
|
|
|
|
@app.get("/api/health")
def health() -> dict[str, str]:
    """Liveness probe: report status, service name and the app version."""
    report = {"status": "ok", "service": "ISISeg"}
    report["version"] = app.version
    return report
|
|
|
|
|
|
@app.get("/api/methods")
def methods() -> dict[str, Any]:
    """Expose the segmentation method catalogue from backend.segmentation."""
    return dict(methods=METHOD_DESCRIPTIONS)
|
|
|
|
|
|
@app.get("/api/samples")
def samples() -> dict[str, Any]:
    """List the sample images and videos stored in ``SAMPLE_DIR``.

    Only regular files with a supported image/video extension are listed;
    directories (even ones named like ``clip.mp4``) and dangling symlinks are
    skipped. Each URL carries ``?v=<mtime_ns>``, which changes whenever the file
    is modified, so clients can use it to invalidate cached copies.

    Returns:
        ``{"samples": [...]}`` in sorted path order. Each item has ``name``,
        ``url``, ``kind`` (``"image"`` or ``"video"``), ``size`` in bytes and
        ``version`` (the mtime in nanoseconds, as a string).
    """
    ensure_dirs()
    accepted = IMAGE_SUFFIXES | VIDEO_SUFFIXES  # loop-invariant, build once
    items = []
    for path in sorted(SAMPLE_DIR.glob("*")):
        suffix = path.suffix.lower()
        if suffix not in accepted or not path.is_file():
            continue
        # A single stat keeps size and version describing the same file state.
        stat = path.stat()
        version = str(stat.st_mtime_ns)
        items.append(
            {
                "name": path.name,
                "url": f"{_public(path)}?v={version}",
                "kind": "image" if suffix in IMAGE_SUFFIXES else "video",
                "size": stat.st_size,
                "version": version,
            }
        )
    return {"samples": items}
|
|
|
|
|
|
# Inline SVG icon: rounded dark square with a teal "IS" mark.
_FAVICON_SVG = (
    '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 64 64">\n'
    '<rect width="64" height="64" rx="12" fill="#10211d"/>\n'
    '<path d="M18 40V18h6v22h-6Zm14 0c-4.9 0-8.2-2.4-8.5-6.4h5.8c.3 1.4 1.4 2.2 3.1 2.2 1.6 0 2.6-.6 2.6-1.8 0-1.3-1.3-1.6-4.3-2.3-3.4-.8-6.6-2-6.6-6.5 0-4.3 3.3-7.2 8.1-7.2 4.7 0 7.8 2.4 8.2 6.4h-5.7c-.2-1.2-1.1-1.9-2.6-1.9-1.4 0-2.3.6-2.3 1.7 0 1.1 1.2 1.5 3.8 2.1 3.8.9 7.2 2.1 7.2 6.9 0 4.2-3.4 6.8-8.8 6.8Z" fill="#38d8b8"/>\n'
    "</svg>"
)


@app.get("/favicon.ico", include_in_schema=False)
def favicon() -> Response:
    """Serve the inline SVG favicon (hidden from the OpenAPI schema)."""
    return Response(media_type="image/svg+xml", content=_FAVICON_SVG)
|
|
|
|
|
|
def _public(path: Path) -> str:
    """Return the root-relative URL (leading ``/``, POSIX separators) of a path under ROOT.

    Raises ValueError (from ``Path.relative_to``) if ``path`` is not inside ROOT.
    """
    relative = path.relative_to(ROOT)
    return f"/{relative.as_posix()}"
|
|
|
|
|
|
def _save_upload(file: UploadFile, job_path: Path) -> Path:
    """Store an uploaded image/video in ``job_path`` as ``upload<ext>``.

    Only the (lower-cased) extension of the client-supplied filename is used,
    so the stored name is always ``upload<ext>`` inside ``job_path``.

    Returns:
        Path of the written file.

    Raises:
        HTTPException: 400 if the extension is not a supported image or video type.
    """
    ext = Path(file.filename or "").suffix.lower()
    if not (ext in IMAGE_SUFFIXES or ext in VIDEO_SUFFIXES):
        raise HTTPException(status_code=400, detail="仅支持图片或视频文件")
    target = job_path.joinpath("upload" + ext)
    with target.open("wb") as sink:
        shutil.copyfileobj(file.file, sink)
    return target
|
|
|
|
|
|
def _save_frame_outputs(
    job_path: Path,
    frame_index: int,
    method: str,
    frame,
    output,
    source_time: float | None = None,
    result_time: float | None = None,
    result_index: int | None = None,
) -> dict[str, Any]:
    """Write one frame's original, mask and overlay PNGs and describe them.

    Images are written to ``job_path/<method>/frame_<NNNN>_<role>.png`` for the
    roles ``original``, ``mask`` and ``overlay``.

    Args:
        job_path: directory of the current job.
        frame_index: frame number used in the filenames and echoed in the payload.
        method: method name; also the sub-directory the images are written to.
        frame: source image, written as the ``original`` PNG.
        output: segmentation result whose ``mask``, ``overlay`` and ``metrics``
            attributes are read (type defined in backend.segmentation).
        source_time: optional time in seconds, rounded to 4 decimals when given.
        result_time: optional time in seconds, rounded to 4 decimals when given.
        result_index: optional index, stored as an int when given.

    Returns:
        Dict with ``frame_index``, ``method``, ``original_url``, ``mask_url``,
        ``overlay_url``, ``metrics`` and whichever optional fields were supplied.

    Raises:
        HTTPException: 500 if OpenCV fails to write any of the images.
    """
    method_path = job_path / method
    method_path.mkdir(exist_ok=True)
    stem = f"frame_{frame_index:04d}"
    payload: dict[str, Any] = {"frame_index": frame_index, "method": method}
    for role, image in (("original", frame), ("mask", output.mask), ("overlay", output.overlay)):
        image_path = method_path / f"{stem}_{role}.png"
        # cv2.imwrite reports failure through its return value rather than raising;
        # unchecked, the payload would advertise URLs for files that were never written.
        if not cv2.imwrite(str(image_path), image):
            raise HTTPException(status_code=500, detail="无法保存结果图像")
        payload[f"{role}_url"] = _public(image_path)
    payload["metrics"] = output.metrics
    if source_time is not None:
        payload["source_time"] = round(float(source_time), 4)
    if result_time is not None:
        payload["result_time"] = round(float(result_time), 4)
    if result_index is not None:
        payload["result_index"] = int(result_index)
    return payload
|
|
|
|
|
|
def _browser_video(raw_path: Path, final_path: Path) -> Path:
|
|
ffmpeg = shutil.which("ffmpeg")
|
|
if ffmpeg:
|
|
subprocess.run(
|
|
[
|
|
ffmpeg,
|
|
"-y",
|
|
"-i",
|
|
str(raw_path),
|
|
"-c:v",
|
|
"libx264",
|
|
"-pix_fmt",
|
|
"yuv420p",
|
|
"-movflags",
|
|
"+faststart",
|
|
"-preset",
|
|
"veryfast",
|
|
"-crf",
|
|
"20",
|
|
str(final_path),
|
|
],
|
|
check=True,
|
|
capture_output=True,
|
|
text=True,
|
|
)
|
|
raw_path.unlink(missing_ok=True)
|
|
else:
|
|
raw_path.replace(final_path)
|
|
return final_path
|
|
|
|
|
|
def _process_image(job_path: Path, source: Path, method: str, sensitivity: float) -> dict[str, Any]:
    """Segment a single still image and save its outputs under ``job_path``.

    With ``method == "compare"``, every result returned by ``compare_frame`` is
    saved under that result's own method name. Otherwise only ``method`` is run.
    A still image has no predecessor, so ``None`` is passed as the previous frame.

    Raises:
        HTTPException: 400 if the image cannot be decoded.
    """
    frame = cv2.imread(str(source), cv2.IMREAD_COLOR)
    if frame is None:
        raise HTTPException(status_code=400, detail="无法读取图片")
    if method == "compare":
        named_outputs = ((item.method, item) for item in compare_frame(frame, None, sensitivity))
    else:
        named_outputs = iter([(method, segment_frame(frame, method, None, sensitivity))])
    saved = []
    for name, result in named_outputs:
        saved.append(_save_frame_outputs(job_path, 0, name, frame, result, 0.0, 0.0, 0))
    return {
        "kind": "image",
        "frames": saved,
        "video_url": None,
        "source_fps": 1.0,
        "result_fps": 1.0,
        "duration": 0.0,
    }
|
|
|
|
|
|
def _selected_frame(index: int, stride: int, selected_count: int, max_frames: int) -> bool:
|
|
return index % max(1, stride) == 0 and selected_count < max_frames
|
|
|
|
|
|
def _selected_frame_indices(frame_count: int, stride: int, max_frames: int) -> set[int]:
|
|
if frame_count <= 0:
|
|
return set()
|
|
candidates = list(range(0, frame_count, max(1, stride)))
|
|
if len(candidates) <= max_frames:
|
|
return set(candidates)
|
|
if max_frames <= 1:
|
|
return {candidates[0]}
|
|
last = len(candidates) - 1
|
|
return {
|
|
candidates[round(position * last / (max_frames - 1))]
|
|
for position in range(max_frames)
|
|
}
|
|
|
|
|
|
def _process_video(
    job_path: Path,
    source: Path,
    method: str,
    sensitivity: float,
    frame_stride: int,
    max_frames: int,
) -> dict[str, Any]:
    """Segment a video, save stills for sampled frames and render an overlay video.

    Every decoded frame is segmented, with the preceding frame passed as
    ``previous``, and its overlay is written to ``<method>_overlay.mp4``. Only
    sampled frames also get PNG outputs in the returned ``frames`` list. Frames
    are sampled by ``_selected_frame_indices``, or by ``_selected_frame`` when
    the container reports no frame count.

    In ``"compare"`` mode, each sampled frame is saved once per result returned
    by ``compare_frame``, and the overlay video uses the ``"fusion"`` result.

    Returns:
        Dict with ``kind``, ``frames``, ``video_url`` (``None`` if no video was
        produced), ``source_fps``, ``result_fps``, ``duration`` and
        ``result_duration`` (seconds).

    Raises:
        HTTPException: 400 if the video cannot be opened or no frame was sampled.
    """
    capture = cv2.VideoCapture(str(source))
    if not capture.isOpened():
        capture.release()
        raise HTTPException(status_code=400, detail="无法读取视频")

    source_fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
    if source_fps <= 0:
        source_fps = 12.0  # fallback when the container reports no usable FPS
    # Every decoded frame is written, so the output plays at the source rate.
    result_fps = source_fps
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
    selected_indices = _selected_frame_indices(frame_count, frame_stride, max_frames)

    raw_video_path = job_path / f"{method}_overlay.raw.mp4"
    video_path = job_path / f"{method}_overlay.mp4"
    frames: list[dict[str, Any]] = []
    previous = None
    frame_index = 0
    selected_count = 0
    written_count = 0
    writer = None

    try:
        while True:
            ok, frame = capture.read()
            if not ok:
                break
            if selected_indices:
                should_process = frame_index in selected_indices
            else:
                # Frame count unknown: sample on the fly instead.
                should_process = _selected_frame(frame_index, frame_stride, selected_count, max_frames)
            timestamp = frame_index / source_fps

            if method == "compare" and should_process:
                outputs = compare_frame(frame, previous, sensitivity)
                for output in outputs:
                    frames.append(
                        _save_frame_outputs(
                            job_path, frame_index, output.method, frame, output,
                            timestamp, timestamp, frame_index,
                        )
                    )
                video_output = next(item for item in outputs if item.method == "fusion")
            elif method == "compare":
                # Unsampled compare-mode frames only need the fusion overlay for the video.
                video_output = segment_frame(frame, "fusion", previous, sensitivity)
            else:
                video_output = segment_frame(frame, method, previous, sensitivity)
                # Compare mode already saved its per-method entries above; saving here
                # only for single-method runs avoids a duplicate "compare" entry.
                if should_process:
                    frames.append(
                        _save_frame_outputs(
                            job_path, frame_index, method, frame, video_output,
                            timestamp, timestamp, frame_index,
                        )
                    )

            if should_process:
                selected_count += 1

            if writer is None:
                # Lazily open the writer once the frame size is known.
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                writer = cv2.VideoWriter(str(raw_video_path), fourcc, result_fps, (width, height))
            writer.write(video_output.overlay)
            written_count += 1

            previous = frame
            frame_index += 1
    finally:
        capture.release()
        if writer is not None:
            writer.release()

    if not frames:
        raise HTTPException(status_code=400, detail="视频没有可处理帧")
    if raw_video_path.exists():
        _browser_video(raw_video_path, video_path)
    duration = round(written_count / source_fps, 4) if written_count else 0.0
    return {
        "kind": "video",
        "frames": frames,
        "video_url": _public(video_path) if video_path.exists() else None,
        "source_fps": round(source_fps, 4),
        "result_fps": round(result_fps, 4),
        "duration": duration,
        "result_duration": duration,
    }
|
|
|
|
|
|
def _read_video_frame(source: Path, frame_index: int) -> tuple[Any, Any | None, float]:
    """Decode one frame of a video together with the frame before it.

    Args:
        source: path of the video file.
        frame_index: zero-based index of the wanted frame; negatives clamp to 0.

    Returns:
        ``(frame, previous, source_fps)``. ``previous`` is the frame at
        ``frame_index - 1``; it is ``None`` for frame 0 or when that read fails.
        ``source_fps`` falls back to 12.0 when the container reports none.

    Raises:
        HTTPException: 400 if the video cannot be opened or the frame cannot be read.
    """
    capture = cv2.VideoCapture(str(source))
    try:
        if not capture.isOpened():
            raise HTTPException(status_code=400, detail="无法读取视频")
        source_fps = float(capture.get(cv2.CAP_PROP_FPS) or 0.0)
        if source_fps <= 0:
            source_fps = 12.0
        frame_index = max(0, int(frame_index))
        previous = None
        positioned = False
        if frame_index > 0:
            capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index - 1)
            ok_prev, prev_frame = capture.read()
            if ok_prev:
                previous = prev_frame
                # The successful read already advanced to frame_index. Reading on
                # (instead of seeking again) keeps `previous` and `frame` consecutive
                # even on backends where POS_FRAMES seeking is not exact.
                positioned = True
        if not positioned:
            capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ok, frame = capture.read()
    finally:
        capture.release()
    if not ok:
        raise HTTPException(status_code=400, detail="无法读取指定帧")
    return frame, previous, source_fps
|
|
|
|
|
|
@app.post("/api/segment")
def segment(
    file: UploadFile = File(...),
    method: str = Form("fusion"),
    sensitivity: float = Form(0.56),
    frame_stride: int = Form(5),
    max_frames: int = Form(12),
) -> dict[str, Any]:
    """Segment an uploaded image or video with the requested method.

    Args:
        file: uploaded image or video (extension decides which).
        method: key of ``METHOD_DESCRIPTIONS``.
        sensitivity: clamped to [0.05, 0.95].
        frame_stride: video sampling stride, clamped to [1, 90].
        max_frames: maximum sampled video frames, clamped to [1, 80].

    Returns:
        ``job_id`` plus the effective parameters, merged with the result of
        ``_process_image`` or ``_process_video``.

    Raises:
        HTTPException: 400 for an unknown method or an unsupported/unreadable file.
            On any failure the job directory created for the request is removed.
    """
    ensure_dirs()
    if method not in METHOD_DESCRIPTIONS:
        raise HTTPException(status_code=400, detail="未知分割方法")
    sensitivity = max(0.05, min(float(sensitivity), 0.95))
    frame_stride = max(1, min(int(frame_stride), 90))
    max_frames = max(1, min(int(max_frames), 80))
    job_id = uuid.uuid4().hex[:12]
    job_path = JOB_DIR / job_id
    job_path.mkdir(parents=True, exist_ok=True)
    try:
        source = _save_upload(file, job_path)
        suffix = source.suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            result = _process_image(job_path, source, method, sensitivity)
        elif suffix in VIDEO_SUFFIXES:
            result = _process_video(job_path, source, method, sensitivity, frame_stride, max_frames)
        else:
            raise HTTPException(status_code=400, detail="不支持的文件类型")
    except Exception:
        # Rejected or failed jobs would otherwise leave the upload and any partial
        # outputs on disk indefinitely.
        shutil.rmtree(job_path, ignore_errors=True)
        raise
    return {
        "job_id": job_id,
        "method": method,
        "sensitivity": sensitivity,
        "frame_stride": frame_stride,
        "max_frames": max_frames,
        **result,
    }
|
|
|
|
|
|
@app.post("/api/compare-frame")
def compare_single_frame(
    file: UploadFile = File(...),
    frame_index: int = Form(0),
    sensitivity: float = Form(0.56),
) -> dict[str, Any]:
    """Run ``compare_frame`` on one frame and save each of its outputs.

    For an image upload, the image itself is used and there is no previous frame.
    For a video upload, the frame at ``frame_index`` is decoded together with
    its predecessor.

    Args:
        file: uploaded image or video.
        frame_index: zero-based video frame. Negative values are clamped to 0 so
            the filenames and ``source_time`` match the frame actually decoded.
        sensitivity: clamped to [0.05, 0.95].

    Returns:
        ``job_id``, ``kind="compare"``, ``frame_index``, ``source_time`` (seconds)
        and one ``frames`` entry per output of ``compare_frame``.

    Raises:
        HTTPException: 400 for unsupported or unreadable files. On any failure
            the job directory created for the request is removed.
    """
    ensure_dirs()
    sensitivity = max(0.05, min(float(sensitivity), 0.95))
    frame_index = max(0, int(frame_index))
    job_id = uuid.uuid4().hex[:12]
    job_path = JOB_DIR / job_id
    job_path.mkdir(parents=True, exist_ok=True)
    try:
        source = _save_upload(file, job_path)
        suffix = source.suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            frame = cv2.imread(str(source), cv2.IMREAD_COLOR)
            previous = None
            source_time = 0.0
        elif suffix in VIDEO_SUFFIXES:
            frame, previous, source_fps = _read_video_frame(source, frame_index)
            source_time = frame_index / source_fps
        else:
            raise HTTPException(status_code=400, detail="不支持的文件类型")
        if frame is None:
            raise HTTPException(status_code=400, detail="无法读取文件")
        outputs = compare_frame(frame, previous, sensitivity)
        frames = [
            _save_frame_outputs(job_path, frame_index, output.method, frame, output, source_time, 0.0, 0)
            for output in outputs
        ]
    except Exception:
        # Don't keep uploads or partial outputs of failed requests on disk.
        shutil.rmtree(job_path, ignore_errors=True)
        raise
    return {
        "job_id": job_id,
        "kind": "compare",
        "frame_index": frame_index,
        "source_time": round(float(source_time), 4),
        "frames": frames,
    }
|
|
|
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve ``frontend/index.html`` at the site root."""
    page = FRONTEND_DIR.joinpath("index.html")
    return FileResponse(page)
|
|
|
|
|
|
# Catch-all static mount for the remaining frontend assets. It is registered
# last, after the API routes and the /storage mount.
_frontend_files = StaticFiles(directory=FRONTEND_DIR)
app.mount("/", _frontend_files, name="frontend")
|