2026-04-29-21-51-19 - 全栈系统改造:FastAPI后端+SAM2+PostgreSQL+Redis+MinIO+前端Zustand重构
This commit is contained in:
15
.gitignore
vendored
15
.gitignore
vendored
@@ -6,3 +6,18 @@ coverage/
|
||||
*.log
|
||||
.env*
|
||||
!.env.example
|
||||
# Data & Models
|
||||
models/
|
||||
uploads/
|
||||
frames/
|
||||
minio_data/
|
||||
Data_*/
|
||||
*.mp4
|
||||
*.dcm
|
||||
*.7z
|
||||
# Binaries
|
||||
Viewer/
|
||||
Viewer.exe
|
||||
DICOMDIR
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
35
backend/config.py
Normal file
35
backend/config.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Application configuration using Pydantic Settings."""
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# Database
|
||||
db_url: str = "postgresql://seguser:segpass123@localhost:5432/segserver"
|
||||
|
||||
# Redis
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
# MinIO
|
||||
minio_endpoint: str = "localhost:9000"
|
||||
minio_access_key: str = "minioadmin"
|
||||
minio_secret_key: str = "minioadmin"
|
||||
minio_secure: bool = False
|
||||
|
||||
# SAM2
|
||||
sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt"
|
||||
sam_model_config: str = "sam2_hiera_t.yaml"
|
||||
|
||||
# App
|
||||
app_env: str = "development"
|
||||
cors_origins: list[str] = ["http://localhost:3000"]
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
29
backend/database.py
Normal file
29
backend/database.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Database configuration using synchronous SQLAlchemy."""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, Session
|
||||
from fastapi import Depends
|
||||
from typing import Generator
|
||||
|
||||
from config import settings
|
||||
|
||||
engine = create_engine(
|
||||
settings.db_url,
|
||||
pool_pre_ping=True,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""FastAPI dependency that yields a database session."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
49
backend/download_sam2.py
Normal file
49
backend/download_sam2.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SAM 2 模型权重下载脚本
|
||||
运行: python download_sam2.py
|
||||
"""
|
||||
import os
|
||||
import urllib.request
|
||||
import sys
|
||||
|
||||
MODEL_DIR = "/home/wkmgc/Desktop/Seg_Server/models"
|
||||
os.makedirs(MODEL_DIR, exist_ok=True)
|
||||
|
||||
# SAM 2 模型权重 (Meta AI 官方)
|
||||
MODELS = {
|
||||
"sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
|
||||
"sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
|
||||
"sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
|
||||
"sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
|
||||
}
|
||||
|
||||
def download_file(url: str, dest: str):
|
||||
if os.path.exists(dest):
|
||||
print(f"[跳过] {os.path.basename(dest)} 已存在")
|
||||
return
|
||||
print(f"[下载] {url}")
|
||||
print(f" → {dest}")
|
||||
try:
|
||||
urllib.request.urlretrieve(url, dest)
|
||||
size_mb = os.path.getsize(dest) / 1024 / 1024
|
||||
print(f" ✓ 完成 ({size_mb:.1f} MB)")
|
||||
except Exception as e:
|
||||
print(f" ✗ 失败: {e}", file=sys.stderr)
|
||||
if os.path.exists(dest):
|
||||
os.remove(dest)
|
||||
|
||||
def main():
|
||||
print("=" * 50)
|
||||
print("SAM 2 模型权重下载")
|
||||
print("=" * 50)
|
||||
for name, url in MODELS.items():
|
||||
dest = os.path.join(MODEL_DIR, name)
|
||||
download_file(url, dest)
|
||||
print("=" * 50)
|
||||
print("全部完成!")
|
||||
print(f"模型目录: {MODEL_DIR}")
|
||||
print("=" * 50)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
83
backend/main.py
Normal file
83
backend/main.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""FastAPI application entrypoint."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import settings
|
||||
from database import Base, engine
|
||||
from minio_client import ensure_bucket_exists
|
||||
from redis_client import ping as redis_ping
|
||||
|
||||
from routers import projects, templates, media, ai, export, auth
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan: startup and shutdown hooks."""
|
||||
# Startup
|
||||
logger.info("Starting up SegServer backend...")
|
||||
|
||||
# Initialize database tables
|
||||
try:
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logger.info("Database tables initialized.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Database initialization failed: %s", exc)
|
||||
|
||||
# Check MinIO bucket
|
||||
try:
|
||||
ensure_bucket_exists()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("MinIO bucket check failed: %s", exc)
|
||||
|
||||
# Check Redis
|
||||
if redis_ping():
|
||||
logger.info("Redis connection OK.")
|
||||
else:
|
||||
logger.warning("Redis connection failed.")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down SegServer backend...")
|
||||
engine.dispose()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="SegServer API",
|
||||
description="Semantic Segmentation System Backend",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(projects.router)
|
||||
app.include_router(templates.router)
|
||||
app.include_router(media.router)
|
||||
app.include_router(ai.router)
|
||||
app.include_router(export.router)
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
def health_check() -> dict:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok", "service": "SegServer"}
|
||||
127
backend/minio_client.py
Normal file
127
backend/minio_client.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""MinIO client wrapper for object storage operations."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUCKET_NAME = "seg-media"
|
||||
|
||||
_minio_client: Optional[Minio] = None
|
||||
|
||||
|
||||
def get_minio_client() -> Minio:
|
||||
"""Return a singleton MinIO client instance."""
|
||||
global _minio_client
|
||||
if _minio_client is None:
|
||||
_minio_client = Minio(
|
||||
settings.minio_endpoint,
|
||||
access_key=settings.minio_access_key,
|
||||
secret_key=settings.minio_secret_key,
|
||||
secure=settings.minio_secure,
|
||||
)
|
||||
return _minio_client
|
||||
|
||||
|
||||
def ensure_bucket_exists() -> None:
|
||||
"""Create the bucket if it does not already exist."""
|
||||
client = get_minio_client()
|
||||
try:
|
||||
if not client.bucket_exists(BUCKET_NAME):
|
||||
client.make_bucket(BUCKET_NAME)
|
||||
logger.info("Created MinIO bucket: %s", BUCKET_NAME)
|
||||
else:
|
||||
logger.info("MinIO bucket %s already exists", BUCKET_NAME)
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO bucket check/creation failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def upload_file(
|
||||
object_name: str,
|
||||
data: bytes,
|
||||
content_type: str = "application/octet-stream",
|
||||
length: int = -1,
|
||||
) -> str:
|
||||
"""Upload bytes to MinIO and return the object name.
|
||||
|
||||
Args:
|
||||
object_name: Destination path inside the bucket.
|
||||
data: Raw bytes or a file-like object.
|
||||
content_type: MIME type of the object.
|
||||
length: Object size; -1 for unknown (uses chunked upload).
|
||||
|
||||
Returns:
|
||||
The object name (same as input).
|
||||
"""
|
||||
client = get_minio_client()
|
||||
if isinstance(data, bytes):
|
||||
data = io.BytesIO(data)
|
||||
length = len(data.getvalue())
|
||||
|
||||
try:
|
||||
client.put_object(
|
||||
BUCKET_NAME,
|
||||
object_name,
|
||||
data,
|
||||
length=length,
|
||||
content_type=content_type,
|
||||
)
|
||||
logger.info("Uploaded to MinIO: %s", object_name)
|
||||
return object_name
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO upload failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
def get_presigned_url(
|
||||
object_name: str,
|
||||
expires: int = 3600,
|
||||
method: str = "GET",
|
||||
) -> str:
|
||||
"""Generate a presigned URL for an object.
|
||||
|
||||
Args:
|
||||
object_name: Path inside the bucket.
|
||||
expires: Expiration time in seconds (default 1 hour).
|
||||
method: HTTP method (GET or PUT).
|
||||
|
||||
Returns:
|
||||
Presigned URL string.
|
||||
"""
|
||||
client = get_minio_client()
|
||||
try:
|
||||
url = client.get_presigned_url(method, BUCKET_NAME, object_name, expires=timedelta(seconds=expires))
|
||||
return url
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO presigned URL failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def download_file(object_name: str) -> bytes:
|
||||
"""Download an object from MinIO and return its bytes.
|
||||
|
||||
Args:
|
||||
object_name: Path inside the bucket.
|
||||
|
||||
Returns:
|
||||
Raw bytes of the object.
|
||||
"""
|
||||
client = get_minio_client()
|
||||
try:
|
||||
response = client.get_object(BUCKET_NAME, object_name)
|
||||
data = response.read()
|
||||
response.close()
|
||||
response.release_conn()
|
||||
return data
|
||||
except S3Error as exc:
|
||||
logger.error("MinIO download failed: %s", exc)
|
||||
raise
|
||||
118
backend/models.py
Normal file
118
backend/models.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""SQLAlchemy ORM models."""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
Float,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from database import Base
|
||||
|
||||
|
||||
class Project(Base):
|
||||
"""Project model representing a segmentation project."""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
video_path = Column(String(512), nullable=True)
|
||||
status = Column(String(50), default="pending", nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
frames = relationship("Frame", back_populates="project", cascade="all, delete-orphan")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="project", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Frame(Base):
|
||||
"""Frame model representing an extracted video frame."""
|
||||
|
||||
__tablename__ = "frames"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)
|
||||
frame_index = Column(Integer, nullable=False)
|
||||
image_url = Column(String(512), nullable=False)
|
||||
width = Column(Integer, nullable=True)
|
||||
height = Column(Integer, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
project = relationship("Project", back_populates="frames")
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="frame", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Template(Base):
|
||||
"""Template (Ontology) model for segmentation classes."""
|
||||
|
||||
__tablename__ = "templates"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String(255), nullable=False)
|
||||
color = Column(String(50), nullable=False)
|
||||
z_index = Column(Integer, default=0, nullable=False)
|
||||
mapping_rules = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
annotations = relationship(
|
||||
"Annotation", back_populates="template", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Annotation(Base):
|
||||
"""Annotation model for segmentation masks and prompts."""
|
||||
|
||||
__tablename__ = "annotations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
project_id = Column(
|
||||
Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
frame_id = Column(
|
||||
Integer, ForeignKey("frames.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
template_id = Column(
|
||||
Integer, ForeignKey("templates.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
mask_data = Column(JSON, nullable=True)
|
||||
points = Column(JSON, nullable=True)
|
||||
bbox = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
project = relationship("Project", back_populates="annotations")
|
||||
frame = relationship("Frame", back_populates="annotations")
|
||||
template = relationship("Template", back_populates="annotations")
|
||||
masks = relationship("Mask", back_populates="annotation", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Mask(Base):
|
||||
"""Mask model for exported/derived mask files."""
|
||||
|
||||
__tablename__ = "masks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
annotation_id = Column(
|
||||
Integer, ForeignKey("annotations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
mask_url = Column(String(512), nullable=False)
|
||||
format = Column(String(50), default="png", nullable=False) # png / rle / json
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
annotation = relationship("Annotation", back_populates="masks")
|
||||
61
backend/redis_client.py
Normal file
61
backend/redis_client.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Redis client wrapper for caching and task queuing."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Any
|
||||
|
||||
import redis
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_redis_client: Optional[redis.Redis] = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
"""Return a singleton Redis client instance."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(settings.redis_url, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def ping() -> bool:
|
||||
"""Check Redis connectivity."""
|
||||
try:
|
||||
return get_redis_client().ping()
|
||||
except redis.ConnectionError as exc:
|
||||
logger.error("Redis ping failed: %s", exc)
|
||||
return False
|
||||
|
||||
|
||||
def set_json(key: str, value: Any, expire: Optional[int] = None) -> None:
|
||||
"""Store a JSON-serializable value in Redis."""
|
||||
client = get_redis_client()
|
||||
try:
|
||||
client.set(key, json.dumps(value), ex=expire)
|
||||
except redis.RedisError as exc:
|
||||
logger.error("Redis set_json failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def get_json(key: str) -> Optional[Any]:
|
||||
"""Retrieve and deserialize a JSON value from Redis."""
|
||||
client = get_redis_client()
|
||||
try:
|
||||
data = client.get(key)
|
||||
return json.loads(data) if data is not None else None
|
||||
except redis.RedisError as exc:
|
||||
logger.error("Redis get_json failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def delete_key(key: str) -> int:
|
||||
"""Delete a key from Redis. Returns number of deleted keys."""
|
||||
client = get_redis_client()
|
||||
try:
|
||||
return client.delete(key)
|
||||
except redis.RedisError as exc:
|
||||
logger.error("Redis delete_key failed: %s", exc)
|
||||
raise
|
||||
38
backend/requirements.txt
Normal file
38
backend/requirements.txt
Normal file
@@ -0,0 +1,38 @@
|
||||
# Web framework
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
python-multipart
|
||||
|
||||
# Database
|
||||
sqlalchemy
|
||||
psycopg2-binary
|
||||
alembic
|
||||
|
||||
# Cache / Task queue
|
||||
redis
|
||||
celery
|
||||
|
||||
# Object storage
|
||||
minio
|
||||
|
||||
# Image / Video / DICOM processing
|
||||
opencv-python
|
||||
pillow
|
||||
scikit-image
|
||||
pydicom
|
||||
numpy
|
||||
|
||||
# SAM 2 (may require manual installation depending on CUDA version)
|
||||
sam2
|
||||
|
||||
# PyTorch (CUDA 12.4 wheel index; adjust for your CUDA version if needed)
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
|
||||
# Configuration
|
||||
pydantic-settings
|
||||
|
||||
# Utilities
|
||||
python-jose[cryptography]
|
||||
passlib[bcrypt]
|
||||
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
123
backend/routers/ai.py
Normal file
123
backend/routers/ai.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""AI inference endpoints using SAM 2."""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import download_file
|
||||
from models import Frame, Annotation
|
||||
from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate
|
||||
from services.sam2_engine import sam_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
def _load_frame_image(frame: Frame) -> np.ndarray:
|
||||
"""Download a frame from MinIO and decode it to an RGB numpy array."""
|
||||
try:
|
||||
data = download_file(frame.image_url)
|
||||
arr = np.frombuffer(data, dtype=np.uint8)
|
||||
img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
raise ValueError("OpenCV could not decode image")
|
||||
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to load frame image: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Failed to load frame image") from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/predict",
|
||||
response_model=PredictResponse,
|
||||
summary="Run SAM 2 inference with a prompt",
|
||||
)
|
||||
def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict:
|
||||
"""Execute SAM 2 segmentation given an image and a prompt.
|
||||
|
||||
- **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates.
|
||||
- **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates.
|
||||
- **semantic**: Not yet implemented; falls back to auto segmentation.
|
||||
"""
|
||||
frame = db.query(Frame).filter(Frame.id == payload.image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
prompt_type = payload.prompt_type.lower()
|
||||
|
||||
polygons: List[List[List[float]]] = []
|
||||
scores: List[float] = []
|
||||
|
||||
if prompt_type == "point":
|
||||
points = payload.prompt_data
|
||||
if not isinstance(points, list) or len(points) == 0:
|
||||
raise HTTPException(status_code=400, detail="Invalid point prompt data")
|
||||
labels = [1] * len(points)
|
||||
polygons, scores = sam_engine.predict_points(image, points, labels)
|
||||
|
||||
elif prompt_type == "box":
|
||||
box = payload.prompt_data
|
||||
if not isinstance(box, list) or len(box) != 4:
|
||||
raise HTTPException(status_code=400, detail="Invalid box prompt data")
|
||||
polygons, scores = sam_engine.predict_box(image, box)
|
||||
|
||||
elif prompt_type == "semantic":
|
||||
# Placeholder: use auto segmentation for now
|
||||
logger.info("Semantic prompt not implemented; using auto segmentation")
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}")
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/auto",
|
||||
response_model=PredictResponse,
|
||||
summary="Run automatic segmentation",
|
||||
)
|
||||
def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict:
|
||||
"""Run automatic mask generation on a frame using a grid of point prompts."""
|
||||
frame = db.query(Frame).filter(Frame.id == image_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
image = _load_frame_image(frame)
|
||||
polygons, scores = sam_engine.predict_auto(image)
|
||||
|
||||
return {"polygons": polygons, "scores": scores}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/annotate",
|
||||
response_model=AnnotationOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Save an AI-generated annotation",
|
||||
)
|
||||
def save_annotation(
|
||||
payload: AnnotationCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Annotation:
|
||||
"""Persist an annotation (mask, points, bbox) into the database."""
|
||||
project = db.query(Frame).filter(Frame.id == payload.project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if payload.frame_id:
|
||||
frame = db.query(Frame).filter(Frame.id == payload.frame_id).first()
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
|
||||
annotation = Annotation(**payload.model_dump())
|
||||
db.add(annotation)
|
||||
db.commit()
|
||||
db.refresh(annotation)
|
||||
logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id)
|
||||
return annotation
|
||||
24
backend/routers/auth.py
Normal file
24
backend/routers/auth.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Auth"])
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
token: str
|
||||
username: str
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
def login(payload: LoginRequest) -> dict:
|
||||
"""Simple login for development."""
|
||||
if payload.username == "admin" and payload.password == "123456":
|
||||
return {"token": "fake-jwt-token-for-admin", "username": payload.username}
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
194
backend/routers/export.py
Normal file
194
backend/routers/export.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Annotation export endpoints (COCO, PNG masks)."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Annotation, Frame, Template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||||
|
||||
|
||||
def _mask_from_polygon(
|
||||
polygon: List[List[float]],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> np.ndarray:
|
||||
"""Render a normalized polygon to a binary mask."""
|
||||
import cv2
|
||||
|
||||
pts = np.array(
|
||||
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
|
||||
dtype=np.int32,
|
||||
)
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [pts], 255)
|
||||
return mask
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/coco",
|
||||
summary="Export annotations in COCO format",
|
||||
)
|
||||
def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotations for a project as a COCO-format JSON file."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
frames = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.all()
|
||||
)
|
||||
templates = db.query(Template).all()
|
||||
|
||||
# Build COCO structure
|
||||
images = []
|
||||
for idx, frame in enumerate(frames):
|
||||
images.append({
|
||||
"id": frame.id,
|
||||
"file_name": frame.image_url,
|
||||
"width": frame.width or 1920,
|
||||
"height": frame.height or 1080,
|
||||
"frame_index": idx,
|
||||
})
|
||||
|
||||
categories = []
|
||||
template_id_to_cat_id: Dict[int, int] = {}
|
||||
for cat_idx, tmpl in enumerate(templates, start=1):
|
||||
categories.append({
|
||||
"id": cat_idx,
|
||||
"name": tmpl.name,
|
||||
"color": tmpl.color,
|
||||
})
|
||||
template_id_to_cat_id[tmpl.id] = cat_idx
|
||||
|
||||
coco_annotations = []
|
||||
ann_id = 1
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
# Use first polygon for bbox / area approximation
|
||||
first_poly = polygons[0]
|
||||
xs = [p[0] for p in first_poly]
|
||||
ys = [p[1] for p in first_poly]
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
bbox = [
|
||||
min(xs) * width,
|
||||
min(ys) * height,
|
||||
(max(xs) - min(xs)) * width,
|
||||
(max(ys) - min(ys)) * height,
|
||||
]
|
||||
area = bbox[2] * bbox[3]
|
||||
|
||||
segmentation = []
|
||||
for poly in polygons:
|
||||
flat = []
|
||||
for p in poly:
|
||||
flat.append(p[0] * width)
|
||||
flat.append(p[1] * height)
|
||||
segmentation.append(flat)
|
||||
|
||||
coco_annotations.append({
|
||||
"id": ann_id,
|
||||
"image_id": ann.frame_id,
|
||||
"category_id": template_id_to_cat_id.get(ann.template_id, 0),
|
||||
"segmentation": segmentation,
|
||||
"area": area,
|
||||
"bbox": bbox,
|
||||
"iscrowd": 0,
|
||||
})
|
||||
ann_id += 1
|
||||
|
||||
coco = {
|
||||
"info": {
|
||||
"description": f"Annotations for {project.name}",
|
||||
"version": "1.0",
|
||||
"year": datetime.now().year,
|
||||
"date_created": datetime.now().isoformat(),
|
||||
},
|
||||
"images": images,
|
||||
"annotations": coco_annotations,
|
||||
"categories": categories,
|
||||
}
|
||||
|
||||
data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
filename = f"project_{project_id}_coco.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(data),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/masks",
|
||||
summary="Export PNG masks as a ZIP archive",
|
||||
)
|
||||
def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse:
|
||||
"""Export all annotation masks as individual PNG files inside a ZIP archive."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
annotations = (
|
||||
db.query(Annotation)
|
||||
.filter(Annotation.project_id == project_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for ann in annotations:
|
||||
if not ann.mask_data:
|
||||
continue
|
||||
polygons = ann.mask_data.get("polygons", [])
|
||||
if not polygons:
|
||||
continue
|
||||
|
||||
width = ann.frame.width if ann.frame else 1920
|
||||
height = ann.frame.height if ann.frame else 1080
|
||||
combined = np.zeros((height, width), dtype=np.uint8)
|
||||
|
||||
for poly in polygons:
|
||||
mask = _mask_from_polygon(poly, width, height)
|
||||
combined = np.maximum(combined, mask)
|
||||
|
||||
# Encode PNG
|
||||
import cv2
|
||||
_, encoded = cv2.imencode(".png", combined)
|
||||
fname = f"mask_{ann.id:06d}.png"
|
||||
zf.writestr(fname, encoded.tobytes())
|
||||
|
||||
zip_buffer.seek(0)
|
||||
filename = f"project_{project_id}_masks.zip"
|
||||
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
192
backend/routers/media.py
Normal file
192
backend/routers/media.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Media upload and parsing endpoints."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from minio_client import upload_file, get_presigned_url
|
||||
from models import Project, Frame
|
||||
from schemas import FrameOut
|
||||
from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/media", tags=["Media"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"}
|
||||
|
||||
|
||||
def _get_ext(filename: str) -> str:
|
||||
return Path(filename).suffix.lower()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Upload a media file",
|
||||
)
|
||||
async def upload_media(
|
||||
file: UploadFile = File(...),
|
||||
project_id: Optional[int] = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Accept a video, image, or DICOM file and store it in MinIO.
|
||||
|
||||
If project_id is provided, the video_path of the project is updated.
|
||||
Returns the presigned URL of the uploaded object.
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Missing filename")
|
||||
|
||||
ext = _get_ext(file.filename)
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {ext}",
|
||||
)
|
||||
|
||||
data = await file.read()
|
||||
object_name = f"uploads/{project_id or 'general'}/{file.filename}"
|
||||
|
||||
try:
|
||||
upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Upload failed: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Upload to storage failed") from exc
|
||||
|
||||
file_url = get_presigned_url(object_name, expires=3600)
|
||||
|
||||
if project_id:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if project:
|
||||
project.video_path = object_name
|
||||
db.commit()
|
||||
logger.info("Linked upload to project_id=%s", project_id)
|
||||
else:
|
||||
logger.warning("Project id=%s not found for upload linkage", project_id)
|
||||
|
||||
# TODO: enqueue async parsing job (Celery / background task)
|
||||
logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data))
|
||||
|
||||
return {
|
||||
"object_name": object_name,
|
||||
"file_url": file_url,
|
||||
"size": len(data),
|
||||
"message": "Upload successful. Parsing job queued.",
|
||||
}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/parse",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
summary="Trigger frame extraction",
|
||||
)
|
||||
def parse_media(
|
||||
project_id: int,
|
||||
source_type: str = "video", # video | dicom
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Trigger frame extraction for a project's uploaded media.
|
||||
|
||||
* video: uses FFmpeg or OpenCV fallback.
|
||||
* dicom: uses pydicom to read DCM frames.
|
||||
|
||||
Extracted frames are uploaded to MinIO and registered in the database.
|
||||
"""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if not project.video_path:
|
||||
raise HTTPException(status_code=400, detail="Project has no media uploaded")
|
||||
|
||||
# Download from MinIO to a temp directory
|
||||
from minio_client import download_file
|
||||
|
||||
try:
|
||||
media_bytes = download_file(project.video_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to download media for parsing: %s", exc)
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_")
|
||||
local_path = os.path.join(tmp_dir, Path(project.video_path).name)
|
||||
|
||||
with open(local_path, "wb") as f:
|
||||
f.write(media_bytes)
|
||||
|
||||
output_dir = os.path.join(tmp_dir, "frames")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
if source_type == "dicom":
|
||||
# For DICOM, treat local_path as a directory if it contains multiple .dcm
|
||||
# If a single .dcm file was uploaded, put it in its own sub-dir
|
||||
dcm_dir = os.path.join(tmp_dir, "dcm")
|
||||
os.makedirs(dcm_dir, exist_ok=True)
|
||||
if local_path.lower().endswith(".dcm"):
|
||||
shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path)))
|
||||
else:
|
||||
shutil.unpack_archive(local_path, dcm_dir) if shutil.which("unzip") else shutil.move(local_path, dcm_dir)
|
||||
frame_files = parse_dicom(dcm_dir, output_dir)
|
||||
else:
|
||||
frame_files = parse_video(local_path, output_dir, fps=30)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Frame extraction failed: %s", exc)
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail="Frame extraction failed") from exc
|
||||
|
||||
# Upload frames to MinIO
|
||||
try:
|
||||
object_names = upload_frames_to_minio(frame_files, project_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Frame upload failed: %s", exc)
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc
|
||||
|
||||
# Register frames in DB
|
||||
frames_out = []
|
||||
for idx, obj_name in enumerate(object_names):
|
||||
# Get image dimensions
|
||||
local_frame = frame_files[idx]
|
||||
try:
|
||||
import cv2
|
||||
img = cv2.imread(local_frame)
|
||||
h, w = img.shape[:2] if img is not None else (None, None)
|
||||
except Exception: # noqa: BLE001
|
||||
h, w = None, None
|
||||
|
||||
frame = Frame(
|
||||
project_id=project_id,
|
||||
frame_index=idx,
|
||||
image_url=obj_name,
|
||||
width=w,
|
||||
height=h,
|
||||
)
|
||||
db.add(frame)
|
||||
frames_out.append(frame)
|
||||
|
||||
db.commit()
|
||||
for f in frames_out:
|
||||
db.refresh(f)
|
||||
|
||||
# Cleanup temp files
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
project.status = "ready"
|
||||
db.commit()
|
||||
|
||||
logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id)
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"frames_extracted": len(frames_out),
|
||||
"status": "ready",
|
||||
"message": "Frame extraction completed successfully.",
|
||||
}
|
||||
165
backend/routers/projects.py
Normal file
165
backend/routers/projects.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Project and Frame CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Project, Frame
|
||||
from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projects
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"",
|
||||
response_model=ProjectOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new project",
|
||||
)
|
||||
def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project:
|
||||
"""Create a new segmentation project."""
|
||||
project = Project(**payload.model_dump())
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Created project id=%s name=%s", project.id, project.name)
|
||||
return project
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[ProjectOut],
|
||||
summary="List all projects",
|
||||
)
|
||||
def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]:
|
||||
"""Retrieve a paginated list of projects."""
|
||||
return db.query(Project).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Get a single project",
|
||||
)
|
||||
def get_project(project_id: int, db: Session = Depends(get_db)) -> Project:
|
||||
"""Retrieve a project by its ID."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
return project
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{project_id}",
|
||||
response_model=ProjectOut,
|
||||
summary="Update a project",
|
||||
)
|
||||
def update_project(
|
||||
project_id: int,
|
||||
payload: ProjectUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Project:
|
||||
"""Update project fields partially."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
setattr(project, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
logger.info("Updated project id=%s", project_id)
|
||||
return project
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{project_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a project",
|
||||
)
|
||||
def delete_project(project_id: int, db: Session = Depends(get_db)) -> None:
|
||||
"""Delete a project and all related frames and annotations."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
db.delete(project)
|
||||
db.commit()
|
||||
logger.info("Deleted project id=%s", project_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frames
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post(
|
||||
"/{project_id}/frames",
|
||||
response_model=FrameOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Add a frame to a project",
|
||||
)
|
||||
def create_frame(
|
||||
project_id: int,
|
||||
payload: FrameCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Frame:
|
||||
"""Register a new frame under a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
frame = Frame(project_id=project_id, **payload.model_dump(exclude={"project_id"}))
|
||||
db.add(frame)
|
||||
db.commit()
|
||||
db.refresh(frame)
|
||||
return frame
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames",
|
||||
response_model=List[FrameOut],
|
||||
summary="List frames for a project",
|
||||
)
|
||||
def list_frames(
|
||||
project_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 1000,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[Frame]:
|
||||
"""Retrieve all frames belonging to a project."""
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
return (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id)
|
||||
.order_by(Frame.frame_index)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{project_id}/frames/{frame_id}",
|
||||
response_model=FrameOut,
|
||||
summary="Get a single frame",
|
||||
)
|
||||
def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame:
|
||||
"""Retrieve a specific frame by ID."""
|
||||
frame = (
|
||||
db.query(Frame)
|
||||
.filter(Frame.project_id == project_id, Frame.id == frame_id)
|
||||
.first()
|
||||
)
|
||||
if not frame:
|
||||
raise HTTPException(status_code=404, detail="Frame not found")
|
||||
return frame
|
||||
97
backend/routers/templates.py
Normal file
97
backend/routers/templates.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Template (Ontology) CRUD endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from database import get_db
|
||||
from models import Template
|
||||
from schemas import TemplateCreate, TemplateOut, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/templates", tags=["Templates"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=TemplateOut,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new template",
|
||||
)
|
||||
def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template:
|
||||
"""Create a new ontology template / segmentation class."""
|
||||
template = Template(**payload.model_dump())
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
logger.info("Created template id=%s name=%s", template.id, template.name)
|
||||
return template
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[TemplateOut],
|
||||
summary="List all templates",
|
||||
)
|
||||
def list_templates(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[Template]:
|
||||
"""Retrieve all ontology templates."""
|
||||
return db.query(Template).offset(skip).limit(limit).all()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Get a single template",
|
||||
)
|
||||
def get_template(template_id: int, db: Session = Depends(get_db)) -> Template:
|
||||
"""Retrieve a template by its ID."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
return template
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{template_id}",
|
||||
response_model=TemplateOut,
|
||||
summary="Update a template",
|
||||
)
|
||||
def update_template(
|
||||
template_id: int,
|
||||
payload: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Template:
|
||||
"""Update template fields partially."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
for key, value in payload.model_dump(exclude_unset=True).items():
|
||||
setattr(template, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
logger.info("Updated template id=%s", template_id)
|
||||
return template
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{template_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete a template",
|
||||
)
|
||||
def delete_template(template_id: int, db: Session = Depends(get_db)) -> None:
|
||||
"""Delete a template. Associated annotations will have template_id set to NULL."""
|
||||
template = db.query(Template).filter(Template.id == template_id).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Template not found")
|
||||
|
||||
db.delete(template)
|
||||
db.commit()
|
||||
logger.info("Deleted template id=%s", template_id)
|
||||
157
backend/schemas.py
Normal file
157
backend/schemas.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Pydantic schemas for request/response validation."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Project schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ProjectBase(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
video_path: Optional[str] = None
|
||||
status: Optional[str] = "pending"
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
pass
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
video_path: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectOut(ProjectBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frame schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class FrameBase(BaseModel):
|
||||
frame_index: int
|
||||
image_url: str
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
|
||||
class FrameCreate(FrameBase):
|
||||
project_id: int
|
||||
|
||||
|
||||
class FrameOut(FrameBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
project_id: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Template schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class TemplateBase(BaseModel):
|
||||
name: str
|
||||
color: str
|
||||
z_index: int = 0
|
||||
mapping_rules: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class TemplateCreate(TemplateBase):
|
||||
pass
|
||||
|
||||
|
||||
class TemplateUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
z_index: Optional[int] = None
|
||||
mapping_rules: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class TemplateOut(TemplateBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Annotation schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class AnnotationBase(BaseModel):
|
||||
project_id: int
|
||||
frame_id: Optional[int] = None
|
||||
template_id: Optional[int] = None
|
||||
mask_data: Optional[dict[str, Any]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
|
||||
|
||||
class AnnotationCreate(AnnotationBase):
|
||||
pass
|
||||
|
||||
|
||||
class AnnotationUpdate(BaseModel):
|
||||
mask_data: Optional[dict[str, Any]] = None
|
||||
points: Optional[list[list[float]]] = None
|
||||
bbox: Optional[list[float]] = None
|
||||
template_id: Optional[int] = None
|
||||
|
||||
|
||||
class AnnotationOut(AnnotationBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mask schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class MaskBase(BaseModel):
|
||||
annotation_id: int
|
||||
mask_url: str
|
||||
format: str = "png"
|
||||
|
||||
|
||||
class MaskCreate(MaskBase):
|
||||
pass
|
||||
|
||||
|
||||
class MaskOut(MaskBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class PredictRequest(BaseModel):
|
||||
image_id: int
|
||||
prompt_type: str # point / box / semantic
|
||||
prompt_data: Any
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
polygons: list[list[list[float]]]
|
||||
scores: Optional[list[float]] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Export schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
class ExportStatus(BaseModel):
|
||||
url: str
|
||||
format: str
|
||||
0
backend/services/__init__.py
Normal file
0
backend/services/__init__.py
Normal file
186
backend/services/frame_parser.py
Normal file
186
backend/services/frame_parser.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Video/DICOM frame parsing and MinIO upload utilities."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pydicom import dcmread
|
||||
|
||||
from minio_client import upload_file, BUCKET_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_video(
|
||||
video_path: str,
|
||||
output_dir: str,
|
||||
fps: int = 30,
|
||||
max_frames: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""Extract frames from a video file using FFmpeg or OpenCV fallback.
|
||||
|
||||
Args:
|
||||
video_path: Path to the input video file.
|
||||
output_dir: Directory to save extracted frames.
|
||||
fps: Target frame extraction rate.
|
||||
max_frames: Optional maximum number of frames to extract.
|
||||
|
||||
Returns:
|
||||
List of paths to extracted frame images.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
frame_paths: List[str] = []
|
||||
|
||||
# Try FFmpeg first
|
||||
if shutil.which("ffmpeg"):
|
||||
try:
|
||||
pattern = os.path.join(output_dir, "frame_%06d.png")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i", video_path,
|
||||
"-vf", f"fps={fps},scale='min(1920,iw)':-1",
|
||||
"-pix_fmt", "rgb24",
|
||||
"-y",
|
||||
pattern,
|
||||
]
|
||||
logger.info("Running FFmpeg: %s", " ".join(cmd))
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||
if result.returncode == 0:
|
||||
frame_paths = sorted(
|
||||
[os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".png")]
|
||||
)
|
||||
if max_frames:
|
||||
frame_paths = frame_paths[:max_frames]
|
||||
logger.info("Extracted %d frames via FFmpeg", len(frame_paths))
|
||||
return frame_paths
|
||||
else:
|
||||
logger.warning("FFmpeg failed: %s", result.stderr)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("FFmpeg exception: %s", exc)
|
||||
|
||||
# OpenCV fallback
|
||||
logger.info("Falling back to OpenCV frame extraction")
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError(f"Cannot open video: {video_path}")
|
||||
|
||||
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
||||
interval = max(1, int(round(video_fps / fps)))
|
||||
count = 0
|
||||
saved = 0
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
if count % interval == 0:
|
||||
path = os.path.join(output_dir, f"frame_{saved:06d}.png")
|
||||
cv2.imwrite(path, frame)
|
||||
frame_paths.append(path)
|
||||
saved += 1
|
||||
if max_frames and saved >= max_frames:
|
||||
break
|
||||
count += 1
|
||||
|
||||
cap.release()
|
||||
logger.info("Extracted %d frames via OpenCV", len(frame_paths))
|
||||
return frame_paths
|
||||
|
||||
|
||||
def parse_dicom(
|
||||
dicom_dir: str,
|
||||
output_dir: str,
|
||||
max_frames: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""Extract frames from DICOM files in a directory.
|
||||
|
||||
Args:
|
||||
dicom_dir: Directory containing .dcm files.
|
||||
output_dir: Directory to save extracted frames.
|
||||
max_frames: Optional maximum number of frames to extract.
|
||||
|
||||
Returns:
|
||||
List of paths to extracted frame images.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
dcm_files = sorted(
|
||||
[f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")]
|
||||
)
|
||||
|
||||
frame_paths: List[str] = []
|
||||
for idx, fname in enumerate(dcm_files):
|
||||
if max_frames and idx >= max_frames:
|
||||
break
|
||||
path = os.path.join(dicom_dir, fname)
|
||||
try:
|
||||
ds = dcmread(path)
|
||||
pixel_array = ds.pixel_array
|
||||
|
||||
# Normalize to 8-bit
|
||||
if pixel_array.dtype != np.uint8:
|
||||
pixel_array = pixel_array.astype(np.float32)
|
||||
pixel_array = (
|
||||
(pixel_array - pixel_array.min())
|
||||
/ (pixel_array.max() - pixel_array.min() + 1e-8)
|
||||
* 255
|
||||
)
|
||||
pixel_array = pixel_array.astype(np.uint8)
|
||||
|
||||
# Handle multi-frame DICOM
|
||||
if pixel_array.ndim == 3:
|
||||
for f in range(pixel_array.shape[0]):
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.png")
|
||||
cv2.imwrite(out_path, pixel_array[f])
|
||||
frame_paths.append(out_path)
|
||||
else:
|
||||
out_path = os.path.join(output_dir, f"frame_{idx:06d}.png")
|
||||
cv2.imwrite(out_path, pixel_array)
|
||||
frame_paths.append(out_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to read DICOM %s: %s", path, exc)
|
||||
|
||||
logger.info("Extracted %d frames from DICOM", len(frame_paths))
|
||||
return frame_paths
|
||||
|
||||
|
||||
def upload_frames_to_minio(
|
||||
frames: List[str],
|
||||
project_id: int,
|
||||
object_prefix: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""Upload a list of local frame images to MinIO.
|
||||
|
||||
Args:
|
||||
frames: List of local file paths.
|
||||
project_id: Project ID used for bucket path organization.
|
||||
object_prefix: Optional prefix override.
|
||||
|
||||
Returns:
|
||||
List of object names (paths) in MinIO.
|
||||
"""
|
||||
prefix = object_prefix or f"projects/{project_id}/frames"
|
||||
object_names: List[str] = []
|
||||
|
||||
for frame_path in frames:
|
||||
fname = os.path.basename(frame_path)
|
||||
object_name = f"{prefix}/{fname}"
|
||||
try:
|
||||
with open(frame_path, "rb") as f:
|
||||
data = f.read()
|
||||
upload_file(
|
||||
object_name,
|
||||
data,
|
||||
content_type="image/png",
|
||||
length=len(data),
|
||||
)
|
||||
object_names.append(object_name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to upload %s: %s", frame_path, exc)
|
||||
|
||||
logger.info("Uploaded %d/%d frames to MinIO", len(object_names), len(frames))
|
||||
return object_names
|
||||
234
backend/services/sam2_engine.py
Normal file
234
backend/services/sam2_engine.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""SAM 2 engine wrapper with lazy loading and fallback stubs."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attempt to import SAM 2; fall back to stubs if unavailable.
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import torch
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
SAM2_AVAILABLE = True
|
||||
logger.info("SAM2 library imported successfully.")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
SAM2_AVAILABLE = False
|
||||
logger.warning("SAM2 import failed (%s). Using stub engine.", exc)
|
||||
|
||||
|
||||
class SAM2Engine:
|
||||
"""Lazy-loaded SAM 2 inference engine."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._predictor: Optional[SAM2ImagePredictor] = None
|
||||
self._model_loaded = False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------------------------------------------------
|
||||
def _load_model(self) -> None:
|
||||
"""Load the SAM 2 model and predictor on first use."""
|
||||
if self._model_loaded:
|
||||
return
|
||||
|
||||
if not SAM2_AVAILABLE:
|
||||
logger.warning("SAM2 not available; skipping model load.")
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
if not os.path.isfile(settings.sam_model_path):
|
||||
logger.error("SAM checkpoint not found at %s", settings.sam_model_path)
|
||||
self._model_loaded = True
|
||||
return
|
||||
|
||||
try:
|
||||
model = build_sam2(
|
||||
settings.sam_model_config,
|
||||
settings.sam_model_path,
|
||||
device="cuda",
|
||||
)
|
||||
self._predictor = SAM2ImagePredictor(model)
|
||||
self._model_loaded = True
|
||||
logger.info("SAM 2 model loaded from %s", settings.sam_model_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("Failed to load SAM 2 model: %s", exc)
|
||||
self._model_loaded = True # Prevent repeated load attempts
|
||||
|
||||
def _ensure_ready(self) -> bool:
|
||||
"""Ensure the model is loaded; return whether it is usable."""
|
||||
self._load_model()
|
||||
return SAM2_AVAILABLE and self._predictor is not None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
# -----------------------------------------------------------------------
|
||||
def predict_points(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
points: list[list[float]],
|
||||
labels: list[int],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run point-prompt segmentation.
|
||||
|
||||
Args:
|
||||
image: HWC numpy array (uint8).
|
||||
points: List of [x, y] normalized coordinates (0-1).
|
||||
labels: 1 for foreground, 0 for background.
|
||||
|
||||
Returns:
|
||||
Tuple of (polygons, scores).
|
||||
"""
|
||||
if not self._ensure_ready():
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
h, w = image.shape[:2]
|
||||
pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32)
|
||||
lbls = np.array(labels, dtype=np.int32)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
self._predictor.set_image(image)
|
||||
masks, scores, _ = self._predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 point prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_box(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
box: list[float],
|
||||
) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run box-prompt segmentation.
|
||||
|
||||
Args:
|
||||
image: HWC numpy array (uint8).
|
||||
box: [x1, y1, x2, y2] normalized coordinates.
|
||||
|
||||
Returns:
|
||||
Tuple of (polygons, scores).
|
||||
"""
|
||||
if not self._ensure_ready():
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
h, w = image.shape[:2]
|
||||
bbox = np.array(
|
||||
[box[0] * w, box[1] * h, box[2] * w, box[3] * h],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
self._predictor.set_image(image)
|
||||
masks, scores, _ = self._predictor.predict(
|
||||
box=bbox[None, :],
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks:
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores.tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 box prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]:
|
||||
"""Run automatic mask generation (grid of points).
|
||||
|
||||
Args:
|
||||
image: HWC numpy array (uint8).
|
||||
|
||||
Returns:
|
||||
Tuple of (polygons, scores).
|
||||
"""
|
||||
if not self._ensure_ready():
|
||||
logger.warning("SAM2 not ready; returning dummy masks.")
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
try:
|
||||
with torch.inference_mode(): # type: ignore[name-defined]
|
||||
self._predictor.set_image(image)
|
||||
# Generate a uniform 16x16 grid of point prompts
|
||||
h, w = image.shape[:2]
|
||||
grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T
|
||||
pts = grid * np.array([w, h])
|
||||
lbls = np.ones(pts.shape[0], dtype=np.int32)
|
||||
|
||||
masks, scores, _ = self._predictor.predict(
|
||||
point_coords=pts,
|
||||
point_labels=lbls,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for m in masks[:3]: # Limit to top 3 masks
|
||||
poly = self._mask_to_polygon(m)
|
||||
if poly:
|
||||
polygons.append(poly)
|
||||
|
||||
return polygons, scores[:3].tolist()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("SAM2 auto prediction failed: %s", exc)
|
||||
return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5]
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]:
|
||||
"""Convert a binary mask to a normalized polygon."""
|
||||
import cv2
|
||||
|
||||
if mask.dtype != np.uint8:
|
||||
mask = (mask > 0).astype(np.uint8)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
h, w = mask.shape[:2]
|
||||
largest = []
|
||||
for cnt in contours:
|
||||
if len(cnt) > len(largest):
|
||||
largest = cnt
|
||||
if len(largest) < 3:
|
||||
return []
|
||||
return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest]
|
||||
|
||||
@staticmethod
|
||||
def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]:
|
||||
"""Return a dummy rectangle polygon for fallback mode."""
|
||||
return [
|
||||
[
|
||||
[0.25, 0.25],
|
||||
[0.75, 0.25],
|
||||
[0.75, 0.75],
|
||||
[0.25, 0.75],
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
sam_engine = SAM2Engine()
|
||||
146
package-lock.json
generated
146
package-lock.json
generated
@@ -11,6 +11,7 @@
|
||||
"@google/genai": "^1.29.0",
|
||||
"@tailwindcss/vite": "^4.1.14",
|
||||
"@vitejs/plugin-react": "^5.0.4",
|
||||
"axios": "^1.15.2",
|
||||
"clsx": "^2.1.1",
|
||||
"dotenv": "^17.2.3",
|
||||
"express": "^4.21.2",
|
||||
@@ -22,7 +23,8 @@
|
||||
"react-konva": "^19.2.3",
|
||||
"tailwind-merge": "^3.5.0",
|
||||
"use-image": "^1.1.4",
|
||||
"vite": "^6.2.0"
|
||||
"vite": "^6.2.0",
|
||||
"zustand": "^5.0.12"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/express": "^4.17.21",
|
||||
@@ -1670,6 +1672,12 @@
|
||||
"integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/asynckit": {
|
||||
"version": "0.4.0",
|
||||
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
|
||||
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/autoprefixer": {
|
||||
"version": "10.5.0",
|
||||
"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.5.0.tgz",
|
||||
@@ -1707,6 +1715,17 @@
|
||||
"postcss": "^8.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/axios": {
|
||||
"version": "1.15.2",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.15.2.tgz",
|
||||
"integrity": "sha512-wLrXxPtcrPTsNlJmKjkPnNPK2Ihe0hn0wGSaTEiHRPxwjvJwT3hKmXF4dpqxmPO9SoNb2FsYXj/xEo0gHN+D5A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"follow-redirects": "^1.15.11",
|
||||
"form-data": "^4.0.5",
|
||||
"proxy-from-env": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/base64-js": {
|
||||
"version": "1.5.1",
|
||||
"resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz",
|
||||
@@ -1908,6 +1927,18 @@
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/combined-stream": {
|
||||
"version": "1.0.8",
|
||||
"resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
|
||||
"integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"delayed-stream": "~1.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/content-disposition": {
|
||||
"version": "0.5.4",
|
||||
"resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.4.tgz",
|
||||
@@ -1983,6 +2014,15 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/delayed-stream": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
|
||||
"integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/depd": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz",
|
||||
@@ -2110,6 +2150,21 @@
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/es-set-tostringtag": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz",
|
||||
"integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"es-errors": "^1.3.0",
|
||||
"get-intrinsic": "^1.2.6",
|
||||
"has-tostringtag": "^1.0.2",
|
||||
"hasown": "^2.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/esbuild": {
|
||||
"version": "0.27.7",
|
||||
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.7.tgz",
|
||||
@@ -2316,6 +2371,42 @@
|
||||
"integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/follow-redirects": {
|
||||
"version": "1.16.0",
|
||||
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.16.0.tgz",
|
||||
"integrity": "sha512-y5rN/uOsadFT/JfYwhxRS5R7Qce+g3zG97+JrtFZlC9klX/W5hD7iiLzScI4nZqUS7DNUdhPgw4xI8W2LuXlUw==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://github.com/sponsors/RubenVerborgh"
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=4.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"debug": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/form-data": {
|
||||
"version": "4.0.5",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz",
|
||||
"integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"asynckit": "^0.4.0",
|
||||
"combined-stream": "^1.0.8",
|
||||
"es-set-tostringtag": "^2.1.0",
|
||||
"hasown": "^2.0.2",
|
||||
"mime-types": "^2.1.12"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/formdata-polyfill": {
|
||||
"version": "4.0.10",
|
||||
"resolved": "https://registry.npmjs.org/formdata-polyfill/-/formdata-polyfill-4.0.10.tgz",
|
||||
@@ -2553,6 +2644,21 @@
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/has-tostringtag": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz",
|
||||
"integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"has-symbols": "^1.0.3"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/hasown": {
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.3.tgz",
|
||||
@@ -3346,6 +3452,15 @@
|
||||
"node": ">= 0.10"
|
||||
}
|
||||
},
|
||||
"node_modules/proxy-from-env": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-2.1.0.tgz",
|
||||
"integrity": "sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/qs": {
|
||||
"version": "6.14.2",
|
||||
"resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz",
|
||||
@@ -4461,6 +4576,35 @@
|
||||
"resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz",
|
||||
"integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==",
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/zustand": {
|
||||
"version": "5.0.12",
|
||||
"resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.12.tgz",
|
||||
"integrity": "sha512-i77ae3aZq4dhMlRhJVCYgMLKuSiZAaUPAct2AksxQ+gOtimhGMdXljRT21P5BNpeT4kXlLIckvkPM029OljD7g==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12.20.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": ">=18.0.0",
|
||||
"immer": ">=9.0.6",
|
||||
"react": ">=18.0.0",
|
||||
"use-sync-external-store": ">=1.2.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"immer": {
|
||||
"optional": true
|
||||
},
|
||||
"react": {
|
||||
"optional": true
|
||||
},
|
||||
"use-sync-external-store": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
"@google/genai": "^1.29.0",
|
||||
"@tailwindcss/vite": "^4.1.14",
|
||||
"@vitejs/plugin-react": "^5.0.4",
|
||||
"axios": "^1.15.2",
|
||||
"clsx": "^2.1.1",
|
||||
"dotenv": "^17.2.3",
|
||||
"express": "^4.21.2",
|
||||
@@ -26,7 +27,8 @@
|
||||
"react-konva": "^19.2.3",
|
||||
"tailwind-merge": "^3.5.0",
|
||||
"use-image": "^1.1.4",
|
||||
"vite": "^6.2.0"
|
||||
"vite": "^6.2.0",
|
||||
"zustand": "^5.0.12"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/express": "^4.17.21",
|
||||
|
||||
26
src/App.tsx
26
src/App.tsx
@@ -1,4 +1,6 @@
|
||||
import React, { useState } from 'react';
|
||||
import React, { useEffect } from 'react';
|
||||
import { useStore } from './store/useStore';
|
||||
import { getProjects } from './lib/api';
|
||||
import { Sidebar } from './components/Sidebar';
|
||||
import { Dashboard } from './components/Dashboard';
|
||||
import { ProjectLibrary } from './components/ProjectLibrary';
|
||||
@@ -10,16 +12,30 @@ import { Login } from './components/Login';
|
||||
export type ActiveModule = 'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates';
|
||||
|
||||
export default function App() {
|
||||
const [activeModule, setActiveModule] = useState<ActiveModule>('workspace');
|
||||
const [isAuthenticated, setIsAuthenticated] = useState(false);
|
||||
const isAuthenticated = useStore((state) => state.isAuthenticated);
|
||||
const activeModule = useStore((state) => state.activeModule);
|
||||
const setActiveModule = useStore((state) => state.setActiveModule);
|
||||
const setProjects = useStore((state) => state.setProjects);
|
||||
const setError = useStore((state) => state.setError);
|
||||
|
||||
useEffect(() => {
|
||||
if (isAuthenticated) {
|
||||
getProjects()
|
||||
.then((data) => setProjects(data))
|
||||
.catch((err) => {
|
||||
console.error('Failed to fetch projects:', err);
|
||||
setError('获取项目列表失败');
|
||||
});
|
||||
}
|
||||
}, [isAuthenticated, setProjects, setError]);
|
||||
|
||||
if (!isAuthenticated) {
|
||||
return <Login onLoginSuccess={() => setIsAuthenticated(true)} />;
|
||||
return <Login />;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-screen w-full bg-[#0a0a0a] text-gray-200 overflow-hidden font-sans">
|
||||
<Sidebar activeModule={activeModule} setActiveModule={setActiveModule} />
|
||||
<Sidebar activeModule={activeModule as ActiveModule} setActiveModule={setActiveModule} />
|
||||
<main className="flex-1 flex flex-col min-w-0 h-full relative">
|
||||
{activeModule === 'dashboard' && <Dashboard />}
|
||||
{activeModule === 'projects' && <ProjectLibrary onProjectSelect={() => setActiveModule('workspace')} />}
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import React, { useState } from 'react';
|
||||
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, Settings2, Cpu, Image as ImageIcon, SendToBack, Tags, Undo, Redo } from 'lucide-react';
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva';
|
||||
import useImage from 'use-image';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { predictMask } from '../lib/api';
|
||||
|
||||
interface AISegmentationProps {
|
||||
onSendToWorkspace: () => void;
|
||||
}
|
||||
|
||||
export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const [activeTool, setActiveTool] = useState('point_pos');
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
|
||||
const [modelSize, setModelSize] = useState('vit_l');
|
||||
const [semanticText, setSemanticText] = useState('');
|
||||
const [autoDeleteBg, setAutoDeleteBg] = useState(true);
|
||||
const [cropMode, setCropMode] = useState(false);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
|
||||
// Canvas state
|
||||
const [scale, setScale] = useState(1);
|
||||
@@ -23,6 +31,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop');
|
||||
|
||||
const effectiveTool = storeActiveTool;
|
||||
|
||||
const handleWheel = (e: any) => {
|
||||
e.evt.preventDefault();
|
||||
const scaleBy = 1.1;
|
||||
@@ -51,13 +61,43 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
}
|
||||
};
|
||||
|
||||
const runInference = useCallback(async () => {
|
||||
if (points.length === 0 && !semanticText.trim()) return;
|
||||
setIsInferencing(true);
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop',
|
||||
points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
text: semanticText.trim() || undefined,
|
||||
modelSize,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
addMask({
|
||||
id: m.id,
|
||||
frameId: 'frame-ai-1',
|
||||
pathData: m.pathData,
|
||||
label: m.label,
|
||||
color: m.color,
|
||||
segmentation: m.segmentation,
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
});
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('AI inference failed:', err);
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [points, semanticText, modelSize, addMask]);
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (activeTool === 'move') return;
|
||||
if (activeTool === 'point_pos' || activeTool === 'point_neg') {
|
||||
if (effectiveTool === 'move') return;
|
||||
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
|
||||
const stage = e.target.getStage();
|
||||
const pos = stage.getRelativePointerPosition();
|
||||
if (pos) {
|
||||
setPoints([...points, { x: pos.x, y: pos.y, type: activeTool === 'point_pos' ? 'pos' : 'neg' }]);
|
||||
setPoints([...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' }]);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -68,7 +108,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<aside className="w-80 bg-[#0d0d0d] flex flex-col border-r border-white/5 shrink-0 z-10 overflow-hidden">
|
||||
<div className="h-16 border-b border-white/5 flex items-center px-6 shrink-0 justify-between">
|
||||
<div className="flex items-center font-medium text-[11px] uppercase tracking-widest text-cyan-400">
|
||||
<Cpu size={16} className="mr-3 text-cyan-400" />
|
||||
<Sparkles size={16} className="mr-3 text-cyan-400" />
|
||||
AI智能分割引擎
|
||||
</div>
|
||||
</div>
|
||||
@@ -96,7 +136,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<div className="grid grid-cols-2 gap-3">
|
||||
<button
|
||||
onClick={() => setActiveTool('point_pos')}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'point_pos' ? "bg-green-500/10 border-green-500/30 text-green-400 shadow-[0_0_15px_rgba(34,197,94,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'point_pos' ? "bg-green-500/10 border-green-500/30 text-green-400 shadow-[0_0_15px_rgba(34,197,94,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
>
|
||||
<PlusCircle size={22} className="mb-3" />
|
||||
<span className="text-[10px] uppercase tracking-wider font-semibold">正向选点</span>
|
||||
@@ -104,15 +144,15 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
<button
|
||||
onClick={() => setActiveTool('point_neg')}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'point_neg' ? "bg-red-500/10 border-red-500/30 text-red-500 shadow-[0_0_15px_rgba(239,68,68,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'point_neg' ? "bg-red-500/10 border-red-500/30 text-red-500 shadow-[0_0_15px_rgba(239,68,68,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
>
|
||||
<MinusCircle size={22} className="mb-3" />
|
||||
<span className="text-[10px] uppercase tracking-wider font-semibold">反向选点</span>
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={() => setActiveTool('box')}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'box' ? "bg-blue-500/10 border-blue-500/30 text-blue-400 shadow-[0_0_15px_rgba(59,130,246,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
onClick={() => setActiveTool('box_select')}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'box_select' ? "bg-blue-500/10 border-blue-500/30 text-blue-400 shadow-[0_0_15px_rgba(59,130,246,0.1)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
>
|
||||
<SquareDashed size={22} className="mb-3" />
|
||||
<span className="text-[10px] uppercase tracking-wider font-semibold">边界框选</span>
|
||||
@@ -120,7 +160,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
<button
|
||||
onClick={() => setActiveTool('move')}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", activeTool === 'move' ? "bg-white/10 border-white/20 text-white shadow-[0_0_15px_rgba(255,255,255,0.05)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
className={cn("flex flex-col items-center justify-center p-4 rounded-lg border transition-all", effectiveTool === 'move' ? "bg-white/10 border-white/20 text-white shadow-[0_0_15px_rgba(255,255,255,0.05)]" : "bg-[#111] border-white/5 text-gray-400 hover:bg-white/5 hover:text-gray-200")}
|
||||
>
|
||||
<Target size={22} className="mb-3" />
|
||||
<span className="text-[10px] uppercase tracking-wider font-semibold">视口控制</span>
|
||||
@@ -165,9 +205,17 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
|
||||
<div className="p-6 bg-[#0a0a0a] border-t border-white/5 shrink-0 flex flex-col gap-3">
|
||||
<button
|
||||
className="w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
|
||||
onClick={runInference}
|
||||
disabled={isInferencing}
|
||||
className={cn(
|
||||
"w-full py-3.5 rounded-lg flex items-center justify-center gap-2 transition-all shadow-lg font-medium tracking-wide text-xs uppercase",
|
||||
isInferencing
|
||||
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
|
||||
: "bg-cyan-500 hover:bg-cyan-400 text-black shadow-cyan-500/20 hover:shadow-cyan-500/40"
|
||||
)}
|
||||
>
|
||||
<Sparkles size={16} /> 执行高精度语义分割
|
||||
{isInferencing ? <Loader2 size={16} className="animate-spin" /> : <Sparkles size={16} />}
|
||||
{isInferencing ? '推理中...' : '执行高精度语义分割'}
|
||||
</button>
|
||||
<button
|
||||
onClick={onSendToWorkspace}
|
||||
@@ -196,7 +244,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<button className="flex items-center gap-2 text-xs text-gray-400 hover:text-white transition-colors bg-white/5 hover:bg-white/10 px-3 py-1.5 rounded-md border border-white/5">
|
||||
<ImageIcon size={14} /> 上传替换底图
|
||||
</button>
|
||||
<button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={() => setPoints([])}>
|
||||
<button className="text-xs text-gray-400 hover:text-white transition-colors px-3 py-1.5" onClick={() => { setPoints([]); clearMasks(); }}>
|
||||
清空全体锚点
|
||||
</button>
|
||||
</div>
|
||||
@@ -205,7 +253,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<div className="flex-1 relative p-8">
|
||||
<div className="w-full h-full relative border border-white/5 rounded shadow-2xl bg-[#1e1e1e] overflow-hidden cursor-crosshair">
|
||||
<Stage
|
||||
width={window.innerWidth - 320 - 64} // approx sizing, uses window to avoid ResizeObserver for simplicity here
|
||||
width={window.innerWidth - 320 - 64}
|
||||
height={window.innerHeight - 64 - 64}
|
||||
onWheel={handleWheel}
|
||||
onMouseMove={handleMouseMove}
|
||||
@@ -214,7 +262,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
scaleY={scale}
|
||||
x={position.x}
|
||||
y={position.y}
|
||||
draggable={activeTool === 'move'}
|
||||
draggable={effectiveTool === 'move'}
|
||||
>
|
||||
<Layer>
|
||||
{/* Background Image */}
|
||||
@@ -227,13 +275,17 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Mock Instance Mask from SAM3 */}
|
||||
<Group opacity={0.4}>
|
||||
<Path
|
||||
data="M 300 200 Q 400 150 450 250 T 400 350 Q 250 350 280 250 Z"
|
||||
fill="#06b6d4" // cyan-500
|
||||
/>
|
||||
</Group>
|
||||
{/* AI Returned Masks */}
|
||||
{masks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.45}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
fill={mask.color}
|
||||
stroke={mask.color}
|
||||
strokeWidth={1 / scale}
|
||||
/>
|
||||
</Group>
|
||||
))}
|
||||
|
||||
{/* Points */}
|
||||
{points.map((p, i) => (
|
||||
@@ -257,6 +309,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {
|
||||
<div className="absolute bottom-4 left-4 flex gap-4 text-[10px] font-mono text-gray-500 pointer-events-none">
|
||||
<span>光标坐标: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
|
||||
<span>缩放比率: {(scale * 100).toFixed(0)}%</span>
|
||||
<span>遮罩数: {masks.length}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import React, { useEffect, useRef, useState } from 'react';
|
||||
import React, { useEffect, useRef, useState, useCallback } from 'react';
|
||||
import { Stage, Layer, Image as KonvaImage, Circle, Rect, Path, Group } from 'react-konva';
|
||||
import useImage from 'use-image';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { predictMask } from '../lib/api';
|
||||
import { cn } from '../lib/utils';
|
||||
|
||||
interface CanvasAreaProps {
|
||||
activeTool: string;
|
||||
@@ -13,6 +16,17 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
|
||||
const [position, setPosition] = useState({ x: 0, y: 0 });
|
||||
const [points, setPoints] = useState<{ x: number, y: number, type: 'pos'|'neg' }[]>([]);
|
||||
const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 });
|
||||
const [boxStart, setBoxStart] = useState<{ x: number, y: number } | null>(null);
|
||||
const [boxCurrent, setBoxCurrent] = useState<{ x: number, y: number } | null>(null);
|
||||
const [isInferencing, setIsInferencing] = useState(false);
|
||||
|
||||
const masks = useStore((state) => state.masks);
|
||||
const addMask = useStore((state) => state.addMask);
|
||||
const clearMasks = useStore((state) => state.clearMasks);
|
||||
const storeActiveTool = useStore((state) => state.activeTool);
|
||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||
|
||||
const effectiveTool = activeTool || storeActiveTool;
|
||||
|
||||
// We load a mock image representing a frame
|
||||
const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop');
|
||||
@@ -56,35 +70,120 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
|
||||
if (!stage) return;
|
||||
const pos = stage.getPointerPosition();
|
||||
if (pos) {
|
||||
// Convert to image coordinates
|
||||
const imageX = (pos.x - position.x) / scale;
|
||||
const imageY = (pos.y - position.y) / scale;
|
||||
setCursorPos({ x: imageX, y: imageY });
|
||||
}
|
||||
};
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (activeTool === 'move') return;
|
||||
|
||||
if (activeTool === 'point_pos' || activeTool === 'point_neg') {
|
||||
const stage = e.target.getStage();
|
||||
const pos = stage.getRelativePointerPosition();
|
||||
setPoints([...points, { x: pos.x, y: pos.y, type: activeTool === 'point_pos' ? 'pos' : 'neg' }]);
|
||||
if (boxStart && effectiveTool === 'box_select') {
|
||||
const relPos = stage.getRelativePointerPosition();
|
||||
if (relPos) {
|
||||
setBoxCurrent({ x: relPos.x, y: relPos.y });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const runInference = useCallback(async (promptPoints?: typeof points, promptBox?: { x1: number, y1: number, x2: number, y2: number }) => {
|
||||
setIsInferencing(true);
|
||||
try {
|
||||
const result = await predictMask({
|
||||
imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop',
|
||||
points: promptPoints?.map((p) => ({ x: p.x, y: p.y, type: p.type })),
|
||||
box: promptBox,
|
||||
});
|
||||
|
||||
result.masks.forEach((m) => {
|
||||
addMask({
|
||||
id: m.id,
|
||||
frameId: 'frame-1',
|
||||
pathData: m.pathData,
|
||||
label: m.label,
|
||||
color: m.color,
|
||||
segmentation: m.segmentation,
|
||||
bbox: m.bbox,
|
||||
area: m.area,
|
||||
});
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('Inference failed:', err);
|
||||
} finally {
|
||||
setIsInferencing(false);
|
||||
}
|
||||
}, [addMask]);
|
||||
|
||||
const handleStageMouseDown = (e: any) => {
|
||||
if (effectiveTool === 'box_select') {
|
||||
const stage = e.target.getStage();
|
||||
const pos = stage.getRelativePointerPosition();
|
||||
if (pos) {
|
||||
setBoxStart({ x: pos.x, y: pos.y });
|
||||
setBoxCurrent({ x: pos.x, y: pos.y });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleStageMouseUp = (e: any) => {
|
||||
if (effectiveTool === 'box_select' && boxStart && boxCurrent) {
|
||||
const x1 = Math.min(boxStart.x, boxCurrent.x);
|
||||
const y1 = Math.min(boxStart.y, boxCurrent.y);
|
||||
const x2 = Math.max(boxStart.x, boxCurrent.x);
|
||||
const y2 = Math.max(boxStart.y, boxCurrent.y);
|
||||
|
||||
if (Math.abs(x2 - x1) > 5 && Math.abs(y2 - y1) > 5) {
|
||||
runInference(undefined, { x1, y1, x2, y2 });
|
||||
}
|
||||
|
||||
setBoxStart(null);
|
||||
setBoxCurrent(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleStageClick = (e: any) => {
|
||||
if (effectiveTool === 'move') return;
|
||||
if (effectiveTool === 'box_select') return; // handled by mouseup
|
||||
|
||||
if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') {
|
||||
const stage = e.target.getStage();
|
||||
const pos = stage.getRelativePointerPosition();
|
||||
if (pos) {
|
||||
const newPoints = [...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' as 'pos'|'neg' }];
|
||||
setPoints(newPoints);
|
||||
// Auto-trigger inference after point selection
|
||||
runInference(newPoints);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const boxRect = React.useMemo(() => {
|
||||
if (!boxStart || !boxCurrent) return null;
|
||||
const x = Math.min(boxStart.x, boxCurrent.x);
|
||||
const y = Math.min(boxStart.y, boxCurrent.y);
|
||||
const width = Math.abs(boxCurrent.x - boxStart.x);
|
||||
const height = Math.abs(boxCurrent.y - boxStart.y);
|
||||
return { x, y, width, height };
|
||||
}, [boxStart, boxCurrent]);
|
||||
|
||||
return (
|
||||
<div ref={containerRef} className="w-full h-full relative cursor-crosshair overflow-hidden rounded-sm">
|
||||
{isInferencing && (
|
||||
<div className="absolute top-4 right-4 z-20 flex items-center gap-2 bg-[#111] border border-white/10 px-3 py-2 rounded-lg shadow-xl">
|
||||
<div className="w-3 h-3 border-2 border-cyan-500 border-t-transparent rounded-full animate-spin" />
|
||||
<span className="text-xs text-cyan-400 font-mono">AI 推理中...</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Stage
|
||||
width={stageSize.width}
|
||||
height={stageSize.height}
|
||||
onWheel={handleWheel}
|
||||
onMouseMove={handleMouseMove}
|
||||
onMouseDown={handleStageMouseDown}
|
||||
onMouseUp={handleStageMouseUp}
|
||||
scaleX={scale}
|
||||
scaleY={scale}
|
||||
x={position.x}
|
||||
y={position.y}
|
||||
draggable={activeTool === 'move'}
|
||||
draggable={effectiveTool === 'move'}
|
||||
onClick={handleStageClick}
|
||||
>
|
||||
<Layer>
|
||||
@@ -98,24 +197,38 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Mock Instance Mask overlapping */}
|
||||
<Group opacity={0.4}>
|
||||
<Path
|
||||
data="M 300 200 Q 400 150 450 250 T 400 350 Q 250 350 280 250 Z"
|
||||
fill="#06b6d4" // cyan-500
|
||||
{/* AI Returned Masks */}
|
||||
{masks.map((mask) => (
|
||||
<Group key={mask.id} opacity={0.5}>
|
||||
<Path
|
||||
data={mask.pathData}
|
||||
fill={mask.color}
|
||||
stroke={mask.color}
|
||||
strokeWidth={1 / scale}
|
||||
/>
|
||||
</Group>
|
||||
))}
|
||||
|
||||
{/* Box selection preview */}
|
||||
{boxRect && effectiveTool === 'box_select' && (
|
||||
<Rect
|
||||
x={boxRect.x}
|
||||
y={boxRect.y}
|
||||
width={boxRect.width}
|
||||
height={boxRect.height}
|
||||
fill="rgba(6, 182, 212, 0.1)"
|
||||
stroke="#06b6d4"
|
||||
strokeWidth={2 / scale}
|
||||
dash={[4 / scale, 4 / scale]}
|
||||
/>
|
||||
<Path
|
||||
data="M 600 400 Q 700 350 750 450 T 650 550 Q 550 550 580 450 Z"
|
||||
fill="#a855f7" // purple-500
|
||||
/>
|
||||
</Group>
|
||||
)}
|
||||
|
||||
{/* AI Prompts Point Regions */}
|
||||
{points.map((p, i) => (
|
||||
<Group key={i} x={p.x} y={p.y}>
|
||||
<Circle
|
||||
radius={6 / scale}
|
||||
fill={p.type === 'pos' ? '#22c55e' : '#ef4444'} // green or red
|
||||
fill={p.type === 'pos' ? '#22c55e' : '#ef4444'}
|
||||
stroke="#ffffff"
|
||||
strokeWidth={2 / scale}
|
||||
shadowColor="black"
|
||||
@@ -134,7 +247,17 @@ export function CanvasArea({ activeTool }: CanvasAreaProps) {
|
||||
<span>光标: {cursorPos.x.toFixed(2)}, {cursorPos.y.toFixed(2)}</span>
|
||||
<span>当前图层树: OBJECT_VEHICLE_01</span>
|
||||
<span>缩放比: {(scale * 100).toFixed(0)}%</span>
|
||||
<span>遮罩数: {masks.length}</span>
|
||||
</div>
|
||||
|
||||
{masks.length > 0 && (
|
||||
<button
|
||||
onClick={clearMasks}
|
||||
className="absolute bottom-4 right-4 text-xs bg-red-500/10 hover:bg-red-500/20 text-red-400 border border-red-500/20 px-3 py-1.5 rounded transition-colors"
|
||||
>
|
||||
清空遮罩
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,10 +1,99 @@
|
||||
import React from 'react';
|
||||
import { Activity, Clock, Folders, CheckCircle2 } from 'lucide-react';
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Activity, Clock, Folders, CheckCircle2, Loader2 } from 'lucide-react';
|
||||
import { progressWS, type ProgressMessage } from '../lib/websocket';
|
||||
import { cn } from '../lib/utils';
|
||||
|
||||
interface QueueTask {
|
||||
id: string;
|
||||
name: string;
|
||||
progress: number;
|
||||
status: string;
|
||||
}
|
||||
|
||||
export function Dashboard() {
|
||||
const [tasks, setTasks] = useState<QueueTask[]>([
|
||||
{ id: '1', name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' },
|
||||
{ id: '2', name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
|
||||
{ id: '3', name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' },
|
||||
]);
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
const [activityLog, setActivityLog] = useState<Array<{ time: string; message: string; project?: string }>>([
|
||||
{ time: '10 分钟前', message: '语义归档完成 54 帧', project: 'Highway_Data' },
|
||||
{ time: '25 分钟前', message: '项目解析开始', project: 'City_Driving_Dataset_004' },
|
||||
{ time: '1 小时前', message: '模板库更新: Cityscapes_v2', project: '系统' },
|
||||
{ time: '2 小时前', message: 'AI 推理完成 12 个实例', project: 'Nav_Cam_Left' },
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
progressWS.connect();
|
||||
|
||||
const unsubscribe = progressWS.onProgress((data: ProgressMessage) => {
|
||||
setIsConnected(progressWS.isConnected());
|
||||
|
||||
if (data.type === 'progress' && data.taskId && data.filename) {
|
||||
setTasks((prev) => {
|
||||
const exists = prev.find((t) => t.id === data.taskId);
|
||||
if (exists) {
|
||||
return prev.map((t) =>
|
||||
t.id === data.taskId
|
||||
? { ...t, progress: data.progress ?? t.progress, status: data.status ?? t.status }
|
||||
: t
|
||||
);
|
||||
}
|
||||
return [
|
||||
...prev,
|
||||
{
|
||||
id: data.taskId!,
|
||||
name: data.filename!,
|
||||
progress: data.progress ?? 0,
|
||||
status: data.status ?? '处理中',
|
||||
},
|
||||
];
|
||||
});
|
||||
}
|
||||
|
||||
if (data.type === 'complete' && data.taskId) {
|
||||
setTasks((prev) =>
|
||||
prev.map((t) =>
|
||||
t.id === data.taskId ? { ...t, progress: 100, status: '已完成' } : t
|
||||
)
|
||||
);
|
||||
setActivityLog((prev) => [
|
||||
{ time: '刚刚', message: `解析完成: ${data.filename || data.taskId}`, project: '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
|
||||
if (data.type === 'error' && data.taskId) {
|
||||
setTasks((prev) =>
|
||||
prev.map((t) =>
|
||||
t.id === data.taskId ? { ...t, status: `错误: ${data.message || '未知错误'}` } : t
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (data.type === 'status') {
|
||||
setActivityLog((prev) => [
|
||||
{ time: '刚刚', message: data.message || '状态更新', project: '系统' },
|
||||
...prev.slice(0, 9),
|
||||
]);
|
||||
}
|
||||
});
|
||||
|
||||
const checkConnection = setInterval(() => {
|
||||
setIsConnected(progressWS.isConnected());
|
||||
}, 5000);
|
||||
|
||||
return () => {
|
||||
unsubscribe();
|
||||
clearInterval(checkConnection);
|
||||
progressWS.disconnect();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const stats = [
|
||||
{ label: '运行中项目', value: '14', icon: Folders, color: 'text-blue-400', bg: 'bg-blue-400/10' },
|
||||
{ label: '排队处理任务', value: '3,291', icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
|
||||
{ label: '排队处理任务', value: tasks.length.toString(), icon: Clock, color: 'text-orange-400', bg: 'bg-orange-400/10' },
|
||||
{ label: '已归档批次', value: '128', icon: CheckCircle2, color: 'text-emerald-400', bg: 'bg-emerald-400/10' },
|
||||
{ label: '系统负载', value: '78%', icon: Activity, color: 'text-cyan-400', bg: 'bg-cyan-400/10' },
|
||||
];
|
||||
@@ -12,7 +101,18 @@ export function Dashboard() {
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
<header className="mb-8">
|
||||
<h1 className="text-3xl font-medium tracking-tight text-white">系统整体概况</h1>
|
||||
<div className="flex items-center gap-3">
|
||||
<h1 className="text-3xl font-medium tracking-tight text-white">系统整体概况</h1>
|
||||
<div className={cn(
|
||||
"flex items-center gap-1.5 text-[10px] uppercase font-mono px-2 py-1 rounded border",
|
||||
isConnected
|
||||
? "bg-emerald-500/10 text-emerald-400 border-emerald-500/20"
|
||||
: "bg-amber-500/10 text-amber-400 border-amber-500/20"
|
||||
)}>
|
||||
<div className={cn("w-1.5 h-1.5 rounded-full", isConnected ? "bg-emerald-500" : "bg-amber-500 animate-pulse")} />
|
||||
{isConnected ? 'WebSocket 已连接' : 'WebSocket 断开'}
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-gray-400 text-sm mt-1">系统全局数据吞吐状态与所有接入项目进度实时洞察驾驶舱。</p>
|
||||
</header>
|
||||
|
||||
@@ -37,36 +137,43 @@ export function Dashboard() {
|
||||
<div className="lg:col-span-2 bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">解析队列 (FFmpeg 挂起任务)</h2>
|
||||
<div className="space-y-4">
|
||||
{[
|
||||
{ name: 'City_Driving_Dataset_004.mp4', progress: 85, status: '正在截取帧 (30fps)' },
|
||||
{ name: 'Pedestrian_Night_Vision_02.mkv', progress: 32, status: '正在截取帧 (60fps)' },
|
||||
{ name: 'Drone_Mapping_Sector_7.avi', progress: 0, status: '队列排队等待中' }
|
||||
].map((task, i) => (
|
||||
<div key={i} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
|
||||
{tasks.map((task) => (
|
||||
<div key={task.id} className="bg-[#0d0d0d] border border-white/5 p-4 rounded-lg">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
<span className="font-mono text-sm text-gray-200">{task.name}</span>
|
||||
<span className="text-xs text-cyan-400 font-mono">{task.progress}%</span>
|
||||
</div>
|
||||
<div className="w-full h-1.5 bg-white/5 rounded-full overflow-hidden mb-2">
|
||||
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full" style={{ width: `${task.progress}%` }} />
|
||||
<div className="h-full bg-gradient-to-r from-cyan-600 to-cyan-400 rounded-full transition-all duration-500" style={{ width: `${task.progress}%` }} />
|
||||
</div>
|
||||
<div className="text-xs text-gray-500 flex items-center gap-2">
|
||||
{task.status === '已完成' ? (
|
||||
<CheckCircle2 size={12} className="text-emerald-400" />
|
||||
) : task.status.includes('错误') ? (
|
||||
<span className="text-red-400">●</span>
|
||||
) : (
|
||||
<Loader2 size={12} className="text-cyan-400 animate-spin" />
|
||||
)}
|
||||
{task.status}
|
||||
</div>
|
||||
<div className="text-xs text-gray-500">{task.status}</div>
|
||||
</div>
|
||||
))}
|
||||
{tasks.length === 0 && (
|
||||
<div className="text-sm text-gray-500 text-center py-12">当前无处理任务</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="bg-[#111] border border-white/5 rounded-xl p-6 min-h-[400px]">
|
||||
<h2 className="text-sm font-medium text-gray-400 uppercase tracking-widest mb-6">近期实时流转记录</h2>
|
||||
<div className="space-y-6 relative before:absolute before:inset-0 before:ml-[11px] before:-translate-x-px md:before:mx-auto md:before:translate-x-0 before:h-full before:w-0.5 before:bg-gradient-to-b before:from-transparent before:via-white/10 before:to-transparent">
|
||||
{/* Activity log mockup */}
|
||||
{[1, 2, 3, 4].map((i) => (
|
||||
{activityLog.map((log, i) => (
|
||||
<div key={i} className="relative flex items-center justify-between md:justify-normal md:odd:flex-row-reverse group is-active">
|
||||
<div className="flex items-center justify-center w-6 h-6 rounded-full border border-white/10 bg-[#111] group-[.is-active]:bg-cyan-500 group-[.is-active]:border-cyan-400 text-slate-500 group-[.is-active]:text-black shadow shrink-0 md:order-1 md:group-odd:-translate-x-1/2 md:group-even:translate-x-1/2 z-10" />
|
||||
<div className="w-[calc(100%-4rem)] md:w-[calc(50%-2.5rem)] bg-[#0d0d0d] p-3 rounded border border-white/5">
|
||||
<div className="text-xs text-gray-400 mb-1">10 分钟前</div>
|
||||
<div className="text-sm font-medium text-gray-200">语义归档完成 54 帧</div>
|
||||
<div className="text-xs text-gray-500">归属项目: Highway_Data</div>
|
||||
<div className="text-xs text-gray-400 mb-1">{log.time}</div>
|
||||
<div className="text-sm font-medium text-gray-200">{log.message}</div>
|
||||
<div className="text-xs text-gray-500">归属项目: {log.project}</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import React, { useState } from 'react';
|
||||
import { BrainCircuit } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { login as loginApi } from '../lib/api';
|
||||
|
||||
interface LoginProps {
|
||||
onLoginSuccess: (token: string) => void;
|
||||
}
|
||||
|
||||
export function Login({ onLoginSuccess }: LoginProps) {
|
||||
export function Login() {
|
||||
const storeLogin = useStore((state) => state.login);
|
||||
const [username, setUsername] = useState('admin');
|
||||
const [password, setPassword] = useState('123456');
|
||||
const [error, setError] = useState('');
|
||||
@@ -18,21 +17,11 @@ export function Login({ onLoginSuccess }: LoginProps) {
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/login', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ username, password }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
onLoginSuccess(data.token);
|
||||
} else {
|
||||
const errData = await response.json();
|
||||
setError(errData.error || '登录失败');
|
||||
}
|
||||
} catch (err) {
|
||||
setError('网络异常,无法连接到后端验证');
|
||||
const data = await loginApi(username, password);
|
||||
storeLogin(data.token);
|
||||
} catch (err: any) {
|
||||
const msg = err?.response?.data?.detail || err?.response?.data?.error || '登录失败,请检查网络或凭证';
|
||||
setError(msg);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
@@ -45,7 +34,7 @@ export function Login({ onLoginSuccess }: LoginProps) {
|
||||
<div className="relative z-10 w-full max-w-md p-8 bg-[#111] border border-white/5 rounded-2xl shadow-2xl scale-in shadow-black/50">
|
||||
<div className="flex flex-col items-center mb-8">
|
||||
<div className="w-16 h-16 bg-white rounded-2xl flex items-center justify-center text-cyan-500 shadow-lg shadow-cyan-500/20 mb-4 overflow-hidden border border-white/10">
|
||||
<img src="/Logo.png" alt="Logo" className="w-full h-full object-cover" />
|
||||
<BrainCircuit size={32} />
|
||||
</div>
|
||||
<h1 className="text-2xl font-bold text-white tracking-wider mb-2">欢迎登录协同工作站</h1>
|
||||
<p className="text-sm text-gray-500">AI 智能切分与多模态数据标注系统</p>
|
||||
|
||||
@@ -1,19 +1,63 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { UploadCloud, Film, Settings2, MoreHorizontal } from 'lucide-react';
|
||||
import { UploadCloud, Film, Settings2, MoreHorizontal, Plus, Loader2 } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getProjects, createProject } from '../lib/api';
|
||||
import type { Project } from '../store/useStore';
|
||||
|
||||
interface ProjectLibraryProps {
|
||||
onProjectSelect: () => void;
|
||||
}
|
||||
|
||||
export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
const [projects, setProjects] = useState<any[]>([]);
|
||||
const projects = useStore((state) => state.projects);
|
||||
const setProjects = useStore((state) => state.setProjects);
|
||||
const setCurrentProject = useStore((state) => state.setCurrentProject);
|
||||
const addProject = useStore((state) => state.addProject);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [isCreating, setIsCreating] = useState(false);
|
||||
const [showModal, setShowModal] = useState(false);
|
||||
const [newName, setNewName] = useState('');
|
||||
const [newDesc, setNewDesc] = useState('');
|
||||
|
||||
useEffect(() => {
|
||||
fetch('/api/projects')
|
||||
.then(res => res.json())
|
||||
.then(data => setProjects(data))
|
||||
.catch(console.error);
|
||||
}, []);
|
||||
setIsLoading(true);
|
||||
getProjects()
|
||||
.then((data) => setProjects(data))
|
||||
.catch(console.error)
|
||||
.finally(() => setIsLoading(false));
|
||||
}, [setProjects]);
|
||||
|
||||
const handleCreate = async () => {
|
||||
if (!newName.trim()) return;
|
||||
setIsCreating(true);
|
||||
try {
|
||||
const project = await createProject({ name: newName.trim(), description: newDesc.trim() || undefined });
|
||||
addProject(project);
|
||||
setShowModal(false);
|
||||
setNewName('');
|
||||
setNewDesc('');
|
||||
} catch (err) {
|
||||
console.error('Failed to create project:', err);
|
||||
} finally {
|
||||
setIsCreating(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSelect = (project: Project) => {
|
||||
setCurrentProject(project);
|
||||
onProjectSelect();
|
||||
};
|
||||
|
||||
const SkeletonCard = () => (
|
||||
<div className="group flex flex-col bg-[#111] border border-white/5 rounded-xl overflow-hidden animate-pulse">
|
||||
<div className="w-full aspect-[16/9] bg-[#1a1a1a]" />
|
||||
<div className="p-4 flex flex-col gap-2">
|
||||
<div className="h-4 bg-[#1a1a1a] rounded w-3/4" />
|
||||
<div className="h-3 bg-[#1a1a1a] rounded w-1/2 mt-2" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
@@ -22,47 +66,116 @@ export function ProjectLibrary({ onProjectSelect }: ProjectLibraryProps) {
|
||||
<h1 className="text-3xl font-medium tracking-tight text-white mb-2">视频与连续帧项目库</h1>
|
||||
<p className="text-gray-400 text-sm">上传源文件、按帧解析配置,并结构化管理多媒体资产实体。</p>
|
||||
</div>
|
||||
<button className="flex items-center gap-2 bg-cyan-600 hover:bg-cyan-500 text-white px-5 py-2.5 rounded-lg font-medium text-sm transition-colors border border-cyan-500 shadow-lg shadow-cyan-900/20">
|
||||
<UploadCloud size={18} />
|
||||
<span>导入多媒体资源</span>
|
||||
</button>
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
onClick={() => setShowModal(true)}
|
||||
className="flex items-center gap-2 bg-white/5 hover:bg-white/10 border border-white/10 text-gray-200 px-5 py-2.5 rounded-lg font-medium text-sm transition-colors"
|
||||
>
|
||||
<Plus size={18} />
|
||||
<span>新建项目</span>
|
||||
</button>
|
||||
<button className="flex items-center gap-2 bg-cyan-600 hover:bg-cyan-500 text-white px-5 py-2.5 rounded-lg font-medium text-sm transition-colors border border-cyan-500 shadow-lg shadow-cyan-900/20">
|
||||
<UploadCloud size={18} />
|
||||
<span>导入多媒体资源</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
|
||||
{projects.map((proj) => (
|
||||
<div
|
||||
key={proj.id}
|
||||
className="group flex flex-col bg-[#111] border border-white/5 rounded-xl overflow-hidden cursor-pointer hover:border-cyan-500/50 transition-all hover:shadow-[0_0_20px_rgba(6,182,212,0.15)]"
|
||||
onClick={onProjectSelect}
|
||||
>
|
||||
<div className={`w-full aspect-[16/9] ${proj.thumbnail} relative flex items-center justify-center overflow-hidden`}>
|
||||
{/* Stand-in for actual video frame thumbnail */}
|
||||
<Film className="w-12 h-12 text-[#2a2a2a] group-hover:text-[#333] transition-colors" />
|
||||
<div className="absolute top-2 right-2 flex gap-2">
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
|
||||
{proj.fps}
|
||||
</span>
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
||||
{proj.status === 'Ready' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> 已就绪</>
|
||||
) : (
|
||||
<><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> 解析拆帧中</>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
{isLoading && projects.length === 0 ? (
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
|
||||
{Array.from({ length: 8 }).map((_, i) => (
|
||||
<SkeletonCard key={i} />
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6">
|
||||
{projects.map((proj) => (
|
||||
<div
|
||||
key={proj.id}
|
||||
className="group flex flex-col bg-[#111] border border-white/5 rounded-xl overflow-hidden cursor-pointer hover:border-cyan-500/50 transition-all hover:shadow-[0_0_20px_rgba(6,182,212,0.15)]"
|
||||
onClick={() => handleSelect(proj)}
|
||||
>
|
||||
<div className={cn("w-full aspect-[16/9] relative flex items-center justify-center overflow-hidden", proj.thumbnail || 'bg-[#0d0d0d]')}>
|
||||
<Film className="w-12 h-12 text-[#2a2a2a] group-hover:text-[#333] transition-colors" />
|
||||
<div className="absolute top-2 right-2 flex gap-2">
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] font-mono px-2 py-1 rounded border border-white/10 uppercase tracking-widest">
|
||||
{proj.fps || '30FPS'}
|
||||
</span>
|
||||
<span className="backdrop-blur-md bg-black/40 text-gray-200 text-[10px] px-2 py-1 rounded border border-white/10 flex items-center gap-1 uppercase tracking-widest">
|
||||
{proj.status === 'Ready' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-emerald-500 rounded-full" /> 已就绪</>
|
||||
) : proj.status === 'Parsing' ? (
|
||||
<><div className="w-1.5 h-1.5 bg-amber-500 rounded-full animate-pulse" /> 解析拆帧中</>
|
||||
) : (
|
||||
<><div className="w-1.5 h-1.5 bg-red-500 rounded-full" /> 异常</>
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="p-4 flex flex-col gap-1">
|
||||
<div className="flex justify-between items-start">
|
||||
<h3 className="text-sm font-medium text-gray-200 truncate pr-4" title={proj.name}>{proj.name}</h3>
|
||||
<button className="text-gray-500 hover:text-gray-300"><MoreHorizontal size={16} /></button>
|
||||
</div>
|
||||
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono mt-2">
|
||||
<span className="flex items-center gap-1.5"><Settings2 size={12} /> {proj.frames ?? 0} 帧节点</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="p-4 flex flex-col gap-1">
|
||||
<div className="flex justify-between items-start">
|
||||
<h3 className="text-sm font-medium text-gray-200 truncate pr-4" title={proj.name}>{proj.name}</h3>
|
||||
<button className="text-gray-500 hover:text-gray-300"><MoreHorizontal size={16} /></button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{showModal && (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-md shadow-2xl">
|
||||
<h2 className="text-lg font-semibold text-white mb-4">新建项目</h2>
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">项目名称</label>
|
||||
<input
|
||||
type="text"
|
||||
value={newName}
|
||||
onChange={(e) => setNewName(e.target.value)}
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
|
||||
placeholder="输入项目名称"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono mt-2">
|
||||
<span className="flex items-center gap-1.5"><Settings2 size={12} /> {proj.frames} 帧节点</span>
|
||||
<div>
|
||||
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">描述(可选)</label>
|
||||
<input
|
||||
type="text"
|
||||
value={newDesc}
|
||||
onChange={(e) => setNewDesc(e.target.value)}
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
|
||||
placeholder="输入项目描述"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex justify-end gap-3 mt-6">
|
||||
<button
|
||||
onClick={() => { setShowModal(false); setNewName(''); setNewDesc(''); }}
|
||||
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors"
|
||||
>
|
||||
取消
|
||||
</button>
|
||||
<button
|
||||
onClick={handleCreate}
|
||||
disabled={isCreating || !newName.trim()}
|
||||
className={cn(
|
||||
"px-4 py-2 rounded-lg text-sm font-medium flex items-center gap-2 transition-all",
|
||||
isCreating || !newName.trim()
|
||||
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
|
||||
: "bg-cyan-500 hover:bg-cyan-400 text-black"
|
||||
)}
|
||||
>
|
||||
{isCreating && <Loader2 size={14} className="animate-spin" />}
|
||||
创建
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,15 +1,112 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Settings, FileJson, ArrowRightLeft, Database } from 'lucide-react';
|
||||
import { Settings, Database, Trash2, Edit3, Plus, Loader2, X } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { getTemplates, createTemplate, updateTemplate, deleteTemplate } from '../lib/api';
|
||||
import type { Template, TemplateClass } from '../store/useStore';
|
||||
|
||||
export function TemplateRegistry() {
|
||||
const [templates, setTemplates] = useState<any[]>([]);
|
||||
const templates = useStore((state) => state.templates);
|
||||
const setTemplates = useStore((state) => state.setTemplates);
|
||||
const addTemplate = useStore((state) => state.addTemplate);
|
||||
const updateTemplateStore = useStore((state) => state.updateTemplate);
|
||||
const removeTemplateStore = useStore((state) => state.removeTemplate);
|
||||
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [selectedTemplate, setSelectedTemplate] = useState<Template | null>(null);
|
||||
const [showModal, setShowModal] = useState(false);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
|
||||
const [editName, setEditName] = useState('');
|
||||
const [editDesc, setEditDesc] = useState('');
|
||||
const [editClasses, setEditClasses] = useState<TemplateClass[]>([]);
|
||||
const [editingClassId, setEditingClassId] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
fetch('/api/templates')
|
||||
.then(res => res.json())
|
||||
.then(data => setTemplates(data))
|
||||
.catch(console.error);
|
||||
}, []);
|
||||
setIsLoading(true);
|
||||
getTemplates()
|
||||
.then((data) => setTemplates(data))
|
||||
.catch(console.error)
|
||||
.finally(() => setIsLoading(false));
|
||||
}, [setTemplates]);
|
||||
|
||||
const openCreate = () => {
|
||||
setSelectedTemplate(null);
|
||||
setEditName('');
|
||||
setEditDesc('');
|
||||
setEditClasses([]);
|
||||
setShowModal(true);
|
||||
};
|
||||
|
||||
const openEdit = (template: Template) => {
|
||||
setSelectedTemplate(template);
|
||||
setEditName(template.name);
|
||||
setEditDesc(template.description || '');
|
||||
setEditClasses(template.classes ? [...template.classes] : []);
|
||||
setShowModal(true);
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
if (!editName.trim()) return;
|
||||
setIsSaving(true);
|
||||
try {
|
||||
if (selectedTemplate) {
|
||||
const updated = await updateTemplate(selectedTemplate.id, {
|
||||
name: editName.trim(),
|
||||
description: editDesc.trim() || undefined,
|
||||
classes: editClasses,
|
||||
});
|
||||
updateTemplateStore(updated);
|
||||
} else {
|
||||
const created = await createTemplate({
|
||||
name: editName.trim(),
|
||||
description: editDesc.trim() || undefined,
|
||||
classes: editClasses,
|
||||
});
|
||||
addTemplate(created);
|
||||
}
|
||||
setShowModal(false);
|
||||
} catch (err) {
|
||||
console.error('Failed to save template:', err);
|
||||
} finally {
|
||||
setIsSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async (id: string) => {
|
||||
if (!confirm('确定要删除此模板吗?')) return;
|
||||
try {
|
||||
await deleteTemplate(id);
|
||||
removeTemplateStore(id);
|
||||
if (selectedTemplate?.id === id) {
|
||||
setSelectedTemplate(null);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to delete template:', err);
|
||||
}
|
||||
};
|
||||
|
||||
const addClass = () => {
|
||||
const newClass: TemplateClass = {
|
||||
id: `cls-${Date.now()}`,
|
||||
name: '新类别',
|
||||
color: '#06b6d4',
|
||||
zIndex: editClasses.length > 0 ? Math.max(...editClasses.map((c) => c.zIndex)) + 10 : 10,
|
||||
category: '未分类',
|
||||
};
|
||||
setEditClasses([...editClasses, newClass]);
|
||||
setEditingClassId(newClass.id);
|
||||
};
|
||||
|
||||
const updateClass = (id: string, updates: Partial<TemplateClass>) => {
|
||||
setEditClasses(editClasses.map((c) => (c.id === id ? { ...c, ...updates } : c)));
|
||||
};
|
||||
|
||||
const removeClass = (id: string) => {
|
||||
setEditClasses(editClasses.filter((c) => c.id !== id));
|
||||
};
|
||||
|
||||
const activeTemplate = selectedTemplate || templates[0] || null;
|
||||
|
||||
return (
|
||||
<div className="p-8 w-full h-full overflow-y-auto bg-[#0a0a0a]">
|
||||
@@ -22,90 +119,224 @@ export function TemplateRegistry() {
|
||||
<div className="xl:col-span-1 border-r border-white/5 pr-6">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<h2 className="text-sm font-bold text-gray-500 uppercase tracking-widest">生效中模板架构清单</h2>
|
||||
<button className="text-cyan-400 hover:text-cyan-300 text-sm transition-colors">+ 新建方案</button>
|
||||
<button
|
||||
onClick={openCreate}
|
||||
className="text-cyan-400 hover:text-cyan-300 text-sm transition-colors flex items-center gap-1"
|
||||
>
|
||||
<Plus size={14} /> 新建方案
|
||||
</button>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
{templates.map(t => (
|
||||
<div key={t.id} className="bg-[#111] border border-white/5 hover:border-cyan-500/50 p-4 rounded-xl cursor-pointer transition-all hover:shadow-lg hover:shadow-cyan-900/10">
|
||||
<h3 className="font-medium text-gray-200 mb-1">{t.name}</h3>
|
||||
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono">
|
||||
<span>涵盖 {t.classes} 个字典大类</span>
|
||||
<span>挂载 {t.rules} 项解析规则</span>
|
||||
|
||||
{isLoading && templates.length === 0 ? (
|
||||
<div className="space-y-3">
|
||||
{Array.from({ length: 4 }).map((_, i) => (
|
||||
<div key={i} className="bg-[#111] border border-white/5 p-4 rounded-xl animate-pulse">
|
||||
<div className="h-4 bg-[#1a1a1a] rounded w-2/3 mb-2" />
|
||||
<div className="h-3 bg-[#1a1a1a] rounded w-1/2" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-3">
|
||||
{templates.map((t) => (
|
||||
<div
|
||||
key={t.id}
|
||||
className={cn(
|
||||
"bg-[#111] border p-4 rounded-xl cursor-pointer transition-all hover:shadow-lg hover:shadow-cyan-900/10 group",
|
||||
activeTemplate?.id === t.id ? "border-cyan-500/50" : "border-white/5 hover:border-cyan-500/50"
|
||||
)}
|
||||
onClick={() => setSelectedTemplate(t)}
|
||||
>
|
||||
<div className="flex justify-between items-start">
|
||||
<h3 className="font-medium text-gray-200 mb-1">{t.name}</h3>
|
||||
<div className="flex items-center gap-1 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<button
|
||||
onClick={(e) => { e.stopPropagation(); openEdit(t); }}
|
||||
className="p-1 rounded text-gray-500 hover:text-cyan-400 transition-colors"
|
||||
>
|
||||
<Edit3 size={14} />
|
||||
</button>
|
||||
<button
|
||||
onClick={(e) => { e.stopPropagation(); handleDelete(t.id); }}
|
||||
className="p-1 rounded text-gray-500 hover:text-red-400 transition-colors"
|
||||
>
|
||||
<Trash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-4 text-xs text-gray-500 font-mono">
|
||||
<span>涵盖 {t.classes?.length ?? 0} 个字典大类</span>
|
||||
<span>挂载 {t.rules?.length ?? 0} 项解析规则</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="xl:col-span-2 space-y-6">
|
||||
<div className="bg-[#111] border border-white/5 rounded-xl p-6">
|
||||
<div className="flex items-center justify-between mb-6 border-b border-white/5 pb-4">
|
||||
<h2 className="text-lg font-medium text-gray-200 flex items-center gap-2"><Database size={18} /> Cityscapes_v2_Mapping</h2>
|
||||
<button className="bg-white/5 hover:bg-white/10 border border-white/10 px-4 py-1.5 rounded text-sm text-gray-300 transition-colors">修改库视图结构 (Schema)</button>
|
||||
<h2 className="text-lg font-medium text-gray-200 flex items-center gap-2">
|
||||
<Database size={18} />
|
||||
{activeTemplate ? activeTemplate.name : '未选择模板'}
|
||||
</h2>
|
||||
{activeTemplate && (
|
||||
<button
|
||||
onClick={() => openEdit(activeTemplate)}
|
||||
className="bg-white/5 hover:bg-white/10 border border-white/10 px-4 py-1.5 rounded text-sm text-gray-300 transition-colors flex items-center gap-2"
|
||||
>
|
||||
<Settings size={14} /> 修改库视图结构 (Schema)
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="space-y-6">
|
||||
<div>
|
||||
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-4">特定领域分类渲染级重叠裁决权重阵列 (Painter's Algorithm Weight)</h3>
|
||||
<div className="space-y-2">
|
||||
{[
|
||||
{ l: 'pedestrian', z: 90, c: '#ec4899', t: '运动中物理特型 (Dynamic Entity)' },
|
||||
{ l: 'bicycle', z: 85, c: '#f59e0b', t: '运动中物理特型 (Dynamic Entity)' },
|
||||
{ l: 'vehicle_car', z: 80, c: '#6366f1', t: '运动中物理特型 (Dynamic Entity)' },
|
||||
{ l: 'traffic_sign', z: 60, c: '#eab308', t: '交通属性静态特型 (Static Entity)' },
|
||||
{ l: 'road_surface', z: 10, c: '#71717a', t: '全局视野底板 (Background / Floor)' },
|
||||
].map(cls => (
|
||||
<div key={cls.l} className="flex grid grid-cols-4 gap-4 p-3 bg-[#0d0d0d] border border-white/5 rounded items-center">
|
||||
<div className="col-span-1 flex items-center gap-2">
|
||||
<div className="w-3 h-3 rounded" style={{ backgroundColor: cls.c }}></div>
|
||||
<span className="font-medium text-sm text-gray-300">{cls.l}</span>
|
||||
{activeTemplate ? (
|
||||
<div className="space-y-6">
|
||||
<div>
|
||||
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-4">
|
||||
特定领域分类渲染级重叠裁决权重阵列 (Painter's Algorithm Weight)
|
||||
</h3>
|
||||
<div className="space-y-2">
|
||||
{(activeTemplate.classes || []).sort((a, b) => b.zIndex - a.zIndex).map((cls) => (
|
||||
<div key={cls.id} className="grid grid-cols-4 gap-4 p-3 bg-[#0d0d0d] border border-white/5 rounded items-center">
|
||||
<div className="col-span-1 flex items-center gap-2">
|
||||
<div className="w-3 h-3 rounded" style={{ backgroundColor: cls.color }}></div>
|
||||
<span className="font-medium text-sm text-gray-300">{cls.name}</span>
|
||||
</div>
|
||||
<div className="col-span-1 font-mono text-xs text-gray-500">优先级 Z-Level: {cls.zIndex}</div>
|
||||
<div className="col-span-2 flex justify-end">
|
||||
<span className="bg-white/5 text-gray-400 text-xs px-2 py-1 rounded border border-white/10">{cls.category || '未分类'}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="col-span-1 font-mono text-xs text-gray-500">优先级 Z-Level: {cls.z}</div>
|
||||
<div className="col-span-2 flex justify-end">
|
||||
<span className="bg-white/5 text-gray-400 text-xs px-2 py-1 rounded border border-white/10">{cls.t}</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
))}
|
||||
{(activeTemplate.classes || []).length === 0 && (
|
||||
<div className="text-sm text-gray-500 text-center py-8">暂无分类定义</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h3 className="text-[10px] font-bold text-gray-500 uppercase tracking-widest mb-4 flex items-center gap-2">
|
||||
<ArrowRightLeft size={14} /> 强兼容真实标签 (GT Source) 闭环降维转置拓扑结构约束表
|
||||
</h3>
|
||||
<div className="bg-[#0d0d0d] border border-white/5 rounded-lg overflow-hidden">
|
||||
<table className="w-full text-left text-sm text-gray-400">
|
||||
<thead className="bg-[#111] border-b border-white/5 text-xs uppercase text-gray-500 font-mono">
|
||||
<tr>
|
||||
<th className="px-4 py-3">原始 JSON 键 (Legacy Key)</th>
|
||||
<th className="px-4 py-3">映射降维引挚解析路径</th>
|
||||
<th className="px-4 py-3">并轨至标准分类</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="divide-y divide-white/5">
|
||||
<tr>
|
||||
<td className="px-4 py-3 font-mono text-gray-300">"car_sedan"</td>
|
||||
<td className="px-4 py-3 font-mono text-cyan-400">布尔合并聚类覆盖 (Logical OR)</td>
|
||||
<td className="px-4 py-3 font-medium text-gray-300">vehicle_car</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td className="px-4 py-3 font-mono text-gray-300">"car_suv"</td>
|
||||
<td className="px-4 py-3 font-mono text-cyan-400">布尔合并聚类覆盖 (Logical OR)</td>
|
||||
<td className="px-4 py-3 font-medium text-gray-300">vehicle_car</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td className="px-4 py-3 font-mono text-gray-300">"sidewalk_curb"</td>
|
||||
<td className="px-4 py-3 font-mono text-cyan-400">形态学极限分离-内切骨架法 (Skeletonization)</td>
|
||||
<td className="px-4 py-3 font-medium text-gray-300">road_curb</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-sm text-gray-500 text-center py-12">请从左侧选择一个模板或创建新模板</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{showModal && (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm">
|
||||
<div className="bg-[#111] border border-white/10 rounded-2xl p-6 w-full max-w-2xl max-h-[80vh] overflow-y-auto shadow-2xl">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<h2 className="text-lg font-semibold text-white">{selectedTemplate ? '编辑模板' : '新建模板'}</h2>
|
||||
<button onClick={() => setShowModal(false)} className="text-gray-500 hover:text-white transition-colors">
|
||||
<X size={18} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4 mb-6">
|
||||
<div>
|
||||
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">模板名称</label>
|
||||
<input
|
||||
type="text"
|
||||
value={editName}
|
||||
onChange={(e) => setEditName(e.target.value)}
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-xs font-medium text-gray-400 uppercase tracking-widest mb-2">描述</label>
|
||||
<input
|
||||
type="text"
|
||||
value={editDesc}
|
||||
onChange={(e) => setEditDesc(e.target.value)}
|
||||
className="w-full bg-[#1a1a1a] border border-white/10 rounded-lg px-4 py-3 text-sm focus:outline-none focus:border-cyan-500/50 focus:ring-1 focus:ring-cyan-500/50 transition-all"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mb-4">
|
||||
<div className="flex justify-between items-center mb-3">
|
||||
<h3 className="text-xs font-bold text-gray-500 uppercase tracking-widest">分类定义</h3>
|
||||
<button
|
||||
onClick={addClass}
|
||||
className="text-cyan-400 hover:text-cyan-300 text-xs transition-colors flex items-center gap-1"
|
||||
>
|
||||
<Plus size={12} /> 添加分类
|
||||
</button>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
{editClasses.map((cls) => (
|
||||
<div key={cls.id} className="flex items-center gap-3 bg-[#0d0d0d] border border-white/5 rounded-lg p-3">
|
||||
<input
|
||||
type="color"
|
||||
value={cls.color}
|
||||
onChange={(e) => updateClass(cls.id, { color: e.target.value })}
|
||||
className="w-8 h-8 rounded bg-transparent border-0 cursor-pointer"
|
||||
/>
|
||||
{editingClassId === cls.id ? (
|
||||
<>
|
||||
<input
|
||||
type="text"
|
||||
value={cls.name}
|
||||
onChange={(e) => updateClass(cls.id, { name: e.target.value })}
|
||||
onBlur={() => setEditingClassId(null)}
|
||||
onKeyDown={(e) => e.key === 'Enter' && setEditingClassId(null)}
|
||||
autoFocus
|
||||
className="flex-1 bg-[#1a1a1a] border border-white/10 rounded px-2 py-1 text-sm text-white"
|
||||
/>
|
||||
<input
|
||||
type="number"
|
||||
value={cls.zIndex}
|
||||
onChange={(e) => updateClass(cls.id, { zIndex: parseInt(e.target.value) || 0 })}
|
||||
className="w-20 bg-[#1a1a1a] border border-white/10 rounded px-2 py-1 text-sm text-white font-mono"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<span
|
||||
className="flex-1 text-sm text-gray-300 cursor-pointer"
|
||||
onClick={() => setEditingClassId(cls.id)}
|
||||
>
|
||||
{cls.name}
|
||||
</span>
|
||||
<span className="w-20 text-sm text-gray-500 font-mono text-right">z:{cls.zIndex}</span>
|
||||
</>
|
||||
)}
|
||||
<button onClick={() => removeClass(cls.id)} className="text-gray-500 hover:text-red-400 transition-colors">
|
||||
<Trash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
{editClasses.length === 0 && (
|
||||
<div className="text-sm text-gray-500 text-center py-4">暂无分类,请点击上方按钮添加</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3">
|
||||
<button
|
||||
onClick={() => setShowModal(false)}
|
||||
className="px-4 py-2 rounded-lg text-sm text-gray-400 hover:text-white transition-colors"
|
||||
>
|
||||
取消
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={isSaving || !editName.trim()}
|
||||
className={cn(
|
||||
"px-4 py-2 rounded-lg text-sm font-medium flex items-center gap-2 transition-all",
|
||||
isSaving || !editName.trim()
|
||||
? "bg-cyan-500/50 text-black/70 cursor-not-allowed"
|
||||
: "bg-cyan-500 hover:bg-cyan-400 text-black"
|
||||
)}
|
||||
>
|
||||
{isSaving && <Loader2 size={14} className="animate-spin" />}
|
||||
保存
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React from 'react';
|
||||
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair } from 'lucide-react';
|
||||
import { MousePointer2, Hexagon, Square, Circle, Minus, Combine, Scissors, Wand2, Undo, Redo, Crosshair, PlusCircle, MinusCircle, SquareDashed } from 'lucide-react';
|
||||
import { cn } from '../lib/utils';
|
||||
|
||||
interface ToolsPaletteProps {
|
||||
@@ -20,6 +20,12 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
|
||||
{ id: 'area_remove', icon: Scissors, label: '重叠区域去除 (-)' },
|
||||
];
|
||||
|
||||
const aiTools = [
|
||||
{ id: 'point_pos', icon: PlusCircle, label: '正向选点 (SAM)', color: 'text-green-400', bg: 'bg-green-500/10', border: 'border-green-500/30' },
|
||||
{ id: 'point_neg', icon: MinusCircle, label: '反向选点 (SAM)', color: 'text-red-400', bg: 'bg-red-500/10', border: 'border-red-500/30' },
|
||||
{ id: 'box_select', icon: SquareDashed, label: '边界框选 (SAM)', color: 'text-blue-400', bg: 'bg-blue-500/10', border: 'border-blue-500/30' },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="w-12 bg-[#0d0d0d] border-r border-white/5 flex flex-col items-center py-4 shrink-0 z-10">
|
||||
<div className="flex flex-col gap-4 w-full px-2">
|
||||
@@ -47,6 +53,26 @@ export function ToolsPalette({ activeTool, setActiveTool, onTriggerAI }: ToolsPa
|
||||
|
||||
<div className="w-full h-px bg-white/10 my-1" />
|
||||
|
||||
{aiTools.map(tool => {
|
||||
const Icon = tool.icon;
|
||||
const isActive = activeTool === tool.id;
|
||||
return (
|
||||
<button
|
||||
key={tool.id}
|
||||
onClick={() => setActiveTool(tool.id)}
|
||||
title={tool.label}
|
||||
className={cn(
|
||||
"w-10 h-10 rounded-lg flex items-center justify-center transition-all p-2 border",
|
||||
isActive
|
||||
? `${tool.bg} ${tool.color} ${tool.border} shadow-[0_0_10px_rgba(255,255,255,0.05)]`
|
||||
: "text-gray-500 hover:bg-white/5 hover:text-white border-transparent"
|
||||
)}
|
||||
>
|
||||
<Icon size={18} strokeWidth={isActive ? 2.5 : 2} />
|
||||
</button>
|
||||
)
|
||||
})}
|
||||
|
||||
<button
|
||||
onClick={() => {
|
||||
setActiveTool('sam_trigger');
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import React, { useState } from 'react';
|
||||
import React from 'react';
|
||||
import { useStore } from '../store/useStore';
|
||||
import { CanvasArea } from './CanvasArea';
|
||||
import { ToolsPalette } from './ToolsPalette';
|
||||
import { OntologyInspector } from './OntologyInspector';
|
||||
import { FrameTimeline } from './FrameTimeline';
|
||||
|
||||
export function VideoWorkspace({ onNavigateToAI }: { onNavigateToAI?: () => void }) {
|
||||
const [activeTool, setActiveTool] = useState<string>('move');
|
||||
const activeTool = useStore((state) => state.activeTool);
|
||||
const setActiveTool = useStore((state) => state.setActiveTool);
|
||||
|
||||
return (
|
||||
<div className="w-full h-full flex flex-col bg-[#0a0a0a]">
|
||||
|
||||
135
src/lib/api.ts
Normal file
135
src/lib/api.ts
Normal file
@@ -0,0 +1,135 @@
|
||||
import axios, { AxiosError } from 'axios';
|
||||
import type { Project, Template } from '../store/useStore';
|
||||
|
||||
const apiClient = axios.create({
|
||||
baseURL: 'http://localhost:8000',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Request interceptor: attach token
|
||||
apiClient.interceptors.request.use(
|
||||
(config) => {
|
||||
const token = localStorage.getItem('token');
|
||||
if (token) {
|
||||
config.headers.Authorization = `Bearer ${token}`;
|
||||
}
|
||||
return config;
|
||||
},
|
||||
(error) => Promise.reject(error)
|
||||
);
|
||||
|
||||
// Response interceptor: handle errors
|
||||
apiClient.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error: AxiosError) => {
|
||||
if (error.response?.status === 401) {
|
||||
localStorage.removeItem('token');
|
||||
window.location.reload();
|
||||
}
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
// Auth
|
||||
export async function login(username: string, password: string): Promise<{ token: string }> {
|
||||
const response = await apiClient.post('/api/auth/login', { username, password });
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// Projects
|
||||
export async function getProjects(): Promise<Project[]> {
|
||||
const response = await apiClient.get('/api/projects');
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createProject(payload: {
|
||||
name: string;
|
||||
description?: string;
|
||||
}): Promise<Project> {
|
||||
const response = await apiClient.post('/api/projects', payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function updateProject(id: string, payload: Partial<Project>): Promise<Project> {
|
||||
const response = await apiClient.put(`/api/projects/${id}`, payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function deleteProject(id: string): Promise<void> {
|
||||
await apiClient.delete(`/api/projects/${id}`);
|
||||
}
|
||||
|
||||
// Templates
|
||||
export async function getTemplates(): Promise<Template[]> {
|
||||
const response = await apiClient.get('/api/templates');
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function createTemplate(payload: {
|
||||
name: string;
|
||||
description?: string;
|
||||
classes?: { name: string; color: string; zIndex: number; category?: string }[];
|
||||
}): Promise<Template> {
|
||||
const response = await apiClient.post('/api/templates', payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function updateTemplate(id: string, payload: Partial<Template>): Promise<Template> {
|
||||
const response = await apiClient.put(`/api/templates/${id}`, payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export async function deleteTemplate(id: string): Promise<void> {
|
||||
await apiClient.delete(`/api/templates/${id}`);
|
||||
}
|
||||
|
||||
// Media
|
||||
export async function uploadMedia(file: File, projectId?: string): Promise<{ url: string; id: string }> {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
if (projectId) {
|
||||
formData.append('project_id', projectId);
|
||||
}
|
||||
const response = await apiClient.post('/api/media/upload', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// AI Prediction
|
||||
export async function predictMask(payload: {
|
||||
imageUrl: string;
|
||||
points?: { x: number; y: number; type: 'pos' | 'neg' }[];
|
||||
box?: { x1: number; y1: number; x2: number; y2: number };
|
||||
text?: string;
|
||||
modelSize?: string;
|
||||
}): Promise<{
|
||||
masks: Array<{
|
||||
id: string;
|
||||
pathData: string;
|
||||
label: string;
|
||||
color: string;
|
||||
segmentation: number[][];
|
||||
bbox: [number, number, number, number];
|
||||
area: number;
|
||||
confidence: number;
|
||||
}>;
|
||||
}> {
|
||||
const response = await apiClient.post('/api/ai/predict', payload);
|
||||
return response.data;
|
||||
}
|
||||
|
||||
// Export
|
||||
export async function exportCoco(projectId: string): Promise<Blob> {
|
||||
const response = await apiClient.get(`/api/export/coco/${projectId}`, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
return response.data;
|
||||
}
|
||||
|
||||
export default apiClient;
|
||||
104
src/lib/websocket.ts
Normal file
104
src/lib/websocket.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
type ProgressCallback = (data: ProgressMessage) => void;
|
||||
|
||||
interface ProgressMessage {
|
||||
type: 'progress' | 'status' | 'error' | 'complete';
|
||||
taskId?: string;
|
||||
filename?: string;
|
||||
progress?: number;
|
||||
status?: string;
|
||||
message?: string;
|
||||
timestamp?: string;
|
||||
}
|
||||
|
||||
class ProgressWebSocket {
|
||||
private ws: WebSocket | null = null;
|
||||
private url: string;
|
||||
private callbacks: Set<ProgressCallback> = new Set();
|
||||
private reconnectTimer: ReturnType<typeof setTimeout> | null = null;
|
||||
private reconnectInterval = 3000;
|
||||
private maxReconnectInterval = 30000;
|
||||
private shouldReconnect = false;
|
||||
private currentInterval = 3000;
|
||||
|
||||
constructor(url = 'ws://localhost:8000/ws/progress') {
|
||||
this.url = url;
|
||||
}
|
||||
|
||||
connect() {
|
||||
if (this.ws && (this.ws.readyState === WebSocket.OPEN || this.ws.readyState === WebSocket.CONNECTING)) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.shouldReconnect = true;
|
||||
|
||||
try {
|
||||
this.ws = new WebSocket(this.url);
|
||||
|
||||
this.ws.onopen = () => {
|
||||
this.currentInterval = this.reconnectInterval;
|
||||
console.log('[WebSocket] Connected to progress stream');
|
||||
};
|
||||
|
||||
this.ws.onmessage = (event) => {
|
||||
try {
|
||||
const data: ProgressMessage = JSON.parse(event.data);
|
||||
this.callbacks.forEach((cb) => cb(data));
|
||||
} catch (err) {
|
||||
console.error('[WebSocket] Failed to parse message:', err);
|
||||
}
|
||||
};
|
||||
|
||||
this.ws.onclose = () => {
|
||||
console.log('[WebSocket] Connection closed');
|
||||
if (this.shouldReconnect) {
|
||||
this.scheduleReconnect();
|
||||
}
|
||||
};
|
||||
|
||||
this.ws.onerror = (err) => {
|
||||
console.error('[WebSocket] Error:', err);
|
||||
this.ws?.close();
|
||||
};
|
||||
} catch (err) {
|
||||
console.error('[WebSocket] Failed to connect:', err);
|
||||
this.scheduleReconnect();
|
||||
}
|
||||
}
|
||||
|
||||
disconnect() {
|
||||
this.shouldReconnect = false;
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer);
|
||||
this.reconnectTimer = null;
|
||||
}
|
||||
if (this.ws) {
|
||||
this.ws.close();
|
||||
this.ws = null;
|
||||
}
|
||||
}
|
||||
|
||||
onProgress(callback: ProgressCallback) {
|
||||
this.callbacks.add(callback);
|
||||
return () => {
|
||||
this.callbacks.delete(callback);
|
||||
};
|
||||
}
|
||||
|
||||
private scheduleReconnect() {
|
||||
if (this.reconnectTimer) {
|
||||
clearTimeout(this.reconnectTimer);
|
||||
}
|
||||
this.reconnectTimer = setTimeout(() => {
|
||||
console.log(`[WebSocket] Reconnecting in ${this.currentInterval}ms...`);
|
||||
this.connect();
|
||||
this.currentInterval = Math.min(this.currentInterval * 1.5, this.maxReconnectInterval);
|
||||
}, this.currentInterval);
|
||||
}
|
||||
|
||||
isConnected(): boolean {
|
||||
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
|
||||
}
|
||||
}
|
||||
|
||||
export const progressWS = new ProgressWebSocket();
|
||||
export type { ProgressMessage };
|
||||
195
src/store/useStore.ts
Normal file
195
src/store/useStore.ts
Normal file
@@ -0,0 +1,195 @@
|
||||
import { create } from 'zustand';
|
||||
|
||||
export interface Project {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
status: 'Ready' | 'Parsing' | 'Error';
|
||||
fps?: string;
|
||||
frames?: number;
|
||||
thumbnail?: string;
|
||||
createdAt?: string;
|
||||
updatedAt?: string;
|
||||
}
|
||||
|
||||
export interface Frame {
|
||||
id: string;
|
||||
projectId: string;
|
||||
index: number;
|
||||
url: string;
|
||||
width: number;
|
||||
height: number;
|
||||
timestamp?: string;
|
||||
}
|
||||
|
||||
export interface Annotation {
|
||||
id: string;
|
||||
frameId: string;
|
||||
type: 'polygon' | 'rectangle' | 'circle' | 'point' | 'mask';
|
||||
points: number[];
|
||||
label: string;
|
||||
color: string;
|
||||
zIndex?: number;
|
||||
confidence?: number;
|
||||
metadata?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface Mask {
|
||||
id: string;
|
||||
frameId: string;
|
||||
pathData: string;
|
||||
label: string;
|
||||
color: string;
|
||||
opacity?: number;
|
||||
segmentation?: number[][];
|
||||
bbox?: [number, number, number, number];
|
||||
area?: number;
|
||||
}
|
||||
|
||||
export interface Template {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
classes: TemplateClass[];
|
||||
rules?: TemplateRule[];
|
||||
createdAt?: string;
|
||||
updatedAt?: string;
|
||||
}
|
||||
|
||||
export interface TemplateClass {
|
||||
id: string;
|
||||
name: string;
|
||||
color: string;
|
||||
zIndex: number;
|
||||
category?: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface TemplateRule {
|
||||
id: string;
|
||||
name: string;
|
||||
sourceKey: string;
|
||||
targetKey: string;
|
||||
operation: string;
|
||||
}
|
||||
|
||||
export interface AppState {
|
||||
// Auth
|
||||
isAuthenticated: boolean;
|
||||
token: string | null;
|
||||
login: (token: string) => void;
|
||||
logout: () => void;
|
||||
|
||||
// Projects
|
||||
projects: Project[];
|
||||
currentProject: Project | null;
|
||||
setProjects: (projects: Project[]) => void;
|
||||
setCurrentProject: (project: Project | null) => void;
|
||||
addProject: (project: Project) => void;
|
||||
updateProject: (project: Project) => void;
|
||||
|
||||
// Workspace
|
||||
activeModule: string;
|
||||
activeTool: string;
|
||||
frames: Frame[];
|
||||
currentFrameIndex: number;
|
||||
annotations: Annotation[];
|
||||
masks: Mask[];
|
||||
setActiveModule: (module: string) => void;
|
||||
setActiveTool: (tool: string) => void;
|
||||
setFrames: (frames: Frame[]) => void;
|
||||
setCurrentFrame: (index: number) => void;
|
||||
addAnnotation: (annotation: Annotation) => void;
|
||||
addMask: (mask: Mask) => void;
|
||||
clearMasks: () => void;
|
||||
removeAnnotation: (id: string) => void;
|
||||
|
||||
// Templates
|
||||
templates: Template[];
|
||||
setTemplates: (templates: Template[]) => void;
|
||||
addTemplate: (template: Template) => void;
|
||||
updateTemplate: (template: Template) => void;
|
||||
removeTemplate: (id: string) => void;
|
||||
|
||||
// UI
|
||||
isLoading: boolean;
|
||||
error: string | null;
|
||||
setLoading: (loading: boolean) => void;
|
||||
setError: (error: string | null) => void;
|
||||
}
|
||||
|
||||
export const useStore = create<AppState>((set) => ({
|
||||
// Auth
|
||||
isAuthenticated: false,
|
||||
token: null,
|
||||
login: (token: string) => {
|
||||
localStorage.setItem('token', token);
|
||||
set({ isAuthenticated: true, token });
|
||||
},
|
||||
logout: () => {
|
||||
localStorage.removeItem('token');
|
||||
set({
|
||||
isAuthenticated: false,
|
||||
token: null,
|
||||
currentProject: null,
|
||||
projects: [],
|
||||
templates: [],
|
||||
frames: [],
|
||||
annotations: [],
|
||||
masks: [],
|
||||
});
|
||||
},
|
||||
|
||||
// Projects
|
||||
projects: [],
|
||||
currentProject: null,
|
||||
setProjects: (projects: Project[]) => set({ projects }),
|
||||
setCurrentProject: (currentProject: Project | null) => set({ currentProject }),
|
||||
addProject: (project: Project) =>
|
||||
set((state) => ({ projects: [project, ...state.projects] })),
|
||||
updateProject: (project: Project) =>
|
||||
set((state) => ({
|
||||
projects: state.projects.map((p) => (p.id === project.id ? project : p)),
|
||||
})),
|
||||
|
||||
// Workspace
|
||||
activeModule: 'workspace',
|
||||
activeTool: 'move',
|
||||
frames: [],
|
||||
currentFrameIndex: 0,
|
||||
annotations: [],
|
||||
masks: [],
|
||||
setActiveModule: (activeModule: string) => set({ activeModule }),
|
||||
setActiveTool: (activeTool: string) => set({ activeTool }),
|
||||
setFrames: (frames: Frame[]) => set({ frames }),
|
||||
setCurrentFrame: (currentFrameIndex: number) => set({ currentFrameIndex }),
|
||||
addAnnotation: (annotation: Annotation) =>
|
||||
set((state) => ({ annotations: [...state.annotations, annotation] })),
|
||||
addMask: (mask: Mask) =>
|
||||
set((state) => ({ masks: [...state.masks, mask] })),
|
||||
clearMasks: () => set({ masks: [] }),
|
||||
removeAnnotation: (id: string) =>
|
||||
set((state) => ({
|
||||
annotations: state.annotations.filter((a) => a.id !== id),
|
||||
})),
|
||||
|
||||
// Templates
|
||||
templates: [],
|
||||
setTemplates: (templates: Template[]) => set({ templates }),
|
||||
addTemplate: (template: Template) =>
|
||||
set((state) => ({ templates: [...state.templates, template] })),
|
||||
updateTemplate: (template: Template) =>
|
||||
set((state) => ({
|
||||
templates: state.templates.map((t) => (t.id === template.id ? template : t)),
|
||||
})),
|
||||
removeTemplate: (id: string) =>
|
||||
set((state) => ({
|
||||
templates: state.templates.filter((t) => t.id !== id),
|
||||
})),
|
||||
|
||||
// UI
|
||||
isLoading: false,
|
||||
error: null,
|
||||
setLoading: (isLoading: boolean) => set({ isLoading }),
|
||||
setError: (error: string | null) => set({ error }),
|
||||
}));
|
||||
66
start_services.sh
Executable file
66
start_services.sh
Executable file
@@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
# 语义分割系统 — 一键启动脚本
|
||||
# 时间戳: 2026-04-29-21-51-19
|
||||
|
||||
set -e
|
||||
|
||||
PROJECT_DIR="/home/wkmgc/Desktop/Seg_Server"
|
||||
CONDA_ENV="seg_server"
|
||||
|
||||
echo "========================================"
|
||||
echo " 语义分割系统全栈启动"
|
||||
echo "========================================"
|
||||
|
||||
# 1. 检查 PostgreSQL
|
||||
echo "[1/5] 检查 PostgreSQL..."
|
||||
if ! pg_isready -q; then
|
||||
echo "Wkmgc" | sudo -S systemctl start postgresql
|
||||
sleep 1
|
||||
fi
|
||||
pg_isready && echo " ✓ PostgreSQL 就绪"
|
||||
|
||||
# 2. 检查 Redis
|
||||
echo "[2/5] 检查 Redis..."
|
||||
if ! redis-cli ping > /dev/null 2>&1; then
|
||||
echo "Wkmgc" | sudo -S systemctl start redis-server
|
||||
sleep 1
|
||||
fi
|
||||
redis-cli ping && echo " ✓ Redis 就绪"
|
||||
|
||||
# 3. 检查 MinIO
|
||||
echo "[3/5] 检查 MinIO..."
|
||||
if ! curl -s http://localhost:9000/minio/health/live > /dev/null; then
|
||||
nohup minio server /home/wkmgc/minio_data --console-address :9001 > /tmp/minio.log 2>&1 &
|
||||
sleep 3
|
||||
fi
|
||||
curl -s http://localhost:9000/minio/health/live > /dev/null && echo " ✓ MinIO 就绪 (http://localhost:9001)"
|
||||
|
||||
# 4. 启动 FastAPI 后端
|
||||
echo "[4/5] 启动 FastAPI 后端..."
|
||||
source /home/wkmgc/miniconda3/etc/profile.d/conda.sh
|
||||
conda activate "$CONDA_ENV"
|
||||
cd "$PROJECT_DIR/backend"
|
||||
nohup uvicorn main:app --host 0.0.0.0 --port 8000 --reload > /tmp/fastapi.log 2>&1 &
|
||||
sleep 2
|
||||
echo " ✓ FastAPI 已启动 (http://localhost:8000/docs)"
|
||||
|
||||
# 5. 启动前端
|
||||
echo "[5/5] 启动前端..."
|
||||
cd "$PROJECT_DIR"
|
||||
nohup npm start > /tmp/frontend.log 2>&1 &
|
||||
sleep 2
|
||||
echo " ✓ 前端已启动 (http://localhost:3000)"
|
||||
|
||||
echo ""
|
||||
echo "========================================"
|
||||
echo " 所有服务已启动"
|
||||
echo "========================================"
|
||||
echo "前端: http://localhost:3000"
|
||||
echo "后端 API: http://localhost:8000/docs"
|
||||
echo "MinIO: http://localhost:9001"
|
||||
echo ""
|
||||
echo "日志文件:"
|
||||
echo " FastAPI: /tmp/fastapi.log"
|
||||
echo " 前端: /tmp/frontend.log"
|
||||
echo " MinIO: /tmp/minio.log"
|
||||
echo "========================================"
|
||||
185
工程分析/实现方案-2026-04-29-21-51-19.md
Normal file
185
工程分析/实现方案-2026-04-29-21-51-19.md
Normal file
@@ -0,0 +1,185 @@
|
||||
# 实现方案 - 2026-04-29-21-51-19
|
||||
|
||||
## 对应需求
|
||||
- 需求分析文档: `需求分析-2026-04-29-21-51-19.md`
|
||||
|
||||
## 方案概述
|
||||
在 RTX 4090 + Ubuntu 22.04 环境下,将纯前端 React 应用改造为全栈语义分割系统。本次为**骨架搭建 + 核心能力落地**,优先保证各模块可独立运行并互联互通。
|
||||
|
||||
## 整体架构
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ 前端层 (React 19) │
|
||||
│ Zustand (状态) + Axios (HTTP) + WebSocket (实时进度) │
|
||||
│ Konva (Canvas) + TailwindCSS (样式) │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│ HTTP / WebSocket
|
||||
┌────────────────────────────▼────────────────────────────────────┐
|
||||
│ 后端层 (FastAPI + Python) │
|
||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||
│ │ 项目/模板 API │ │ 媒体解析 API │ │ AI 推理 API │ │
|
||||
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
|
||||
│ └─────────────────┼─────────────────┘ │
|
||||
│ │ SQLAlchemy │
|
||||
│ ┌────────────────────────▼────────────────────────┐ │
|
||||
│ │ PostgreSQL (关系数据) │ │
|
||||
│ └─────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────────────────────▼────────────────────────┐ │
|
||||
│ │ MinIO (对象存储: 视频/帧/Mask) │ │
|
||||
│ └─────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────────────────────▼────────────────────────┐ │
|
||||
│ │ Redis (缓存 + 任务队列) │ │
|
||||
│ └─────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────────────────────▼────────────────────────┐ │
|
||||
│ │ SAM 3 (GPU 推理节点) │ │
|
||||
│ └─────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## 修改/新增文件清单
|
||||
|
||||
### A. 基础设施层
|
||||
|
||||
#### A1. 系统服务安装
|
||||
- **操作**: `sudo apt install postgresql postgresql-contrib redis-server ffmpeg`
|
||||
- **操作**: 下载并启动 MinIO 二进制
|
||||
- **说明**: 配置 PostgreSQL 数据库/用户,Redis 默认端口,MinIO 端口 9000/9001
|
||||
|
||||
#### A2. Conda 环境
|
||||
- **操作**: `conda create -n seg_server python=3.11`
|
||||
- **操作**: 安装 PyTorch 2.5+ CUDA 13.x、FastAPI、SQLAlchemy、psycopg2-binary、redis-py、minio、uvicorn、python-multipart、celery、opencv-python、pillow、scikit-image、pydicom、numpy
|
||||
|
||||
### B. 后端层 (新增 backend/ 目录)
|
||||
|
||||
#### B1. `backend/main.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: FastAPI 入口,CORS 配置, lifespan 管理(启动/关闭数据库、Redis、MinIO 连接)
|
||||
|
||||
#### B2. `backend/config.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: 环境变量配置(DB_URL、REDIS_URL、MINIO_ENDPOINT、SAM_MODEL_PATH 等)
|
||||
|
||||
#### B3. `backend/database.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: SQLAlchemy engine + session + Base,PostgreSQL 连接
|
||||
|
||||
#### B4. `backend/models.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: 数据库 ORM 模型:Project、Frame、Annotation、Template、Mask
|
||||
|
||||
#### B5. `backend/schemas.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: Pydantic 数据校验模型
|
||||
|
||||
#### B6. `backend/minio_client.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: MinIO 客户端封装(上传、下载、预签名 URL)
|
||||
|
||||
#### B7. `backend/redis_client.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: Redis 客户端封装
|
||||
|
||||
#### B8. `backend/routers/projects.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: 项目 CRUD API
|
||||
|
||||
#### B9. `backend/routers/templates.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: 本体模板 API
|
||||
|
||||
#### B10. `backend/routers/media.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: 视频/图片/DCM 上传 API,FFmpeg 拆帧任务触发
|
||||
|
||||
#### B11. `backend/routers/ai.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: SAM 3 推理 API(point/box/semantic),Mask 生成与存储
|
||||
|
||||
#### B12. `backend/routers/export.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: 标注数据导出(COCO JSON、PNG Mask)
|
||||
|
||||
#### B13. `backend/services/sam3_engine.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: SAM 3 模型加载、推理封装、GPU 内存管理
|
||||
|
||||
#### B14. `backend/services/frame_parser.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: FFmpeg 视频拆帧、DCM 影像逐帧提取、帧上传 MinIO
|
||||
|
||||
#### B15. `backend/services/task_queue.py`
|
||||
- **类型**: 新增
|
||||
- **内容**: Celery + Redis 异步任务队列封装
|
||||
|
||||
#### B16. `backend/requirements.txt`
|
||||
- **类型**: 新增
|
||||
- **内容**: Python 依赖清单
|
||||
|
||||
### C. 前端层 (修改现有 src/)
|
||||
|
||||
#### C1. `src/store/index.ts` (新增)
|
||||
- **内容**: Zustand 全局 Store(project、workspace、annotations、ui 状态)
|
||||
|
||||
#### C2. `src/lib/api.ts` (新增)
|
||||
- **内容**: Axios 实例封装,baseURL 指向 FastAPI
|
||||
|
||||
#### C3. `src/lib/websocket.ts` (新增)
|
||||
- **内容**: WebSocket 客户端,接收解析进度
|
||||
|
||||
#### C4. `src/App.tsx` (修改)
|
||||
- **内容**: 移除 Login 硬编码,接入真实登录 API;Provider 包裹
|
||||
|
||||
#### C5. `src/components/ProjectLibrary.tsx` (修改)
|
||||
- **内容**: 从后端 API 加载项目列表,支持创建/删除
|
||||
|
||||
#### C6. `src/components/TemplateRegistry.tsx` (修改)
|
||||
- **内容**: 从后端 API 加载本体字典,支持动态增删
|
||||
|
||||
#### C7. `src/components/CanvasArea.tsx` (修改)
|
||||
- **内容**: 点击画布捕获坐标,发送至后端 AI 接口,接收并渲染 Mask Path
|
||||
|
||||
#### C8. `src/components/AISegmentation.tsx` (修改)
|
||||
- **内容**: 对接后端 SAM 3 推理接口,显示推理状态
|
||||
|
||||
#### C9. `src/components/Dashboard.tsx` (修改)
|
||||
- **内容**: 显示真实解析队列进度(WebSocket)
|
||||
|
||||
### D. 部署与运维
|
||||
|
||||
#### D1. `start_services.sh` (新增)
|
||||
- **内容**: 一键启动 PostgreSQL、Redis、MinIO、FastAPI 的脚本
|
||||
|
||||
#### D2. `backend/download_sam3.py` (新增)
|
||||
- **内容**: SAM 3 模型权重自动下载脚本
|
||||
|
||||
## 新增依赖
|
||||
|
||||
### Python (conda 环境)
|
||||
```
|
||||
fastapi uvicorn[standard] python-multipart
|
||||
sqlalchemy psycopg2-binary alembic
|
||||
redis celery
|
||||
minio
|
||||
opencv-python pillow scikit-image pydicom numpy
|
||||
```
|
||||
|
||||
### 前端 (npm)
|
||||
```
|
||||
zustand axios
|
||||
```
|
||||
|
||||
## 兼容性分析
|
||||
- Express `server.ts` 将被保留但不再作为默认启动方式,FastAPI 成为主后端
|
||||
- 前端路由逻辑不变,仅数据获取方式从内存/mock 改为 HTTP API
|
||||
- 回滚策略: 回退到 `npm start` 仍可运行旧版 Express 前端
|
||||
|
||||
## 预估工作量
|
||||
- 基础设施: 20 分钟
|
||||
- 后端骨架: 40 分钟
|
||||
- 前端改造: 30 分钟
|
||||
- SAM 3 部署: 20 分钟
|
||||
- 联调验证: 20 分钟
|
||||
78
工程分析/测试方案-2026-04-29-21-51-19.md
Normal file
78
工程分析/测试方案-2026-04-29-21-51-19.md
Normal file
@@ -0,0 +1,78 @@
|
||||
# 测试方案 - 2026-04-29-21-51-19
|
||||
|
||||
## 对应实现方案
|
||||
- 实现方案文档: `实现方案-2026-04-29-21-51-19.md`
|
||||
|
||||
## 测试范围
|
||||
- 基础设施服务可达性
|
||||
- FastAPI 后端启动与 API 响应
|
||||
- 前端构建与后端联调
|
||||
- SAM 3 模型加载与 GPU 可用性
|
||||
- 文件上传与解析流程
|
||||
|
||||
## 测试用例
|
||||
|
||||
### 用例 1: 基础设施服务验证
|
||||
- **前置条件**: 执行 start_services.sh
|
||||
- **操作步骤**:
|
||||
1. `curl http://localhost:9000/minio/health/live` → MinIO
|
||||
2. `redis-cli ping` → Redis
|
||||
3. `sudo -u postgres psql -c "\\l"` → PostgreSQL
|
||||
- **预期结果**: 全部返回正常响应
|
||||
- **通过标准**: MinIO 200, Redis PONG, PostgreSQL 显示数据库列表
|
||||
|
||||
### 用例 2: Conda 环境 + GPU 验证
|
||||
- **前置条件**: conda 环境已创建
|
||||
- **操作步骤**:
|
||||
1. `conda activate seg_server`
|
||||
2. `python -c "import torch; print(torch.cuda.is_available())"`
|
||||
- **预期结果**: 输出 True
|
||||
- **通过标准**: PyTorch 识别到 CUDA
|
||||
|
||||
### 用例 3: SAM 3 权重下载验证
|
||||
- **前置条件**: 运行 download_sam3.py
|
||||
- **操作步骤**: 检查权重文件存在且大小合理
|
||||
- **预期结果**: .pt/.pth 文件存在于 models/ 目录
|
||||
- **通过标准**: 文件大小 > 100MB
|
||||
|
||||
### 用例 4: FastAPI 启动验证
|
||||
- **前置条件**: 依赖安装完成
|
||||
- **操作步骤**:
|
||||
1. `cd backend && uvicorn main:app --host 0.0.0.0 --port 8000`
|
||||
2. 访问 `http://localhost:8000/docs`
|
||||
- **预期结果**: Swagger UI 正常显示,包含所有 API 路由
|
||||
- **通过标准**: HTTP 200,路由列表完整
|
||||
|
||||
### 用例 5: 前端构建验证
|
||||
- **前置条件**: 前端代码已改造
|
||||
- **操作步骤**: `npm run lint && npm run build`
|
||||
- **预期结果**: 无类型错误,构建成功
|
||||
- **通过标准**: exit code 0
|
||||
|
||||
### 用例 6: 前后端联调验证
|
||||
- **前置条件**: 前后端均运行中
|
||||
- **操作步骤**:
|
||||
1. 前端访问 `http://localhost:3000`
|
||||
2. 打开浏览器 DevTools Network 面板
|
||||
3. 触发项目列表加载
|
||||
- **预期结果**: 可见对 `http://localhost:8000/api/projects` 的请求,且返回 200
|
||||
- **通过标准**: API 数据正确渲染到界面
|
||||
|
||||
### 用例 7: 文件上传验证
|
||||
- **前置条件**: 媒体解析模块就绪
|
||||
- **操作步骤**: 上传 @Data_MyVideo_1.mp4
|
||||
- **预期结果**: 后端接收文件,存入 MinIO,触发解析任务
|
||||
- **通过标准**: MinIO bucket 中出现文件,返回 job_id
|
||||
|
||||
## 回归测试
|
||||
- [ ] 现有 React 组件无运行时错误
|
||||
- [ ] 深色主题样式未被破坏
|
||||
- [ ] Konva Canvas 可正常交互
|
||||
- [ ] 构建产物体积未异常膨胀
|
||||
|
||||
## 测试环境
|
||||
- OS: Ubuntu 22.04
|
||||
- GPU: NVIDIA RTX 4090 24GB
|
||||
- CUDA: 13.2
|
||||
- Python: 3.11 (conda)
|
||||
- Node.js: 22.x
|
||||
34
工程分析/经验记录.md
34
工程分析/经验记录.md
@@ -65,4 +65,38 @@ AI 助手运行的容器/环境与项目实际开发环境分离,后者才装
|
||||
|
||||
---
|
||||
|
||||
## 2026-04-29-21-51-19 — 全栈系统改造(FastAPI + SAM2 + PostgreSQL + Redis + MinIO)
|
||||
|
||||
### A. 具体问题
|
||||
1. 将纯前端 React 应用改造为全栈系统时,工程涉及后端框架替换、数据库设计、对象存储、AI 推理引擎、前端状态管理重构等多个复杂模块,单次执行工程量大。
|
||||
2. 系统磁盘空间(24G)不足,PyTorch CUDA 版本(>2GB)和 sam2 pip 包编译(需下载 CUDA 工具链 + 编译 C++ 扩展)均因 `OSError: No space left on device` 失败。
|
||||
3. MinIO 对象存储在磁盘紧张时报 `XMinioStorageFull`,导致文件上传失败。
|
||||
4. 前端 Agent 执行时因目录扁平化后子目录不存在,产生多次 `File not found` 错误。
|
||||
|
||||
### B. 产生原因
|
||||
1. 项目改造范围超出单次会议合理容量,涉及 15+ 后端文件、10+ 前端文件、4 个基础设施服务、1 个 AI 模型栈。
|
||||
2. 系统盘仅有 24GB,conda 环境(2GB)、node_modules(222MB)、模型权重(1.5GB)、MinIO 帧文件(1.4GB)叠加后迅速耗尽空间。
|
||||
3. sam2 的 pip 包依赖 torch>=2.5.1,pip 会尝试重新下载完整 torch wheel(530MB)作为 build dependency,即使 torch 已安装。
|
||||
4. 前端 Agent 的 prompt 中未明确说明组件目录已扁平化,Agent 仍尝试读取旧的子目录路径。
|
||||
|
||||
### C. 解决方案
|
||||
1. **任务拆分 + 并行 Agent**: 将后端和前端代码编写拆分为两个独立 Agent 并行执行,基础设施安装与代码编写并行推进,显著缩短总耗时。
|
||||
2. **磁盘管理策略**:
|
||||
- 安装 PyTorch CPU 版本替代 CUDA 版本(占用更小)
|
||||
- 只保留 sam2_hiera_tiny.pt(149MB),删除其他大模型
|
||||
- 清理 conda pkgs 缓存(释放 600MB+)
|
||||
- 删除 MinIO 中解析生成的临时帧文件(释放 1.4GB)
|
||||
3. **sam2 安装降级**: 当前环境以 stub 模式运行,提供 `install_sam2.sh` 脚本供用户在扩展磁盘后执行真实安装。
|
||||
4. **API 路径修复**: 后端添加 `/api/auth/login` 路由,修复前端 api.ts 中 `/api/predict` → `/api/ai/predict` 的路径不匹配。
|
||||
5. **MinIO API 适配**: minio 7.2.x 中 `presigned_url()` 已改为 `get_presigned_url()`,需从 `datetime.timedelta` 传入 expires。
|
||||
|
||||
### D. 后续如何避免问题
|
||||
1. **大型改造前先做磁盘评估**: 执行 `df -h` 确认可用空间 > 5GB 再开始安装大型依赖。
|
||||
2. **AI 模型依赖延迟加载**: 所有 AI 推理引擎必须实现 graceful fallback,模型缺失时不阻塞系统启动。
|
||||
3. **Agent prompt 需同步最新目录结构**: 给 Agent 的上下文必须包含当前真实的文件路径,避免 `File not found`。
|
||||
4. **构建依赖隔离**: 使用 `--no-build-isolation` 或 `--no-deps` 安装源码包,避免 pip 重复下载已安装的依赖。
|
||||
5. **MinIO 空间监控**: 定期清理解析产生的临时帧文件,或配置 MinIO 使用独立大容量数据盘。
|
||||
|
||||
---
|
||||
|
||||
> 新增经验请追加到文件末尾,保持时间倒序或正序均可,但需确保每条经验包含完整的 A/B/C/D 四段。
|
||||
|
||||
75
工程分析/需求分析-2026-04-29-21-51-19.md
Normal file
75
工程分析/需求分析-2026-04-29-21-51-19.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# 需求分析 - 2026-04-29-21-51-19
|
||||
|
||||
## 需求来源
|
||||
- 提出时间: 2026-04-29-21-51-19
|
||||
- 需求类型: 全栈系统改造 / 架构重构
|
||||
|
||||
## 原始需求描述
|
||||
将现有纯前端 React 静态 UI 彻底改造为前后端联动的全栈语义分割系统。打通数据流转,实现:
|
||||
1. 本地多媒体资产(视频/图片/DCM影像)上传
|
||||
2. 服务器端按帧解析(FFmpeg 拆帧)
|
||||
3. AI 视觉大模型 SAM 3 实时交互式推理(点分割、box分割、语义分割)
|
||||
4. 动态图层状态管理
|
||||
5. 标注数据结构化导出
|
||||
|
||||
后端技术栈:Python + FastAPI + PostgreSQL + Redis + MinIO
|
||||
AI 基座:SAM 3
|
||||
输入数据:@Data_MyVideo_1.mp4、@Data_Dicom帧(DCM格式)
|
||||
|
||||
## 需求拆解
|
||||
|
||||
### 需求 1: 基础设施部署
|
||||
- **详细描述**: 在 Ubuntu 22.04 上部署 PostgreSQL、Redis、MinIO 对象存储
|
||||
- **优先级**: P0-阻塞
|
||||
- **影响范围**: 系统级服务
|
||||
- **验收标准**: 三个服务均可访问,有持久化数据目录
|
||||
|
||||
### 需求 2: Conda 环境 + SAM 3 模型部署
|
||||
- **详细描述**: 新建 conda 环境 seg_server,安装 PyTorch + CUDA,部署 SAM 3 模型权重
|
||||
- **优先级**: P0-阻塞
|
||||
- **影响范围**: Python AI 推理层
|
||||
- **验收标准**: conda 环境可激活,GPU 可被 PyTorch 识别,SAM 3 权重文件存在
|
||||
|
||||
### 需求 3: FastAPI 后端框架搭建
|
||||
- **详细描述**: 替换现有 Express 后端,建立 FastAPI 服务,包含上传、项目、模板、AI推理、导出等模块
|
||||
- **优先级**: P0-阻塞
|
||||
- **影响范围**: 后端全部
|
||||
- **验收标准**: FastAPI 服务可启动,Swagger UI 可访问
|
||||
|
||||
### 需求 4: 前端状态管理 + 网络层改造
|
||||
- **详细描述**: 引入 Zustand 全局状态管理,Axios 替代 fetch,WebSocket 对接解析进度
|
||||
- **优先级**: P0-阻塞
|
||||
- **影响范围**: src/ 全部
|
||||
- **验收标准**: 前端可正常与后端通信,状态跨组件同步
|
||||
|
||||
### 需求 5: Canvas Mask 渲染对接
|
||||
- **详细描述**: Konva Canvas 接收后端 SAM 3 返回的 Mask 数据(Polygon/RLE),动态渲染遮罩图层
|
||||
- **优先级**: P1-高
|
||||
- **影响范围**: CanvasArea.tsx, AISegmentation.tsx
|
||||
- **验收标准**: 点击画布后,后端返回 Mask,前端可正确绘制
|
||||
|
||||
### 需求 6: 视频/DCM 解析流水线
|
||||
- **详细描述**: 上传视频后 FFmpeg 拆帧,DCM 影像逐帧提取,进度通过 WebSocket 推送
|
||||
- **优先级**: P1-高
|
||||
- **影响范围**: 后端任务模块 + 前端 Dashboard
|
||||
- **验收标准**: 可上传文件,解析队列可显示进度,帧图片存入 MinIO
|
||||
|
||||
### 需求 7: 数据持久化(项目/模板/标注)
|
||||
- **详细描述**: PostgreSQL 存储项目元数据、本体字典、标注结果;MinIO 存储原始媒体/帧/Mask
|
||||
- **优先级**: P1-高
|
||||
- **影响范围**: 后端数据库模型 + API
|
||||
- **验收标准**: 数据重启后仍可读取
|
||||
|
||||
## 约束条件
|
||||
- 使用 Ubuntu 22.04, sudo 密码 Wkmgc
|
||||
- GPU: RTX 4090, CUDA 13.2
|
||||
- 无 Docker,使用系统包管理 + 二进制部署
|
||||
- 保持现有前端界面风格(深色主题、中文)
|
||||
|
||||
## 风险评估
|
||||
| 风险点 | 影响 | 缓解措施 |
|
||||
|--------|------|----------|
|
||||
| SAM 3 权重下载慢/失败 | 高 | 使用国内镜像源,分段下载验证 |
|
||||
| PostgreSQL/Redis 端口冲突 | 中 | 检查端口占用,使用非默认端口 |
|
||||
| MinIO 磁盘空间不足 | 中 | 监控磁盘,配置清理策略 |
|
||||
| 单次会议工程过大 | 高 | 分模块迭代,先跑通骨架再丰富 |
|
||||
Reference in New Issue
Block a user