20260430_001744-feat: PyTorch CUDA + SAM2 GPU inference, video thumbnail, real FPS + configurable parse FPS, DICOM batch import

This commit is contained in:
2026-04-30 00:30:58 +08:00
parent 35d6e1503c
commit 6d008ec4a2
15 changed files with 555 additions and 101 deletions

View File

@@ -20,7 +20,7 @@ class Settings(BaseSettings):
# SAM2
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
sam_model_config: str = "sam2_hiera_t.yaml"
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
# App
app_env: str = "development"

View File

@@ -30,7 +30,7 @@ def _seed_default_project_sync() -> None:
"""Synchronously seed the default video project on first startup."""
import cv2
from models import Project, Frame
from services.frame_parser import parse_video, upload_frames_to_minio
from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail
db = SessionLocal()
try:
@@ -46,6 +46,8 @@ def _seed_default_project_sync() -> None:
name="Data_MyVideo_1",
description="默认演示视频",
status="pending",
source_type="video",
parse_fps=30.0,
)
db.add(project)
db.commit()
@@ -67,7 +69,20 @@ def _seed_default_project_sync() -> None:
f.write(data)
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
frame_files = parse_video(local_path, output_dir, fps=30, max_frames=100)
frame_files, original_fps = parse_video(local_path, output_dir, fps=30, max_frames=100)
project.original_fps = original_fps
# Extract thumbnail
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project.id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
object_names = upload_frames_to_minio(frame_files, project.id)

View File

@@ -25,7 +25,11 @@ class Project(Base):
name = Column(String(255), nullable=False)
description = Column(Text, nullable=True)
video_path = Column(String(512), nullable=True)
thumbnail_url = Column(String(512), nullable=True)
status = Column(String(50), default="Ready", nullable=False)
source_type = Column(String(20), default="video", nullable=False) # video | dicom
original_fps = Column(Float, nullable=True)
parse_fps = Column(Float, default=30.0, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()

View File

@@ -6,16 +6,19 @@ import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url
from minio_client import upload_file, get_presigned_url, download_file
from models import Project, Frame
from schemas import FrameOut
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
from services.frame_parser import (
parse_video, parse_dicom, upload_frames_to_minio,
extract_thumbnail, get_video_fps,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["Media"])
@@ -78,6 +81,7 @@ async def upload_media(
description="Auto-created from upload",
status="pending",
video_path=object_name,
source_type="video",
)
db.add(project)
db.commit()
@@ -90,7 +94,6 @@ async def upload_media(
db.commit()
logger.info("Auto-created project id=%s for upload %s", project_id, file.filename)
# TODO: enqueue async parsing job (Celery / background task)
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
return {
@@ -102,6 +105,66 @@ async def upload_media(
}
@router.post(
    "/upload/dicom",
    status_code=status.HTTP_201_CREATED,
    summary="Upload multiple DICOM files",
)
async def upload_dicom_batch(
    files: List[UploadFile] = File(...),
    project_id: Optional[int] = Form(None),
    db: Session = Depends(get_db),
) -> dict:
    """Upload multiple .dcm files for a DICOM series.

    If project_id is provided, files are added to the existing project.
    Otherwise a new DICOM project is created.

    Args:
        files: Uploaded files; entries whose name does not end in ``.dcm``
            (case-insensitive) are skipped.
        project_id: Optional id of an existing project to attach the files to.
        db: Database session.

    Returns:
        A dict with ``project_id``, ``uploaded_count`` and ``message``.

    Raises:
        HTTPException: 400 when no files, or no ``.dcm`` files, were sent;
            404 when ``project_id`` does not match an existing project.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files uploaded")

    # The filename is client-controlled. Reduce it to its final path component
    # (treating backslashes as separators too) so it cannot inject extra path
    # segments such as "../" into the storage object key.
    dicom_files = []
    for file in files:
        safe_name = Path((file.filename or "").replace("\\", "/")).name
        if safe_name.lower().endswith(".dcm"):
            dicom_files.append((safe_name, file))

    # Reject the request before touching the database, so a batch without any
    # DICOM file does not leave behind an empty project.
    if not dicom_files:
        raise HTTPException(status_code=400, detail="No .dcm files uploaded")

    uploaded = []

    # Compare against None so that an explicit id of 0 is looked up instead of
    # silently creating a new project.
    if project_id is not None:
        project = db.query(Project).filter(Project.id == project_id).first()
        if not project:
            raise HTTPException(status_code=404, detail="Project not found")
        # NOTE(review): the existing project's source_type is neither checked
        # nor updated here, while video_path is overwritten below — confirm
        # this is intended when the target is not a DICOM project.
    else:
        # Create new DICOM project
        first_name = files[0].filename or "DICOM_Series"
        project = Project(
            name=first_name,
            description=f"DICOM series with {len(files)} files",
            status="pending",
            source_type="dicom",
        )
        db.add(project)
        db.commit()
        db.refresh(project)
        project_id = project.id
        logger.info("Auto-created DICOM project id=%s", project_id)

    for safe_name, file in dicom_files:
        data = await file.read()
        object_name = f"uploads/{project_id}/dicom/{safe_name}"
        try:
            upload_file(object_name, data, content_type="application/dicom", length=len(data))
            uploaded.append(object_name)
        except Exception as exc:  # noqa: BLE001
            # Best effort: one failed object must not abort the whole batch.
            logger.error("Failed to upload DICOM %s: %s", file.filename, exc)

    # For DICOM projects video_path holds the object-key prefix of the series
    # rather than a single object.
    project.video_path = f"uploads/{project_id}/dicom"
    db.commit()

    return {
        "project_id": project_id,
        "uploaded_count": len(uploaded),
        "message": f"Uploaded {len(uploaded)} DICOM files. Parsing job queued.",
    }
@router.post(
"/parse",
status_code=status.HTTP_202_ACCEPTED,
@@ -109,12 +172,12 @@ async def upload_media(
)
def parse_media(
project_id: int,
source_type: str = "video", # video | dicom
source_type: Optional[str] = None,
db: Session = Depends(get_db),
) -> dict:
"""Trigger frame extraction for a project's uploaded media.
* video: uses FFmpeg or OpenCV fallback.
* video: uses FFmpeg or OpenCV fallback, extracts thumbnail.
* dicom: uses pydicom to read DCM frames.
Extracted frames are uploaded to MinIO and registered in the database.
@@ -126,37 +189,53 @@ def parse_media(
if not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
# Download from MinIO to a temp directory
from minio_client import download_file
try:
media_bytes = download_file(project.video_path)
except Exception as exc: # noqa: BLE001
logger.error("Failed to download media for parsing: %s", exc)
raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc
effective_source = source_type or project.source_type or "video"
parse_fps = project.parse_fps or 30.0
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
try:
if source_type == "dicom":
# For DICOM, treat local_path as a directory if it contains multiple .dcm
# If a single .dcm file was uploaded, put it in its own sub-dir
if effective_source == "dicom":
# Download all dicom files from MinIO
dcm_dir = os.path.join(tmp_dir, "dcm")
os.makedirs(dcm_dir, exist_ok=True)
if local_path.lower().endswith(".dcm"):
shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
else:
shutil.unpack_archive(local_path, dcm_dir) if shutil.which("unzip") else shutil.move(local_path, dcm_dir)
from minio_client import get_minio_client, BUCKET_NAME
client = get_minio_client()
prefix = project.video_path
objects = list(client.list_objects(BUCKET_NAME, prefix=prefix, recursive=True))
for obj in objects:
if obj.object_name.lower().endswith(".dcm"):
data = download_file(obj.object_name)
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
with open(local_dcm, "wb") as f:
f.write(data)
frame_files = parse_dicom(dcm_dir, output_dir)
else:
frame_files = parse_video(local_path, output_dir, fps=30)
# Video: download and parse
media_bytes = download_file(project.video_path)
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
project.original_fps = original_fps
# Extract thumbnail from first frame
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
try:
extract_thumbnail(local_path, thumbnail_path)
with open(thumbnail_path, "rb") as f:
thumb_data = f.read()
thumb_object = f"projects/{project_id}/thumbnail.jpg"
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
project.thumbnail_url = thumb_object
logger.info("Uploaded thumbnail for project_id=%s", project_id)
except Exception as exc: # noqa: BLE001
logger.warning("Thumbnail extraction failed: %s", exc)
except Exception as exc: # noqa: BLE001
logger.error("Frame extraction failed: %s", exc)
shutil.rmtree(tmp_dir, ignore_errors=True)
@@ -173,7 +252,6 @@ def parse_media(
# Register frames in DB
frames_out = []
for idx, obj_name in enumerate(object_names):
# Get image dimensions
local_frame = frame_files[idx]
try:
import cv2

View File

@@ -44,6 +44,8 @@ def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
projects = db.query(Project).offset(skip).limit(limit).all()
for p in projects:
p.frame_count = len(p.frames)
if p.thumbnail_url:
p.thumbnail_url = get_presigned_url(p.thumbnail_url, expires=3600)
return projects
@@ -58,6 +60,8 @@ def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
if not project:
raise HTTPException(status_code=404, detail="Project not found")
project.frame_count = len(project.frames)
if project.thumbnail_url:
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
return project

View File

@@ -12,7 +12,11 @@ class ProjectBase(BaseModel):
name: str
description: Optional[str] = None
video_path: Optional[str] = None
thumbnail_url: Optional[str] = None
status: Optional[str] = "pending"
source_type: Optional[str] = "video"
original_fps: Optional[float] = None
parse_fps: Optional[float] = 30.0
class ProjectCreate(ProjectBase):
@@ -23,7 +27,11 @@ class ProjectUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
video_path: Optional[str] = None
thumbnail_url: Optional[str] = None
status: Optional[str] = None
source_type: Optional[str] = None
original_fps: Optional[float] = None
parse_fps: Optional[float] = None
class ProjectOut(ProjectBase):
@@ -103,7 +111,7 @@ class AnnotationCreate(AnnotationBase):
class AnnotationUpdate(BaseModel):
mask_data: Optional[dict[str, Any]] = None
points: Optional[list[list[float]]] = None
points: Optional[list[float]] = None
bbox: Optional[list[float]] = None
template_id: Optional[int] = None

View File

@@ -5,7 +5,7 @@ import os
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
import cv2
import numpy as np
@@ -16,12 +16,43 @@ from minio_client import upload_file, BUCKET_NAME
logger = logging.getLogger(__name__)
def get_video_fps(video_path: str) -> float:
    """Read the original frame rate of a video file.

    Args:
        video_path: Path of the video file to probe.

    Returns:
        The frame rate reported by OpenCV, or 30.0 when the file cannot be
        opened or the reported rate is not greater than zero.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return 30.0
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release on every path, including the early return above and any
        # error raised while querying the property.
        cap.release()
    return fps if fps > 0 else 30.0
def extract_thumbnail(video_path: str, output_path: str, width: int = 640) -> str:
    """Extract the first frame of a video as a thumbnail JPEG.

    Frames wider than ``width`` are scaled down to that width, keeping the
    aspect ratio; narrower frames are written at their original size.

    Args:
        video_path: Path of the source video.
        output_path: Destination path of the JPEG thumbnail.
        width: Maximum thumbnail width in pixels; must be positive.

    Returns:
        ``output_path``, once the thumbnail has been written.

    Raises:
        ValueError: If ``width`` is not positive.
        RuntimeError: If the video cannot be opened, its first frame cannot
            be read, or the thumbnail cannot be written.
    """
    if width <= 0:
        raise ValueError(f"Thumbnail width must be positive, got {width}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video for thumbnail: {video_path}")
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including the open failure.
        cap.release()

    if not ret or frame is None:
        raise RuntimeError(f"Cannot read first frame from: {video_path}")

    h, w = frame.shape[:2]
    if w > width:
        scale = width / w
        # Use the requested width directly: int(w * scale) can land one pixel
        # short because of float rounding. Clamp the height so that a very
        # wide frame never yields a zero-sized resize target.
        new_w = width
        new_h = max(1, int(h * scale))
        frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # cv2.imwrite signals some failures by returning False instead of raising.
    if not cv2.imwrite(output_path, frame, [cv2.IMWRITE_JPEG_QUALITY, 85]):
        raise RuntimeError(f"Cannot write thumbnail to: {output_path}")
    return output_path
def parse_video(
video_path: str,
output_dir: str,
fps: int = 30,
max_frames: Optional[int] = None,
) -> List[str]:
) -> Tuple[List[str], float]:
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
Args:
@@ -31,10 +62,11 @@ def parse_video(
max_frames: Optional maximum number of frames to extract.
Returns:
List of paths to extracted frame images.
Tuple of (frame_paths, original_fps).
"""
os.makedirs(output_dir, exist_ok=True)
frame_paths: List[str] = []
original_fps = get_video_fps(video_path)
# Try FFmpeg first
if shutil.which("ffmpeg"):
@@ -57,7 +89,7 @@ def parse_video(
if max_frames:
frame_paths = frame_paths[:max_frames]
logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
return frame_paths
return frame_paths, original_fps
else:
logger.warning("FFmpeg failed: %s", result.stderr)
except Exception as exc: # noqa: BLE001
@@ -89,7 +121,7 @@ def parse_video(
cap.release()
logger.info("Extracted %d frames via OpenCV", len(frame_paths))
return frame_paths
return frame_paths, original_fps
def parse_dicom(
@@ -134,12 +166,12 @@ def parse_dicom(
# Handle multi-frame DICOM
if pixel_array.ndim == 3:
for f in range(pixel_array.shape[0]):
out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.png")
cv2.imwrite(out_path, pixel_array[f])
out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.jpg")
cv2.imwrite(out_path, pixel_array[f], [cv2.IMWRITE_JPEG_QUALITY, 85])
frame_paths.append(out_path)
else:
out_path = os.path.join(output_dir, f"frame_{idx:06d}.png")
cv2.imwrite(out_path, pixel_array)
out_path = os.path.join(output_dir, f"frame_{idx:06d}.jpg")
cv2.imwrite(out_path, pixel_array, [cv2.IMWRITE_JPEG_QUALITY, 85])
frame_paths.append(out_path)
except Exception as exc: # noqa: BLE001
logger.error("Failed to read DICOM %s: %s", path, exc)