2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构
This commit is contained in:
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
123
backend/routers/ai.py
Normal file
123
backend/routers/ai.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""AI inference endpoints using SAM 2."""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Frame, Annotation
|
||||
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
|
||||
from services.sam2_engine import sam_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
"""Download a frame from MinIO and decode it to an RGB numpy array."""
|
||||
try:
|
||||
data = download_file(frame.image_url)
|
||||
arr = np.frombuffer(data, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise ValueError("OpenCV could not decode image")
|
||||
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to load frame image: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/predict",
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM 2 inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Execute SAM 2 segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **semantic**: Not yet implemented; falls back to auto segmentation.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
|
||||
polygons: List[List[List[float]]] = []
|
||||
scores: List[float] = []
|
||||
|
||||
if prompt_type == "point":
|
||||
points = payload.prompt_data
|
||||
if not isinstance(points, list) or len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_engine.predict_points(image, points, labels)
|
||||
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_engine.predict_box(image, box)
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
# Placeholder: use auto segmentation for now
|
||||
logger.info("Semantic prompt not implemented; using auto segmentation")
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
summary="Run automatic segmentation",
|
||||
)
|
||||
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
"""Run automatic mask generation on a frame using a grid of point prompts."""
|
||||
frame = db.query(Frame).filter(Frame.id == image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/annotate",
|
||||
response_model=AnnotationOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Save an AI-generated annotation",
|
||||
)
|
||||
def save_annotation(
|
||||
payload: AnnotationCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if payload.frame_id:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
annotation = Annotation(**payload.model_dump())
|
||||
db.add(annotation)
|
||||
db.commit()
|
||||
db.refresh(annotation)
|
||||
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
|
||||
return annotation
|
||||
24
backend/routers/auth.py
Normal file
24
backend/routers/auth.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Auth"])
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
username: str
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest) -> dict:
|
||||
"""Simple login for development."""
|
||||
if payload.username == "admin" and payload.password == "123456":
|
||||
return {"token": "fake-jwt-token-for-admin", "username": payload.username}
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
194
backend/routers/export.py
Normal file
194
backend/routers/export.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Annotation export endpoints (COCO, PNG masks)."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Annotation, Frame, Template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||||
|
||||
|
||||
def _mask_from_polygon(
|
||||
polygon: List[List[float]],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> np.ndarray:
|
||||
"""Render a normalized polygon to a binary mask."""
|
||||
import cv2
|
||||
|
||||
pts = np.array(
|
||||
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
|
||||
dtype=np.int32,
|
||||
)
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [pts], 255)
|
||||
return mask
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
templates = db.query(Template).all()
|
||||
|
||||
# Build COCO structure
|
||||
images = []
|
||||
for idx, frame in enumerate(frames):
|
||||
images.append({
|
||||
"id": frame.id,
|
||||
"file_name": frame.image_url,
|
||||
"width": frame.width or 1920,
|
||||
"height": frame.height or 1080,
|
||||
"frame_index": idx,
|
||||
})
|
||||
|
||||
categories = []
|
||||
template_id_to_cat_id: Dict[int, int] = {}
|
||||
for cat_idx, tmpl in enumerate(templates, start=1):
|
||||
categories.append({
|
||||
"id": cat_idx,
|
||||
"name": tmpl.name,
|
||||
"color": tmpl.color,
|
||||
})
|
||||
template_id_to_cat_id[tmpl.id] = cat_idx
|
||||
|
||||
coco_annotations = []
|
||||
ann_id = 1
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
# Use first polygon for bbox / area approximation
|
||||
first_poly = polygons[0]
|
||||
xs = [p[0] for p in first_poly]
|
||||
ys = [p[1] for p in first_poly]
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
bbox = [
|
||||
min(xs) * width,
|
||||
min(ys) * height,
|
||||
(max(xs) - min(xs)) * width,
|
||||
(max(ys) - min(ys)) * height,
|
||||
]
|
||||
area = bbox[2] * bbox[3]
|
||||
|
||||
segmentation = []
|
||||
for poly in polygons:
|
||||
flat = []
|
||||
for p in poly:
|
||||
flat.append(p[0] * width)
|
||||
flat.append(p[1] * height)
|
||||
segmentation.append(flat)
|
||||
|
||||
coco_annotations.append({
|
||||
"id": ann_id,
|
||||
"image_id": ann.frame_id,
|
||||
"category_id": template_id_to_cat_id.get(ann.template_id, 0),
|
||||
"segmentation": segmentation,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
"iscrowd": 0,
|
||||
})
|
||||
ann_id += 1
|
||||
|
||||
coco = {
|
||||
"info": {
|
||||
"description": f"Annotations for {project.name}",
|
||||
"version": "1.0",
|
||||
"year": datetime.now().year,
|
||||
"date_created": datetime.now().isoformat(),
|
||||
},
|
||||
"images": images,
|
||||
"annotations": coco_annotations,
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotation masks as individual PNG files inside a ZIP archive."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
# Encode PNG
|
||||
import cv2
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
fname = f"mask_{ann.id:06d}.png"
|
||||
zf.writestr(fname, encoded.tobytes())
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
192
backend/routers/media.py
Normal file
192
backend/routers/media.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import Project, Frame
|
||||
from schemas import FrameOut
|
||||
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
|
||||
|
||||
|
||||
def _get_ext(filename: str) -> str:
|
||||
return Path(filename).suffix.lower()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Upload a media file",
|
||||
)
|
||||
async def upload_media(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Accept a video, image, or DICOM file and store it in MinIO.
|
||||
|
||||
If project_id is provided, the video_path of the project is updated.
|
||||
Returns the presigned URL of the uploaded object.
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Missing filename")
|
||||
|
||||
ext = _get_ext(file.filename)
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {ext}",
|
||||
)
|
||||
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id or 'general'}/{file.filename}"
|
||||
|
||||
try:
|
||||
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Upload failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Upload to storage failed") from exc
|
||||
|
||||
file_url = get_presigned_url(object_name, expires=3600)
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if project:
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
logger.warning("Project id=%s not found for upload linkage", project_id)
|
||||
|
||||
# TODO: enqueue async parsing job (Celery / background task)
|
||||
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
|
||||
|
||||
return {
|
||||
"object_name": object_name,
|
||||
"file_url": file_url,
|
||||
"size": len(data),
|
||||
"message": "Upload successful. Parsing job queued.",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Trigger frame extraction",
|
||||
)
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: str = "video", # video | dicom
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Trigger frame extraction for a project's uploaded media.
|
||||
|
||||
* video: uses FFmpeg or OpenCV fallback.
|
||||
* dicom: uses pydicom to read DCM frames.
|
||||
|
||||
Extracted frames are uploaded to MinIO and registered in the database.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
# Download from MinIO to a temp directory
|
||||
from minio_client import download_file
|
||||
|
||||
try:
|
||||
media_bytes = download_file(project.video_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to download media for parsing: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
if source_type == "dicom":
|
||||
# For DICOM, treat local_path as a directory if it contains multiple .dcm
|
||||
# If a single .dcm file was uploaded, put it in its own sub-dir
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
if local_path.lower().endswith(".dcm"):
|
||||
shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
|
||||
else:
|
||||
shutil.unpack_archive(local_path, dcm_dir) if shutil.which("unzip") else shutil.move(local_path, dcm_dir)
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
else:
|
||||
frame_files = parse_video(local_path, output_dir, fps=30)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Frame extraction failed: %s", exc)
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail="Frame extraction failed") from exc
|
||||
|
||||
# Upload frames to MinIO
|
||||
try:
|
||||
object_names = upload_frames_to_minio(frame_files, project_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Frame upload failed: %s", exc)
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc
|
||||
|
||||
# Register frames in DB
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
# Get image dimensions
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
img = cv2.imread(local_frame)
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
frame = Frame(
|
||||
project_id=project_id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
|
||||
db.commit()
|
||||
for f in frames_out:
|
||||
db.refresh(f)
|
||||
|
||||
# Cleanup temp files
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
project.status = "ready"
|
||||
db.commit()
|
||||
|
||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": "ready",
|
||||
"message": "Frame extraction completed successfully.",
|
||||
}
|
||||
165
backend/routers/projects.py
Normal file
165
backend/routers/projects.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Project and Frame CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projects
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new project",
|
||||
)
|
||||
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
|
||||
"""Create a new segmentation project."""
|
||||
project = Project(**payload.model_dump())
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Created project id=%s name=%s", project.id, project.name)
|
||||
return project
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[ProjectOut],
|
||||
summary="List all projects",
|
||||
)
|
||||
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
|
||||
"""Retrieve a paginated list of projects."""
|
||||
return db.query(Project).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Get a single project",
|
||||
)
|
||||
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
|
||||
"""Retrieve a project by its ID."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Update a project",
|
||||
)
|
||||
def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Project:
|
||||
"""Update project fields partially."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
setattr(project, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Updated project id=%s", project_id)
|
||||
return project
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{project_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a project",
|
||||
)
|
||||
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
|
||||
"""Delete a project and all related frames and annotations."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
db.delete(project)
|
||||
db.commit()
|
||||
logger.info("Deleted project id=%s", project_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frames
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"/{project_id}/frames",
|
||||
response_model=FrameOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Add a frame to a project",
|
||||
)
|
||||
def create_frame(
|
||||
project_id: int,
|
||||
payload: FrameCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Frame:
|
||||
"""Register a new frame under a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
frame = Frame(project_id=project_id, **payload.model_dump(exclude={"project_id"}))
|
||||
db.add(frame)
|
||||
db.commit()
|
||||
db.refresh(frame)
|
||||
return frame
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames",
|
||||
response_model=List[FrameOut],
|
||||
summary="List frames for a project",
|
||||
)
|
||||
def list_frames(
|
||||
project_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 1000,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[Frame]:
|
||||
"""Retrieve all frames belonging to a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
return (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames/{frame_id}",
|
||||
response_model=FrameOut,
|
||||
summary="Get a single frame",
|
||||
)
|
||||
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
|
||||
"""Retrieve a specific frame by ID."""
|
||||
frame = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id, Frame.id == frame_id)
|
||||
.first()
|
||||
)
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
97
backend/routers/templates.py
Normal file
97
backend/routers/templates.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Template (Ontology) CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/templates", tags=["Templates"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=TemplateOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new template",
|
||||
)
|
||||
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
|
||||
"""Create a new ontology template / segmentation class."""
|
||||
template = Template(**payload.model_dump())
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
logger.info("Created template id=%s name=%s", template.id, template.name)
|
||||
return template
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[TemplateOut],
|
||||
summary="List all templates",
|
||||
)
|
||||
def list_templates(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[Template]:
|
||||
"""Retrieve all ontology templates."""
|
||||
return db.query(Template).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Get a single template",
|
||||
)
|
||||
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
|
||||
"""Retrieve a template by its ID."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
return template
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Update a template",
|
||||
)
|
||||
def update_template(
|
||||
template_id: int,
|
||||
payload: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Template:
|
||||
"""Update template fields partially."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
setattr(template, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
logger.info("Updated template id=%s", template_id)
|
||||
return template
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{template_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a template",
|
||||
)
|
||||
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
|
||||
"""Delete a template. Associated annotations will have template_id set to NULL."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
db.delete(template)
|
||||
db.commit()
|
||||
logger.info("Deleted template id=%s", template_id)
|
||||
Reference in New Issue
Block a user