20260430_001744-feat: PyTorch CUDA + SAM2 GPU inference, video thumbnail, real FPS + configurable parse FPS, DICOM batch import
This commit is contained in:
@@ -20,7 +20,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# SAM2
|
||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config: str = "sam2_hiera_t.yaml"
|
||||
sam_model_config: str = "configs/sam2/sam2_hiera_t.yaml"
|
||||
|
||||
# App
|
||||
app_env: str = "development"
|
||||
|
||||
@@ -30,7 +30,7 @@ def _seed_default_project_sync() -> None:
|
||||
"""Synchronously seed the default video project on first startup."""
|
||||
import cv2
|
||||
from models import Project, Frame
|
||||
from services.frame_parser import parse_video, upload_frames_to_minio
|
||||
from services.frame_parser import parse_video, upload_frames_to_minio, extract_thumbnail
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -46,6 +46,8 @@ def _seed_default_project_sync() -> None:
|
||||
name="Data_MyVideo_1",
|
||||
description="默认演示视频",
|
||||
status="pending",
|
||||
source_type="video",
|
||||
parse_fps=30.0,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -67,7 +69,20 @@ def _seed_default_project_sync() -> None:
|
||||
f.write(data)
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
frame_files = parse_video(local_path, output_dir, fps=30, max_frames=100)
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=30, max_frames=100)
|
||||
project.original_fps = original_fps
|
||||
|
||||
# Extract thumbnail
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(local_path, thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as f:
|
||||
thumb_data = f.read()
|
||||
thumb_object = f"projects/{project.id}/thumbnail.jpg"
|
||||
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
|
||||
project.thumbnail_url = thumb_object
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
|
||||
object_names = upload_frames_to_minio(frame_files, project.id)
|
||||
|
||||
|
||||
@@ -25,7 +25,11 @@ class Project(Base):
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
video_path = Column(String(512), nullable=True)
|
||||
thumbnail_url = Column(String(512), nullable=True)
|
||||
status = Column(String(50), default="Ready", nullable=False)
|
||||
source_type = Column(String(20), default="video", nullable=False) # video | dicom
|
||||
original_fps = Column(Float, nullable=True)
|
||||
parse_fps = Column(Float, default=30.0, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
|
||||
@@ -6,16 +6,19 @@ import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from minio_client import upload_file, get_presigned_url, download_file
|
||||
from models import Project, Frame
|
||||
from schemas import FrameOut
|
||||
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
|
||||
from services.frame_parser import (
|
||||
parse_video, parse_dicom, upload_frames_to_minio,
|
||||
extract_thumbnail, get_video_fps,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
@@ -78,6 +81,7 @@ async def upload_media(
|
||||
description="Auto-created from upload",
|
||||
status="pending",
|
||||
video_path=object_name,
|
||||
source_type="video",
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
@@ -90,7 +94,6 @@ async def upload_media(
|
||||
db.commit()
|
||||
logger.info("Auto-created project id=%s for upload %s", project_id, file.filename)
|
||||
|
||||
# TODO: enqueue async parsing job (Celery / background task)
|
||||
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
|
||||
|
||||
return {
|
||||
@@ -102,6 +105,66 @@ async def upload_media(
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload/dicom",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Upload multiple DICOM files",
|
||||
)
|
||||
async def upload_dicom_batch(
|
||||
files: List[UploadFile] = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Upload multiple .dcm files for a DICOM series.
|
||||
|
||||
If project_id is provided, files are added to the existing project.
|
||||
Otherwise a new DICOM project is created.
|
||||
"""
|
||||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files uploaded")
|
||||
|
||||
uploaded = []
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
else:
|
||||
# Create new DICOM project
|
||||
first_name = files[0].filename or "DICOM_Series"
|
||||
project = Project(
|
||||
name=first_name,
|
||||
description=f"DICOM series with {len(files)} files",
|
||||
status="pending",
|
||||
source_type="dicom",
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
project_id = project.id
|
||||
logger.info("Auto-created DICOM project id=%s", project_id)
|
||||
|
||||
for file in files:
|
||||
if not file.filename or not file.filename.lower().endswith(".dcm"):
|
||||
continue
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id}/dicom/{file.filename}"
|
||||
try:
|
||||
upload_file(object_name, data, content_type="application/dicom", length=len(data))
|
||||
uploaded.append(object_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to upload DICOM %s: %s", file.filename, exc)
|
||||
|
||||
project.video_path = f"uploads/{project_id}/dicom"
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"uploaded_count": len(uploaded),
|
||||
"message": f"Uploaded {len(uploaded)} DICOM files. Parsing job queued.",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
@@ -109,12 +172,12 @@ async def upload_media(
|
||||
)
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: str = "video", # video | dicom
|
||||
source_type: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Trigger frame extraction for a project's uploaded media.
|
||||
|
||||
* video: uses FFmpeg or OpenCV fallback.
|
||||
* video: uses FFmpeg or OpenCV fallback, extracts thumbnail.
|
||||
* dicom: uses pydicom to read DCM frames.
|
||||
|
||||
Extracted frames are uploaded to MinIO and registered in the database.
|
||||
@@ -126,37 +189,53 @@ def parse_media(
|
||||
if not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
# Download from MinIO to a temp directory
|
||||
from minio_client import download_file
|
||||
|
||||
try:
|
||||
media_bytes = download_file(project.video_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to download media for parsing: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc
|
||||
effective_source = source_type or project.source_type or "video"
|
||||
parse_fps = project.parse_fps or 30.0
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
if source_type == "dicom":
|
||||
# For DICOM, treat local_path as a directory if it contains multiple .dcm
|
||||
# If a single .dcm file was uploaded, put it in its own sub-dir
|
||||
if effective_source == "dicom":
|
||||
# Download all dicom files from MinIO
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
if local_path.lower().endswith(".dcm"):
|
||||
shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
|
||||
else:
|
||||
shutil.unpack_archive(local_path, dcm_dir) if shutil.which("unzip") else shutil.move(local_path, dcm_dir)
|
||||
|
||||
from minio_client import get_minio_client, BUCKET_NAME
|
||||
client = get_minio_client()
|
||||
prefix = project.video_path
|
||||
objects = list(client.list_objects(BUCKET_NAME, prefix=prefix, recursive=True))
|
||||
for obj in objects:
|
||||
if obj.object_name.lower().endswith(".dcm"):
|
||||
data = download_file(obj.object_name)
|
||||
local_dcm = os.path.join(dcm_dir, os.path.basename(obj.object_name))
|
||||
with open(local_dcm, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
else:
|
||||
frame_files = parse_video(local_path, output_dir, fps=30)
|
||||
# Video: download and parse
|
||||
media_bytes = download_file(project.video_path)
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
frame_files, original_fps = parse_video(local_path, output_dir, fps=int(parse_fps))
|
||||
project.original_fps = original_fps
|
||||
|
||||
# Extract thumbnail from first frame
|
||||
thumbnail_path = os.path.join(tmp_dir, "thumbnail.jpg")
|
||||
try:
|
||||
extract_thumbnail(local_path, thumbnail_path)
|
||||
with open(thumbnail_path, "rb") as f:
|
||||
thumb_data = f.read()
|
||||
thumb_object = f"projects/{project_id}/thumbnail.jpg"
|
||||
upload_file(thumb_object, thumb_data, content_type="image/jpeg", length=len(thumb_data))
|
||||
project.thumbnail_url = thumb_object
|
||||
logger.info("Uploaded thumbnail for project_id=%s", project_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("Thumbnail extraction failed: %s", exc)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Frame extraction failed: %s", exc)
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
@@ -173,7 +252,6 @@ def parse_media(
|
||||
# Register frames in DB
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
# Get image dimensions
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
|
||||
@@ -44,6 +44,8 @@ def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
|
||||
projects = db.query(Project).offset(skip).limit(limit).all()
|
||||
for p in projects:
|
||||
p.frame_count = len(p.frames)
|
||||
if p.thumbnail_url:
|
||||
p.thumbnail_url = get_presigned_url(p.thumbnail_url, expires=3600)
|
||||
return projects
|
||||
|
||||
|
||||
@@ -58,6 +60,8 @@ def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
project.frame_count = len(project.frames)
|
||||
if project.thumbnail_url:
|
||||
project.thumbnail_url = get_presigned_url(project.thumbnail_url, expires=3600)
|
||||
return project
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,11 @@ class ProjectBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
video_path: Optional[str] = None
|
||||
thumbnail_url: Optional[str] = None
|
||||
status: Optional[str] = "pending"
|
||||
source_type: Optional[str] = "video"
|
||||
original_fps: Optional[float] = None
|
||||
parse_fps: Optional[float] = 30.0
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
@@ -23,7 +27,11 @@ class ProjectUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
video_path: Optional[str] = None
|
||||
thumbnail_url: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
original_fps: Optional[float] = None
|
||||
parse_fps: Optional[float] = None
|
||||
|
||||
|
||||
class ProjectOut(ProjectBase):
|
||||
@@ -103,7 +111,7 @@ class AnnotationCreate(AnnotationBase):
|
||||
|
||||
class AnnotationUpdate(BaseModel):
|
||||
mask_data: Optional[dict[str, Any]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
points: Optional[list[float]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
template_id: Optional[int] = None
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -16,12 +16,43 @@ from minio_client import upload_file, BUCKET_NAME
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_video_fps(video_path: str) -> float:
|
||||
"""Read the original frame rate of a video file."""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
return 30.0
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
cap.release()
|
||||
return fps if fps > 0 else 30.0
|
||||
|
||||
|
||||
def extract_thumbnail(video_path: str, output_path: str, width: int = 640) -> str:
|
||||
"""Extract the first frame of a video as a thumbnail JPEG."""
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError(f"Cannot open video for thumbnail: {video_path}")
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
if not ret or frame is None:
|
||||
raise RuntimeError(f"Cannot read first frame from: {video_path}")
|
||||
|
||||
h, w = frame.shape[:2]
|
||||
if w > width:
|
||||
scale = width / w
|
||||
new_w = int(w * scale)
|
||||
new_h = int(h * scale)
|
||||
frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
|
||||
cv2.imwrite(output_path, frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
return output_path
|
||||
|
||||
|
||||
def parse_video(
|
||||
video_path: str,
|
||||
output_dir: str,
|
||||
fps: int = 30,
|
||||
max_frames: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
) -> Tuple[List[str], float]:
|
||||
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
|
||||
|
||||
Args:
|
||||
@@ -31,10 +62,11 @@ def parse_video(
|
||||
max_frames: Optional maximum number of frames to extract.
|
||||
|
||||
Returns:
|
||||
List of paths to extracted frame images.
|
||||
Tuple of (frame_paths, original_fps).
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
frame_paths: List[str] = []
|
||||
original_fps = get_video_fps(video_path)
|
||||
|
||||
# Try FFmpeg first
|
||||
if shutil.which("ffmpeg"):
|
||||
@@ -57,7 +89,7 @@ def parse_video(
|
||||
if max_frames:
|
||||
frame_paths = frame_paths[:max_frames]
|
||||
logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
|
||||
return frame_paths
|
||||
return frame_paths, original_fps
|
||||
else:
|
||||
logger.warning("FFmpeg failed: %s", result.stderr)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
@@ -89,7 +121,7 @@ def parse_video(
|
||||
|
||||
cap.release()
|
||||
logger.info("Extracted %d frames via OpenCV", len(frame_paths))
|
||||
return frame_paths
|
||||
return frame_paths, original_fps
|
||||
|
||||
|
||||
def parse_dicom(
|
||||
@@ -134,12 +166,12 @@ def parse_dicom(
|
||||
# Handle multi-frame DICOM
|
||||
if pixel_array.ndim == 3:
|
||||
for f in range(pixel_array.shape[0]):
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.png")
|
||||
cv2.imwrite(out_path, pixel_array[f])
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.jpg")
|
||||
cv2.imwrite(out_path, pixel_array[f], [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
frame_paths.append(out_path)
|
||||
else:
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}.png")
|
||||
cv2.imwrite(out_path, pixel_array)
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}.jpg")
|
||||
cv2.imwrite(out_path, pixel_array, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||
frame_paths.append(out_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to read DICOM %s: %s", path, exc)
|
||||
|
||||
Reference in New Issue
Block a user