2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构

This commit is contained in:
2026-04-29 22:17:25 +08:00
parent c8f8686097
commit fd4b5e5b3d
39 changed files with 3816 additions and 211 deletions

15
.gitignore vendored
View File

@@ -6,3 +6,18 @@ coverage/
*.log
.env*
!.env.example
# Data & Models
models/
uploads/
frames/
minio_data/
Data_*/
*.mp4
*.dcm
*.7z
# Binaries
Viewer/
Viewer.exe
DICOMDIR
__pycache__/
*.pyc

35
backend/config.py Normal file
View File

@@ -0,0 +1,35 @@
"""Application configuration using Pydantic Settings."""
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Backend configuration resolved from the environment.

    The literals below are fallbacks; values are read from environment
    variables and from the ``.env`` file named in ``Config``.
    """

    # --- PostgreSQL ---------------------------------------------------
    db_url: str = "postgresql://seguser:segpass123@localhost:5432/segserver"

    # --- Redis --------------------------------------------------------
    redis_url: str = "redis://localhost:6379/0"

    # --- MinIO object storage -----------------------------------------
    minio_endpoint: str = "localhost:9000"
    minio_access_key: str = "minioadmin"
    minio_secret_key: str = "minioadmin"
    minio_secure: bool = False

    # --- SAM 2 checkpoint and model config ----------------------------
    # NOTE(review): the default checkpoint path is machine-specific;
    # other hosts presumably need to override it.
    sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
    sam_model_config: str = "sam2_hiera_t.yaml"

    # --- Application --------------------------------------------------
    app_env: str = "development"
    cors_origins: list[str] = ["http://localhost:3000"]

    class Config:
        # Unknown keys in the environment / .env file are ignored.
        env_file = ".env"
        env_file_encoding = "utf-8"
        extra = "ignore"
settings = Settings()

29
backend/database.py Normal file
View File

@@ -0,0 +1,29 @@
"""Database configuration using synchronous SQLAlchemy."""
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from fastapi import Depends
from typing import Generator
from config import settings
# Shared engine: a pool of 10 connections with up to 20 overflow connections.
# pool_pre_ping validates a connection before it is handed out.
engine = create_engine(
    settings.db_url,
    echo=False,
    pool_size=10,
    max_overflow=20,
    pool_pre_ping=True,
)

# Session factory bound to the engine; flushing and committing stay explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that all ORM models inherit from.
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
    """Yield a database session and close it when the consumer is done.

    Written as a generator so FastAPI can use it with ``Depends``; the
    ``finally`` clause closes the session even if the request fails.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()

49
backend/download_sam2.py Normal file
View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""
SAM 2 模型权重下载脚本
运行: python download_sam2.py
"""
import os
import urllib.request
import sys
# Target directory for the checkpoints. The SAM2_MODEL_DIR environment
# variable overrides it, so the script also works outside the original
# developer's machine; the default is unchanged.
MODEL_DIR = os.environ.get(
    "SAM2_MODEL_DIR", "/home/wkmgc/Desktop/Seg_Server/models"
)
os.makedirs(MODEL_DIR, exist_ok=True)

# SAM 2 checkpoints (official Meta AI download URLs), keyed by file name.
MODELS = {
    "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
    "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
    "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
    "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
}
def download_file(url: str, dest: str) -> bool:
    """Download *url* to *dest* unless *dest* already exists.

    The payload is written to ``dest + ".part"`` and renamed to *dest* only
    after the transfer finishes. An interrupted run therefore never leaves a
    truncated file that a later run would skip as "already exists".
    Failures are reported on stderr and do not raise.

    Args:
        url: Source URL.
        dest: Final path of the downloaded file.

    Returns:
        True if *dest* is present on return (already there or freshly
        downloaded), False if the download failed.
    """
    if os.path.exists(dest):
        print(f"[跳过] {os.path.basename(dest)} 已存在")
        return True

    print(f"[下载] {url}")
    print(f"{dest}")
    partial = dest + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        # Atomic on the same filesystem: dest is either absent or complete.
        os.replace(partial, dest)
        size_mb = os.path.getsize(dest) / 1024 / 1024
        print(f" ✓ 完成 ({size_mb:.1f} MB)")
        return True
    except Exception as e:
        print(f" ✗ 失败: {e}", file=sys.stderr)
        return False
    finally:
        # Also runs on KeyboardInterrupt, so no stale partial file survives.
        if os.path.exists(partial):
            os.remove(partial)
def main():
    """Download every checkpoint listed in ``MODELS`` into ``MODEL_DIR``.

    ``download_file`` reports errors itself and does not raise, so success
    is judged by whether each target file exists afterwards. If any file is
    missing, the names are printed to stderr and the process exits with
    status 1 instead of announcing success.
    """
    banner = "=" * 50
    print(banner)
    print("SAM 2 模型权重下载")
    print(banner)

    missing = []
    for name, url in MODELS.items():
        dest = os.path.join(MODEL_DIR, name)
        download_file(url, dest)
        if not os.path.exists(dest):
            missing.append(name)

    print(banner)
    if missing:
        print(f"下载失败: {', '.join(missing)}", file=sys.stderr)
    else:
        print("全部完成!")
    print(f"模型目录: {MODEL_DIR}")
    print(banner)

    if missing:
        sys.exit(1)
if __name__ == "__main__":
main()

83
backend/main.py Normal file
View File

@@ -0,0 +1,83 @@
"""FastAPI application entrypoint."""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from config import settings
from database import Base, engine
from minio_client import ensure_bucket_exists
from redis_client import ping as redis_ping
from routers import projects, templates, media, ai, export, auth
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup checks before serving and dispose the engine on shutdown.

    Each backing service is probed on its own. A failure is logged and
    startup continues, so the API can come up with a dependency offline.
    """
    logger.info("Starting up SegServer backend...")

    # Database: create any tables that do not exist yet.
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as exc:  # noqa: BLE001
        logger.error("Database initialization failed: %s", exc)
    else:
        logger.info("Database tables initialized.")

    # Object storage: make sure the media bucket is present.
    try:
        ensure_bucket_exists()
    except Exception as exc:  # noqa: BLE001
        logger.error("MinIO bucket check failed: %s", exc)

    # Cache: connectivity probe only.
    if not redis_ping():
        logger.warning("Redis connection failed.")
    else:
        logger.info("Redis connection OK.")

    yield

    logger.info("Shutting down SegServer backend...")
    engine.dispose()
# Application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="SegServer API",
    description="Semantic Segmentation System Backend",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: only the configured origins are allowed; any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers are mounted in the same order as before.
for _router_module in (auth, projects, templates, media, ai, export):
    app.include_router(_router_module.router)
@app.get("/health", tags=["Health"])
def health_check() -> dict:
    """Return a static payload showing that the process is serving requests."""
    payload = {"status": "ok", "service": "SegServer"}
    return payload

127
backend/minio_client.py Normal file
View File

@@ -0,0 +1,127 @@
"""MinIO client wrapper for object storage operations."""
import io
import logging
from typing import Optional
from minio import Minio
from minio.error import S3Error
from config import settings
logger = logging.getLogger(__name__)
BUCKET_NAME = "seg-media"
_minio_client: Optional[Minio] = None
def get_minio_client() -> Minio:
    """Return the module-wide MinIO client, building it on first use."""
    global _minio_client
    if _minio_client is not None:
        return _minio_client

    # First call: build the client from the application settings and cache it.
    _minio_client = Minio(
        settings.minio_endpoint,
        access_key=settings.minio_access_key,
        secret_key=settings.minio_secret_key,
        secure=settings.minio_secure,
    )
    return _minio_client
def ensure_bucket_exists() -> None:
    """Make sure ``BUCKET_NAME`` exists, creating it if necessary.

    Raises:
        S3Error: Re-raised after logging when the check or creation fails.
    """
    client = get_minio_client()
    try:
        already_present = client.bucket_exists(BUCKET_NAME)
        if already_present:
            logger.info("MinIO bucket %s already exists", BUCKET_NAME)
        else:
            client.make_bucket(BUCKET_NAME)
            logger.info("Created MinIO bucket: %s", BUCKET_NAME)
    except S3Error as exc:
        logger.error("MinIO bucket check/creation failed: %s", exc)
        raise
def upload_file(
    object_name: str,
    data: bytes,
    content_type: str = "application/octet-stream",
    length: int = -1,
) -> str:
    """Upload a payload to ``BUCKET_NAME`` and return the object name.

    Args:
        object_name: Destination path inside the bucket.
        data: Raw bytes, or a readable binary file-like object.
        content_type: MIME type stored with the object.
        length: Size of a file-like payload in bytes, or -1 when unknown.
            Ignored for ``bytes``/``bytearray`` input, whose size is measured.

    Returns:
        The object name (same as the input).

    Raises:
        S3Error: Re-raised after logging when the upload fails.
    """
    client = get_minio_client()

    if isinstance(data, (bytes, bytearray)):
        # Measure before wrapping; getvalue() would copy the whole payload.
        length = len(data)
        data = io.BytesIO(data)

    extra_args = {}
    if length < 0:
        # The MinIO SDK only accepts an unknown length together with an
        # explicit multipart part size (minimum 5 MiB).
        extra_args["part_size"] = 10 * 1024 * 1024

    try:
        client.put_object(
            BUCKET_NAME,
            object_name,
            data,
            length=length,
            content_type=content_type,
            **extra_args,
        )
        logger.info("Uploaded to MinIO: %s", object_name)
        return object_name
    except S3Error as exc:
        logger.error("MinIO upload failed: %s", exc)
        raise
from datetime import timedelta
def get_presigned_url(
    object_name: str,
    expires: int = 3600,
    method: str = "GET",
) -> str:
    """Create a time-limited presigned URL for an object in ``BUCKET_NAME``.

    Args:
        object_name: Path inside the bucket.
        expires: Lifetime of the URL in seconds (default: one hour).
        method: HTTP method the URL is signed for (GET or PUT).

    Returns:
        The presigned URL.

    Raises:
        S3Error: Re-raised after logging when signing fails.
    """
    client = get_minio_client()
    lifetime = timedelta(seconds=expires)
    try:
        return client.get_presigned_url(
            method, BUCKET_NAME, object_name, expires=lifetime
        )
    except S3Error as exc:
        logger.error("MinIO presigned URL failed: %s", exc)
        raise
def download_file(object_name: str) -> bytes:
    """Download an object from ``BUCKET_NAME`` and return its content.

    Args:
        object_name: Path inside the bucket.

    Returns:
        The object's raw bytes.

    Raises:
        S3Error: Re-raised after logging when the download fails.
    """
    client = get_minio_client()
    response = None
    try:
        response = client.get_object(BUCKET_NAME, object_name)
        return response.read()
    except S3Error as exc:
        logger.error("MinIO download failed: %s", exc)
        raise
    finally:
        # Close and release even when read() fails, so the HTTP connection
        # always returns to the pool.
        if response is not None:
            response.close()
            response.release_conn()

118
backend/models.py Normal file
View File

@@ -0,0 +1,118 @@
"""SQLAlchemy ORM models."""
from sqlalchemy import (
Column,
Integer,
String,
Text,
DateTime,
ForeignKey,
JSON,
Float,
)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from database import Base
class Project(Base):
    """ORM row for a segmentation project (table ``projects``)."""

    __tablename__ = "projects"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)
    # Location of the uploaded media; empty until something is uploaded.
    video_path = Column(String(512), nullable=True)
    status = Column(String(50), default="pending", nullable=False)

    # Timestamps are filled in by the database server.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Deleting a project through the ORM removes its frames and annotations.
    frames = relationship(
        "Frame",
        back_populates="project",
        cascade="all, delete-orphan",
    )
    annotations = relationship(
        "Annotation",
        back_populates="project",
        cascade="all, delete-orphan",
    )
class Frame(Base):
    """ORM row for one extracted frame of a project (table ``frames``)."""

    __tablename__ = "frames"

    id = Column(Integer, primary_key=True, index=True)
    project_id = Column(
        Integer,
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
    )
    frame_index = Column(Integer, nullable=False)
    image_url = Column(String(512), nullable=False)
    # Pixel dimensions are optional and may be NULL.
    width = Column(Integer, nullable=True)
    height = Column(Integer, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    project = relationship("Project", back_populates="frames")
    # Annotations attached to a frame are deleted together with it.
    annotations = relationship(
        "Annotation",
        back_populates="frame",
        cascade="all, delete-orphan",
    )
class Template(Base):
    """Template (ontology) entry describing one segmentation class."""

    __tablename__ = "templates"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(255), nullable=False)
    color = Column(String(50), nullable=False)
    z_index = Column(Integer, default=0, nullable=False)
    mapping_rules = Column(JSON, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # Annotation.template_id is nullable with ON DELETE SET NULL, so
    # annotations are meant to outlive their template. An ORM-level
    # "all, delete-orphan" cascade would delete them instead, contradicting
    # the foreign key. passive_deletes leaves unloaded rows to the database
    # rule.
    annotations = relationship(
        "Annotation",
        back_populates="template",
        passive_deletes=True,
    )
class Annotation(Base):
    """ORM row holding a segmentation mask and its prompts (``annotations``)."""

    __tablename__ = "annotations"

    id = Column(Integer, primary_key=True, index=True)

    # Owning project is mandatory; frame and template links are optional.
    project_id = Column(
        Integer,
        ForeignKey("projects.id", ondelete="CASCADE"),
        nullable=False,
    )
    frame_id = Column(
        Integer,
        ForeignKey("frames.id", ondelete="CASCADE"),
        nullable=True,
    )
    template_id = Column(
        Integer,
        ForeignKey("templates.id", ondelete="SET NULL"),
        nullable=True,
    )

    # Free-form JSON payloads.
    mask_data = Column(JSON, nullable=True)
    points = Column(JSON, nullable=True)
    bbox = Column(JSON, nullable=True)

    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    project = relationship("Project", back_populates="annotations")
    frame = relationship("Frame", back_populates="annotations")
    template = relationship("Template", back_populates="annotations")
    masks = relationship(
        "Mask",
        back_populates="annotation",
        cascade="all, delete-orphan",
    )
class Mask(Base):
    """ORM row for a mask file derived from an annotation (table ``masks``)."""

    __tablename__ = "masks"

    id = Column(Integer, primary_key=True, index=True)
    annotation_id = Column(
        Integer,
        ForeignKey("annotations.id", ondelete="CASCADE"),
        nullable=False,
    )
    mask_url = Column(String(512), nullable=False)
    # One of: png / rle / json
    format = Column(String(50), default="png", nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    annotation = relationship("Annotation", back_populates="masks")

61
backend/redis_client.py Normal file
View File

@@ -0,0 +1,61 @@
"""Redis client wrapper for caching and task queuing."""
import json
import logging
from typing import Optional, Any
import redis
from config import settings
logger = logging.getLogger(__name__)
_redis_client: Optional[redis.Redis] = None
def get_redis_client() -> redis.Redis:
    """Return the module-wide Redis client, building it on first use."""
    global _redis_client
    if _redis_client is not None:
        return _redis_client

    # decode_responses=True makes the client return str instead of bytes.
    _redis_client = redis.from_url(settings.redis_url, decode_responses=True)
    return _redis_client
def ping() -> bool:
    """Report whether Redis answers a PING.

    Any Redis failure (connection refused, timeout, authentication, ...) is
    logged and reported as ``False``. Callers use this as a plain health
    probe and do not guard it with their own ``try``.

    Returns:
        True when the server replied, False otherwise.
    """
    try:
        return bool(get_redis_client().ping())
    except redis.RedisError as exc:
        # RedisError is the common base of redis-py errors, so timeouts and
        # auth errors are covered as well as connection failures.
        logger.error("Redis ping failed: %s", exc)
        return False
def set_json(key: str, value: Any, expire: Optional[int] = None) -> None:
    """Serialize *value* to JSON and store it under *key*.

    Args:
        key: Redis key.
        value: Any JSON-serializable object.
        expire: Optional time-to-live in seconds.

    Raises:
        redis.RedisError: Re-raised after logging when the write fails.
    """
    client = get_redis_client()
    serialized = json.dumps(value)
    try:
        client.set(key, serialized, ex=expire)
    except redis.RedisError as exc:
        logger.error("Redis set_json failed: %s", exc)
        raise
def get_json(key: str) -> Optional[Any]:
    """Fetch *key* from Redis and decode it from JSON.

    Returns:
        The decoded value, or ``None`` when the key does not exist.

    Raises:
        redis.RedisError: Re-raised after logging when the read fails.
    """
    client = get_redis_client()
    try:
        raw = client.get(key)
    except redis.RedisError as exc:
        logger.error("Redis get_json failed: %s", exc)
        raise
    if raw is None:
        return None
    return json.loads(raw)
def delete_key(key: str) -> int:
    """Remove *key* from Redis and return how many keys were deleted.

    Raises:
        redis.RedisError: Re-raised after logging when the command fails.
    """
    client = get_redis_client()
    try:
        removed = client.delete(key)
    except redis.RedisError as exc:
        logger.error("Redis delete_key failed: %s", exc)
        raise
    return removed

38
backend/requirements.txt Normal file
View File

@@ -0,0 +1,38 @@
# Web framework
fastapi
uvicorn[standard]
python-multipart
# Database
sqlalchemy
psycopg2-binary
alembic
# Cache / Task queue
redis
celery
# Object storage
minio
# Image / Video / DICOM processing
opencv-python
pillow
scikit-image
pydicom
numpy
# SAM 2 (may require manual installation depending on CUDA version)
sam2
# PyTorch (no wheel index is configured here; for a CUDA build install with the index matching your CUDA version, e.g. --index-url https://download.pytorch.org/whl/cu124)
torch
torchvision
torchaudio
# Configuration
pydantic-settings
# Utilities
python-jose[cryptography]
passlib[bcrypt]

View File

123
backend/routers/ai.py Normal file
View File

@@ -0,0 +1,123 @@
"""AI inference endpoints using SAM 2."""
import logging
from typing import Any, List
import cv2
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import download_file
from models import Frame, Annotation
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
from services.sam2_engine import sam_engine
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/ai", tags=["AI"])
def _load_frame_image(frame: Frame) -> np.ndarray:
    """Fetch a frame's image from MinIO and return it as an RGB array.

    Raises:
        HTTPException: 500 when the download or decode fails for any reason.
    """
    try:
        raw = download_file(frame.image_url)
        buffer = np.frombuffer(raw, dtype=np.uint8)
        bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        if bgr is None:
            raise ValueError("OpenCV could not decode image")
        # OpenCV decodes to BGR; convert so the caller receives RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return rgb
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to load frame image: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
@router.post(
    "/predict",
    response_model=PredictResponse,
    summary="Run SAM 2 inference with a prompt",
)
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
    """Execute SAM 2 segmentation given an image and a prompt.

    - **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
    - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
    - **semantic**: Not yet implemented; falls back to auto segmentation.
    """
    frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
    if frame is None:
        raise HTTPException(status_code=404, detail="Frame not found")

    image = _load_frame_image(frame)
    prompt_type = payload.prompt_type.lower()
    prompt = payload.prompt_data

    polygons: List[List[List[float]]] = []
    scores: List[float] = []

    if prompt_type == "point":
        if not isinstance(prompt, list) or not prompt:
            raise HTTPException(status_code=400, detail="Invalid point prompt data")
        # Every supplied point is given label 1.
        point_labels = [1] * len(prompt)
        polygons, scores = sam_engine.predict_points(image, prompt, point_labels)
    elif prompt_type == "box":
        if not (isinstance(prompt, list) and len(prompt) == 4):
            raise HTTPException(status_code=400, detail="Invalid box prompt data")
        polygons, scores = sam_engine.predict_box(image, prompt)
    elif prompt_type == "semantic":
        # Placeholder until semantic prompts are supported.
        logger.info("Semantic prompt not implemented; using auto segmentation")
        polygons, scores = sam_engine.predict_auto(image)
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")

    return {"polygons": polygons, "scores": scores}
@router.post(
    "/auto",
    response_model=PredictResponse,
    summary="Run automatic segmentation",
)
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
    """Run automatic mask generation on a frame using a grid of point prompts."""
    frame = db.query(Frame).filter(Frame.id == image_id).first()
    if frame is None:
        raise HTTPException(status_code=404, detail="Frame not found")

    polygons, scores = sam_engine.predict_auto(_load_frame_image(frame))
    return {"polygons": polygons, "scores": scores}
@router.post(
    "/annotate",
    response_model=AnnotationOut,
    status_code=status.HTTP_201_CREATED,
    summary="Save an AI-generated annotation",
)
def save_annotation(
    payload: AnnotationCreate,
    db: Session = Depends(get_db),
) -> Annotation:
    """Persist an annotation (mask, points, bbox) into the database.

    Raises:
        HTTPException: 404 when the project or the frame does not exist,
            400 when the frame belongs to a different project.
    """
    # Project is not imported at module level here; import it locally.
    from models import Project

    # The project id must be checked against the projects table. Looking it
    # up among frames accepts or rejects ids more or less at random.
    project = db.query(Project).filter(Project.id == payload.project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    if payload.frame_id is not None:
        frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
        if not frame:
            raise HTTPException(status_code=404, detail="Frame not found")
        if frame.project_id != payload.project_id:
            raise HTTPException(
                status_code=400,
                detail="Frame does not belong to the given project",
            )

    annotation = Annotation(**payload.model_dump())
    db.add(annotation)
    db.commit()
    db.refresh(annotation)
    logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
    return annotation

24
backend/routers/auth.py Normal file
View File

@@ -0,0 +1,24 @@
"""Authentication endpoints."""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
router = APIRouter(prefix="/api/auth", tags=["Auth"])
class LoginRequest(BaseModel):
    """Request body of the login endpoint."""

    # Credentials as submitted by the client.
    username: str
    password: str
class LoginResponse(BaseModel):
    """Response body of a successful login."""

    # Token issued to the client, and the name it logged in with.
    token: str
    username: str
@router.post("/login", response_model=LoginResponse)
def login(payload: LoginRequest) -> dict:
    """Simple login for development.

    Accepts exactly one hard-coded account and returns a static token.
    """
    # NOTE(review): hard-coded credentials and a fixed token; development
    # stub only, replace before any real deployment.
    credentials_ok = payload.username == "admin" and payload.password == "123456"
    if not credentials_ok:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return {"token": "fake-jwt-token-for-admin", "username": payload.username}

194
backend/routers/export.py Normal file
View File

@@ -0,0 +1,194 @@
"""Annotation export endpoints (COCO, PNG masks)."""
import io
import json
import logging
import os
import zipfile
from datetime import datetime
from typing import Any, Dict, List
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Annotation, Frame, Template
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/export", tags=["Export"])
def _mask_from_polygon(
    polygon: List[List[float]],
    width: int,
    height: int,
) -> np.ndarray:
    """Rasterize a polygon with normalized vertices into a 0/255 mask.

    Args:
        polygon: Vertices as ``[x, y]`` pairs, scaled by *width* / *height*.
        width: Mask width in pixels.
        height: Mask height in pixels.

    Returns:
        A ``(height, width)`` uint8 array with the polygon filled with 255.
    """
    import cv2

    canvas = np.zeros((height, width), dtype=np.uint8)
    # Scale each vertex to pixel coordinates, truncating to int.
    pixel_vertices = np.array(
        [[int(vertex[0] * width), int(vertex[1] * height)] for vertex in polygon],
        dtype=np.int32,
    )
    cv2.fillPoly(canvas, [pixel_vertices], 255)
    return canvas
@router.get(
    "/{project_id}/coco",
    summary="Export annotations in COCO format",
)
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
    """Export all annotations for a project as a COCO-format JSON file.

    Annotations without mask data or without any non-empty polygon are
    skipped. Polygon coordinates are stored normalized and are scaled to
    pixels using the frame size, falling back to 1920x1080 when the frame is
    missing or its dimensions were never recorded.

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    # Used wherever a frame has no stored width/height.
    fallback_width, fallback_height = 1920, 1080

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .all()
    )
    frames = (
        db.query(Frame)
        .filter(Frame.project_id == project_id)
        .order_by(Frame.frame_index)
        .all()
    )
    templates = db.query(Template).all()

    images = [
        {
            "id": frame.id,
            "file_name": frame.image_url,
            "width": frame.width or fallback_width,
            "height": frame.height or fallback_height,
            # Report the stored index, not the position in the result list;
            # the two differ when indices are not contiguous.
            "frame_index": frame.frame_index,
        }
        for frame in frames
    ]

    # COCO category ids are 1-based positions in the template list.
    categories = []
    template_id_to_cat_id: Dict[int, int] = {}
    for cat_id, tmpl in enumerate(templates, start=1):
        categories.append({
            "id": cat_id,
            "name": tmpl.name,
            "color": tmpl.color,
        })
        template_id_to_cat_id[tmpl.id] = cat_id

    coco_annotations = []
    for ann in annotations:
        if not ann.mask_data:
            continue
        # Drop empty polygons: they carry no geometry and would break min().
        polygons = [poly for poly in (ann.mask_data.get("polygons") or []) if poly]
        if not polygons:
            continue

        # Frame dimensions are nullable, so guard the values as well as the
        # frame itself.
        frame = ann.frame
        width = (frame.width if frame else None) or fallback_width
        height = (frame.height if frame else None) or fallback_height

        # Bounding box and area are approximated from the first polygon.
        first_poly = polygons[0]
        xs = [p[0] for p in first_poly]
        ys = [p[1] for p in first_poly]
        x_min, y_min = min(xs), min(ys)
        bbox = [
            x_min * width,
            y_min * height,
            (max(xs) - x_min) * width,
            (max(ys) - y_min) * height,
        ]

        # COCO expects each polygon as a flat [x0, y0, x1, y1, ...] list.
        segmentation = [
            [coord for p in poly for coord in (p[0] * width, p[1] * height)]
            for poly in polygons
        ]

        coco_annotations.append({
            "id": len(coco_annotations) + 1,
            "image_id": ann.frame_id,
            "category_id": template_id_to_cat_id.get(ann.template_id, 0),
            "segmentation": segmentation,
            "area": bbox[2] * bbox[3],
            "bbox": bbox,
            "iscrowd": 0,
        })

    now = datetime.now()
    coco = {
        "info": {
            "description": f"Annotations for {project.name}",
            "version": "1.0",
            "year": now.year,
            "date_created": now.isoformat(),
        },
        "images": images,
        "annotations": coco_annotations,
        "categories": categories,
    }

    data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
    filename = f"project_{project_id}_coco.json"
    return StreamingResponse(
        io.BytesIO(data),
        media_type="application/json",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
@router.get(
    "/{project_id}/masks",
    summary="Export PNG masks as a ZIP archive",
)
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
    """Export all annotation masks as individual PNG files inside a ZIP archive.

    Each annotation with polygon data becomes ``mask_<id>.png``, the union
    of its polygons. Masks are sized from the frame, falling back to
    1920x1080 when the frame is missing or has no recorded dimensions.

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    import cv2

    # Used wherever a frame has no stored width/height.
    fallback_width, fallback_height = 1920, 1080

    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    annotations = (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .all()
    )

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        for ann in annotations:
            if not ann.mask_data:
                continue
            # Empty polygons have no geometry to rasterize.
            polygons = [poly for poly in (ann.mask_data.get("polygons") or []) if poly]
            if not polygons:
                continue

            # Frame dimensions are nullable; np.zeros cannot take None.
            frame = ann.frame
            width = (frame.width if frame else None) or fallback_width
            height = (frame.height if frame else None) or fallback_height

            combined = np.zeros((height, width), dtype=np.uint8)
            for poly in polygons:
                combined = np.maximum(combined, _mask_from_polygon(poly, width, height))

            ok, encoded = cv2.imencode(".png", combined)
            if not ok:
                # Skip this mask instead of writing a corrupt archive entry.
                logger.error("PNG encoding failed for annotation id=%s", ann.id)
                continue
            zf.writestr(f"mask_{ann.id:06d}.png", encoded.tobytes())

    zip_buffer.seek(0)
    filename = f"project_{project_id}_masks.zip"
    return StreamingResponse(
        zip_buffer,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )

192
backend/routers/media.py Normal file
View File

@@ -0,0 +1,192 @@
"""Media upload and parsing endpoints."""
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from sqlalchemy.orm import Session
from database import get_db
from minio_client import upload_file, get_presigned_url
from models import Project, Frame
from schemas import FrameOut
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["Media"])
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
def _get_ext(filename: str) -> str:
return Path(filename).suffix.lower()
@router.post(
    "/upload",
    status_code=status.HTTP_201_CREATED,
    summary="Upload a media file",
)
async def upload_media(
    file: UploadFile = File(...),
    project_id: Optional[int] = Form(None),
    db: Session = Depends(get_db),
) -> dict:
    """Accept a video, image, or DICOM file and store it in MinIO.

    If project_id is provided, the video_path of the project is updated.
    Returns the presigned URL of the uploaded object.

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension,
            500 when the upload to storage fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="Missing filename")

    # The filename comes from the client. Keep only its last path component
    # so directory parts or ".." cannot alter the object key.
    safe_name = Path(file.filename.replace("\\", "/")).name
    if not safe_name:
        raise HTTPException(status_code=400, detail="Missing filename")

    ext = _get_ext(safe_name)
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {ext}",
        )

    data = await file.read()
    object_name = f"uploads/{project_id or 'general'}/{safe_name}"
    try:
        upload_file(
            object_name,
            data,
            content_type=file.content_type or "application/octet-stream",
            length=len(data),
        )
    except Exception as exc:  # noqa: BLE001
        logger.error("Upload failed: %s", exc)
        raise HTTPException(status_code=500, detail="Upload to storage failed") from exc

    file_url = get_presigned_url(object_name, expires=3600)

    if project_id:
        project = db.query(Project).filter(Project.id == project_id).first()
        if project:
            project.video_path = object_name
            db.commit()
            logger.info("Linked upload to project_id=%s", project_id)
        else:
            # Best effort: the object is stored even if the project is unknown.
            logger.warning("Project id=%s not found for upload linkage", project_id)

    # TODO: enqueue async parsing job (Celery / background task)
    logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
    return {
        "object_name": object_name,
        "file_url": file_url,
        "size": len(data),
        "message": "Upload successful. Parsing job queued.",
    }
def _read_frame_size(path: str):
    """Return ``(width, height)`` of the image at *path*, or ``(None, None)``."""
    try:
        import cv2

        img = cv2.imread(path)
    except Exception:  # noqa: BLE001
        return None, None
    if img is None:
        return None, None
    height, width = img.shape[:2]
    return width, height


@router.post(
    "/parse",
    status_code=status.HTTP_202_ACCEPTED,
    summary="Trigger frame extraction",
)
def parse_media(
    project_id: int,
    source_type: str = "video",  # video | dicom
    db: Session = Depends(get_db),
) -> dict:
    """Trigger frame extraction for a project's uploaded media.

    * video: uses FFmpeg or OpenCV fallback.
    * dicom: uses pydicom to read DCM frames.

    Extracted frames are uploaded to MinIO and registered in the database.
    The temporary working directory is removed on every exit path.

    Raises:
        HTTPException: 404 when the project does not exist, 400 when it has
            no media, 500 when download, extraction or frame upload fails.
    """
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    if not project.video_path:
        raise HTTPException(status_code=400, detail="Project has no media uploaded")

    # Download from MinIO into memory, then spill to a temp directory.
    from minio_client import download_file

    try:
        media_bytes = download_file(project.video_path)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to download media for parsing: %s", exc)
        raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc

    tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
    try:
        local_path = os.path.join(tmp_dir, Path(project.video_path).name)
        with open(local_path, "wb") as f:
            f.write(media_bytes)

        output_dir = os.path.join(tmp_dir, "frames")
        os.makedirs(output_dir, exist_ok=True)

        try:
            if source_type == "dicom":
                # parse_dicom works on a directory: a single .dcm file is
                # moved into it, anything else is treated as an archive.
                dcm_dir = os.path.join(tmp_dir, "dcm")
                os.makedirs(dcm_dir, exist_ok=True)
                if local_path.lower().endswith(".dcm"):
                    shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
                else:
                    # unpack_archive is pure Python and needs no `unzip`
                    # binary. Files it does not recognise are moved as-is.
                    # NOTE(review): the archive is user-supplied; consider
                    # the `filter` argument (Python 3.12+) for tar members.
                    try:
                        shutil.unpack_archive(local_path, dcm_dir)
                    except shutil.ReadError:
                        shutil.move(local_path, dcm_dir)
                frame_files = parse_dicom(dcm_dir, output_dir)
            else:
                frame_files = parse_video(local_path, output_dir, fps=30)
        except Exception as exc:  # noqa: BLE001
            logger.error("Frame extraction failed: %s", exc)
            raise HTTPException(status_code=500, detail="Frame extraction failed") from exc

        # Upload frames to MinIO
        try:
            object_names = upload_frames_to_minio(frame_files, project_id)
        except Exception as exc:  # noqa: BLE001
            logger.error("Frame upload failed: %s", exc)
            raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc

        # Register frames in DB; dimensions are read from the local copies.
        frames_out = []
        for idx, obj_name in enumerate(object_names):
            w, h = _read_frame_size(frame_files[idx])
            frame = Frame(
                project_id=project_id,
                frame_index=idx,
                image_url=obj_name,
                width=w,
                height=h,
            )
            db.add(frame)
            frames_out.append(frame)

        # One commit, so frames and the status flip land together.
        project.status = "ready"
        db.commit()
    finally:
        # Runs on success and on every failure above.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
    return {
        "project_id": project_id,
        "frames_extracted": len(frames_out),
        "status": "ready",
        "message": "Frame extraction completed successfully.",
    }

165
backend/routers/projects.py Normal file
View File

@@ -0,0 +1,165 @@
"""Project and Frame CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Project, Frame
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/projects", tags=["Projects"])
# ---------------------------------------------------------------------------
# Projects
# ---------------------------------------------------------------------------
@router.post(
    "",
    response_model=ProjectOut,
    status_code=status.HTTP_201_CREATED,
    summary="Create a new project",
)
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
    """Create a new segmentation project."""
    fields = payload.model_dump()
    new_project = Project(**fields)
    db.add(new_project)
    db.commit()
    # Reload so database-generated columns (id, timestamps) are populated.
    db.refresh(new_project)
    logger.info("Created project id=%s name=%s", new_project.id, new_project.name)
    return new_project
@router.get(
    "",
    response_model=List[ProjectOut],
    summary="List all projects",
)
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
    """Retrieve a paginated list of projects.

    Args:
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
    """
    # OFFSET/LIMIT without ORDER BY gives no stable row order, so pages
    # could overlap or miss rows. Order by primary key to fix the order.
    return (
        db.query(Project)
        .order_by(Project.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
@router.get(
    "/{project_id}",
    response_model=ProjectOut,
    summary="Get a single project",
)
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
    """Retrieve a project by its ID."""
    found = db.query(Project).filter(Project.id == project_id).first()
    if found is None:
        raise HTTPException(status_code=404, detail="Project not found")
    return found
@router.patch(
    "/{project_id}",
    response_model=ProjectOut,
    summary="Update a project",
)
def update_project(
    project_id: int,
    payload: ProjectUpdate,
    db: Session = Depends(get_db),
) -> Project:
    """Update project fields partially."""
    target = db.query(Project).filter(Project.id == project_id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Project not found")

    # Only fields the client actually sent are applied.
    changes = payload.model_dump(exclude_unset=True)
    for field_name, new_value in changes.items():
        setattr(target, field_name, new_value)

    db.commit()
    db.refresh(target)
    logger.info("Updated project id=%s", project_id)
    return target
@router.delete(
    "/{project_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a project",
)
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
    """Delete a project and all related frames and annotations."""
    doomed = db.query(Project).filter(Project.id == project_id).first()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Project not found")

    db.delete(doomed)
    db.commit()
    logger.info("Deleted project id=%s", project_id)
# ---------------------------------------------------------------------------
# Frames
# ---------------------------------------------------------------------------
@router.post(
    "/{project_id}/frames",
    response_model=FrameOut,
    status_code=status.HTTP_201_CREATED,
    summary="Add a frame to a project",
)
def create_frame(
    project_id: int,
    payload: FrameCreate,
    db: Session = Depends(get_db),
) -> Frame:
    """Create a frame row attached to the project named in the URL.

    The ``project_id`` from the path always wins: the value carried in the
    request body is dropped before the ORM object is built.

    Raises:
        HTTPException: 404 when the parent project does not exist.
    """
    parent = db.query(Project).filter(Project.id == project_id).first()
    if not parent:
        raise HTTPException(status_code=404, detail="Project not found")

    frame_fields = payload.model_dump(exclude={"project_id"})
    new_frame = Frame(project_id=project_id, **frame_fields)
    db.add(new_frame)
    db.commit()
    db.refresh(new_frame)
    return new_frame
@router.get(
    "/{project_id}/frames",
    response_model=List[FrameOut],
    summary="List frames for a project",
)
def list_frames(
    project_id: int,
    skip: int = 0,
    limit: int = 1000,
    db: Session = Depends(get_db),
) -> List[Frame]:
    """Return one page of a project's frames, ordered by ``frame_index``.

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    owner = db.query(Project).filter(Project.id == project_id).first()
    if not owner:
        raise HTTPException(status_code=404, detail="Project not found")

    project_frames = db.query(Frame).filter(Frame.project_id == project_id)
    in_sequence = project_frames.order_by(Frame.frame_index)
    return in_sequence.offset(skip).limit(limit).all()
@router.get(
    "/{project_id}/frames/{frame_id}",
    response_model=FrameOut,
    summary="Get a single frame",
)
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
    """Fetch a frame by ID, scoped to the given project.

    A frame that exists under a different project is reported as missing.

    Raises:
        HTTPException: 404 when no frame matches both IDs.
    """
    scoped = db.query(Frame).filter(Frame.project_id == project_id)
    found = scoped.filter(Frame.id == frame_id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Frame not found")

View File

@@ -0,0 +1,97 @@
"""Template (Ontology) CRUD endpoints."""
import logging
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from database import get_db
from models import Template
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/templates", tags=["Templates"])
@router.post(
    "",
    response_model=TemplateOut,
    status_code=status.HTTP_201_CREATED,
    summary="Create a new template",
)
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
    """Persist a new ontology template (segmentation class) and return it."""
    fields = payload.model_dump()
    record = Template(**fields)
    db.add(record)
    db.commit()
    # Refresh so database-generated columns are populated on the returned object.
    db.refresh(record)
    logger.info("Created template id=%s name=%s", record.id, record.name)
    return record
@router.get(
    "",
    response_model=List[TemplateOut],
    summary="List all templates",
)
def list_templates(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
) -> List[Template]:
    """Retrieve one page of ontology templates, ordered by ID.

    Args:
        skip: Number of rows to skip from the start of the result set.
        limit: Maximum number of rows to return.
        db: Request-scoped database session.

    Returns:
        Up to ``limit`` templates starting at offset ``skip``.

    Raises:
        HTTPException: 422 when ``skip`` or ``limit`` is negative.
    """
    # Reject bad paging values here; otherwise they reach the database and
    # surface as an opaque server error.
    if skip < 0 or limit < 0:
        raise HTTPException(
            status_code=422,
            detail="skip and limit must be non-negative",
        )
    # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so successive
    # pages could overlap or skip rows. Ordering by id makes paging stable.
    return (
        db.query(Template)
        .order_by(Template.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
@router.get(
    "/{template_id}",
    response_model=TemplateOut,
    summary="Get a single template",
)
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
    """Fetch the template whose ID matches ``template_id``.

    Raises:
        HTTPException: 404 when no such template exists.
    """
    matches_id = Template.id == template_id
    found = db.query(Template).filter(matches_id).first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Template not found")
@router.patch(
    "/{template_id}",
    response_model=TemplateOut,
    summary="Update a template",
)
def update_template(
    template_id: int,
    payload: TemplateUpdate,
    db: Session = Depends(get_db),
) -> Template:
    """Apply a partial update to an existing template.

    Only the fields the client actually sent are written
    (``model_dump(exclude_unset=True)``); omitted fields keep their values.

    Raises:
        HTTPException: 404 when the template does not exist.
    """
    target = db.query(Template).filter(Template.id == template_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Template not found")

    changes = payload.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    db.commit()
    # Reload so the response reflects what the database actually stored.
    db.refresh(target)
    logger.info("Updated template id=%s", template_id)
    return target
@router.delete(
    "/{template_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a template",
)
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
    """Remove a template row and commit the deletion.

    NOTE(review): annotations referencing the template are presumably left in
    place with ``template_id`` set to NULL by the foreign-key rule on the
    model; that model is not visible here, so confirm.

    Raises:
        HTTPException: 404 when the template does not exist.
    """
    doomed = db.query(Template).filter(Template.id == template_id).first()
    if doomed:
        db.delete(doomed)
        db.commit()
        logger.info("Deleted template id=%s", template_id)
        return
    raise HTTPException(status_code=404, detail="Template not found")

157
backend/schemas.py Normal file
View File

@@ -0,0 +1,157 @@
"""Pydantic schemas for request/response validation."""
from datetime import datetime
from typing import Optional, Any
from pydantic import BaseModel, ConfigDict
# ---------------------------------------------------------------------------
# Project schemas
# ---------------------------------------------------------------------------
class ProjectBase(BaseModel):
    """Fields shared by the project create payload and the project response."""

    # Only required field; everything else has a default.
    name: str
    description: Optional[str] = None
    video_path: Optional[str] = None
    # Defaults to "pending" when the client does not supply a status.
    status: Optional[str] = "pending"
class ProjectCreate(ProjectBase):
    """Request body for creating a project; adds nothing to ``ProjectBase``."""

    pass
class ProjectUpdate(BaseModel):
    """Request body for a partial project update.

    Every field is optional so that a client can send only what it wants to
    change.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    video_path: Optional[str] = None
    status: Optional[str] = None
class ProjectOut(ProjectBase):
    """Project representation returned by the API."""

    # from_attributes lets the model be populated from attribute access on an
    # object (e.g. an ORM row) rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)
    id: int
    created_at: datetime
    updated_at: datetime
# ---------------------------------------------------------------------------
# Frame schemas
# ---------------------------------------------------------------------------
class FrameBase(BaseModel):
    """Fields shared by the frame create payload and the frame response."""

    frame_index: int
    image_url: str
    # Image dimensions are optional and default to None when not supplied.
    width: Optional[int] = None
    height: Optional[int] = None
class FrameCreate(FrameBase):
    """Request body for registering a frame; requires the owning project ID."""

    project_id: int
class FrameOut(FrameBase):
    """Frame representation returned by the API."""

    # from_attributes lets the model be populated from attribute access on an
    # object (e.g. an ORM row) rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)
    id: int
    project_id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# Template schemas
# ---------------------------------------------------------------------------
class TemplateBase(BaseModel):
    """Fields shared by the template create payload and the template response."""

    name: str
    color: str
    # Defaults to 0. NOTE(review): presumably a draw/stacking order — confirm.
    z_index: int = 0
    # Free-form dictionary; its expected keys are not defined in this module.
    mapping_rules: Optional[dict[str, Any]] = None
class TemplateCreate(TemplateBase):
    """Request body for creating a template; adds nothing to ``TemplateBase``."""

    pass
class TemplateUpdate(BaseModel):
    """Request body for a partial template update.

    Every field is optional so that a client can send only what it wants to
    change.
    """

    name: Optional[str] = None
    color: Optional[str] = None
    z_index: Optional[int] = None
    mapping_rules: Optional[dict[str, Any]] = None
class TemplateOut(TemplateBase):
    """Template representation returned by the API."""

    # from_attributes lets the model be populated from attribute access on an
    # object (e.g. an ORM row) rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)
    id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# Annotation schemas
# ---------------------------------------------------------------------------
class AnnotationBase(BaseModel):
    """Fields shared by the annotation create payload and the annotation response."""

    # Only required field; the frame and template links are optional.
    project_id: int
    frame_id: Optional[int] = None
    template_id: Optional[int] = None
    # Free-form dictionary; its expected keys are not defined in this module.
    mask_data: Optional[dict[str, Any]] = None
    # List of float lists. NOTE(review): presumably [x, y] pairs — confirm.
    points: Optional[list[list[float]]] = None
    # Flat list of floats. NOTE(review): coordinate order is not specified here.
    bbox: Optional[list[float]] = None
class AnnotationCreate(AnnotationBase):
    """Request body for creating an annotation; adds nothing to ``AnnotationBase``."""

    pass
class AnnotationUpdate(BaseModel):
    """Request body for a partial annotation update.

    Covers only the geometry fields and the template link; ``project_id`` and
    ``frame_id`` are not part of this schema.
    """

    mask_data: Optional[dict[str, Any]] = None
    points: Optional[list[list[float]]] = None
    bbox: Optional[list[float]] = None
    template_id: Optional[int] = None
class AnnotationOut(AnnotationBase):
    """Annotation representation returned by the API."""

    # from_attributes lets the model be populated from attribute access on an
    # object (e.g. an ORM row) rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)
    id: int
    created_at: datetime
    updated_at: datetime
# ---------------------------------------------------------------------------
# Mask schemas
# ---------------------------------------------------------------------------
class MaskBase(BaseModel):
    """Fields shared by the mask create payload and the mask response."""

    annotation_id: int
    mask_url: str
    # Defaults to "png" when the client does not specify a format.
    format: str = "png"
class MaskCreate(MaskBase):
    """Request body for creating a mask; adds nothing to ``MaskBase``."""

    pass
class MaskOut(MaskBase):
    """Mask representation returned by the API."""

    # from_attributes lets the model be populated from attribute access on an
    # object (e.g. an ORM row) rather than only from a dict.
    model_config = ConfigDict(from_attributes=True)
    id: int
    created_at: datetime
# ---------------------------------------------------------------------------
# AI schemas
# ---------------------------------------------------------------------------
class PredictRequest(BaseModel):
    """Request body for an AI prediction call."""

    image_id: int
    # Plain string, not validated against a fixed set by this schema.
    prompt_type: str  # point / box / semantic
    # Untyped payload. NOTE(review): its shape presumably depends on
    # prompt_type — confirm against the endpoint that consumes it.
    prompt_data: Any
class PredictResponse(BaseModel):
    """Response body for an AI prediction call."""

    # Three levels of nesting: a list of polygons, each a list of float lists.
    # NOTE(review): presumably [x, y] vertices — confirm against the producer.
    polygons: list[list[list[float]]]
    # Optional; defaults to None when no scores are provided.
    scores: Optional[list[float]] = None
# ---------------------------------------------------------------------------
# Export schemas
# ---------------------------------------------------------------------------
class ExportStatus(BaseModel):
    """Result of an export: a URL plus the export format name."""

    url: str
    format: str

View File

View File

@@ -0,0 +1,186 @@
"""Video/DICOM frame parsing and MinIO upload utilities."""
import logging
import os
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional
import cv2
import numpy as np
from pydicom import dcmread
from minio_client import upload_file, BUCKET_NAME
logger = logging.getLogger(__name__)
def _ffmpeg_extract_frames(
    video_path: str,
    output_dir: str,
    fps: int,
    max_frames: Optional[int],
) -> Optional[List[str]]:
    """Extract frames with FFmpeg; return None when the caller should fall back."""
    pattern = os.path.join(output_dir, "frame_%06d.png")
    cmd = [
        "ffmpeg",
        "-i", video_path,
        "-vf", f"fps={fps},scale='min(1920,iw)':-1",
        "-pix_fmt", "rgb24",
    ]
    if max_frames and max_frames > 0:
        # Stop encoding once enough frames exist instead of decoding the whole
        # video and discarding the surplus afterwards.
        cmd += ["-frames:v", str(max_frames)]
    cmd += ["-y", pattern]

    logger.info("Running FFmpeg: %s", " ".join(cmd))
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=False)
        if result.returncode != 0:
            logger.warning("FFmpeg failed: %s", result.stderr)
            return None
        # Collect only files matching the pattern written above, so unrelated
        # PNGs in the directory are not reported as frames.
        # NOTE(review): frames left by an earlier, longer extraction into the
        # same directory still match; use a fresh output_dir per run.
        names = sorted(
            name
            for name in os.listdir(output_dir)
            if name.startswith("frame_")
            and name.endswith(".png")
            and name[len("frame_"):-len(".png")].isdigit()
        )
    except (OSError, ValueError, subprocess.SubprocessError) as exc:
        logger.warning("FFmpeg exception: %s", exc)
        return None

    frame_paths = [os.path.join(output_dir, name) for name in names]
    if max_frames:
        frame_paths = frame_paths[:max_frames]
    return frame_paths


def _opencv_extract_frames(
    video_path: str,
    output_dir: str,
    fps: int,
    max_frames: Optional[int],
) -> List[str]:
    """Extract frames with OpenCV by keeping every N-th decoded frame."""
    cap = cv2.VideoCapture(video_path)
    frame_paths: List[str] = []
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")

        # Some containers report 0 fps; assume 30 in that case.
        video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        interval = max(1, int(round(video_fps / fps)))

        count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                path = os.path.join(output_dir, f"frame_{len(frame_paths):06d}.png")
                # imwrite reports failure through its return value, not an exception.
                if not cv2.imwrite(path, frame):
                    raise RuntimeError(f"Failed to write frame: {path}")
                frame_paths.append(path)
                if max_frames and len(frame_paths) >= max_frames:
                    break
            count += 1
    finally:
        # Release the capture even when reading or writing fails.
        cap.release()
    return frame_paths


def parse_video(
    video_path: str,
    output_dir: str,
    fps: int = 30,
    max_frames: Optional[int] = None,
) -> List[str]:
    """Extract frames from a video file using FFmpeg or OpenCV fallback.

    FFmpeg is used when it is on PATH. If it is missing or fails, frames are
    sampled with OpenCV instead.

    Args:
        video_path: Path to the input video file.
        output_dir: Directory to save extracted frames (created if missing).
        fps: Target frame extraction rate; must be positive.
        max_frames: Optional maximum number of frames to extract.

    Returns:
        List of paths to extracted frame images, in frame order.

    Raises:
        ValueError: If ``fps`` is not positive.
        RuntimeError: If the OpenCV fallback cannot open the video or cannot
            write a frame.
    """
    # Validate before touching the filesystem: fps <= 0 would otherwise divide
    # by zero in the OpenCV path or fail opaquely inside FFmpeg.
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    os.makedirs(output_dir, exist_ok=True)

    # Try FFmpeg first
    if shutil.which("ffmpeg"):
        frame_paths = _ffmpeg_extract_frames(video_path, output_dir, fps, max_frames)
        if frame_paths is not None:
            logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
            return frame_paths

    # OpenCV fallback
    logger.info("Falling back to OpenCV frame extraction")
    frame_paths = _opencv_extract_frames(video_path, output_dir, fps, max_frames)
    logger.info("Extracted %d frames via OpenCV", len(frame_paths))
    return frame_paths
def parse_dicom(
    dicom_dir: str,
    output_dir: str,
    max_frames: Optional[int] = None,
) -> List[str]:
    """Extract frames from DICOM files in a directory.

    Files are processed in sorted filename order. A file that cannot be read
    or written is logged and skipped; the remaining files are still processed.

    Args:
        dicom_dir: Directory containing .dcm files.
        output_dir: Directory to save extracted frames (created if missing).
        max_frames: Optional maximum number of frames to extract, counted over
            individual images (a multi-frame file contributes several).

    Returns:
        List of paths to extracted frame images.
    """
    os.makedirs(output_dir, exist_ok=True)
    dcm_files = sorted(
        f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")
    )
    frame_paths: List[str] = []

    def _limit_reached() -> bool:
        # Counts frames actually written, so a multi-frame file cannot push
        # the result past max_frames.
        return bool(max_frames) and len(frame_paths) >= max_frames

    def _write(out_path: str, image: np.ndarray, is_color: bool) -> None:
        if is_color:
            # OpenCV writes channels as BGR.
            # NOTE(review): assumes the decoded array is RGB; YBR photometric
            # interpretations may need an explicit conversion first — confirm.
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        # imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(out_path, image):
            raise OSError(f"cv2.imwrite failed for {out_path}")
        frame_paths.append(out_path)

    for idx, fname in enumerate(dcm_files):
        if _limit_reached():
            break
        path = os.path.join(dicom_dir, fname)
        try:
            ds = dcmread(path)
            pixel_array = ds.pixel_array

            # Normalize to 8-bit using the min/max of the whole array.
            if pixel_array.dtype != np.uint8:
                pixel_array = pixel_array.astype(np.float32)
                pixel_array = (
                    (pixel_array - pixel_array.min())
                    / (pixel_array.max() - pixel_array.min() + 1e-8)
                    * 255
                )
                pixel_array = pixel_array.astype(np.uint8)

            # A single colour image is (rows, cols, samples), so ndim == 3
            # alone does not mean "multi-frame". Use SamplesPerPixel to work
            # out how many dimensions one image has.
            is_color = int(getattr(ds, "SamplesPerPixel", 1)) > 1
            single_image_ndim = 3 if is_color else 2

            if pixel_array.ndim > single_image_ndim:
                # Multi-frame DICOM: the leading axis indexes frames.
                for f in range(pixel_array.shape[0]):
                    if _limit_reached():
                        break
                    out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.png")
                    _write(out_path, pixel_array[f], is_color)
            else:
                out_path = os.path.join(output_dir, f"frame_{idx:06d}.png")
                _write(out_path, pixel_array, is_color)
        except Exception as exc:  # noqa: BLE001 - best effort: skip unreadable files
            logger.error("Failed to read DICOM %s: %s", path, exc)

    logger.info("Extracted %d frames from DICOM", len(frame_paths))
    return frame_paths
def upload_frames_to_minio(
    frames: List[str],
    project_id: int,
    object_prefix: Optional[str] = None,
) -> List[str]:
    """Push local frame images to MinIO, skipping any that fail.

    Each file is read fully into memory and uploaded as ``image/png`` under
    ``<prefix>/<file name>``. A failure on one file is logged and does not
    stop the remaining uploads, so the result may be shorter than ``frames``.

    Args:
        frames: Local file paths to upload.
        project_id: Used to build the default prefix
            ``projects/<project_id>/frames``.
        object_prefix: Replaces the default prefix when non-empty.

    Returns:
        Object names of the frames that were uploaded successfully.
    """
    prefix = object_prefix if object_prefix else f"projects/{project_id}/frames"
    uploaded: List[str] = []

    for local_path in frames:
        object_name = f"{prefix}/{os.path.basename(local_path)}"
        try:
            with open(local_path, "rb") as handle:
                payload = handle.read()
            upload_file(
                object_name,
                payload,
                content_type="image/png",
                length=len(payload),
            )
        except Exception as exc:  # noqa: BLE001
            logger.error("Failed to upload %s: %s", local_path, exc)
        else:
            uploaded.append(object_name)

    logger.info("Uploaded %d/%d frames to MinIO", len(uploaded), len(frames))
    return uploaded

View File

@@ -0,0 +1,234 @@
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
import logging
import os
from typing import Optional
import numpy as np
from config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Attempt to import SAM 2; fall back to stubs if unavailable.
# ---------------------------------------------------------------------------
try:
    # torch and the sam2 package are optional at import time: if any of these
    # imports fails, the module still loads and SAM2_AVAILABLE records the result.
    import torch
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    SAM2_AVAILABLE = True
    logger.info("SAM2 library imported successfully.")
except Exception as exc:  # noqa: BLE001
    # Deliberately broad: any failure while importing (not only ImportError)
    # switches the engine to its dummy-result fallback.
    SAM2_AVAILABLE = False
    logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
class SAM2Engine:
    """Lazy-loaded SAM 2 inference engine with a dummy-result fallback.

    The model is built on the first prediction call, not at construction.
    When SAM 2 cannot be used (import failed, checkpoint missing, load error)
    or a prediction raises, every ``predict_*`` method returns a fixed dummy
    rectangle with a score of 0.5 instead of raising.

    All coordinates accepted and returned by the public methods are
    normalized to the 0-1 range relative to image width and height.
    """

    def __init__(self) -> None:
        self._predictor: Optional[SAM2ImagePredictor] = None
        # True once a load has been attempted, whether or not it succeeded.
        self._model_loaded = False

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def _load_model(self) -> None:
        """Load the SAM 2 model and predictor on first use.

        Every outcome, including failure, marks the model as "loaded" so the
        expensive load is attempted at most once per engine instance.
        """
        if self._model_loaded:
            return

        if not SAM2_AVAILABLE:
            logger.warning("SAM2 not available; skipping model load.")
            self._model_loaded = True
            return

        if not os.path.isfile(settings.sam_model_path):
            logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
            self._model_loaded = True
            return

        try:
            # Fall back to CPU when CUDA is absent instead of failing the load
            # on machines without a GPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"  # type: ignore[name-defined]
            model = build_sam2(
                settings.sam_model_config,
                settings.sam_model_path,
                device=device,
            )
            self._predictor = SAM2ImagePredictor(model)
            self._model_loaded = True
            logger.info(
                "SAM 2 model loaded from %s on %s", settings.sam_model_path, device
            )
        except Exception as exc:  # noqa: BLE001
            logger.exception("Failed to load SAM 2 model: %s", exc)
            self._model_loaded = True  # Prevent repeated load attempts

    def _ensure_ready(self) -> bool:
        """Ensure the model is loaded; return whether it is usable."""
        self._load_model()
        return SAM2_AVAILABLE and self._predictor is not None

    def _fallback(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Return the dummy result used whenever real inference is unavailable."""
        return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]

    @classmethod
    def _collect_polygons(
        cls,
        masks,
        scores,
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Convert masks to polygons, keeping each polygon paired with its score.

        Masks that yield no usable polygon are dropped together with their
        score, so the two returned lists always have the same length.
        """
        polygons: list[list[list[float]]] = []
        kept_scores: list[float] = []
        for mask, score in zip(masks, scores):
            poly = cls._mask_to_polygon(mask)
            if poly:
                polygons.append(poly)
                kept_scores.append(float(score))
        return polygons, kept_scores

    # -----------------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------------

    def predict_points(
        self,
        image: np.ndarray,
        points: list[list[float]],
        labels: list[int],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run point-prompt segmentation.

        Args:
            image: HWC numpy array (uint8).
            points: List of [x, y] normalized coordinates (0-1).
            labels: 1 for foreground, 0 for background.

        Returns:
            Tuple of (polygons, scores) with one score per polygon. On any
            failure the dummy rectangle and a score of 0.5 are returned.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback(image)

        try:
            h, w = image.shape[:2]
            # Scale normalized prompts to pixel coordinates.
            pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
            lbls = np.array(labels, dtype=np.int32)
            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=True,
                )
            return self._collect_polygons(masks, scores)
        except Exception as exc:  # noqa: BLE001
            logger.exception("SAM2 point prediction failed: %s", exc)
            return self._fallback(image)

    def predict_box(
        self,
        image: np.ndarray,
        box: list[float],
    ) -> tuple[list[list[list[float]]], list[float]]:
        """Run box-prompt segmentation.

        Args:
            image: HWC numpy array (uint8).
            box: [x1, y1, x2, y2] normalized coordinates.

        Returns:
            Tuple of (polygons, scores) with one score per polygon. On any
            failure the dummy rectangle and a score of 0.5 are returned.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback(image)

        try:
            h, w = image.shape[:2]
            bbox = np.array(
                [box[0] * w, box[1] * h, box[2] * w, box[3] * h],
                dtype=np.float32,
            )
            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    box=bbox[None, :],
                    multimask_output=False,
                )
            return self._collect_polygons(masks, scores)
        except Exception as exc:  # noqa: BLE001
            logger.exception("SAM2 box prediction failed: %s", exc)
            return self._fallback(image)

    def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
        """Run automatic mask generation (grid of points).

        NOTE(review): the whole grid is sent as foreground prompts in a single
        ``predict`` call, which is not the same as running a dedicated
        automatic mask generator — confirm this is the intended behaviour.

        Args:
            image: HWC numpy array (uint8).

        Returns:
            Tuple of (polygons, scores) for at most the three highest-scoring
            masks, best first. On any failure the dummy rectangle and a score
            of 0.5 are returned.
        """
        if not self._ensure_ready():
            logger.warning("SAM2 not ready; returning dummy masks.")
            return self._fallback(image)

        try:
            with torch.inference_mode():  # type: ignore[name-defined]
                self._predictor.set_image(image)
                # Uniform 17x17 grid of point prompts covering 0..1 inclusive
                # on both axes (17 samples per axis = 16 intervals).
                h, w = image.shape[:2]
                grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
                pts = grid * np.array([w, h])
                lbls = np.ones(pts.shape[0], dtype=np.int32)
                masks, scores, _ = self._predictor.predict(
                    point_coords=pts,
                    point_labels=lbls,
                    multimask_output=True,
                )
            # Rank by score so the three kept masks really are the best ones.
            best = np.argsort(scores)[::-1][:3]
            return self._collect_polygons(
                [masks[i] for i in best],
                [scores[i] for i in best],
            )
        except Exception as exc:  # noqa: BLE001
            logger.exception("SAM2 auto prediction failed: %s", exc)
            return self._fallback(image)

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------

    @staticmethod
    def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
        """Convert a binary mask to a normalized polygon.

        Only the outer contour enclosing the largest area is returned.

        Args:
            mask: 2-D mask; any non-uint8 input is binarized with ``> 0``.

        Returns:
            List of [x, y] vertices normalized by mask width and height, or an
            empty list when no contour with at least three vertices exists.
        """
        import cv2

        if mask.dtype != np.uint8:
            mask = (mask > 0).astype(np.uint8)

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return []

        # Choose by enclosed area, not vertex count: CHAIN_APPROX_SIMPLE
        # collapses straight edges, so a large rectangle has only four
        # vertices while a small jagged blob can have many more.
        largest = max(contours, key=cv2.contourArea)
        if len(largest) < 3:
            return []

        h, w = mask.shape[:2]
        return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]

    @staticmethod
    def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
        """Return a dummy rectangle polygon for fallback mode.

        The rectangle is already normalized, so ``w`` and ``h`` are accepted
        for interface compatibility but not used.
        """
        return [
            [
                [0.25, 0.25],
                [0.75, 0.25],
                [0.75, 0.75],
                [0.25, 0.75],
            ]
        ]
# Singleton instance shared by importers of this module. Constructing it only
# initialises attributes; the model itself is loaded on the first predict_* call.
sam_engine = SAM2Engine()

146
package-lock.json generated
View File

@@ -11,6 +11,7 @@
"@google/genai": "^1.29.0",
"@tailwindcss/vite": "^4.1.14",
"@vitejs/plugin-react": "^5.0.4",
"axios": "^1.15.2",
"clsx": "^2.1.1",
"dotenv": "^17.2.3",
"express": "^4.21.2",
@@ -22,7 +23,8 @@
"react-konva": "^19.2.3",
"tailwind-merge": "^3.5.0",
"use-image": "^1.1.4",
"vite": "^6.2.0"
"vite": "^6.2.0",
"zustand": "^5.0.12"
},
"devDependencies": {
"@types/express": "^4.17.21",
@@ -1670,6 +1672,12 @@
"integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==",
"license": "MIT"
},
"node_modules/asynckit": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==",
"license": "MIT"
},
"node_modules/autoprefixer": {
"version": "10.5.0",
"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.5.0.tgz",
@@ -1707,6 +1715,17 @@
"postcss": "^8.1.0"
}
},
"node_modules/axios": {
"version": "1.15.2",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.15.2.tgz",
"integrity": "sha512-wLrXxPtcrPTsNlJmKjkPnNPK2Ihe0hn0wGSaTEiHRPxwjvJwT3hKmXF4dpqxmPO9SoNb2FsYXj/xEo0gHN+D5A==",
"license": "MIT",
"dependencies": {
"follow-redirects": "^1.15.11",
"form-data": "^4.0.5",
"proxy-from-env": "^2.1.0"
}
},
"node_modules/base64-js": {
"version": "1.5.1",
"resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz",
@@ -1908,6 +1927,18 @@
"node": ">=6"
}
},
"node_modules/combined-stream": {
"version": "1.0.8",
"resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
"integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
"license": "MIT",
"dependencies": {
"delayed-stream": "~1.0.0"
},
"engines": {
"node": ">= 0.8"
}
},
"node_modules/content-disposition": {
"version": "0.5.4",
"resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.4.tgz",
@@ -1983,6 +2014,15 @@
}
}
},
"node_modules/delayed-stream": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
"integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
"license": "MIT",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/depd": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz",
@@ -2110,6 +2150,21 @@
"node": ">= 0.4"
}
},
"node_modules/es-set-tostringtag": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz",
"integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==",
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.6",
"has-tostringtag": "^1.0.2",
"hasown": "^2.0.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/esbuild": {
"version": "0.27.7",
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.7.tgz",
@@ -2316,6 +2371,42 @@
"integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==",
"license": "MIT"
},
"node_modules/follow-redirects": {
"version": "1.16.0",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.16.0.tgz",
"integrity": "sha512-y5rN/uOsadFT/JfYwhxRS5R7Qce+g3zG97+JrtFZlC9klX/W5hD7iiLzScI4nZqUS7DNUdhPgw4xI8W2LuXlUw==",
"funding": [
{
"type": "individual",
"url": "https://github.com/sponsors/RubenVerborgh"
}
],
"license": "MIT",
"engines": {
"node": ">=4.0"
},
"peerDependenciesMeta": {
"debug": {
"optional": true
}
}
},
"node_modules/form-data": {
"version": "4.0.5",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz",
"integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==",
"license": "MIT",
"dependencies": {
"asynckit": "^0.4.0",
"combined-stream": "^1.0.8",
"es-set-tostringtag": "^2.1.0",
"hasown": "^2.0.2",
"mime-types": "^2.1.12"
},
"engines": {
"node": ">= 6"
}
},
"node_modules/formdata-polyfill": {
"version": "4.0.10",
"resolved": "https://registry.npmjs.org/formdata-polyfill/-/formdata-polyfill-4.0.10.tgz",
@@ -2553,6 +2644,21 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/has-tostringtag": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz",
"integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==",
"license": "MIT",
"dependencies": {
"has-symbols": "^1.0.3"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/hasown": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.3.tgz",
@@ -3346,6 +3452,15 @@
"node": ">= 0.10"
}
},
"node_modules/proxy-from-env": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-2.1.0.tgz",
"integrity": "sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA==",
"license": "MIT",
"engines": {
"node": ">=10"
}
},
"node_modules/qs": {
"version": "6.14.2",
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
@@ -4461,6 +4576,35 @@
"resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz",
"integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==",
"license": "ISC"
},
"node_modules/zustand": {
"version": "5.0.12",
"resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.12.tgz",
"integrity": "sha512-i77ae3aZq4dhMlRhJVCYgMLKuSiZAaUPAct2AksxQ+gOtimhGMdXljRT21P5BNpeT4kXlLIckvkPM029OljD7g==",
"license": "MIT",
"engines": {
"node": ">=12.20.0"
},
"peerDependencies": {
"@types/react": ">=18.0.0",
"immer": ">=9.0.6",
"react": ">=18.0.0",
"use-sync-external-store": ">=1.2.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"immer": {
"optional": true
},
"react": {
"optional": true
},
"use-sync-external-store": {
"optional": true
}
}
}
}
}

View File

@@ -15,6 +15,7 @@
"@google/genai": "^1.29.0",
"@tailwindcss/vite": "^4.1.14",
"@vitejs/plugin-react": "^5.0.4",
"axios": "^1.15.2",
"clsx": "^2.1.1",
"dotenv": "^17.2.3",
"express": "^4.21.2",
@@ -26,7 +27,8 @@
"react-konva": "^19.2.3",
"tailwind-merge": "^3.5.0",
"use-image": "^1.1.4",
"vite": "^6.2.0"
"vite": "^6.2.0",
"zustand": "^5.0.12"
},
"devDependencies": {
"@types/express": "^4.17.21",

View File

@@ -1,4 +1,6 @@
import React, { useState } from 'react';
import React, { useEffect } from 'react';
import { useStore } from './store/useStore';
import { getProjects } from './lib/api';
import { Sidebar } from './components/Sidebar';
import { Dashboard } from './components/Dashboard';
import { ProjectLibrary } from './components/ProjectLibrary';
@@ -10,16 +12,30 @@ import { Login } from './components/Login';
export type ActiveModule = 'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates';
export default function App() {
const [activeModule, setActiveModule] = useState<ActiveModule>('workspace');
const [isAuthenticated, setIsAuthenticated] = useState(false);
const isAuthenticated = useStore((state) => state.isAuthenticated);
const activeModule = useStore((state) => state.activeModule);
const setActiveModule = useStore((state) => state.setActiveModule);
const setProjects = useStore((state) => state.setProjects);
const setError = useStore((state) => state.setError);
useEffect(() => {
if (isAuthenticated) {
getProjects()
.then((data) => setProjects(data))
.catch((err) => {
console.error('Failed to fetch projects:', err);
setError('获取项目列表失败');
});
}
}, [isAuthenticated, setProjects, setError]);
if (!isAuthenticated) {
return <Login onLoginSuccess={() => setIsAuthenticated(true)} />;
return <Login />;
}
return (
<div className="flex h-screen w-full bg-[#0a0a0a] text-gray-200 overflow-hidden font-sans">
<Sidebar activeModule={activeModule} setActiveModule={setActiveModule} />
<Sidebar activeModule={activeModule as ActiveModule} setActiveModule={setActiveModule} />
<main className="flex-1 flex flex-col min-w-0 h-full relative">
{activeModule === 'dashboard' && <Dashboard />}
{activeModule === 'projects' && <ProjectLibrary onProjectSelect={() => setActiveModule('workspace')} />}

View File

@@ -1,20 +1,28 @@
import React, { useState } from 'react';
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, Settings2, Cpu, Image as ImageIcon, SendToBack, Tags, Undo, Redo } from 'lucide-react';
import React, { useState, useCallback } from 'react';
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react';
import { cn } from '../lib/utils';
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
import useImage from 'use-image';
import { OntologyInspector } from './OntologyInspector';
import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api';
interface AISegmentationProps {
onSendToWorkspace: () => void;
}
export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const [activeTool, setActiveTool] = useState('point_pos');
const storeActiveTool = useStore((state) => state.activeTool);
const setActiveTool = useStore((state) => state.setActiveTool);
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks);
const [modelSize, setModelSize] = useState('vit_l');
const [semanticText, setSemanticText] = useState('');
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
const [cropMode, setCropMode] = useState(false);
const [isInferencing, setIsInferencing] = useState(false);
// Canvas state
const [scale, setScale] = useState(1);
@@ -23,6 +31,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop');
const effectiveTool = storeActiveTool;
const handleWheel = (e: any) => {
e.evt.preventDefault();
const scaleBy = 1.1;
@@ -51,13 +61,43 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
}
};
const runInference = useCallback(async () => {
if (points.length === 0 && !semanticText.trim()) return;
setIsInferencing(true);
try {
const result = await predictMask({
imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop',
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
text: semanticText.trim() || undefined,
modelSize,
});
result.masks.forEach((m) => {
addMask({
id: m.id,
frameId: 'frame-ai-1',
pathData: m.pathData,
label: m.label,
color: m.color,
segmentation: m.segmentation,
bbox: m.bbox,
area: m.area,
});
});
} catch (err) {
console.error('AI inference failed:', err);
} finally {
setIsInferencing(false);
}
}, [points, semanticText, modelSize, addMask]);
const handleStageClick = (e: any) => {
if (activeTool === 'move') return;
if (activeTool === 'point_pos' || activeTool === 'point_neg') {
if (effectiveTool === 'move') return;
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition();
if (pos) {
setPoints([...points, { x: pos.x, y: pos.y, type: activeTool === 'point_pos' ? 'pos' : 'neg' }]);
setPoints([...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' }]);
}
}
};
@@ -68,7 +108,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<aside className="w-80 bg-[#0d0d0d] flex flex-col border-r border-white/5 shrink-0 z-10 overflow-hidden">
<div className="h-16 border-b border-white/5 flex items-center px-6 shrink-0 justify-between">
<div className="flex items-center font-medium text-[11px] uppercase tracking-widest text-cyan-400">
<Cpu size={16} className="mr-3 text-cyan-400" />
<Sparkles size={16} className="mr-3 text-cyan-400" />
AI智能分割引擎
</div>
</div>
@@ -96,7 +136,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="grid grid-cols-2 gap-3">
<button
onClick={() => setActiveTool('point_pos')}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'point_pos' ? "bg-green-500/10 border-green-500/30 text-green-400 shadow-[0_0_15px_rgba(34,197,94,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'point_pos' ? "bg-green-500/10 border-green-500/30 text-green-400 shadow-[0_0_15px_rgba(34,197,94,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
>
<PlusCircle size={22} className="mb-3" />
<span className="text-[10px] uppercase tracking-wider font-semibold"></span>
@@ -104,15 +144,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<button
onClick={() => setActiveTool('point_neg')}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'point_neg' ? "bg-red-500/10 border-red-500/30 text-red-500 shadow-[0_0_15px_rgba(239,68,68,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'point_neg' ? "bg-red-500/10 border-red-500/30 text-red-500 shadow-[0_0_15px_rgba(239,68,68,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
>
<MinusCircle size={22} className="mb-3" />
<span className="text-[10px] uppercase tracking-wider font-semibold"></span>
</button>
<button
onClick={() => setActiveTool('box')}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'box' ? "bg-blue-500/10 border-blue-500/30 text-blue-400 shadow-[0_0_15px_rgba(59,130,246,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
onClick={() => setActiveTool('box_select')}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'box_select' ? "bg-blue-500/10 border-blue-500/30 text-blue-400 shadow-[0_0_15px_rgba(59,130,246,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
>
<SquareDashed size={22} className="mb-3" />
<span className="text-[10px] uppercase tracking-wider font-semibold"></span>
@@ -120,7 +160,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<button
onClick={() => setActiveTool('move')}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'move' ? "bg-white/10 border-white/20 text-white shadow-[0_0_15px_rgba(255,255,255,0.05)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'move' ? "bg-white/10 border-white/20 text-white shadow-[0_0_15px_rgba(255,255,255,0.05)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
>
<Target size={22} className="mb-3" />
<span className="text-[10px] uppercase tracking-wider font-semibold"></span>
@@ -165,9 +205,17 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="p-6 bg-[#0a0a0a] border-t border-white/5 shrink-0 flex flex-col gap-3">
<button
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
onClick={runInference}
disabled={isInferencing}
className={cn(
"w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase",
isInferencing
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
: "bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
)}
>
<Sparkles size={16} />
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
{isInferencing ? '推理中...' : '执行高精度语义分割'}
</button>
<button
onClick={onSendToWorkspace}
@@ -196,7 +244,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<button className="flex items-center gap-2 text-xs text-gray-400 hover:text-white transition-colors bg-white/5 hover:bg-white/10 px-3 py-1.5 rounded-md border border-white/5">
<ImageIcon size={14} />
</button>
<button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={() => setPoints([])}>
<button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={() => { setPoints([]); clearMasks(); }}>
</button>
</div>
@@ -205,7 +253,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="flex-1 relative p-8">
<div className="w-full h-full relative border border-white/5 rounded shadow-2xl bg-[#1e1e1e] overflow-hidden cursor-crosshair">
<Stage
width={window.innerWidth - 320 - 64} // approx sizing, uses window to avoid ResizeObserver for simplicity here
width={window.innerWidth - 320 - 64}
height={window.innerHeight - 64 - 64}
onWheel={handleWheel}
onMouseMove={handleMouseMove}
@@ -214,7 +262,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
scaleY={scale}
x={position.x}
y={position.y}
draggable={activeTool === 'move'}
draggable={effectiveTool === 'move'}
>
<Layer>
{/* Background Image */}
@@ -227,13 +275,17 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
/>
)}
{/* Mock Instance Mask from SAM3 */}
<Group opacity={0.4}>
<Path
data="M 300 200 Q 400 150 450 250 T 400 350 Q 250 350 280 250 Z"
fill="#06b6d4" // cyan-500
/>
</Group>
{/* AI Returned Masks */}
{masks.map((mask) => (
<Group key={mask.id} opacity={0.45}>
<Path
data={mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={1 / scale}
/>
</Group>
))}
{/* Points */}
{points.map((p, i) => (
@@ -257,6 +309,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
<div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none">
<span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
<span>: {(scale * 100).toFixed(0)}%</span>
<span>: {masks.length}</span>
</div>
</div>
</div>

View File

@@ -1,6 +1,9 @@
import React, { useEffect, useRef, useState } from 'react';
import React, { useEffect, useRef, useState, useCallback } from 'react';
import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva';
import useImage from 'use-image';
import { useStore } from '../store/useStore';
import { predictMask } from '../lib/api';
import { cn } from '../lib/utils';
interface CanvasAreaProps {
activeTool: string;
@@ -13,6 +16,17 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
const [position, setPosition] = useState({ x: 0, y: 0 });
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
const [isInferencing, setIsInferencing] = useState(false);
const masks = useStore((state) => state.masks);
const addMask = useStore((state) => state.addMask);
const clearMasks = useStore((state) => state.clearMasks);
const storeActiveTool = useStore((state) => state.activeTool);
const setActiveTool = useStore((state) => state.setActiveTool);
const effectiveTool = activeTool || storeActiveTool;
// We load a mock image representing a frame
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop');
@@ -56,35 +70,120 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
if (!stage) return;
const pos = stage.getPointerPosition();
if (pos) {
// Convert to image coordinates
const imageX = (pos.x - position.x) / scale;
const imageY = (pos.y - position.y) / scale;
setCursorPos({ x: imageX, y: imageY });
}
};
const handleStageClick = (e: any) => {
if (activeTool === 'move') return;
if (activeTool === 'point_pos' || activeTool === 'point_neg') {
const stage = e.target.getStage();
const pos = stage.getRelativePointerPosition();
setPoints([...points, { x: pos.x, y: pos.y, type: activeTool === 'point_pos' ? 'pos' : 'neg' }]);
if (boxStart && effectiveTool === 'box_select') {
const relPos = stage.getRelativePointerPosition();
if (relPos) {
setBoxCurrent({ x: relPos.x, y: relPos.y });
}
}
};
// Sends the supplied prompts (click points and/or a bounding box) to the
// prediction API and appends every returned mask to the store. Failures are
// only logged; the busy flag is cleared on both success and failure.
const runInference = useCallback(
  async (
    promptPoints?: typeof points,
    promptBox?: { x1: number, y1: number, x2: number, y2: number }
  ) => {
    setIsInferencing(true);
    try {
      // Forward only the fields the API needs from each point.
      const pointPrompts = promptPoints?.map(({ x, y, type }) => ({ x, y, type }));
      const result = await predictMask({
        imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop',
        points: pointPrompts,
        box: promptBox,
      });
      result.masks.forEach(({ id, pathData, label, color, segmentation, bbox, area }) => {
        addMask({
          id,
          frameId: 'frame-1',
          pathData,
          label,
          color,
          segmentation,
          bbox,
          area,
        });
      });
    } catch (err) {
      console.error('Inference failed:', err);
    } finally {
      setIsInferencing(false);
    }
  },
  [addMask]
);
// Starts a box selection: when the box tool is active, anchor both the start
// and the current corner at the pointer position (stage-relative coordinates).
const handleStageMouseDown = (e: any) => {
  if (effectiveTool !== 'box_select') return;
  const pointer = e.target.getStage().getRelativePointerPosition();
  if (!pointer) return;
  const { x, y } = pointer;
  setBoxStart({ x, y });
  setBoxCurrent({ x, y });
};
// Completes a box selection: normalises the dragged corners into a
// top-left/bottom-right box, runs inference when the box is larger than
// 5 units on both axes, and always clears the in-progress box afterwards.
const handleStageMouseUp = (e: any) => {
  if (effectiveTool !== 'box_select' || !boxStart || !boxCurrent) return;

  const box = {
    x1: Math.min(boxStart.x, boxCurrent.x),
    y1: Math.min(boxStart.y, boxCurrent.y),
    x2: Math.max(boxStart.x, boxCurrent.x),
    y2: Math.max(boxStart.y, boxCurrent.y),
  };

  // Ignore accidental clicks / tiny drags.
  const wideEnough = Math.abs(box.x2 - box.x1) > 5;
  const tallEnough = Math.abs(box.y2 - box.y1) > 5;
  if (wideEnough && tallEnough) {
    runInference(undefined, box);
  }

  setBoxStart(null);
  setBoxCurrent(null);
};
// Click handler for the stage. Only the positive/negative point tools react
// here: 'move' is handled by dragging and 'box_select' by mouse-up. A click
// appends a prompt point and immediately triggers inference with the new list.
const handleStageClick = (e: any) => {
  const isPointTool = effectiveTool === 'point_pos' || effectiveTool === 'point_neg';
  if (!isPointTool) return;

  const pointer = e.target.getStage().getRelativePointerPosition();
  if (!pointer) return;

  const kind: 'pos' | 'neg' = effectiveTool === 'point_pos' ? 'pos' : 'neg';
  const newPoints = [...points, { x: pointer.x, y: pointer.y, type: kind }];
  setPoints(newPoints);
  // Auto-trigger inference after point selection
  runInference(newPoints);
};
// Rectangle (origin + size) spanned by the in-progress box selection, or null
// when no drag is under way. Recomputed only when either corner changes.
const boxRect = React.useMemo(
  () =>
    boxStart && boxCurrent
      ? {
          x: Math.min(boxStart.x, boxCurrent.x),
          y: Math.min(boxStart.y, boxCurrent.y),
          width: Math.abs(boxCurrent.x - boxStart.x),
          height: Math.abs(boxCurrent.y - boxStart.y),
        }
      : null,
  [boxStart, boxCurrent]
);
return (
<div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm">
{isInferencing && (
<div className="absolute top-4 right-4 z-20 flex items-center gap-2 bg-[#111] border border-white/10 px-3 py-2 rounded-lg shadow-xl">
<div className="w-3 h-3 border-2 border-cyan-500 border-t-transparent rounded-full animate-spin" />
<span className="text-xs text-cyan-400 font-mono">AI ...</span>
</div>
)}
<Stage
width={stageSize.width}
height={stageSize.height}
onWheel={handleWheel}
onMouseMove={handleMouseMove}
onMouseDown={handleStageMouseDown}
onMouseUp={handleStageMouseUp}
scaleX={scale}
scaleY={scale}
x={position.x}
y={position.y}
draggable={activeTool === 'move'}
draggable={effectiveTool === 'move'}
onClick={handleStageClick}
>
<Layer>
@@ -98,24 +197,38 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
/>
)}
{/* Mock Instance Mask overlapping */}
<Group opacity={0.4}>
<Path
data="M 300 200 Q 400 150 450 250 T 400 350 Q 250 350 280 250 Z"
fill="#06b6d4" // cyan-500
{/* AI Returned Masks */}
{masks.map((mask) => (
<Group key={mask.id} opacity={0.5}>
<Path
data={mask.pathData}
fill={mask.color}
stroke={mask.color}
strokeWidth={1 / scale}
/>
</Group>
))}
{/* Box selection preview */}
{boxRect && effectiveTool === 'box_select' && (
<Rect
x={boxRect.x}
y={boxRect.y}
width={boxRect.width}
height={boxRect.height}
fill="rgba(6, 182, 212, 0.1)"
stroke="#06b6d4"
strokeWidth={2 / scale}
dash={[4 / scale, 4 / scale]}
/>
<Path
data="M 600 400 Q 700 350 750 450 T 650 550 Q 550 550 580 450 Z"
fill="#a855f7" // purple-500
/>
</Group>
)}
{/* AI Prompts Point Regions */}
{points.map((p, i) => (
<Group key={i} x={p.x} y={p.y}>
<Circle
radius={6 / scale}
fill={p.type === 'pos' ? '#22c55e' : '#ef4444'} // green or red
fill={p.type === 'pos' ? '#22c55e' : '#ef4444'}
stroke="#ffffff"
strokeWidth={2 / scale}
shadowColor="black"
@@ -134,7 +247,17 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
<span>: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
<span>当前图层树: OBJECT_VEHICLE_01</span>
<span>: {(scale * 100).toFixed(0)}%</span>
<span>: {masks.length}</span>
</div>
{masks.length > 0 && (
<button
onClick={clearMasks}
className="absolute bottom-4 right-4 text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
>
</button>
)}
</div>
);
}

View File

@@ -1,10 +1,99 @@
import React from 'react';
import { Activity, Clock, Folders, CheckCircle2 } from 'lucide-react';
import React, { useState, useEffect } from 'react';
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
import { progressWS, type ProgressMessage } from '../lib/websocket';
import { cn } from '../lib/utils';
// One row of the parsing-queue list shown on the dashboard.
interface QueueTask {
// Task identifier; incoming progress messages are matched to rows by comparing it with `taskId`.
id: string;
// Display name of the task (taken from the message's `filename` for rows created from progress messages).
name: string;
// Completion value rendered as `${progress}%` for both the label and the bar width.
progress: number;
// Free-form status text shown under the bar; the row icon is chosen by comparing against this text.
status: string;
}
export function Dashboard() {
const [tasks, setTasks] = useState<QueueTask[]>([
{ id: '1', name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' },
{ id: '2', name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
{ id: '3', name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' },
]);
const [isConnected, setIsConnected] = useState(false);
const [activityLog, setActivityLog] = useState<Array<{ time: string; message: string; project?: string }>>([
{ time: '10 分钟前', message: '语义归档完成 54 帧', project: 'Highway_Data' },
{ time: '25 分钟前', message: '项目解析开始', project: 'City_Driving_Dataset_004' },
{ time: '1 小时前', message: '模板库更新: Cityscapes_v2', project: '系统' },
{ time: '2 小时前', message: 'AI 推理完成 12 个实例', project: 'Nav_Cam_Left' },
]);
// Subscribes the dashboard to progress messages for the lifetime of the
// component: connects on mount, folds each message into the task list and the
// activity log, polls the connection flag, and tears everything down on unmount.
useEffect(() => {
progressWS.connect();
const unsubscribe = progressWS.onProgress((data: ProgressMessage) => {
// Any incoming message is also a chance to refresh the connection indicator.
setIsConnected(progressWS.isConnected());
// 'progress': update the matching task, or append a new row if the id is unknown.
if (data.type === 'progress' && data.taskId && data.filename) {
setTasks((prev) => {
const exists = prev.find((t) => t.id === data.taskId);
if (exists) {
// Fields missing from the message keep their previous values.
return prev.map((t) =>
t.id === data.taskId
? { ...t, progress: data.progress ?? t.progress, status: data.status ?? t.status }
: t
);
}
return [
...prev,
{
id: data.taskId!,
name: data.filename!,
progress: data.progress ?? 0,
status: data.status ?? '处理中',
},
];
});
}
// 'complete': force the task to 100% and record the event in the activity log.
if (data.type === 'complete' && data.taskId) {
setTasks((prev) =>
prev.map((t) =>
t.id === data.taskId ? { ...t, progress: 100, status: '已完成' } : t
)
);
// Newest entry first; slice(0, 9) caps the log at 10 entries.
setActivityLog((prev) => [
{ time: '刚刚', message: `解析完成: ${data.filename || data.taskId}`, project: '系统' },
...prev.slice(0, 9),
]);
}
// 'error': keep the current progress value and replace only the status text.
if (data.type === 'error' && data.taskId) {
setTasks((prev) =>
prev.map((t) =>
t.id === data.taskId ? { ...t, status: `错误: ${data.message || '未知错误'}` } : t
)
);
}
// 'status': informational message, only appended to the activity log.
if (data.type === 'status') {
setActivityLog((prev) => [
{ time: '刚刚', message: data.message || '状态更新', project: '系统' },
...prev.slice(0, 9),
]);
}
});
// Poll every 5 s so the indicator also updates when no messages arrive.
const checkConnection = setInterval(() => {
setIsConnected(progressWS.isConnected());
}, 5000);
// Cleanup: remove the listener, stop polling, and close the connection.
return () => {
unsubscribe();
clearInterval(checkConnection);
progressWS.disconnect();
};
}, []);
const stats = [
{ label: '运行中项目', value: '14', icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
{ label: '排队处理任务', value: '3,291', icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
{ label: '排队处理任务', value: tasks.length.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
{ label: '已归档批次', value: '128', icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
{ label: '系统负载', value: '78%', icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
];
@@ -12,7 +101,18 @@ export function Dashboard() {
return (
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
<header className="mb-8">
<h1 className="text-3xl font-medium tracking-tight text-white"></h1>
<div className="flex items-center gap-3">
<h1 className="text-3xl font-medium tracking-tight text-white"></h1>
<div className={cn(
"flex items-center gap-1.5 text-[10px] uppercase font-mono px-2 py-1 rounded border",
isConnected
? "bg-emerald-500/10 text-emerald-400 border-emerald-500/20"
: "bg-amber-500/10 text-amber-400 border-amber-500/20"
)}>
<div className={cn("w-1.5 h-1.5 rounded-full", isConnected ? "bg-emerald-500" : "bg-amber-500 animate-pulse")} />
{isConnected ? 'WebSocket 已连接' : 'WebSocket 断开'}
</div>
</div>
<p className="text-gray-400 text-sm mt-1"></p>
</header>
@@ -37,36 +137,43 @@ export function Dashboard() {
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"> (FFmpeg )</h2>
<div className="space-y-4">
{[
{ name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' },
{ name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
{ name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' }
].map((task, i) => (
<div key={i} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
{tasks.map((task) => (
<div key={task.id} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
<div className="flex justify-between items-center mb-2">
<span className="font-mono text-sm text-gray-200">{task.name}</span>
<span className="text-xs text-cyan-400 font-mono">{task.progress}%</span>
</div>
<div className="w-full h-1.5 bg-white/5 rounded-full overflow-hidden mb-2">
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full" style={{ width: `${task.progress}%` }} />
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
</div>
<div className="text-xs text-gray-500 flex items-center gap-2">
{task.status === '已完成' ? (
<CheckCircle2 size={12} className="text-emerald-400" />
) : task.status.includes('错误') ? (
<span className="text-red-400"></span>
) : (
<Loader2 size={12} className="text-cyan-400 animate-spin" />
)}
{task.status}
</div>
<div className="text-xs text-gray-500">{task.status}</div>
</div>
))}
{tasks.length === 0 && (
<div className="text-sm text-gray-500 text-center py-12"></div>
)}
</div>
</div>
<div className="bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6"></h2>
<div className="space-y-6 relative before:absolute before:inset-0 before:ml-[11px] before:-translate-x-px md:before:mx-auto md:before:translate-x-0 before:h-full before:w-0.5 before:bg-gradient-to-b before:from-transparent before:via-white/10 before:to-transparent">
{/* Activity log mockup */}
{[1, 2, 3, 4].map((i) => (
{activityLog.map((log, i) => (
<div key={i} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
<div className="flex items-center justify-center w-6 h-6 rounded-full border border-white/10 bg-[#111] group-[.is-active]:bg-cyan-500 group-[.is-active]:border-cyan-400 text-slate-500 group-[.is-active]:text-black shadow shrink-0 md:order-1 md:group-odd:-translate-x-1/2 md:group-even:translate-x-1/2 z-10" />
<div className="w-[calc(100%-4rem)] md:w-[calc(50%-2.5rem)] bg-[#0d0d0d] p-3 rounded border border-white/5">
<div className="text-xs text-gray-400 mb-1">10 </div>
<div className="text-sm font-medium text-gray-200"> 54 </div>
<div className="text-xs text-gray-500">归属项目: Highway_Data</div>
<div className="text-xs text-gray-400 mb-1">{log.time}</div>
<div className="text-sm font-medium text-gray-200">{log.message}</div>
<div className="text-xs text-gray-500">: {log.project}</div>
</div>
</div>
))}

View File

@@ -1,12 +1,11 @@
import React, { useState } from 'react';
import { BrainCircuit } from 'lucide-react';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
import { login as loginApi } from '../lib/api';
interface LoginProps {
onLoginSuccess: (token: string) => void;
}
export function Login({ onLoginSuccess }: LoginProps) {
export function Login() {
const storeLogin = useStore((state) => state.login);
const [username, setUsername] = useState('admin');
const [password, setPassword] = useState('123456');
const [error, setError] = useState('');
@@ -18,21 +17,11 @@ export function Login({ onLoginSuccess }: LoginProps) {
setIsLoading(true);
try {
const response = await fetch('/api/login', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ username, password }),
});
if (response.ok) {
const data = await response.json();
onLoginSuccess(data.token);
} else {
const errData = await response.json();
setError(errData.error || '登录失败');
}
} catch (err) {
setError('网络异常,无法连接到后端验证');
const data = await loginApi(username, password);
storeLogin(data.token);
} catch (err: any) {
const msg = err?.response?.data?.detail || err?.response?.data?.error || '登录失败,请检查网络或凭证';
setError(msg);
} finally {
setIsLoading(false);
}
@@ -45,7 +34,7 @@ export function Login({ onLoginSuccess }: LoginProps) {
<div className="relative z-10 w-full max-w-md p-8 bg-[#111] border border-white/5 rounded-2xl shadow-2xl scale-in shadow-black/50">
<div className="flex flex-col items-center mb-8">
<div className="w-16 h-16 bg-white rounded-2xl flex items-center justify-center text-cyan-500 shadow-lg shadow-cyan-500/20 mb-4 overflow-hidden border border-white/10">
<img src="/Logo.png" alt="Logo" className="w-full h-full object-cover" />
<BrainCircuit size={32} />
</div>
<h1 className="text-2xl font-bold text-white tracking-wider mb-2"></h1>
<p className="text-sm text-gray-500">AI </p>

View File

@@ -1,19 +1,63 @@
import React, { useState, useEffect } from 'react';
import { UploadCloud, Film, Settings2, MoreHorizontal } from 'lucide-react';
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2 } from 'lucide-react';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
import { getProjects, createProject } from '../lib/api';
import type { Project } from '../store/useStore';
interface ProjectLibraryProps {
onProjectSelect: () => void;
}
export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
const [projects, setProjects] = useState<any[]>([]);
const projects = useStore((state) => state.projects);
const setProjects = useStore((state) => state.setProjects);
const setCurrentProject = useStore((state) => state.setCurrentProject);
const addProject = useStore((state) => state.addProject);
const [isLoading, setIsLoading] = useState(false);
const [isCreating, setIsCreating] = useState(false);
const [showModal, setShowModal] = useState(false);
const [newName, setNewName] = useState('');
const [newDesc, setNewDesc] = useState('');
useEffect(() => {
fetch('/api/projects')
.then(res => res.json())
.then(data => setProjects(data))
.catch(console.error);
}, []);
setIsLoading(true);
getProjects()
.then((data) => setProjects(data))
.catch(console.error)
.finally(() => setIsLoading(false));
}, [setProjects]);
// Creates a project from the modal form. A blank name is ignored; on success
// the new project is added to the store and the form is reset and closed.
// Errors are only logged, and the "creating" flag is always cleared.
const handleCreate = async () => {
  const trimmedName = newName.trim();
  if (!trimmedName) return;

  setIsCreating(true);
  try {
    // An empty description is sent as undefined rather than an empty string.
    const description = newDesc.trim() || undefined;
    const project = await createProject({ name: trimmedName, description });
    addProject(project);
    setShowModal(false);
    setNewName('');
    setNewDesc('');
  } catch (err) {
    console.error('Failed to create project:', err);
  } finally {
    setIsCreating(false);
  }
};
// Marks the clicked project as the current one in the store, then notifies
// the parent through the onProjectSelect callback.
const handleSelect = (project: Project) => {
  setCurrentProject(project);
  onProjectSelect();
};
// Pulsing placeholder with the same outline as a project card (thumbnail area
// plus two text bars).
const SkeletonCard = () => {
  return (
    <div className="group flex flex-col bg-[#111] border border-white/5 rounded-xl overflow-hidden animate-pulse">
      <div className="w-full aspect-[16/9] bg-[#1a1a1a]" />
      <div className="p-4 flex flex-col gap-2">
        <div className="h-4 bg-[#1a1a1a] rounded w-3/4" />
        <div className="h-3 bg-[#1a1a1a] rounded w-1/2 mt-2" />
      </div>
    </div>
  );
};
return (
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
@@ -22,47 +66,116 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
<h1 className="text-3xl font-medium tracking-tight text-white mb-2"></h1>
<p className="text-gray-400 text-sm"></p>
</div>
<button className="flex items-center gap-2 bg-cyan-600 hover:bg-cyan-500 text-white px-5 py-2.5 rounded-lg font-medium text-sm transition-colors border border-cyan-500 shadow-lg shadow-cyan-900/20">
<UploadCloud size={18} />
<span></span>
</button>
<div className="flex items-center gap-3">
<button
onClick={() => setShowModal(true)}
className="flex items-center gap-2 bg-white/5 hover:bg-white/10 border border-white/10 text-gray-200 px-5 py-2.5 rounded-lg font-medium text-sm transition-colors"
>
<Plus size={18} />
<span></span>
</button>
<button className="flex items-center gap-2 bg-cyan-600 hover:bg-cyan-500 text-white px-5 py-2.5 rounded-lg font-medium text-sm transition-colors border border-cyan-500 shadow-lg shadow-cyan-900/20">
<UploadCloud size={18} />
<span></span>
</button>
</div>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
{projects.map((proj) => (
<div
key={proj.id}
className="group flex flex-col bg-[#111] border border-white/5 rounded-xl overflow-hidden cursor-pointer hover:border-cyan-500/50 transition-all hover:shadow-[0_0_20px_rgba(6,182,212,0.15)]"
onClick={onProjectSelect}
>
<div className={`w-full aspect-[16/9] ${proj.thumbnail} relative flex items-center justify-center overflow-hidden`}>
{/* Stand-in for actual video frame thumbnail */}
<Film className="w-12 h-12 text-[#2a2a2a] group-hover:text-[#333] transition-colors" />
<div className="absolute top-2 right-2 flex gap-2">
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
{proj.fps}
</span>
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
{proj.status === 'Ready' ? (
<><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> </>
) : (
<><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> </>
)}
</span>
</div>
{isLoading && projects.length === 0 ? (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
{Array.from({ length: 8 }).map((_, i) => (
<SkeletonCard key={i} />
))}
</div>
) : (
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
{projects.map((proj) => (
<div
key={proj.id}
className="group flex flex-col bg-[#111] border border-white/5 rounded-xl overflow-hidden cursor-pointer hover:border-cyan-500/50 transition-all hover:shadow-[0_0_20px_rgba(6,182,212,0.15)]"
onClick={() => handleSelect(proj)}
>
<div className={cn("w-full aspect-[16/9] relative flex items-center justify-center overflow-hidden", proj.thumbnail || 'bg-[#0d0d0d]')}>
<Film className="w-12 h-12 text-[#2a2a2a] group-hover:text-[#333] transition-colors" />
<div className="absolute top-2 right-2 flex gap-2">
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
{proj.fps || '30FPS'}
</span>
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
{proj.status === 'Ready' ? (
<><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> </>
) : proj.status === 'Parsing' ? (
<><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> </>
) : (
<><div className="w-1.5 h-1.5 bg-red-500 rounded-full" /> </>
)}
</span>
</div>
</div>
<div className="p-4 flex flex-col gap-1">
<div className="flex justify-between items-start">
<h3 className="text-sm font-medium text-gray-200 truncate pr-4" title={proj.name}>{proj.name}</h3>
<button className="text-gray-500 hover:text-gray-300"><MoreHorizontal size={16} /></button>
</div>
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono mt-2">
<span className="flex items-center gap-1.5"><Settings2 size={12} /> {proj.frames ?? 0} </span>
</div>
</div>
</div>
<div className="p-4 flex flex-col gap-1">
<div className="flex justify-between items-start">
<h3 className="text-sm font-medium text-gray-200 truncate pr-4" title={proj.name}>{proj.name}</h3>
<button className="text-gray-500 hover:text-gray-300"><MoreHorizontal size={16} /></button>
))}
</div>
)}
{showModal && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
<h2 className="text-lg font-semibold text-white mb-4"></h2>
<div className="space-y-4">
<div>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2"></label>
<input
type="text"
value={newName}
onChange={(e) => setNewName(e.target.value)}
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
placeholder="输入项目名称"
/>
</div>
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono mt-2">
<span className="flex items-center gap-1.5"><Settings2 size={12} /> {proj.frames} </span>
<div>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2"></label>
<input
type="text"
value={newDesc}
onChange={(e) => setNewDesc(e.target.value)}
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
placeholder="输入项目描述"
/>
</div>
</div>
<div className="flex justify-end gap-3 mt-6">
<button
onClick={() => { setShowModal(false); setNewName(''); setNewDesc(''); }}
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors"
>
</button>
<button
onClick={handleCreate}
disabled={isCreating || !newName.trim()}
className={cn(
"px-4 py-2 rounded-lg text-sm font-medium flex items-center gap-2 transition-all",
isCreating || !newName.trim()
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
: "bg-cyan-500 hover:bg-cyan-400 text-black"
)}
>
{isCreating && <Loader2 size={14} className="animate-spin" />}
</button>
</div>
</div>
))}
</div>
</div>
)}
</div>
);
}

View File

@@ -1,15 +1,112 @@
import React, { useState, useEffect } from 'react';
import { Settings, FileJson, ArrowRightLeft, Database } from 'lucide-react';
import { Settings, Database, Trash2, Edit3, Plus, Loader2, X } from 'lucide-react';
import { cn } from '../lib/utils';
import { useStore } from '../store/useStore';
import { getTemplates, createTemplate, updateTemplate, deleteTemplate } from '../lib/api';
import type { Template, TemplateClass } from '../store/useStore';
export function TemplateRegistry() {
const [templates, setTemplates] = useState<any[]>([]);
const templates = useStore((state) => state.templates);
const setTemplates = useStore((state) => state.setTemplates);
const addTemplate = useStore((state) => state.addTemplate);
const updateTemplateStore = useStore((state) => state.updateTemplate);
const removeTemplateStore = useStore((state) => state.removeTemplate);
const [isLoading, setIsLoading] = useState(false);
const [selectedTemplate, setSelectedTemplate] = useState<Template | null>(null);
const [showModal, setShowModal] = useState(false);
const [isSaving, setIsSaving] = useState(false);
const [editName, setEditName] = useState('');
const [editDesc, setEditDesc] = useState('');
const [editClasses, setEditClasses] = useState<TemplateClass[]>([]);
const [editingClassId, setEditingClassId] = useState<string | null>(null);
useEffect(() => {
fetch('/api/templates')
.then(res => res.json())
.then(data => setTemplates(data))
.catch(console.error);
}, []);
setIsLoading(true);
getTemplates()
.then((data) => setTemplates(data))
.catch(console.error)
.finally(() => setIsLoading(false));
}, [setTemplates]);
/** Open the editor modal with an empty form for a brand-new template. */
const openCreate = () => {
  // With no selection, handleSave takes its "create" branch.
  setSelectedTemplate(null);
  // Wipe whatever a previous editing session left in the form fields.
  setEditClasses([]);
  setEditDesc('');
  setEditName('');
  setShowModal(true);
};
/** Open the editor modal pre-filled with the values of an existing template. */
const openEdit = (template: Template) => {
  const { name, description, classes } = template;
  setSelectedTemplate(template);
  setEditName(name);
  setEditDesc(description || '');
  // Work on a copy of the array so adding/removing rows in the modal does not
  // touch the template's own class list before the user saves.
  setEditClasses(classes ? [...classes] : []);
  setShowModal(true);
};
/**
 * Persist the modal form.
 *
 * Updates the selected template when there is one, otherwise creates a new
 * template, then mirrors the server's response into the store and closes the
 * modal. Does nothing when the name is blank. Failures are logged and leave
 * the modal open so the user can retry.
 */
const handleSave = async () => {
  const name = editName.trim();
  if (!name) return;

  const payload = {
    name,
    description: editDesc.trim() || undefined,
    classes: editClasses,
  };

  setIsSaving(true);
  try {
    let saved: Template;
    if (selectedTemplate) {
      saved = await updateTemplate(selectedTemplate.id, payload);
      updateTemplateStore(saved);
    } else {
      saved = await createTemplate(payload);
      addTemplate(saved);
    }
    // Point the selection at the record the server returned. Without this the
    // detail panel keeps rendering the pre-edit object held in local state
    // (activeTemplate prefers selectedTemplate over the store), and a freshly
    // created template is never shown.
    setSelectedTemplate(saved);
    setShowModal(false);
  } catch (err) {
    console.error('Failed to save template:', err);
  } finally {
    setIsSaving(false);
  }
};
/**
 * Ask for confirmation, delete the template on the server, then drop it from
 * the store. Clears the local selection when the deleted template was the
 * selected one. Failures are only logged.
 */
const handleDelete = async (id: string) => {
  const confirmed = confirm('确定要删除此模板吗?');
  if (!confirmed) {
    return;
  }
  try {
    await deleteTemplate(id);
    removeTemplateStore(id);
    const wasSelected = selectedTemplate?.id === id;
    if (wasSelected) {
      setSelectedTemplate(null);
    }
  } catch (err) {
    console.error('Failed to delete template:', err);
  }
};
/**
 * Append a placeholder class to the form and put its name field into edit
 * mode. The new class is stacked 10 above the current highest zIndex, or
 * starts at 10 when the list is empty.
 */
const addClass = () => {
  const highestZ = editClasses.reduce(
    (max, cls) => Math.max(max, cls.zIndex),
    -Infinity,
  );
  const nextZ = editClasses.length > 0 ? highestZ + 10 : 10;
  const newClass: TemplateClass = {
    id: `cls-${Date.now()}`,
    name: '新类别',
    color: '#06b6d4',
    zIndex: nextZ,
    category: '未分类',
  };
  setEditClasses(editClasses.concat(newClass));
  setEditingClassId(newClass.id);
};
/** Merge `updates` into the form class whose id matches; other rows are kept as-is. */
const updateClass = (id: string, updates: Partial<TemplateClass>) => {
  const next = editClasses.map((cls) => {
    if (cls.id !== id) {
      return cls;
    }
    return { ...cls, ...updates };
  });
  setEditClasses(next);
};
/** Drop the class with the given id from the form's class list. */
const removeClass = (id: string) => {
  const remaining = editClasses.filter((cls) => !(cls.id === id));
  setEditClasses(remaining);
};
const activeTemplate = selectedTemplate || templates[0] || null;
return (
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
@@ -22,90 +119,224 @@ export function TemplateRegistry() {
<div className="xl:col-span-1 border-r border-white/5 pr-6">
<div className="flex justify-between items-center mb-6">
<h2 className="text-sm font-bold text-gray-500 uppercase tracking-widest"></h2>
<button className="text-cyan-400 hover:text-cyan-300 text-sm transition-colors">+ </button>
<button
onClick={openCreate}
className="text-cyan-400 hover:text-cyan-300 text-sm transition-colors flex items-center gap-1"
>
<Plus size={14} />
</button>
</div>
<div className="space-y-3">
{templates.map(t => (
<div key={t.id} className="bg-[#111] border border-white/5 hover:border-cyan-500/50 p-4 rounded-xl cursor-pointer transition-all hover:shadow-lg hover:shadow-cyan-900/10">
<h3 className="font-medium text-gray-200 mb-1">{t.name}</h3>
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono">
<span> {t.classes} </span>
<span> {t.rules} </span>
{isLoading && templates.length === 0 ? (
<div className="space-y-3">
{Array.from({ length: 4 }).map((_, i) => (
<div key={i} className="bg-[#111] border border-white/5 p-4 rounded-xl animate-pulse">
<div className="h-4 bg-[#1a1a1a] rounded w-2/3 mb-2" />
<div className="h-3 bg-[#1a1a1a] rounded w-1/2" />
</div>
))}
</div>
) : (
<div className="space-y-3">
{templates.map((t) => (
<div
key={t.id}
className={cn(
"bg-[#111] border p-4 rounded-xl cursor-pointer transition-all hover:shadow-lg hover:shadow-cyan-900/10 group",
activeTemplate?.id === t.id ? "border-cyan-500/50" : "border-white/5 hover:border-cyan-500/50"
)}
onClick={() => setSelectedTemplate(t)}
>
<div className="flex justify-between items-start">
<h3 className="font-medium text-gray-200 mb-1">{t.name}</h3>
<div className="flex items-center gap-1 opacity-0 group-hover:opacity-100 transition-opacity">
<button
onClick={(e) => { e.stopPropagation(); openEdit(t); }}
className="p-1 rounded text-gray-500 hover:text-cyan-400 transition-colors"
>
<Edit3 size={14} />
</button>
<button
onClick={(e) => { e.stopPropagation(); handleDelete(t.id); }}
className="p-1 rounded text-gray-500 hover:text-red-400 transition-colors"
>
<Trash2 size={14} />
</button>
</div>
</div>
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono">
<span> {t.classes?.length ?? 0} </span>
<span> {t.rules?.length ?? 0} </span>
</div>
</div>
</div>
))}
</div>
))}
</div>
)}
</div>
<div className="xl:col-span-2 space-y-6">
<div className="bg-[#111] border border-white/5 rounded-xl p-6">
<div className="flex items-center justify-between mb-6 border-b border-white/5 pb-4">
<h2 className="text-lg font-medium text-gray-200 flex items-center gap-2"><Database size={18} /> Cityscapes_v2_Mapping</h2>
<button className="bg-white/5 hover:bg-white/10 border border-white/10 px-4 py-1.5 rounded text-sm text-gray-300 transition-colors"> (Schema)</button>
<h2 className="text-lg font-medium text-gray-200 flex items-center gap-2">
<Database size={18} />
{activeTemplate ? activeTemplate.name : '未选择模板'}
</h2>
{activeTemplate && (
<button
onClick={() => openEdit(activeTemplate)}
className="bg-white/5 hover:bg-white/10 border border-white/10 px-4 py-1.5 rounded text-sm text-gray-300 transition-colors flex items-center gap-2"
>
<Settings size={14} /> (Schema)
</button>
)}
</div>
<div className="space-y-6">
<div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-4"> (Painter's Algorithm Weight)</h3>
<div className="space-y-2">
{[
{ l: 'pedestrian', z: 90, c: '#ec4899', t: ' (Dynamic Entity)' },
{ l: 'bicycle', z: 85, c: '#f59e0b', t: ' (Dynamic Entity)' },
{ l: 'vehicle_car', z: 80, c: '#6366f1', t: ' (Dynamic Entity)' },
{ l: 'traffic_sign', z: 60, c: '#eab308', t: ' (Static Entity)' },
{ l: 'road_surface', z: 10, c: '#71717a', t: ' (Background / Floor)' },
].map(cls => (
<div key={cls.l} className="flex grid grid-cols-4 gap-4 p-3 bg-[#0d0d0d] border border-white/5 rounded items-center">
<div className="col-span-1 flex items-center gap-2">
<div className="w-3 h-3 rounded" style={{ backgroundColor: cls.c }}></div>
<span className="font-medium text-sm text-gray-300">{cls.l}</span>
{activeTemplate ? (
<div className="space-y-6">
<div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-4">
(Painter's Algorithm Weight)
</h3>
<div className="space-y-2">
{/* Sort a copy: Array.prototype.sort reorders in place, which would mutate the class array owned by the template object (store state) during render. */}
{[...(activeTemplate.classes || [])].sort((a, b) => b.zIndex - a.zIndex).map((cls) => (
<div key={cls.id} className="grid grid-cols-4 gap-4 p-3 bg-[#0d0d0d] border border-white/5 rounded items-center">
<div className="col-span-1 flex items-center gap-2">
<div className="w-3 h-3 rounded" style={{ backgroundColor: cls.color }}></div>
<span className="font-medium text-sm text-gray-300">{cls.name}</span>
</div>
<div className="col-span-1 font-mono text-xs text-gray-500">优先级 Z-Level: {cls.zIndex}</div>
<div className="col-span-2 flex justify-end">
<span className="bg-white/5 text-gray-400 text-xs px-2 py-1 rounded border border-white/10">{cls.category || ''}</span>
</div>
</div>
<div className="col-span-1 font-mono text-xs text-gray-500"> Z-Level: {cls.z}</div>
<div className="col-span-2 flex justify-end">
<span className="bg-white/5 text-gray-400 text-xs px-2 py-1 rounded border border-white/10">{cls.t}</span>
</div>
</div>
))}
))}
{(activeTemplate.classes || []).length === 0 && (
<div className="text-sm text-gray-500 text-center py-8">暂无分类定义</div>
)}
</div>
</div>
</div>
<div>
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-4 flex items-center gap-2">
<ArrowRightLeft size={14} /> (GT Source)
</h3>
<div className="bg-[#0d0d0d] border border-white/5 rounded-lg overflow-hidden">
<table className="w-full text-left text-sm text-gray-400">
<thead className="bg-[#111] border-b border-white/5 text-xs uppercase text-gray-500 font-mono">
<tr>
<th className="px-4 py-3"> JSON (Legacy Key)</th>
<th className="px-4 py-3"></th>
<th className="px-4 py-3"></th>
</tr>
</thead>
<tbody className="divide-y divide-white/5">
<tr>
<td className="px-4 py-3 font-mono text-gray-300">"car_sedan"</td>
<td className="px-4 py-3 font-mono text-cyan-400"> (Logical OR)</td>
<td className="px-4 py-3 font-medium text-gray-300">vehicle_car</td>
</tr>
<tr>
<td className="px-4 py-3 font-mono text-gray-300">"car_suv"</td>
<td className="px-4 py-3 font-mono text-cyan-400"> (Logical OR)</td>
<td className="px-4 py-3 font-medium text-gray-300">vehicle_car</td>
</tr>
<tr>
<td className="px-4 py-3 font-mono text-gray-300">"sidewalk_curb"</td>
<td className="px-4 py-3 font-mono text-cyan-400">- (Skeletonization)</td>
<td className="px-4 py-3 font-medium text-gray-300">road_curb</td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
) : (
<div className="text-sm text-gray-500 text-center py-12">请从左侧选择一个模板或创建新模板</div>
)}
</div>
</div>
</div>
{showModal && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-2xl max-h-[80vh] overflow-y-auto shadow-2xl">
<div className="flex justify-between items-center mb-4">
<h2 className="text-lg font-semibold text-white">{selectedTemplate ? '' : ''}</h2>
<button onClick={() => setShowModal(false)} className="text-gray-500 hover:text-white transition-colors">
<X size={18} />
</button>
</div>
<div className="space-y-4 mb-6">
<div>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">模板名称</label>
<input
type="text"
value={editName}
onChange={(e) => setEditName(e.target.value)}
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
/>
</div>
<div>
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">描述</label>
<input
type="text"
value={editDesc}
onChange={(e) => setEditDesc(e.target.value)}
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
/>
</div>
</div>
<div className="mb-4">
<div className="flex justify-between items-center mb-3">
<h3 className="text-xs font-bold text-gray-500 uppercase tracking-widest">分类定义</h3>
<button
onClick={addClass}
className="text-cyan-400 hover:text-cyan-300 text-xs transition-colors flex items-center gap-1"
>
<Plus size={12} /> 添加分类
</button>
</div>
<div className="space-y-2">
{editClasses.map((cls) => (
<div key={cls.id} className="flex items-center gap-3 bg-[#0d0d0d] border border-white/5 rounded-lg p-3">
<input
type="color"
value={cls.color}
onChange={(e) => updateClass(cls.id, { color: e.target.value })}
className="w-8 h-8 rounded bg-transparent border-0 cursor-pointer"
/>
{editingClassId === cls.id ? (
<>
<input
type="text"
value={cls.name}
onChange={(e) => updateClass(cls.id, { name: e.target.value })}
onBlur={() => setEditingClassId(null)}
onKeyDown={(e) => e.key === 'Enter' && setEditingClassId(null)}
autoFocus
className="flex-1 bg-[#1a1a1a] border border-white/10 rounded px-2 py-1 text-sm text-white"
/>
<input
type="number"
value={cls.zIndex}
onChange={(e) => updateClass(cls.id, { zIndex: parseInt(e.target.value) || 0 })}
className="w-20 bg-[#1a1a1a] border border-white/10 rounded px-2 py-1 text-sm text-white font-mono"
/>
</>
) : (
<>
<span
className="flex-1 text-sm text-gray-300 cursor-pointer"
onClick={() => setEditingClassId(cls.id)}
>
{cls.name}
</span>
<span className="w-20 text-sm text-gray-500 font-mono text-right">z:{cls.zIndex}</span>
</>
)}
<button onClick={() => removeClass(cls.id)} className="text-gray-500 hover:text-red-400 transition-colors">
<Trash2 size={14} />
</button>
</div>
))}
{editClasses.length === 0 && (
<div className="text-sm text-gray-500 text-center py-4"></div>
)}
</div>
</div>
<div className="flex justify-end gap-3">
<button
onClick={() => setShowModal(false)}
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors"
>
</button>
<button
onClick={handleSave}
disabled={isSaving || !editName.trim()}
className={cn(
"px-4 py-2 rounded-lg text-sm font-medium flex items-center gap-2 transition-all",
isSaving || !editName.trim()
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
: "bg-cyan-500 hover:bg-cyan-400 text-black"
)}
>
{isSaving && <Loader2 size={14} className="animate-spin" />}
</button>
</div>
</div>
</div>
)}
</div>
);
}

View File

@@ -1,5 +1,5 @@
import React from 'react';
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair } from 'lucide-react';
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
import { cn } from '../lib/utils';
interface ToolsPaletteProps {
@@ -20,6 +20,12 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
{ id: 'area_remove', icon: Scissors, label: '重叠区域去除 (-)' },
];
// SAM prompt tools rendered as buttons below the manual tools.
// Each entry's `id` is passed to setActiveTool when its button is clicked, and
// its `color` / `bg` / `border` class names are applied only while that tool
// is the active one; `label` is used as the button's title (tooltip).
const aiTools = [
{ id: 'point_pos', icon: PlusCircle, label: '正向选点 (SAM)', color: 'text-green-400', bg: 'bg-green-500/10', border: 'border-green-500/30' },
{ id: 'point_neg', icon: MinusCircle, label: '反向选点 (SAM)', color: 'text-red-400', bg: 'bg-red-500/10', border: 'border-red-500/30' },
{ id: 'box_select', icon: SquareDashed, label: '边界框选 (SAM)', color: 'text-blue-400', bg: 'bg-blue-500/10', border: 'border-blue-500/30' },
];
return (
<div className="w-12 bg-[#0d0d0d] border-r border-white/5 flex flex-col items-center py-4 shrink-0 z-10">
<div className="flex flex-col gap-4 w-full px-2">
@@ -47,6 +53,26 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
<div className="w-full h-px bg-white/10 my-1" />
{aiTools.map(tool => {
const Icon = tool.icon;
const isActive = activeTool === tool.id;
return (
<button
key={tool.id}
onClick={() => setActiveTool(tool.id)}
title={tool.label}
className={cn(
"w-10 h-10 rounded-lg flex items-center justify-center transition-all p-2 border",
isActive
? `${tool.bg} ${tool.color} ${tool.border} shadow-[0_0_10px_rgba(255,255,255,0.05)]`
: "text-gray-500 hover:bg-white/5 hover:text-white border-transparent"
)}
>
<Icon size={18} strokeWidth={isActive ? 2.5 : 2} />
</button>
)
})}
<button
onClick={() => {
setActiveTool('sam_trigger');

View File

@@ -1,11 +1,13 @@
import React, { useState } from 'react';
import React from 'react';
import { useStore } from '../store/useStore';
import { CanvasArea } from './CanvasArea';
import { ToolsPalette } from './ToolsPalette';
import { OntologyInspector } from './OntologyInspector';
import { FrameTimeline } from './FrameTimeline';
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
const [activeTool, setActiveTool] = useState<string>('move');
const activeTool = useStore((state) => state.activeTool);
const setActiveTool = useStore((state) => state.setActiveTool);
return (
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">

135
src/lib/api.ts Normal file
View File

@@ -0,0 +1,135 @@
import axios, { AxiosError } from 'axios';
import type { Project, Template } from '../store/useStore';
// Shared Axios instance used by every request helper in this module.
const API_BASE_URL = 'http://localhost:8000';
const REQUEST_TIMEOUT_MS = 30000;

const apiClient = axios.create({
  timeout: REQUEST_TIMEOUT_MS,
  baseURL: API_BASE_URL,
  headers: { 'Content-Type': 'application/json' },
});
// Request interceptor: send the token saved in localStorage (if any) as a
// Bearer Authorization header.
apiClient.interceptors.request.use(
  (config) => {
    const storedToken = localStorage.getItem('token');
    if (!storedToken) {
      return config;
    }
    config.headers.Authorization = `Bearer ${storedToken}`;
    return config;
  },
  (requestError) => Promise.reject(requestError)
);
// Response interceptor: on HTTP 401 drop the stored token and, if a token was
// actually stored, reload the page so the app restarts unauthenticated.
// The error is always re-thrown to the caller.
apiClient.interceptors.response.use(
  (response) => response,
  (error: AxiosError) => {
    if (error.response?.status === 401) {
      // Reload only when a stored session was just invalidated. A 401 with no
      // token (wrong password on the login form, or a request fired before
      // login) must not reload: that would discard the error the caller wants
      // to show and can loop forever if the page issues a request on load.
      const hadToken = localStorage.getItem('token') !== null;
      localStorage.removeItem('token');
      if (hadToken) {
        window.location.reload();
      }
    }
    return Promise.reject(error);
  }
);
// Auth
/** POST the credentials to /api/auth/login and resolve with the response body. */
export async function login(username: string, password: string): Promise<{ token: string }> {
  const credentials = { username, password };
  const { data } = await apiClient.post('/api/auth/login', credentials);
  return data;
}
// Projects
/** GET /api/projects and resolve with the response body. */
export async function getProjects(): Promise<Project[]> {
  const { data } = await apiClient.get('/api/projects');
  return data;
}
/** POST a new project to /api/projects and resolve with the response body. */
export async function createProject(payload: {
  name: string;
  description?: string;
}): Promise<Project> {
  const { data } = await apiClient.post('/api/projects', payload);
  return data;
}
/** PUT the given fields to /api/projects/{id} and resolve with the response body. */
export async function updateProject(id: string, payload: Partial<Project>): Promise<Project> {
  const url = `/api/projects/${id}`;
  const { data } = await apiClient.put(url, payload);
  return data;
}
/** DELETE /api/projects/{id}; resolves with no value once the request succeeds. */
export async function deleteProject(id: string): Promise<void> {
  const url = `/api/projects/${id}`;
  await apiClient.delete(url);
}
// Templates
/** GET /api/templates and resolve with the response body. */
export async function getTemplates(): Promise<Template[]> {
  const { data } = await apiClient.get('/api/templates');
  return data;
}
/** POST a new template to /api/templates and resolve with the response body. */
export async function createTemplate(payload: {
  name: string;
  description?: string;
  classes?: { name: string; color: string; zIndex: number; category?: string }[];
}): Promise<Template> {
  const { data } = await apiClient.post('/api/templates', payload);
  return data;
}
/** PUT the given fields to /api/templates/{id} and resolve with the response body. */
export async function updateTemplate(id: string, payload: Partial<Template>): Promise<Template> {
  const url = `/api/templates/${id}`;
  const { data } = await apiClient.put(url, payload);
  return data;
}
/** DELETE /api/templates/{id}; resolves with no value once the request succeeds. */
export async function deleteTemplate(id: string): Promise<void> {
  const url = `/api/templates/${id}`;
  await apiClient.delete(url);
}
// Media
/**
 * Upload one file to /api/media/upload as multipart form data.
 * The file goes in the `file` field; `project_id` is added only when a
 * project id is supplied. Resolves with the response body.
 */
export async function uploadMedia(file: File, projectId?: string): Promise<{ url: string; id: string }> {
  const body = new FormData();
  body.append('file', file);
  if (projectId) {
    body.append('project_id', projectId);
  }

  const requestConfig = {
    headers: { 'Content-Type': 'multipart/form-data' },
  };
  const { data } = await apiClient.post('/api/media/upload', body, requestConfig);
  return data;
}
// AI Prediction
/**
 * POST a prompt (points, box and/or text for one image) to /api/ai/predict
 * and resolve with the response body, which is typed as a list of masks.
 */
export async function predictMask(payload: {
  imageUrl: string;
  points?: { x: number; y: number; type: 'pos' | 'neg' }[];
  box?: { x1: number; y1: number; x2: number; y2: number };
  text?: string;
  modelSize?: string;
}): Promise<{
  masks: Array<{
    id: string;
    pathData: string;
    label: string;
    color: string;
    segmentation: number[][];
    bbox: [number, number, number, number];
    area: number;
    confidence: number;
  }>;
}> {
  const { data } = await apiClient.post('/api/ai/predict', payload);
  return data;
}
// Export
/** GET /api/export/coco/{projectId} and resolve with the response body as a Blob. */
export async function exportCoco(projectId: string): Promise<Blob> {
  const url = `/api/export/coco/${projectId}`;
  const { data } = await apiClient.get(url, { responseType: 'blob' });
  return data;
}
export default apiClient;

104
src/lib/websocket.ts Normal file
View File

@@ -0,0 +1,104 @@
// Listener signature registered through ProgressWebSocket.onProgress; it is
// called with every message that parses successfully.
type ProgressCallback = (data: ProgressMessage) => void;
// Shape that incoming WebSocket payloads are typed as after JSON.parse.
// No runtime validation is performed in this module, so the fields are only
// as reliable as the sender.
interface ProgressMessage {
// Message discriminator; the type allows only these four literals.
type: 'progress' | 'status' | 'error' | 'complete';
// All remaining fields are optional.
// NOTE(review): units/range of `progress` and the format of `timestamp` are
// defined by the backend and are not visible here; confirm before relying on them.
taskId?: string;
filename?: string;
progress?: number;
status?: string;
message?: string;
timestamp?: string;
}
/**
 * Reconnecting WebSocket client for the backend progress stream.
 *
 * Incoming text frames are parsed as JSON and fanned out to every listener
 * registered with onProgress(). After connect() the client keeps retrying
 * with a growing delay (x1.5 per attempt, capped at maxReconnectInterval,
 * reset once a connection opens) until disconnect() is called.
 */
class ProgressWebSocket {
  private ws: WebSocket | null = null;
  private url: string;
  private callbacks: Set<ProgressCallback> = new Set();
  private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
  private reconnectInterval = 3000;
  private maxReconnectInterval = 30000;
  private shouldReconnect = false;
  private currentInterval = 3000;

  /** @param url WebSocket endpoint to connect to. */
  constructor(url = 'ws://localhost:8000/ws/progress') {
    this.url = url;
  }

  /** Open the connection unless one is already open or being opened. */
  connect() {
    if (this.ws && (this.ws.readyState === WebSocket.OPEN || this.ws.readyState === WebSocket.CONNECTING)) {
      return;
    }
    this.shouldReconnect = true;
    try {
      // Handlers capture this specific socket so that late events from a
      // socket that has since been replaced cannot act on its successor.
      const socket = new WebSocket(this.url);
      this.ws = socket;

      socket.onopen = () => {
        this.currentInterval = this.reconnectInterval;
        console.log('[WebSocket] Connected to progress stream');
      };

      socket.onmessage = (event) => {
        let data: ProgressMessage;
        try {
          data = JSON.parse(event.data);
        } catch (err) {
          console.error('[WebSocket] Failed to parse message:', err);
          return;
        }
        // Isolate listeners from each other: one throwing listener must not
        // stop delivery to the rest, nor be reported as a parse failure.
        this.callbacks.forEach((cb) => {
          try {
            cb(data);
          } catch (err) {
            console.error('[WebSocket] Progress callback failed:', err);
          }
        });
      };

      socket.onclose = () => {
        console.log('[WebSocket] Connection closed');
        // Only the current socket may trigger a reconnect; a replaced or
        // explicitly disconnected socket closing is expected.
        if (this.shouldReconnect && this.ws === socket) {
          this.scheduleReconnect();
        }
      };

      socket.onerror = (err) => {
        console.error('[WebSocket] Error:', err);
        // Close the socket that errored (not whatever this.ws points at now);
        // its close event then drives the reconnect logic above.
        socket.close();
      };
    } catch (err) {
      console.error('[WebSocket] Failed to connect:', err);
      this.scheduleReconnect();
    }
  }

  /** Stop reconnecting, cancel any pending retry and close the socket. */
  disconnect() {
    this.shouldReconnect = false;
    if (this.reconnectTimer) {
      clearTimeout(this.reconnectTimer);
      this.reconnectTimer = null;
    }
    if (this.ws) {
      this.ws.close();
      this.ws = null;
    }
  }

  /**
   * Register a listener for parsed messages.
   * @returns a function that removes the listener again.
   */
  onProgress(callback: ProgressCallback) {
    this.callbacks.add(callback);
    return () => {
      this.callbacks.delete(callback);
    };
  }

  /** Arrange one reconnect attempt after the current back-off delay. */
  private scheduleReconnect() {
    if (this.reconnectTimer) {
      clearTimeout(this.reconnectTimer);
    }
    const delay = this.currentInterval;
    // Log when the retry is scheduled, so "in Nms" is actually true.
    console.log(`[WebSocket] Reconnecting in ${delay}ms...`);
    this.reconnectTimer = setTimeout(() => {
      this.reconnectTimer = null;
      // Grow the delay before connecting so that a retry scheduled from
      // inside connect() already uses the longer delay. onopen resets it.
      this.currentInterval = Math.min(delay * 1.5, this.maxReconnectInterval);
      this.connect();
    }, delay);
  }

  /** True only while the underlying socket exists and is in the OPEN state. */
  isConnected(): boolean {
    return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
  }
}
export const progressWS = new ProgressWebSocket();
export type { ProgressMessage };

195
src/store/useStore.ts Normal file
View File

@@ -0,0 +1,195 @@
import { create } from 'zustand';
// Project record; this is the element type returned by the project request
// helpers in src/lib/api.ts and held in the store's `projects` list.
export interface Project {
id: string;
name: string;
description?: string;
// Lifecycle state; the type allows only these three literals.
status: 'Ready' | 'Parsing' | 'Error';
// NOTE(review): fps is typed as a string while frames is a number; confirm
// the format the backend actually sends.
fps?: string;
frames?: number;
thumbnail?: string;
createdAt?: string;
updatedAt?: string;
}
// A single frame tied to a project through `projectId`.
export interface Frame {
id: string;
projectId: string;
// presumably the frame's position within its project's sequence; TODO confirm
index: number;
url: string;
// presumably pixel dimensions; TODO confirm
width: number;
height: number;
timestamp?: string;
}
// A labelled shape attached to one frame (`frameId`).
export interface Annotation {
id: string;
frameId: string;
// Shape kind; the type allows only these five literals.
type: 'polygon' | 'rectangle' | 'circle' | 'point' | 'mask';
// NOTE(review): the layout of this number list (e.g. flat x,y pairs) is not
// established in this file; confirm against the code that draws annotations.
points: number[];
label: string;
color: string;
zIndex?: number;
confidence?: number;
// Free-form extra data; values are `unknown` and must be narrowed before use.
metadata?: Record<string, unknown>;
}
// A segmentation mask attached to one frame (`frameId`). Apart from frameId
// and opacity, the fields carry the same names as the mask items in the
// predictMask response type in src/lib/api.ts.
export interface Mask {
id: string;
frameId: string;
pathData: string;
label: string;
color: string;
opacity?: number;
segmentation?: number[][];
// NOTE(review): the ordering of the four bbox numbers is not defined here;
// confirm against the backend.
bbox?: [number, number, number, number];
area?: number;
}
// Ontology template: a named set of classes plus optional mapping rules.
// This is the element type returned by the template request helpers in
// src/lib/api.ts and held in the store's `templates` list.
export interface Template {
id: string;
name: string;
description?: string;
// Declared as required, although TemplateRegistry still guards against it
// being missing (`classes || []`, `classes?.length`).
classes: TemplateClass[];
rules?: TemplateRule[];
createdAt?: string;
updatedAt?: string;
}
// One class (label) inside a Template.
export interface TemplateClass {
id: string;
name: string;
// Used as a CSS background colour and as the value of an <input type="color">
// in TemplateRegistry.
color: string;
// TemplateRegistry lists classes by descending zIndex.
zIndex: number;
category?: string;
description?: string;
}
// One rule inside a Template's optional `rules` list.
// NOTE(review): the meaning of sourceKey / targetKey / operation is not
// established anywhere in this file; confirm against the backend schema.
export interface TemplateRule {
id: string;
name: string;
sourceKey: string;
targetKey: string;
operation: string;
}
// Full shape of the global Zustand store: state fields plus the actions that
// change them. The behaviour notes below describe the implementation in
// `useStore` in this file.
export interface AppState {
// Auth
isAuthenticated: boolean;
token: string | null;
// Saves the token to localStorage and marks the session authenticated.
login: (token: string) => void;
// Removes the token from localStorage and clears the session data.
logout: () => void;
// Projects
projects: Project[];
currentProject: Project | null;
setProjects: (projects: Project[]) => void;
setCurrentProject: (project: Project | null) => void;
// Inserts the project at the front of the list.
addProject: (project: Project) => void;
// Replaces the list entry that has the same id.
updateProject: (project: Project) => void;
// Workspace
activeModule: string;
activeTool: string;
frames: Frame[];
currentFrameIndex: number;
annotations: Annotation[];
masks: Mask[];
setActiveModule: (module: string) => void;
setActiveTool: (tool: string) => void;
setFrames: (frames: Frame[]) => void;
// Stores the index as given; no bounds check against `frames`.
setCurrentFrame: (index: number) => void;
addAnnotation: (annotation: Annotation) => void;
addMask: (mask: Mask) => void;
clearMasks: () => void;
removeAnnotation: (id: string) => void;
// Templates
templates: Template[];
setTemplates: (templates: Template[]) => void;
// Appends the template at the end of the list.
addTemplate: (template: Template) => void;
// Replaces the list entry that has the same id.
updateTemplate: (template: Template) => void;
removeTemplate: (id: string) => void;
// UI
isLoading: boolean;
error: string | null;
setLoading: (loading: boolean) => void;
setError: (error: string | null) => void;
}
/**
 * Global application store (Zustand).
 *
 * Holds auth, project, workspace, template and UI state together with the
 * actions that change them. Every action replaces arrays instead of mutating
 * them, so selectors see a new reference whenever something changes.
 */
export const useStore = create<AppState>((set) => ({
  // Auth
  isAuthenticated: false,
  token: null,
  // The token is also written to localStorage, where the Axios request
  // interceptor in src/lib/api.ts reads it.
  login: (token: string) => {
    localStorage.setItem('token', token);
    set({ isAuthenticated: true, token });
  },
  logout: () => {
    localStorage.removeItem('token');
    set({
      isAuthenticated: false,
      token: null,
      currentProject: null,
      projects: [],
      templates: [],
      frames: [],
      // frames is emptied above, so the cursor must go back to the start;
      // otherwise the next session begins on a stale frame index.
      currentFrameIndex: 0,
      annotations: [],
      masks: [],
    });
  },

  // Projects
  projects: [],
  currentProject: null,
  setProjects: (projects: Project[]) => set({ projects }),
  setCurrentProject: (currentProject: Project | null) => set({ currentProject }),
  // Newest project first.
  addProject: (project: Project) =>
    set((state) => ({ projects: [project, ...state.projects] })),
  updateProject: (project: Project) =>
    set((state) => ({
      projects: state.projects.map((p) => (p.id === project.id ? project : p)),
      // Keep the open project in sync when it is the one being updated, so
      // views bound to currentProject do not keep showing the old record.
      currentProject:
        state.currentProject?.id === project.id ? project : state.currentProject,
    })),

  // Workspace
  activeModule: 'workspace',
  activeTool: 'move',
  frames: [],
  currentFrameIndex: 0,
  annotations: [],
  masks: [],
  setActiveModule: (activeModule: string) => set({ activeModule }),
  setActiveTool: (activeTool: string) => set({ activeTool }),
  setFrames: (frames: Frame[]) => set({ frames }),
  setCurrentFrame: (currentFrameIndex: number) => set({ currentFrameIndex }),
  addAnnotation: (annotation: Annotation) =>
    set((state) => ({ annotations: [...state.annotations, annotation] })),
  addMask: (mask: Mask) =>
    set((state) => ({ masks: [...state.masks, mask] })),
  clearMasks: () => set({ masks: [] }),
  removeAnnotation: (id: string) =>
    set((state) => ({
      annotations: state.annotations.filter((a) => a.id !== id),
    })),

  // Templates
  templates: [],
  setTemplates: (templates: Template[]) => set({ templates }),
  addTemplate: (template: Template) =>
    set((state) => ({ templates: [...state.templates, template] })),
  updateTemplate: (template: Template) =>
    set((state) => ({
      templates: state.templates.map((t) => (t.id === template.id ? template : t)),
    })),
  removeTemplate: (id: string) =>
    set((state) => ({
      templates: state.templates.filter((t) => t.id !== id),
    })),

  // UI
  isLoading: false,
  error: null,
  setLoading: (isLoading: boolean) => set({ isLoading }),
  setError: (error: string | null) => set({ error }),
}));

66
start_services.sh Executable file
View File

@@ -0,0 +1,66 @@
#!/bin/bash
# Semantic segmentation system: one-command startup script.
# Timestamp: 2026-04-29-21-51-19
#
# Starts PostgreSQL, Redis and MinIO when they are not already answering,
# then launches the FastAPI backend and the frontend in the background
# (nohup, logs under /tmp).
#
# PROJECT_DIR, CONDA_ENV, CONDA_SH and MINIO_DATA_DIR may be overridden from
# the environment; the defaults are the paths this script has always used.
#
# SECURITY: sudo is called without a piped-in password. A credential must
# never be stored in a script or committed to a repository. Run this from a
# terminal (sudo will prompt), or grant the user a NOPASSWD sudoers rule for
# the two `systemctl start` commands below.
set -e

PROJECT_DIR="${PROJECT_DIR:-/home/wkmgc/Desktop/Seg_Server}"
CONDA_ENV="${CONDA_ENV:-seg_server}"
CONDA_SH="${CONDA_SH:-/home/wkmgc/miniconda3/etc/profile.d/conda.sh}"
MINIO_DATA_DIR="${MINIO_DATA_DIR:-/home/wkmgc/minio_data}"
MINIO_HEALTH_URL="http://localhost:9000/minio/health/live"

# Report on stderr that a service still fails its health check after the
# start attempt. The script keeps going, as it always has, but the failure is
# now visible instead of silently skipping the success line.
warn_not_ready() {
  echo " ✗ $1 未就绪" >&2
}

echo "========================================"
echo " 语义分割系统全栈启动"
echo "========================================"

# 1. Check PostgreSQL
echo "[1/5] 检查 PostgreSQL..."
if ! pg_isready -q; then
  sudo systemctl start postgresql
  sleep 1
fi
if pg_isready; then
  echo " ✓ PostgreSQL 就绪"
else
  warn_not_ready "PostgreSQL"
fi

# 2. Check Redis
echo "[2/5] 检查 Redis..."
if ! redis-cli ping > /dev/null 2>&1; then
  sudo systemctl start redis-server
  sleep 1
fi
if redis-cli ping; then
  echo " ✓ Redis 就绪"
else
  warn_not_ready "Redis"
fi

# 3. Check MinIO
echo "[3/5] 检查 MinIO..."
if ! curl -s "$MINIO_HEALTH_URL" > /dev/null; then
  nohup minio server "$MINIO_DATA_DIR" --console-address :9001 > /tmp/minio.log 2>&1 &
  sleep 3
fi
if curl -s "$MINIO_HEALTH_URL" > /dev/null; then
  echo " ✓ MinIO 就绪 (http://localhost:9001)"
else
  warn_not_ready "MinIO"
fi

# 4. Start the FastAPI backend
echo "[4/5] 启动 FastAPI 后端..."
source "$CONDA_SH"
conda activate "$CONDA_ENV"
cd "$PROJECT_DIR/backend"
nohup uvicorn main:app --host 0.0.0.0 --port 8000 --reload > /tmp/fastapi.log 2>&1 &
sleep 2
echo " ✓ FastAPI 已启动 (http://localhost:8000/docs)"

# 5. Start the frontend
echo "[5/5] 启动前端..."
cd "$PROJECT_DIR"
nohup npm start > /tmp/frontend.log 2>&1 &
sleep 2
echo " ✓ 前端已启动 (http://localhost:3000)"

echo ""
echo "========================================"
echo " 所有服务已启动"
echo "========================================"
echo "前端: http://localhost:3000"
echo "后端 API: http://localhost:8000/docs"
echo "MinIO: http://localhost:9001"
echo ""
echo "日志文件:"
echo " FastAPI: /tmp/fastapi.log"
echo " 前端: /tmp/frontend.log"
echo " MinIO: /tmp/minio.log"
echo "========================================"

View File

@@ -0,0 +1,185 @@
# 实现方案 - 2026-04-29-21-51-19
## 对应需求
- 需求分析文档: `需求分析-2026-04-29-21-51-19.md`
## 方案概述
在 RTX 4090 + Ubuntu 22.04 环境下,将纯前端 React 应用改造为全栈语义分割系统。本次为**骨架搭建 + 核心能力落地**,优先保证各模块可独立运行并互联互通。
## 整体架构
```
┌─────────────────────────────────────────────────────────────────┐
│ 前端层 (React 19) │
│ Zustand (状态) + Axios (HTTP) + WebSocket (实时进度) │
│ Konva (Canvas) + TailwindCSS (样式) │
└────────────────────────────┬────────────────────────────────────┘
│ HTTP / WebSocket
┌────────────────────────────▼────────────────────────────────────┐
│ 后端层 (FastAPI + Python) │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 项目/模板 API │ │ 媒体解析 API │ │ AI 推理 API │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ └─────────────────┼─────────────────┘ │
│ │ SQLAlchemy │
│ ┌────────────────────────▼────────────────────────┐ │
│ │ PostgreSQL (关系数据) │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────▼────────────────────────┐ │
│ │ MinIO (对象存储: 视频/帧/Mask) │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────▼────────────────────────┐ │
│ │ Redis (缓存 + 任务队列) │ │
│ └─────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────▼────────────────────────┐ │
│ │ SAM 3 (GPU 推理节点) │ │
│ └─────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
```
## 修改/新增文件清单
### A. 基础设施层
#### A1. 系统服务安装
- **操作**: `sudo apt install postgresql postgresql-contrib redis-server ffmpeg`
- **操作**: 下载并启动 MinIO 二进制
- **说明**: 配置 PostgreSQL 数据库/用户，Redis 默认端口，MinIO 端口 9000/9001
#### A2. Conda 环境
- **操作**: `conda create -n seg_server python=3.11`
- **操作**: 安装 PyTorch 2.5+ CUDA 13.x、FastAPI、SQLAlchemy、psycopg2-binary、redis-py、minio、uvicorn、python-multipart、celery、opencv-python、pillow、scikit-image、pydicom、numpy
### B. 后端层 (新增 backend/ 目录)
#### B1. `backend/main.py`
- **类型**: 新增
- **内容**: FastAPI 入口，CORS 配置，lifespan 管理（启动/关闭数据库、Redis、MinIO 连接）
#### B2. `backend/config.py`
- **类型**: 新增
- **内容**: 环境变量配置（DB_URL、REDIS_URL、MINIO_ENDPOINT、SAM_MODEL_PATH 等）
#### B3. `backend/database.py`
- **类型**: 新增
- **内容**: SQLAlchemy engine + session + Base，PostgreSQL 连接
#### B4. `backend/models.py`
- **类型**: 新增
- **内容**: 数据库 ORM 模型（Project、Frame、Annotation、Template、Mask）
#### B5. `backend/schemas.py`
- **类型**: 新增
- **内容**: Pydantic 数据校验模型
#### B6. `backend/minio_client.py`
- **类型**: 新增
- **内容**: MinIO 客户端封装（上传、下载、预签名 URL）
#### B7. `backend/redis_client.py`
- **类型**: 新增
- **内容**: Redis 客户端封装
#### B8. `backend/routers/projects.py`
- **类型**: 新增
- **内容**: 项目 CRUD API
#### B9. `backend/routers/templates.py`
- **类型**: 新增
- **内容**: 本体模板 API
#### B10. `backend/routers/media.py`
- **类型**: 新增
- **内容**: 视频/图片/DCM 上传 API，FFmpeg 拆帧任务触发
#### B11. `backend/routers/ai.py`
- **类型**: 新增
- **内容**: SAM 3 推理 API（point/box/semantic），Mask 生成与存储
#### B12. `backend/routers/export.py`
- **类型**: 新增
- **内容**: 标注数据导出（COCO JSON、PNG Mask）
#### B13. `backend/services/sam3_engine.py`
- **类型**: 新增
- **内容**: SAM 3 模型加载、推理封装、GPU 内存管理
#### B14. `backend/services/frame_parser.py`
- **类型**: 新增
- **内容**: FFmpeg 视频拆帧、DCM 影像逐帧提取、帧上传 MinIO
#### B15. `backend/services/task_queue.py`
- **类型**: 新增
- **内容**: Celery + Redis 异步任务队列封装
#### B16. `backend/requirements.txt`
- **类型**: 新增
- **内容**: Python 依赖清单
### C. 前端层 (修改现有 src/)
#### C1. `src/store/index.ts` (新增)
- **内容**: Zustand 全局 Store（project、workspace、annotations、ui 状态）
#### C2. `src/lib/api.ts` (新增)
- **内容**: Axios 实例封装，baseURL 指向 FastAPI
#### C3. `src/lib/websocket.ts` (新增)
- **内容**: WebSocket 客户端,接收解析进度
#### C4. `src/App.tsx` (修改)
- **内容**: 移除 Login 硬编码，接入真实登录 API，Provider 包裹
#### C5. `src/components/ProjectLibrary.tsx` (修改)
- **内容**: 从后端 API 加载项目列表,支持创建/删除
#### C6. `src/components/TemplateRegistry.tsx` (修改)
- **内容**: 从后端 API 加载本体字典,支持动态增删
#### C7. `src/components/CanvasArea.tsx` (修改)
- **内容**: 点击画布捕获坐标,发送至后端 AI 接口,接收并渲染 Mask Path
#### C8. `src/components/AISegmentation.tsx` (修改)
- **内容**: 对接后端 SAM 3 推理接口,显示推理状态
#### C9. `src/components/Dashboard.tsx` (修改)
- **内容**: 显示真实解析队列进度（WebSocket）
### D. 部署与运维
#### D1. `start_services.sh` (新增)
- **内容**: 一键启动 PostgreSQL、Redis、MinIO、FastAPI 的脚本
#### D2. `backend/download_sam3.py` (新增)
- **内容**: SAM 3 模型权重自动下载脚本
## 新增依赖
### Python (conda 环境)
```
fastapi uvicorn[standard] python-multipart
sqlalchemy psycopg2-binary alembic
redis celery
minio
opencv-python pillow scikit-image pydicom numpy
```
### 前端 (npm)
```
zustand axios
```
## 兼容性分析
- Express `server.ts` 将被保留但不再作为默认启动方式，FastAPI 成为主后端
- 前端路由逻辑不变,仅数据获取方式从内存/mock 改为 HTTP API
- 回滚策略: 回退到 `npm start` 仍可运行旧版 Express 前端
## 预估工作量
- 基础设施: 20 分钟
- 后端骨架: 40 分钟
- 前端改造: 30 分钟
- SAM 3 部署: 20 分钟
- 联调验证: 20 分钟

View File

@@ -0,0 +1,78 @@
# 测试方案 - 2026-04-29-21-51-19
## 对应实现方案
- 实现方案文档: `实现方案-2026-04-29-21-51-19.md`
## 测试范围
- 基础设施服务可达性
- FastAPI 后端启动与 API 响应
- 前端构建与后端联调
- SAM 3 模型加载与 GPU 可用性
- 文件上传与解析流程
## 测试用例
### 用例 1: 基础设施服务验证
- **前置条件**: 执行 start_services.sh
- **操作步骤**:
1. `curl http://localhost:9000/minio/health/live` → MinIO
2. `redis-cli ping` → Redis
3. `sudo -u postgres psql -c "\\l"` → PostgreSQL
- **预期结果**: 全部返回正常响应
- **通过标准**: MinIO 200, Redis PONG, PostgreSQL 显示数据库列表
### 用例 2: Conda 环境 + GPU 验证
- **前置条件**: conda 环境已创建
- **操作步骤**:
1. `conda activate seg_server`
2. `python -c "import torch; print(torch.cuda.is_available())"`
- **预期结果**: 输出 True
- **通过标准**: PyTorch 识别到 CUDA
### 用例 3: SAM 3 权重下载验证
- **前置条件**: 运行 download_sam3.py
- **操作步骤**: 检查权重文件存在且大小合理
- **预期结果**: .pt/.pth 文件存在于 models/ 目录
- **通过标准**: 文件大小 > 100MB
### 用例 4: FastAPI 启动验证
- **前置条件**: 依赖安装完成
- **操作步骤**:
1. `cd backend && uvicorn main:app --host 0.0.0.0 --port 8000`
2. 访问 `http://localhost:8000/docs`
- **预期结果**: Swagger UI 正常显示,包含所有 API 路由
- **通过标准**: HTTP 200,路由列表完整
### 用例 5: 前端构建验证
- **前置条件**: 前端代码已改造
- **操作步骤**: `npm run lint && npm run build`
- **预期结果**: 无类型错误,构建成功
- **通过标准**: exit code 0
### 用例 6: 前后端联调验证
- **前置条件**: 前后端均运行中
- **操作步骤**:
1. 前端访问 `http://localhost:3000`
2. 打开浏览器 DevTools Network 面板
3. 触发项目列表加载
- **预期结果**: 可见对 `http://localhost:8000/api/projects` 的请求,且返回 200
- **通过标准**: API 数据正确渲染到界面
### 用例 7: 文件上传验证
- **前置条件**: 媒体解析模块就绪
- **操作步骤**: 上传 @Data_MyVideo_1.mp4
- **预期结果**: 后端接收文件,存入 MinIO,触发解析任务
- **通过标准**: MinIO bucket 中出现文件,返回 job_id
## 回归测试
- [ ] 现有 React 组件无运行时错误
- [ ] 深色主题样式未被破坏
- [ ] Konva Canvas 可正常交互
- [ ] 构建产物体积未异常膨胀
## 测试环境
- OS: Ubuntu 22.04
- GPU: NVIDIA RTX 4090 24GB
- CUDA: 13.2
- Python: 3.11 (conda)
- Node.js: 22.x

View File

@@ -65,4 +65,38 @@ AI 助手运行的容器/环境与项目实际开发环境分离,后者才装
---
## 2026-04-29-21-51-19 — 全栈系统改造:FastAPI + SAM2 + PostgreSQL + Redis + MinIO
### A. 具体问题
1. 将纯前端 React 应用改造为全栈系统时,工程涉及后端框架替换、数据库设计、对象存储、AI 推理引擎、前端状态管理重构等多个复杂模块,单次执行工程量大。
2. 系统磁盘空间(24G)不足,PyTorch CUDA 版本(>2GB)和 sam2 pip 包编译(需下载 CUDA 工具链 + 编译 C++ 扩展)均因 `OSError: No space left on device` 失败。
3. MinIO 对象存储在磁盘紧张时报 `XMinioStorageFull`,导致文件上传失败。
4. 前端 Agent 执行时因目录扁平化后子目录不存在,产生多次 `File not found` 错误。
### B. 产生原因
1. 项目改造范围超出单次会议合理容量,涉及 15+ 后端文件、10+ 前端文件、4 个基础设施服务、1 个 AI 模型栈。
2. 系统盘仅有 24GB,conda 环境(2GB)、node_modules(222MB)、模型权重(1.5GB)、MinIO 帧文件(1.4GB)叠加后迅速耗尽空间。
3. sam2 的 pip 包依赖 torch>=2.5.1,pip 会尝试重新下载完整 torch wheel(530MB)作为 build dependency,即使 torch 已安装。
4. 前端 Agent 的 prompt 中未明确说明组件目录已扁平化,Agent 仍尝试读取旧的子目录路径。
### C. 解决方案
1. **任务拆分 + 并行 Agent**: 将后端和前端代码编写拆分为两个独立 Agent 并行执行,基础设施安装与代码编写并行推进,显著缩短总耗时。
2. **磁盘管理策略**:
- 安装 PyTorch CPU 版本替代 CUDA 版本(占用更小)
- 只保留 sam2_hiera_tiny.pt(149MB),删除其他大模型
- 清理 conda pkgs 缓存(释放 600MB+)
- 删除 MinIO 中解析生成的临时帧文件(释放 1.4GB)
3. **sam2 安装降级**: 当前环境以 stub 模式运行,提供 `install_sam2.sh` 脚本供用户在扩展磁盘后执行真实安装。
4. **API 路径修复**: 后端添加 `/api/auth/login` 路由,修复前端 api.ts 中 `/api/predict` 与 `/api/ai/predict` 的路径不匹配。
5. **MinIO API 适配**: minio 7.2.x 中 `presigned_url()` 已改为 `get_presigned_url()`,expires 参数需传入 `datetime.timedelta` 对象。
### D. 后续如何避免问题
1. **大型改造前先做磁盘评估**: 执行 `df -h` 确认可用空间 > 5GB 再开始安装大型依赖。
2. **AI 模型依赖延迟加载**: 所有 AI 推理引擎必须实现 graceful fallback,模型缺失时不阻塞系统启动。
3. **Agent prompt 需同步最新目录结构**: 给 Agent 的上下文必须包含当前真实的文件路径,避免 `File not found`。
4. **构建依赖隔离**: 使用 `--no-build-isolation` 或 `--no-deps` 安装源码包,避免 pip 重复下载已安装的依赖。
5. **MinIO 空间监控**: 定期清理解析产生的临时帧文件,或配置 MinIO 使用独立大容量数据盘。
---
> 新增经验请追加到文件末尾,保持时间倒序或正序均可,但需确保每条经验包含完整的 A/B/C/D 四段。

View File

@@ -0,0 +1,75 @@
# 需求分析 - 2026-04-29-21-51-19
## 需求来源
- 提出时间: 2026-04-29-21-51-19
- 需求类型: 全栈系统改造 / 架构重构
## 原始需求描述
将现有纯前端 React 静态 UI 彻底改造为前后端联动的全栈语义分割系统。打通数据流转,实现:
1. 本地多媒体资产(视频/图片/DCM影像)上传
2. 服务器端按帧解析(FFmpeg 拆帧)
3. AI 视觉大模型 SAM 3 实时交互式推理(点分割、box分割、语义分割)
4. 动态图层状态管理
5. 标注数据结构化导出
后端技术栈:Python + FastAPI + PostgreSQL + Redis + MinIO
AI 基座:SAM 3
输入数据:@Data_MyVideo_1.mp4、@Data_Dicom帧(DCM格式)
## 需求拆解
### 需求 1: 基础设施部署
- **详细描述**: 在 Ubuntu 22.04 上部署 PostgreSQL、Redis、MinIO 对象存储
- **优先级**: P0-阻塞
- **影响范围**: 系统级服务
- **验收标准**: 三个服务均可访问,有持久化数据目录
### 需求 2: Conda 环境 + SAM 3 模型部署
- **详细描述**: 新建 conda 环境 seg_server,安装 PyTorch + CUDA,部署 SAM 3 模型权重
- **优先级**: P0-阻塞
- **影响范围**: Python AI 推理层
- **验收标准**: conda 环境可激活,GPU 可被 PyTorch 识别,SAM 3 权重文件存在
### 需求 3: FastAPI 后端框架搭建
- **详细描述**: 替换现有 Express 后端,建立 FastAPI 服务,包含上传、项目、模板、AI推理、导出等模块
- **优先级**: P0-阻塞
- **影响范围**: 后端全部
- **验收标准**: FastAPI 服务可启动,Swagger UI 可访问
### 需求 4: 前端状态管理 + 网络层改造
- **详细描述**: 引入 Zustand 全局状态管理,Axios 替代 fetch,WebSocket 对接解析进度
- **优先级**: P0-阻塞
- **影响范围**: src/ 全部
- **验收标准**: 前端可正常与后端通信,状态跨组件同步
### 需求 5: Canvas Mask 渲染对接
- **详细描述**: Konva Canvas 接收后端 SAM 3 返回的 Mask 数据(Polygon/RLE),动态渲染遮罩图层
- **优先级**: P1-高
- **影响范围**: CanvasArea.tsx, AISegmentation.tsx
- **验收标准**: 点击画布后,后端返回 Mask,前端可正确绘制
### 需求 6: 视频/DCM 解析流水线
- **详细描述**: 上传视频后 FFmpeg 拆帧,DCM 影像逐帧提取,进度通过 WebSocket 推送
- **优先级**: P1-高
- **影响范围**: 后端任务模块 + 前端 Dashboard
- **验收标准**: 可上传文件,解析队列可显示进度,帧图片存入 MinIO
### 需求 7: 数据持久化(项目/模板/标注)
- **详细描述**: PostgreSQL 存储项目元数据、本体字典、标注结果,MinIO 存储原始媒体/帧/Mask
- **优先级**: P1-高
- **影响范围**: 后端数据库模型 + API
- **验收标准**: 数据重启后仍可读取
## 约束条件
- 使用 Ubuntu 22.04, sudo 密码 Wkmgc
- GPU: RTX 4090, CUDA 13.2
- 无 Docker,使用系统包管理 + 二进制部署
- 保持现有前端界面风格(深色主题、中文)
## 风险评估
| 风险点 | 影响 | 缓解措施 |
|--------|------|----------|
| SAM 3 权重下载慢/失败 | 高 | 使用国内镜像源,分段下载验证 |
| PostgreSQL/Redis 端口冲突 | 中 | 检查端口占用,使用非默认端口 |
| MinIO 磁盘空间不足 | 中 | 监控磁盘,配置清理策略 |
| 单次会议工程过大 | 高 | 分模块迭代,先跑通骨架再丰富 |