2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构

This commit is contained in:
2026-04-29 22:17:25 +08:00
parent c8f8686097
commit fd4b5e5b3d
39 changed files with 3816 additions and 211 deletions

View File

123
backend/routers/ai.py Normal file
View File

@@ -0,0 +1,123 @@
"""AI inference endpoints using SAM 2."""
import logging
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Frame, Annotation
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
def _load_frame_image(frame: Frame) -> np.ndarray:
"""Download a frame from MinIO and decode it to an RGB numpy array."""
try:
data = download_file(frame.image_url)
arr = np.frombuffer(data, dtype=np.uint8)
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
if img is None:
raise ValueError("OpenCV could not decode image")
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
except Exception as exc: # noqa: BLE001
logger.error("Failed to load frame image: %s", exc)
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
@router.post(
"/predict",
response_model=PredictResponse,
summary="Run SAM 2 inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
"""Execute SAM 2 segmentation given an image and a prompt.
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
- **semantic**: Not yet implemented; falls back to auto segmentation.
"""
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
image = _load_frame_image(frame)
prompt_type = payload.prompt_type.lower()
polygons: List[List[List[float]]] = []
scores: List[float] = []
if prompt_type == "point":
points = payload.prompt_data
if not isinstance(points, list) or len(points) == 0:
raise HTTPException(status_code=400, detail="Invalid point prompt data")
labels = [1] * len(points)
polygons, scores = sam_engine.predict_points(image, points, labels)
elif prompt_type == "box":
box = payload.prompt_data
if not isinstance(box, list) or len(box) != 4:
raise HTTPException(status_code=400, detail="Invalid box prompt data")
polygons, scores = sam_engine.predict_box(image, box)
elif prompt_type == "semantic":
# Placeholder: use auto segmentation for now
logger.info("Semantic prompt not implemented; using auto segmentation")
polygons, scores = sam_engine.predict_auto(image)
else:
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
return {"polygons": polygons, "scores": scores}
@router.post(
"/auto",
response_model=PredictResponse,
summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
"""Run automatic mask generation on a frame using a grid of point prompts."""
frame = db.query(Frame).filter(Frame.id == image_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
image = _load_frame_image(frame)
polygons, scores = sam_engine.predict_auto(image)
return {"polygons": polygons, "scores": scores}
@router.post(
"/annotate",
response_model=AnnotationOut,
status_code=status.HTTP_201_CREATED,
summary="Save an AI-generated annotation",
)
def save_annotation(
payload: AnnotationCreate,
db: Session = Depends(get_db),
) -> Annotation:
"""Persist an annotation (mask, points, bbox) into the database."""
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if payload.frame_id:
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
annotation = Annotation(**payload.model_dump())
db.add(annotation)
db.commit()
db.refresh(annotation)
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
return annotation

24
backend/routers/auth.py Normal file
View File

@@ -0,0 +1,24 @@
"""Authentication endpoints."""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
router = APIRouter(prefix="/api/auth", tags=["Auth"])
class LoginRequest(BaseModel):
username: str
password: str
class LoginResponse(BaseModel):
token: str
username: str
@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest) -> dict:
"""Simple login for development."""
if payload.username == "admin" and payload.password == "123456":
return {"token": "fake-jwt-token-for-admin", "username": payload.username}
raise HTTPException(status_code=401, detail="Invalid credentials")

194
backend/routers/export.py Normal file
View File

@@ -0,0 +1,194 @@
"""Annotation export endpoints (COCO, PNG masks)."""
import io
import json
import logging
import os
import zipfile
from datetime import datetime
from typing import Any, Dict, List
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Annotation, Frame, Template
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/export", tags=["Export"])
def _mask_from_polygon(
polygon: List[List[float]],
width: int,
height: int,
) -> np.ndarray:
"""Render a normalized polygon to a binary mask."""
import cv2
pts = np.array(
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
dtype=np.int32,
)
mask = np.zeros((height, width), dtype=np.uint8)
cv2.fillPoly(mask, [pts], 255)
return mask
@router.get(
"/{project_id}/coco",
summary="Export annotations in COCO format",
)
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
"""Export all annotations for a project as a COCO-format JSON file."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
annotations = (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.all()
)
frames = (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.all()
)
templates = db.query(Template).all()
# Build COCO structure
images = []
for idx, frame in enumerate(frames):
images.append({
"id": frame.id,
"file_name": frame.image_url,
"width": frame.width or 1920,
"height": frame.height or 1080,
"frame_index": idx,
})
categories = []
template_id_to_cat_id: Dict[int, int] = {}
for cat_idx, tmpl in enumerate(templates, start=1):
categories.append({
"id": cat_idx,
"name": tmpl.name,
"color": tmpl.color,
})
template_id_to_cat_id[tmpl.id] = cat_idx
coco_annotations = []
ann_id = 1
for ann in annotations:
if not ann.mask_data:
continue
polygons = ann.mask_data.get("polygons", [])
if not polygons:
continue
# Use first polygon for bbox / area approximation
first_poly = polygons[0]
xs = [p[0] for p in first_poly]
ys = [p[1] for p in first_poly]
width = ann.frame.width if ann.frame else 1920
height = ann.frame.height if ann.frame else 1080
bbox = [
min(xs) * width,
min(ys) * height,
(max(xs) - min(xs)) * width,
(max(ys) - min(ys)) * height,
]
area = bbox[2] * bbox[3]
segmentation = []
for poly in polygons:
flat = []
for p in poly:
flat.append(p[0] * width)
flat.append(p[1] * height)
segmentation.append(flat)
coco_annotations.append({
"id": ann_id,
"image_id": ann.frame_id,
"category_id": template_id_to_cat_id.get(ann.template_id, 0),
"segmentation": segmentation,
"area": area,
"bbox": bbox,
"iscrowd": 0,
})
ann_id += 1
coco = {
"info": {
"description": f"Annotations for {project.name}",
"version": "1.0",
"year": datetime.now().year,
"date_created": datetime.now().isoformat(),
},
"images": images,
"annotations": coco_annotations,
"categories": categories,
}
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
filename = f"project_{project_id}_coco.json"
return StreamingResponse(
io.BytesIO(data),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@router.get(
"/{project_id}/masks",
summary="Export PNG masks as a ZIP archive",
)
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
"""Export all annotation masks as individual PNG files inside a ZIP archive."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
annotations = (
db.query(Annotation)
.filter(Annotation.project_id == project_id)
.all()
)
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
for ann in annotations:
if not ann.mask_data:
continue
polygons = ann.mask_data.get("polygons", [])
if not polygons:
continue
width = ann.frame.width if ann.frame else 1920
height = ann.frame.height if ann.frame else 1080
combined = np.zeros((height, width), dtype=np.uint8)
for poly in polygons:
mask = _mask_from_polygon(poly, width, height)
combined = np.maximum(combined, mask)
# Encode PNG
import cv2
_, encoded = cv2.imencode(".png", combined)
fname = f"mask_{ann.id:06d}.png"
zf.writestr(fname, encoded.tobytes())
zip_buffer.seek(0)
filename = f"project_{project_id}_masks.zip"
return StreamingResponse(
zip_buffer,
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)

192
backend/routers/media.py Normal file
View File

@@ -0,0 +1,192 @@
"""Media upload and parsing endpoints."""
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url
from models import Project, Frame
from schemas import FrameOut
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["Media"])
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
def _get_ext(filename: str) -> str:
return Path(filename).suffix.lower()
@router.post(
"/upload",
status_code=status.HTTP_201_CREATED,
summary="Upload a media file",
)
async def upload_media(
file: UploadFile = File(...),
project_id: Optional[int] = Form(None),
db: Session = Depends(get_db),
) -> dict:
"""Accept a video, image, or DICOM file and store it in MinIO.
If project_id is provided, the video_path of the project is updated.
Returns the presigned URL of the uploaded object.
"""
if not file.filename:
raise HTTPException(status_code=400, detail="Missing filename")
ext = _get_ext(file.filename)
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {ext}",
)
data = await file.read()
object_name = f"uploads/{project_id or 'general'}/{file.filename}"
try:
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
except Exception as exc: # noqa: BLE001
logger.error("Upload failed: %s", exc)
raise HTTPException(status_code=500, detail="Upload to storage failed") from exc
file_url = get_presigned_url(object_name, expires=3600)
if project_id:
project = db.query(Project).filter(Project.id == project_id).first()
if project:
project.video_path = object_name
db.commit()
logger.info("Linked upload to project_id=%s", project_id)
else:
logger.warning("Project id=%s not found for upload linkage", project_id)
# TODO: enqueue async parsing job (Celery / background task)
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
return {
"object_name": object_name,
"file_url": file_url,
"size": len(data),
"message": "Upload successful. Parsing job queued.",
}
@router.post(
"/parse",
status_code=status.HTTP_202_ACCEPTED,
summary="Trigger frame extraction",
)
def parse_media(
project_id: int,
source_type: str = "video", # video | dicom
db: Session = Depends(get_db),
) -> dict:
"""Trigger frame extraction for a project's uploaded media.
* video: uses FFmpeg or OpenCV fallback.
* dicom: uses pydicom to read DCM frames.
Extracted frames are uploaded to MinIO and registered in the database.
"""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if not project.video_path:
raise HTTPException(status_code=400, detail="Project has no media uploaded")
# Download from MinIO to a temp directory
from minio_client import download_file
try:
media_bytes = download_file(project.video_path)
except Exception as exc: # noqa: BLE001
logger.error("Failed to download media for parsing: %s", exc)
raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
with open(local_path, "wb") as f:
f.write(media_bytes)
output_dir = os.path.join(tmp_dir, "frames")
os.makedirs(output_dir, exist_ok=True)
try:
if source_type == "dicom":
# For DICOM, treat local_path as a directory if it contains multiple .dcm
# If a single .dcm file was uploaded, put it in its own sub-dir
dcm_dir = os.path.join(tmp_dir, "dcm")
os.makedirs(dcm_dir, exist_ok=True)
if local_path.lower().endswith(".dcm"):
shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
else:
shutil.unpack_archive(local_path, dcm_dir) if shutil.which("unzip") else shutil.move(local_path, dcm_dir)
frame_files = parse_dicom(dcm_dir, output_dir)
else:
frame_files = parse_video(local_path, output_dir, fps=30)
except Exception as exc: # noqa: BLE001
logger.error("Frame extraction failed: %s", exc)
shutil.rmtree(tmp_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail="Frame extraction failed") from exc
# Upload frames to MinIO
try:
object_names = upload_frames_to_minio(frame_files, project_id)
except Exception as exc: # noqa: BLE001
logger.error("Frame upload failed: %s", exc)
shutil.rmtree(tmp_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc
# Register frames in DB
frames_out = []
for idx, obj_name in enumerate(object_names):
# Get image dimensions
local_frame = frame_files[idx]
try:
import cv2
img = cv2.imread(local_frame)
h, w = img.shape[:2] if img is not None else (None, None)
except Exception: # noqa: BLE001
h, w = None, None
frame = Frame(
project_id=project_id,
frame_index=idx,
image_url=obj_name,
width=w,
height=h,
)
db.add(frame)
frames_out.append(frame)
db.commit()
for f in frames_out:
db.refresh(f)
# Cleanup temp files
shutil.rmtree(tmp_dir, ignore_errors=True)
project.status = "ready"
db.commit()
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
return {
"project_id": project_id,
"frames_extracted": len(frames_out),
"status": "ready",
"message": "Frame extraction completed successfully.",
}

165
backend/routers/projects.py Normal file
View File

@@ -0,0 +1,165 @@
"""Project and Frame CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Frame
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects", tags=["Projects"])
# ---------------------------------------------------------------------------
# Projects
# ---------------------------------------------------------------------------
@router.post(
"",
response_model=ProjectOut,
status_code=status.HTTP_201_CREATED,
summary="Create a new project",
)
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
"""Create a new segmentation project."""
project = Project(**payload.model_dump())
db.add(project)
db.commit()
db.refresh(project)
logger.info("Created project id=%s name=%s", project.id, project.name)
return project
@router.get(
"",
response_model=List[ProjectOut],
summary="List all projects",
)
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
"""Retrieve a paginated list of projects."""
return db.query(Project).offset(skip).limit(limit).all()
@router.get(
"/{project_id}",
response_model=ProjectOut,
summary="Get a single project",
)
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
"""Retrieve a project by its ID."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return project
@router.patch(
"/{project_id}",
response_model=ProjectOut,
summary="Update a project",
)
def update_project(
project_id: int,
payload: ProjectUpdate,
db: Session = Depends(get_db),
) -> Project:
"""Update project fields partially."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
for key, value in payload.model_dump(exclude_unset=True).items():
setattr(project, key, value)
db.commit()
db.refresh(project)
logger.info("Updated project id=%s", project_id)
return project
@router.delete(
"/{project_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a project",
)
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
"""Delete a project and all related frames and annotations."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
db.delete(project)
db.commit()
logger.info("Deleted project id=%s", project_id)
# ---------------------------------------------------------------------------
# Frames
# ---------------------------------------------------------------------------
@router.post(
"/{project_id}/frames",
response_model=FrameOut,
status_code=status.HTTP_201_CREATED,
summary="Add a frame to a project",
)
def create_frame(
project_id: int,
payload: FrameCreate,
db: Session = Depends(get_db),
) -> Frame:
"""Register a new frame under a project."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
frame = Frame(project_id=project_id, **payload.model_dump(exclude={"project_id"}))
db.add(frame)
db.commit()
db.refresh(frame)
return frame
@router.get(
"/{project_id}/frames",
response_model=List[FrameOut],
summary="List frames for a project",
)
def list_frames(
project_id: int,
skip: int = 0,
limit: int = 1000,
db: Session = Depends(get_db),
) -> List[Frame]:
"""Retrieve all frames belonging to a project."""
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return (
db.query(Frame)
.filter(Frame.project_id == project_id)
.order_by(Frame.frame_index)
.offset(skip)
.limit(limit)
.all()
)
@router.get(
"/{project_id}/frames/{frame_id}",
response_model=FrameOut,
summary="Get a single frame",
)
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
"""Retrieve a specific frame by ID."""
frame = (
db.query(Frame)
.filter(Frame.project_id == project_id, Frame.id == frame_id)
.first()
)
if not frame:
raise HTTPException(status_code=404, detail="Frame not found")
return frame

View File

@@ -0,0 +1,97 @@
"""Template (Ontology) CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Template
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/templates", tags=["Templates"])
@router.post(
"",
response_model=TemplateOut,
status_code=status.HTTP_201_CREATED,
summary="Create a new template",
)
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
"""Create a new ontology template / segmentation class."""
template = Template(**payload.model_dump())
db.add(template)
db.commit()
db.refresh(template)
logger.info("Created template id=%s name=%s", template.id, template.name)
return template
@router.get(
"",
response_model=List[TemplateOut],
summary="List all templates",
)
def list_templates(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
) -> List[Template]:
"""Retrieve all ontology templates."""
return db.query(Template).offset(skip).limit(limit).all()
@router.get(
"/{template_id}",
response_model=TemplateOut,
summary="Get a single template",
)
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
"""Retrieve a template by its ID."""
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
return template
@router.patch(
"/{template_id}",
response_model=TemplateOut,
summary="Update a template",
)
def update_template(
template_id: int,
payload: TemplateUpdate,
db: Session = Depends(get_db),
) -> Template:
"""Update template fields partially."""
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
for key, value in payload.model_dump(exclude_unset=True).items():
setattr(template, key, value)
db.commit()
db.refresh(template)
logger.info("Updated template id=%s", template_id)
return template
@router.delete(
"/{template_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete a template",
)
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
"""Delete a template. Associated annotations will have template_id set to NULL."""
template = db.query(Template).filter(Template.id == template_id).first()
if not template:
raise HTTPException(status_code=404, detail="Template not found")
db.delete(template)
db.commit()
logger.info("Deleted template id=%s", template_id)