2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构

This commit is contained in:
2026-04-29 22:17:25 +08:00
parent c8f8686097
commit fd4b5e5b3d
39 changed files with 3816 additions and 211 deletions

View File

View File

@@ -0,0 +1,186 @@
"""Video/DICOM frame parsing and MinIO upload utilities."""
import logging
import os
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional
import cv2
import numpy as np
from pydicom import dcmread
from minio_client import upload_file, BUCKET_NAME
logger = logging.getLogger(__name__)
def parse_video(
video_path: str,
output_dir: str,
fps: int = 30,
max_frames: Optional[int] = None,
) -> List[str]:
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
Args:
video_path: Path to the input video file.
output_dir: Directory to save extracted frames.
fps: Target frame extraction rate.
max_frames: Optional maximum number of frames to extract.
Returns:
List of paths to extracted frame images.
"""
os.makedirs(output_dir, exist_ok=True)
frame_paths: List[str] = []
# Try FFmpeg first
if shutil.which("ffmpeg"):
try:
pattern = os.path.join(output_dir, "frame_%06d.png")
cmd = [
"ffmpeg",
"-i", video_path,
"-vf", f"fps={fps},scale='min(1920,iw)':-1",
"-pix_fmt", "rgb24",
"-y",
pattern,
]
logger.info("Running FFmpeg: %s", " ".join(cmd))
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode == 0:
frame_paths = sorted(
[os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".png")]
)
if max_frames:
frame_paths = frame_paths[:max_frames]
logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
return frame_paths
else:
logger.warning("FFmpeg failed: %s", result.stderr)
except Exception as exc: # noqa: BLE001
logger.warning("FFmpeg exception: %s", exc)
# OpenCV fallback
logger.info("Falling back to OpenCV frame extraction")
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise RuntimeError(f"Cannot open video: {video_path}")
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
interval = max(1, int(round(video_fps / fps)))
count = 0
saved = 0
while True:
ret, frame = cap.read()
if not ret:
break
if count % interval == 0:
path = os.path.join(output_dir, f"frame_{saved:06d}.png")
cv2.imwrite(path, frame)
frame_paths.append(path)
saved += 1
if max_frames and saved >= max_frames:
break
count += 1
cap.release()
logger.info("Extracted %d frames via OpenCV", len(frame_paths))
return frame_paths
def parse_dicom(
dicom_dir: str,
output_dir: str,
max_frames: Optional[int] = None,
) -> List[str]:
"""Extract frames from DICOM files in a directory.
Args:
dicom_dir: Directory containing .dcm files.
output_dir: Directory to save extracted frames.
max_frames: Optional maximum number of frames to extract.
Returns:
List of paths to extracted frame images.
"""
os.makedirs(output_dir, exist_ok=True)
dcm_files = sorted(
[f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")]
)
frame_paths: List[str] = []
for idx, fname in enumerate(dcm_files):
if max_frames and idx >= max_frames:
break
path = os.path.join(dicom_dir, fname)
try:
ds = dcmread(path)
pixel_array = ds.pixel_array
# Normalize to 8-bit
if pixel_array.dtype != np.uint8:
pixel_array = pixel_array.astype(np.float32)
pixel_array = (
(pixel_array - pixel_array.min())
/ (pixel_array.max() - pixel_array.min() + 1e-8)
* 255
)
pixel_array = pixel_array.astype(np.uint8)
# Handle multi-frame DICOM
if pixel_array.ndim == 3:
for f in range(pixel_array.shape[0]):
out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.png")
cv2.imwrite(out_path, pixel_array[f])
frame_paths.append(out_path)
else:
out_path = os.path.join(output_dir, f"frame_{idx:06d}.png")
cv2.imwrite(out_path, pixel_array)
frame_paths.append(out_path)
except Exception as exc: # noqa: BLE001
logger.error("Failed to read DICOM %s: %s", path, exc)
logger.info("Extracted %d frames from DICOM", len(frame_paths))
return frame_paths
def upload_frames_to_minio(
frames: List[str],
project_id: int,
object_prefix: Optional[str] = None,
) -> List[str]:
"""Upload a list of local frame images to MinIO.
Args:
frames: List of local file paths.
project_id: Project ID used for bucket path organization.
object_prefix: Optional prefix override.
Returns:
List of object names (paths) in MinIO.
"""
prefix = object_prefix or f"projects/{project_id}/frames"
object_names: List[str] = []
for frame_path in frames:
fname = os.path.basename(frame_path)
object_name = f"{prefix}/{fname}"
try:
with open(frame_path, "rb") as f:
data = f.read()
upload_file(
object_name,
data,
content_type="image/png",
length=len(data),
)
object_names.append(object_name)
except Exception as exc: # noqa: BLE001
logger.error("Failed to upload %s: %s", frame_path, exc)
logger.info("Uploaded %d/%d frames to MinIO", len(object_names), len(frames))
return object_names

View File

@@ -0,0 +1,234 @@
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
import logging
import os
from typing import Optional
import numpy as np
from config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Attempt to import SAM 2; fall back to stubs if unavailable.
# ---------------------------------------------------------------------------
try:
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
SAM2_AVAILABLE = True
logger.info("SAM2 library imported successfully.")
except Exception as exc: # noqa: BLE001
SAM2_AVAILABLE = False
logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
class SAM2Engine:
"""Lazy-loaded SAM 2 inference engine."""
def __init__(self) -> None:
self._predictor: Optional[SAM2ImagePredictor] = None
self._model_loaded = False
# -----------------------------------------------------------------------
# Internal helpers
# -----------------------------------------------------------------------
def _load_model(self) -> None:
"""Load the SAM 2 model and predictor on first use."""
if self._model_loaded:
return
if not SAM2_AVAILABLE:
logger.warning("SAM2 not available; skipping model load.")
self._model_loaded = True
return
if not os.path.isfile(settings.sam_model_path):
logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
self._model_loaded = True
return
try:
model = build_sam2(
settings.sam_model_config,
settings.sam_model_path,
device="cuda",
)
self._predictor = SAM2ImagePredictor(model)
self._model_loaded = True
logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
except Exception as exc: # noqa: BLE001
logger.error("Failed to load SAM 2 model: %s", exc)
self._model_loaded = True # Prevent repeated load attempts
def _ensure_ready(self) -> bool:
"""Ensure the model is loaded; return whether it is usable."""
self._load_model()
return SAM2_AVAILABLE and self._predictor is not None
# -----------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------
def predict_points(
self,
image: np.ndarray,
points: list[list[float]],
labels: list[int],
) -> tuple[list[list[list[float]]], list[float]]:
"""Run point-prompt segmentation.
Args:
image: HWC numpy array (uint8).
points: List of [x, y] normalized coordinates (0-1).
labels: 1 for foreground, 0 for background.
Returns:
Tuple of (polygons, scores).
"""
if not self._ensure_ready():
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
h, w = image.shape[:2]
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
lbls = np.array(labels, dtype=np.int32)
with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image)
masks, scores, _ = self._predictor.predict(
point_coords=pts,
point_labels=lbls,
multimask_output=True,
)
polygons = []
for m in masks:
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores.tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 point prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_box(
self,
image: np.ndarray,
box: list[float],
) -> tuple[list[list[list[float]]], list[float]]:
"""Run box-prompt segmentation.
Args:
image: HWC numpy array (uint8).
box: [x1, y1, x2, y2] normalized coordinates.
Returns:
Tuple of (polygons, scores).
"""
if not self._ensure_ready():
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
h, w = image.shape[:2]
bbox = np.array(
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
dtype=np.float32,
)
with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image)
masks, scores, _ = self._predictor.predict(
box=bbox[None, :],
multimask_output=False,
)
polygons = []
for m in masks:
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores.tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 box prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
"""Run automatic mask generation (grid of points).
Args:
image: HWC numpy array (uint8).
Returns:
Tuple of (polygons, scores).
"""
if not self._ensure_ready():
logger.warning("SAM2 not ready; returning dummy masks.")
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
try:
with torch.inference_mode(): # type: ignore[name-defined]
self._predictor.set_image(image)
# Generate a uniform 16x16 grid of point prompts
h, w = image.shape[:2]
grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
pts = grid * np.array([w, h])
lbls = np.ones(pts.shape[0], dtype=np.int32)
masks, scores, _ = self._predictor.predict(
point_coords=pts,
point_labels=lbls,
multimask_output=True,
)
polygons = []
for m in masks[:3]: # Limit to top 3 masks
poly = self._mask_to_polygon(m)
if poly:
polygons.append(poly)
return polygons, scores[:3].tolist()
except Exception as exc: # noqa: BLE001
logger.error("SAM2 auto prediction failed: %s", exc)
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
# -----------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------
@staticmethod
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
"""Convert a binary mask to a normalized polygon."""
import cv2
if mask.dtype != np.uint8:
mask = (mask > 0).astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
h, w = mask.shape[:2]
largest = []
for cnt in contours:
if len(cnt) > len(largest):
largest = cnt
if len(largest) < 3:
return []
return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]
@staticmethod
def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
"""Return a dummy rectangle polygon for fallback mode."""
return [
[
[0.25, 0.25],
[0.75, 0.25],
[0.75, 0.75],
[0.25, 0.75],
]
]
# Singleton instance
sam_engine = SAM2Engine()