From fd4b5e5b3d09e8f1b081af1e483ec20317aec99c Mon Sep 17 00:00:00 2001 From: admin <572701190@qq.com> Date: Wed, 29 Apr 2026 22:17:25 +0800 Subject: [PATCH] =?UTF-8?q?2026-04-29-21-51-19=20-=20=E5=85=A8=E6=A0=88?= =?UTF-8?q?=E7=B3=BB=E7=BB=9F=E6=94=B9=E9=80=A0=EF=BC=9AFastAPI=E5=90=8E?= =?UTF-8?q?=E7=AB=AF+SAM2+PostgreSQL+Redis+MinIO+=E5=89=8D=E7=AB=AFZustand?= =?UTF-8?q?=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 15 + backend/config.py | 35 +++ backend/database.py | 29 ++ backend/download_sam2.py | 49 +++ backend/main.py | 83 +++++ backend/minio_client.py | 127 ++++++++ backend/models.py | 118 +++++++ backend/redis_client.py | 61 ++++ backend/requirements.txt | 38 +++ backend/routers/__init__.py | 0 backend/routers/ai.py | 123 ++++++++ backend/routers/auth.py | 24 ++ backend/routers/export.py | 194 ++++++++++++ backend/routers/media.py | 192 ++++++++++++ backend/routers/projects.py | 165 ++++++++++ backend/routers/templates.py | 97 ++++++ backend/schemas.py | 157 ++++++++++ backend/services/__init__.py | 0 backend/services/frame_parser.py | 186 +++++++++++ backend/services/sam2_engine.py | 234 ++++++++++++++ package-lock.json | 146 ++++++++- package.json | 4 +- src/App.tsx | 26 +- src/components/AISegmentation.tsx | 101 ++++-- src/components/CanvasArea.tsx | 167 ++++++++-- src/components/Dashboard.tsx | 141 ++++++++- src/components/Login.tsx | 31 +- src/components/ProjectLibrary.tsx | 195 +++++++++--- src/components/TemplateRegistry.tsx | 383 ++++++++++++++++++----- src/components/ToolsPalette.tsx | 28 +- src/components/VideoWorkspace.tsx | 6 +- src/lib/api.ts | 135 ++++++++ src/lib/websocket.ts | 104 ++++++ src/store/useStore.ts | 195 ++++++++++++ start_services.sh | 66 ++++ 工程分析/实现方案-2026-04-29-21-51-19.md | 185 +++++++++++ 工程分析/测试方案-2026-04-29-21-51-19.md | 78 +++++ 工程分析/经验记录.md | 34 ++ 工程分析/需求分析-2026-04-29-21-51-19.md | 75 +++++ 39 files changed, 3816 insertions(+), 211 deletions(-) create mode 100644 backend/config.py create mode 100644 backend/database.py create mode 100644 backend/download_sam2.py create mode 100644 backend/main.py create mode 100644 backend/minio_client.py create mode 100644 backend/models.py create mode 100644 backend/redis_client.py create mode 100644 backend/requirements.txt create mode 100644 backend/routers/__init__.py create mode 100644 backend/routers/ai.py create mode 100644 backend/routers/auth.py create mode 100644 backend/routers/export.py create mode 100644 backend/routers/media.py create mode 100644 backend/routers/projects.py create mode 100644 backend/routers/templates.py create mode 100644 backend/schemas.py create mode 100644 backend/services/__init__.py create mode 100644 backend/services/frame_parser.py create mode 100644 backend/services/sam2_engine.py create mode 100644 src/lib/api.ts create mode 100644 src/lib/websocket.ts create mode 100644 src/store/useStore.ts create mode 100755 start_services.sh create mode 100644 工程分析/实现方案-2026-04-29-21-51-19.md create mode 100644 工程分析/测试方案-2026-04-29-21-51-19.md create mode 100644 工程分析/需求分析-2026-04-29-21-51-19.md diff --git a/.gitignore b/.gitignore index 5a86d2a..c399003 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,18 @@ coverage/ *.log .env* !.env.example +# Data & Models +models/ +uploads/ +frames/ +minio_data/ +Data_*/ +*.mp4 +*.dcm +*.7z +# Binaries +Viewer/ +Viewer.exe +DICOMDIR +__pycache__/ +*.pyc diff --git a/backend/config.py b/backend/config.py new file mode 100644 index 0000000..ba626e1 --- /dev/null +++ b/backend/config.py @@ -0,0 +1,35 @@ +"""Application configuration using Pydantic Settings.""" + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + # Database + db_url: str = "postgresql://seguser:segpass123@localhost:5432/segserver" + + # Redis + redis_url: str = "redis://localhost:6379/0" + + # MinIO + minio_endpoint: str = "localhost:9000" + minio_access_key: str = "minioadmin" + minio_secret_key: str = "minioadmin" + minio_secure: bool = False + + # SAM2 + sam_model_path: str = "/home/wkmgc/Desktop/Seg_Server/models/sam2_hiera_tiny.pt" + sam_model_config: str = "sam2_hiera_t.yaml" + + # App + app_env: str = "development" + cors_origins: list[str] = ["http://localhost:3000"] + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + extra = "ignore" + + +settings = Settings() diff --git a/backend/database.py b/backend/database.py new file mode 100644 index 0000000..a30d9c6 --- /dev/null +++ b/backend/database.py @@ -0,0 +1,29 @@ +"""Database configuration using synchronous SQLAlchemy.""" + +from sqlalchemy import create_engine +from sqlalchemy.orm import declarative_base, sessionmaker, Session +from fastapi import Depends +from typing import Generator + +from config import settings + +engine = create_engine( + settings.db_url, + pool_pre_ping=True, + pool_size=10, + max_overflow=20, + echo=False, +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + + +def get_db() -> Generator[Session, None, None]: + """FastAPI dependency that yields a database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/backend/download_sam2.py b/backend/download_sam2.py new file mode 100644 index 0000000..9e55aa1 --- /dev/null +++ b/backend/download_sam2.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +""" +SAM 2 模型权重下载脚本 +运行: python download_sam2.py +""" +import os +import urllib.request +import sys + +MODEL_DIR = "/home/wkmgc/Desktop/Seg_Server/models" +os.makedirs(MODEL_DIR, exist_ok=True) + +# SAM 2 模型权重 (Meta AI 官方) +MODELS = { + "sam2_hiera_tiny.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt", + "sam2_hiera_small.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt", + "sam2_hiera_base_plus.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt", + "sam2_hiera_large.pt": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt", +} + +def download_file(url: str, dest: str): + if os.path.exists(dest): + print(f"[跳过] {os.path.basename(dest)} 已存在") + return + print(f"[下载] {url}") + print(f" → {dest}") + try: + urllib.request.urlretrieve(url, dest) + size_mb = os.path.getsize(dest) / 1024 / 1024 + print(f" ✓ 完成 ({size_mb:.1f} MB)") + except Exception as e: + print(f" ✗ 失败: {e}", file=sys.stderr) + if os.path.exists(dest): + os.remove(dest) + +def main(): + print("=" * 50) + print("SAM 2 模型权重下载") + print("=" * 50) + for name, url in MODELS.items(): + dest = os.path.join(MODEL_DIR, name) + download_file(url, dest) + print("=" * 50) + print("全部完成!") + print(f"模型目录: {MODEL_DIR}") + print("=" * 50) + +if __name__ == "__main__": + main() diff --git a/backend/main.py b/backend/main.py new file mode 100644 index 0000000..80241c3 --- /dev/null +++ b/backend/main.py @@ -0,0 +1,83 @@ +"""FastAPI application entrypoint.""" + +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from config import settings +from database import Base, engine +from minio_client import ensure_bucket_exists +from redis_client import ping as redis_ping + +from routers import projects, templates, media, ai, export, auth + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", +) +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan: startup and shutdown hooks.""" + # Startup + logger.info("Starting up SegServer backend...") + + # Initialize database tables + try: + Base.metadata.create_all(bind=engine) + logger.info("Database tables initialized.") + except Exception as exc: # noqa: BLE001 + logger.error("Database initialization failed: %s", exc) + + # Check MinIO bucket + try: + ensure_bucket_exists() + except Exception as exc: # noqa: BLE001 + logger.error("MinIO bucket check failed: %s", exc) + + # Check Redis + if redis_ping(): + logger.info("Redis connection OK.") + else: + logger.warning("Redis connection failed.") + + yield + + # Shutdown + logger.info("Shutting down SegServer backend...") + engine.dispose() + + +app = FastAPI( + title="SegServer API", + description="Semantic Segmentation System Backend", + version="1.0.0", + lifespan=lifespan, +) + +# CORS +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Routers +app.include_router(auth.router) +app.include_router(projects.router) +app.include_router(templates.router) +app.include_router(media.router) +app.include_router(ai.router) +app.include_router(export.router) + + +@app.get("/health", tags=["Health"]) +def health_check() -> dict: + """Health check endpoint.""" + return {"status": "ok", "service": "SegServer"} diff --git a/backend/minio_client.py b/backend/minio_client.py new file mode 100644 index 0000000..817adf4 --- /dev/null +++ b/backend/minio_client.py @@ -0,0 +1,127 @@ +"""MinIO client wrapper for object storage operations.""" + +import io +import logging +from typing import Optional + +from minio import Minio +from minio.error import S3Error + +from config import settings + +logger = logging.getLogger(__name__) + +BUCKET_NAME = "seg-media" + +_minio_client: Optional[Minio] = None + + +def get_minio_client() -> Minio: + """Return a singleton MinIO client instance.""" + global _minio_client + if _minio_client is None: + _minio_client = Minio( + settings.minio_endpoint, + access_key=settings.minio_access_key, + secret_key=settings.minio_secret_key, + secure=settings.minio_secure, + ) + return _minio_client + + +def ensure_bucket_exists() -> None: + """Create the bucket if it does not already exist.""" + client = get_minio_client() + try: + if not client.bucket_exists(BUCKET_NAME): + client.make_bucket(BUCKET_NAME) + logger.info("Created MinIO bucket: %s", BUCKET_NAME) + else: + logger.info("MinIO bucket %s already exists", BUCKET_NAME) + except S3Error as exc: + logger.error("MinIO bucket check/creation failed: %s", exc) + raise + + +def upload_file( + object_name: str, + data: bytes, + content_type: str = "application/octet-stream", + length: int = -1, +) -> str: + """Upload bytes to MinIO and return the object name. + + Args: + object_name: Destination path inside the bucket. + data: Raw bytes or a file-like object. + content_type: MIME type of the object. + length: Object size; -1 for unknown (uses chunked upload). + + Returns: + The object name (same as input). + """ + client = get_minio_client() + if isinstance(data, bytes): + data = io.BytesIO(data) + length = len(data.getvalue()) + + try: + client.put_object( + BUCKET_NAME, + object_name, + data, + length=length, + content_type=content_type, + ) + logger.info("Uploaded to MinIO: %s", object_name) + return object_name + except S3Error as exc: + logger.error("MinIO upload failed: %s", exc) + raise + + +from datetime import timedelta + +def get_presigned_url( + object_name: str, + expires: int = 3600, + method: str = "GET", +) -> str: + """Generate a presigned URL for an object. + + Args: + object_name: Path inside the bucket. + expires: Expiration time in seconds (default 1 hour). + method: HTTP method (GET or PUT). + + Returns: + Presigned URL string. + """ + client = get_minio_client() + try: + url = client.get_presigned_url(method, BUCKET_NAME, object_name, expires=timedelta(seconds=expires)) + return url + except S3Error as exc: + logger.error("MinIO presigned URL failed: %s", exc) + raise + + +def download_file(object_name: str) -> bytes: + """Download an object from MinIO and return its bytes. + + Args: + object_name: Path inside the bucket. + + Returns: + Raw bytes of the object. + """ + client = get_minio_client() + try: + response = client.get_object(BUCKET_NAME, object_name) + data = response.read() + response.close() + response.release_conn() + return data + except S3Error as exc: + logger.error("MinIO download failed: %s", exc) + raise diff --git a/backend/models.py b/backend/models.py new file mode 100644 index 0000000..fb3fd75 --- /dev/null +++ b/backend/models.py @@ -0,0 +1,118 @@ +"""SQLAlchemy ORM models.""" + +from sqlalchemy import ( + Column, + Integer, + String, + Text, + DateTime, + ForeignKey, + JSON, + Float, +) +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from database import Base + + +class Project(Base): + """Project model representing a segmentation project.""" + + __tablename__ = "projects" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String(255), nullable=False) + description = Column(Text, nullable=True) + video_path = Column(String(512), nullable=True) + status = Column(String(50), default="pending", nullable=False) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + frames = relationship("Frame", back_populates="project", cascade="all, delete-orphan") + annotations = relationship( + "Annotation", back_populates="project", cascade="all, delete-orphan" + ) + + +class Frame(Base): + """Frame model representing an extracted video frame.""" + + __tablename__ = "frames" + + id = Column(Integer, primary_key=True, index=True) + project_id = Column(Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) + frame_index = Column(Integer, nullable=False) + image_url = Column(String(512), nullable=False) + width = Column(Integer, nullable=True) + height = Column(Integer, nullable=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + project = relationship("Project", back_populates="frames") + annotations = relationship( + "Annotation", back_populates="frame", cascade="all, delete-orphan" + ) + + +class Template(Base): + """Template (Ontology) model for segmentation classes.""" + + __tablename__ = "templates" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String(255), nullable=False) + color = Column(String(50), nullable=False) + z_index = Column(Integer, default=0, nullable=False) + mapping_rules = Column(JSON, nullable=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + annotations = relationship( + "Annotation", back_populates="template", cascade="all, delete-orphan" + ) + + +class Annotation(Base): + """Annotation model for segmentation masks and prompts.""" + + __tablename__ = "annotations" + + id = Column(Integer, primary_key=True, index=True) + project_id = Column( + Integer, ForeignKey("projects.id", ondelete="CASCADE"), nullable=False + ) + frame_id = Column( + Integer, ForeignKey("frames.id", ondelete="CASCADE"), nullable=True + ) + template_id = Column( + Integer, ForeignKey("templates.id", ondelete="SET NULL"), nullable=True + ) + mask_data = Column(JSON, nullable=True) + points = Column(JSON, nullable=True) + bbox = Column(JSON, nullable=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + project = relationship("Project", back_populates="annotations") + frame = relationship("Frame", back_populates="annotations") + template = relationship("Template", back_populates="annotations") + masks = relationship("Mask", back_populates="annotation", cascade="all, delete-orphan") + + +class Mask(Base): + """Mask model for exported/derived mask files.""" + + __tablename__ = "masks" + + id = Column(Integer, primary_key=True, index=True) + annotation_id = Column( + Integer, ForeignKey("annotations.id", ondelete="CASCADE"), nullable=False + ) + mask_url = Column(String(512), nullable=False) + format = Column(String(50), default="png", nullable=False) # png / rle / json + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + annotation = relationship("Annotation", back_populates="masks") diff --git a/backend/redis_client.py b/backend/redis_client.py new file mode 100644 index 0000000..d7d4799 --- /dev/null +++ b/backend/redis_client.py @@ -0,0 +1,61 @@ +"""Redis client wrapper for caching and task queuing.""" + +import json +import logging +from typing import Optional, Any + +import redis + +from config import settings + +logger = logging.getLogger(__name__) + +_redis_client: Optional[redis.Redis] = None + + +def get_redis_client() -> redis.Redis: + """Return a singleton Redis client instance.""" + global _redis_client + if _redis_client is None: + _redis_client = redis.from_url(settings.redis_url, decode_responses=True) + return _redis_client + + +def ping() -> bool: + """Check Redis connectivity.""" + try: + return get_redis_client().ping() + except redis.ConnectionError as exc: + logger.error("Redis ping failed: %s", exc) + return False + + +def set_json(key: str, value: Any, expire: Optional[int] = None) -> None: + """Store a JSON-serializable value in Redis.""" + client = get_redis_client() + try: + client.set(key, json.dumps(value), ex=expire) + except redis.RedisError as exc: + logger.error("Redis set_json failed: %s", exc) + raise + + +def get_json(key: str) -> Optional[Any]: + """Retrieve and deserialize a JSON value from Redis.""" + client = get_redis_client() + try: + data = client.get(key) + return json.loads(data) if data is not None else None + except redis.RedisError as exc: + logger.error("Redis get_json failed: %s", exc) + raise + + +def delete_key(key: str) -> int: + """Delete a key from Redis. Returns number of deleted keys.""" + client = get_redis_client() + try: + return client.delete(key) + except redis.RedisError as exc: + logger.error("Redis delete_key failed: %s", exc) + raise diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000..cc6a567 --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,38 @@ +# Web framework +fastapi +uvicorn[standard] +python-multipart + +# Database +sqlalchemy +psycopg2-binary +alembic + +# Cache / Task queue +redis +celery + +# Object storage +minio + +# Image / Video / DICOM processing +opencv-python +pillow +scikit-image +pydicom +numpy + +# SAM 2 (may require manual installation depending on CUDA version) +sam2 + +# PyTorch (CUDA 12.4 wheel index; adjust for your CUDA version if needed) +torch +torchvision +torchaudio + +# Configuration +pydantic-settings + +# Utilities +python-jose[cryptography] +passlib[bcrypt] diff --git a/backend/routers/__init__.py b/backend/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/routers/ai.py b/backend/routers/ai.py new file mode 100644 index 0000000..d4ea034 --- /dev/null +++ b/backend/routers/ai.py @@ -0,0 +1,123 @@ +"""AI inference endpoints using SAM 2.""" + +import logging +from typing import Any, List + +import cv2 +import numpy as np +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from database import get_db +from minio_client import download_file +from models import Frame, Annotation +from schemas import PredictRequest, PredictResponse, AnnotationOut, AnnotationCreate +from services.sam2_engine import sam_engine + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/ai", tags=["AI"]) + + +def _load_frame_image(frame: Frame) -> np.ndarray: + """Download a frame from MinIO and decode it to an RGB numpy array.""" + try: + data = download_file(frame.image_url) + arr = np.frombuffer(data, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + if img is None: + raise ValueError("OpenCV could not decode image") + return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + except Exception as exc: # noqa: BLE001 + logger.error("Failed to load frame image: %s", exc) + raise HTTPException(status_code=500, detail="Failed to load frame image") from exc + + +@router.post( + "/predict", + response_model=PredictResponse, + summary="Run SAM 2 inference with a prompt", +) +def predict(payload: PredictRequest, db: Session = Depends(get_db)) -> dict: + """Execute SAM 2 segmentation given an image and a prompt. + + - **point**: `prompt_data` is a list of `[[x, y], ...]` normalized coordinates. + - **box**: `prompt_data` is `[x1, y1, x2, y2]` normalized coordinates. + - **semantic**: Not yet implemented; falls back to auto segmentation. + """ + frame = db.query(Frame).filter(Frame.id == payload.image_id).first() + if not frame: + raise HTTPException(status_code=404, detail="Frame not found") + + image = _load_frame_image(frame) + prompt_type = payload.prompt_type.lower() + + polygons: List[List[List[float]]] = [] + scores: List[float] = [] + + if prompt_type == "point": + points = payload.prompt_data + if not isinstance(points, list) or len(points) == 0: + raise HTTPException(status_code=400, detail="Invalid point prompt data") + labels = [1] * len(points) + polygons, scores = sam_engine.predict_points(image, points, labels) + + elif prompt_type == "box": + box = payload.prompt_data + if not isinstance(box, list) or len(box) != 4: + raise HTTPException(status_code=400, detail="Invalid box prompt data") + polygons, scores = sam_engine.predict_box(image, box) + + elif prompt_type == "semantic": + # Placeholder: use auto segmentation for now + logger.info("Semantic prompt not implemented; using auto segmentation") + polygons, scores = sam_engine.predict_auto(image) + + else: + raise HTTPException(status_code=400, detail=f"Unsupported prompt_type: {prompt_type}") + + return {"polygons": polygons, "scores": scores} + + +@router.post( + "/auto", + response_model=PredictResponse, + summary="Run automatic segmentation", +) +def auto_segment(image_id: int, db: Session = Depends(get_db)) -> dict: + """Run automatic mask generation on a frame using a grid of point prompts.""" + frame = db.query(Frame).filter(Frame.id == image_id).first() + if not frame: + raise HTTPException(status_code=404, detail="Frame not found") + + image = _load_frame_image(frame) + polygons, scores = sam_engine.predict_auto(image) + + return {"polygons": polygons, "scores": scores} + + +@router.post( + "/annotate", + response_model=AnnotationOut, + status_code=status.HTTP_201_CREATED, + summary="Save an AI-generated annotation", +) +def save_annotation( + payload: AnnotationCreate, + db: Session = Depends(get_db), +) -> Annotation: + """Persist an annotation (mask, points, bbox) into the database.""" + project = db.query(Frame).filter(Frame.id == payload.project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + if payload.frame_id: + frame = db.query(Frame).filter(Frame.id == payload.frame_id).first() + if not frame: + raise HTTPException(status_code=404, detail="Frame not found") + + annotation = Annotation(**payload.model_dump()) + db.add(annotation) + db.commit() + db.refresh(annotation) + logger.info("Saved annotation id=%s project_id=%s", annotation.id, annotation.project_id) + return annotation diff --git a/backend/routers/auth.py b/backend/routers/auth.py new file mode 100644 index 0000000..2a5ab10 --- /dev/null +++ b/backend/routers/auth.py @@ -0,0 +1,24 @@ +"""Authentication endpoints.""" + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +router = APIRouter(prefix="/api/auth", tags=["Auth"]) + + +class LoginRequest(BaseModel): + username: str + password: str + + +class LoginResponse(BaseModel): + token: str + username: str + + +@router.post("/login", response_model=LoginResponse) +def login(payload: LoginRequest) -> dict: + """Simple login for development.""" + if payload.username == "admin" and payload.password == "123456": + return {"token": "fake-jwt-token-for-admin", "username": payload.username} + raise HTTPException(status_code=401, detail="Invalid credentials") diff --git a/backend/routers/export.py b/backend/routers/export.py new file mode 100644 index 0000000..9662dd1 --- /dev/null +++ b/backend/routers/export.py @@ -0,0 +1,194 @@ +"""Annotation export endpoints (COCO, PNG masks).""" + +import io +import json +import logging +import os +import zipfile +from datetime import datetime +from typing import Any, Dict, List + +import numpy as np +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session + +from database import get_db +from models import Project, Annotation, Frame, Template + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/export", tags=["Export"]) + + +def _mask_from_polygon( + polygon: List[List[float]], + width: int, + height: int, +) -> np.ndarray: + """Render a normalized polygon to a binary mask.""" + import cv2 + + pts = np.array( + [[int(p[0] * width), int(p[1] * height)] for p in polygon], + dtype=np.int32, + ) + mask = np.zeros((height, width), dtype=np.uint8) + cv2.fillPoly(mask, [pts], 255) + return mask + + +@router.get( + "/{project_id}/coco", + summary="Export annotations in COCO format", +) +def export_coco(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse: + """Export all annotations for a project as a COCO-format JSON file.""" + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + annotations = ( + db.query(Annotation) + .filter(Annotation.project_id == project_id) + .all() + ) + frames = ( + db.query(Frame) + .filter(Frame.project_id == project_id) + .order_by(Frame.frame_index) + .all() + ) + templates = db.query(Template).all() + + # Build COCO structure + images = [] + for idx, frame in enumerate(frames): + images.append({ + "id": frame.id, + "file_name": frame.image_url, + "width": frame.width or 1920, + "height": frame.height or 1080, + "frame_index": idx, + }) + + categories = [] + template_id_to_cat_id: Dict[int, int] = {} + for cat_idx, tmpl in enumerate(templates, start=1): + categories.append({ + "id": cat_idx, + "name": tmpl.name, + "color": tmpl.color, + }) + template_id_to_cat_id[tmpl.id] = cat_idx + + coco_annotations = [] + ann_id = 1 + for ann in annotations: + if not ann.mask_data: + continue + polygons = ann.mask_data.get("polygons", []) + if not polygons: + continue + + # Use first polygon for bbox / area approximation + first_poly = polygons[0] + xs = [p[0] for p in first_poly] + ys = [p[1] for p in first_poly] + width = ann.frame.width if ann.frame else 1920 + height = ann.frame.height if ann.frame else 1080 + bbox = [ + min(xs) * width, + min(ys) * height, + (max(xs) - min(xs)) * width, + (max(ys) - min(ys)) * height, + ] + area = bbox[2] * bbox[3] + + segmentation = [] + for poly in polygons: + flat = [] + for p in poly: + flat.append(p[0] * width) + flat.append(p[1] * height) + segmentation.append(flat) + + coco_annotations.append({ + "id": ann_id, + "image_id": ann.frame_id, + "category_id": template_id_to_cat_id.get(ann.template_id, 0), + "segmentation": segmentation, + "area": area, + "bbox": bbox, + "iscrowd": 0, + }) + ann_id += 1 + + coco = { + "info": { + "description": f"Annotations for {project.name}", + "version": "1.0", + "year": datetime.now().year, + "date_created": datetime.now().isoformat(), + }, + "images": images, + "annotations": coco_annotations, + "categories": categories, + } + + data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8") + filename = f"project_{project_id}_coco.json" + + return StreamingResponse( + io.BytesIO(data), + media_type="application/json", + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) + + +@router.get( + "/{project_id}/masks", + summary="Export PNG masks as a ZIP archive", +) +def export_masks(project_id: int, db: Session = Depends(get_db)) -> StreamingResponse: + """Export all annotation masks as individual PNG files inside a ZIP archive.""" + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + annotations = ( + db.query(Annotation) + .filter(Annotation.project_id == project_id) + .all() + ) + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: + for ann in annotations: + if not ann.mask_data: + continue + polygons = ann.mask_data.get("polygons", []) + if not polygons: + continue + + width = ann.frame.width if ann.frame else 1920 + height = ann.frame.height if ann.frame else 1080 + combined = np.zeros((height, width), dtype=np.uint8) + + for poly in polygons: + mask = _mask_from_polygon(poly, width, height) + combined = np.maximum(combined, mask) + + # Encode PNG + import cv2 + _, encoded = cv2.imencode(".png", combined) + fname = f"mask_{ann.id:06d}.png" + zf.writestr(fname, encoded.tobytes()) + + zip_buffer.seek(0) + filename = f"project_{project_id}_masks.zip" + + return StreamingResponse( + zip_buffer, + media_type="application/zip", + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) diff --git a/backend/routers/media.py b/backend/routers/media.py new file mode 100644 index 0000000..f282756 --- /dev/null +++ b/backend/routers/media.py @@ -0,0 +1,192 @@ +"""Media upload and parsing endpoints.""" + +import logging +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Optional + +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status +from sqlalchemy.orm import Session + +from database import get_db +from minio_client import upload_file, get_presigned_url +from models import Project, Frame +from schemas import FrameOut +from services.frame_parser import parse_video, parse_dicom, upload_frames_to_minio + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/media", tags=["Media"]) + +ALLOWED_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".png", ".jpg", ".jpeg", ".dcm"} + + +def _get_ext(filename: str) -> str: + return Path(filename).suffix.lower() + + +@router.post( + "/upload", + status_code=status.HTTP_201_CREATED, + summary="Upload a media file", +) +async def upload_media( + file: UploadFile = File(...), + project_id: Optional[int] = Form(None), + db: Session = Depends(get_db), +) -> dict: + """Accept a video, image, or DICOM file and store it in MinIO. + + If project_id is provided, the video_path of the project is updated. + Returns the presigned URL of the uploaded object. + """ + if not file.filename: + raise HTTPException(status_code=400, detail="Missing filename") + + ext = _get_ext(file.filename) + if ext not in ALLOWED_EXTENSIONS: + raise HTTPException( + status_code=400, + detail=f"Unsupported file type: {ext}", + ) + + data = await file.read() + object_name = f"uploads/{project_id or 'general'}/{file.filename}" + + try: + upload_file(object_name, data, content_type=file.content_type or "application/octet-stream", length=len(data)) + except Exception as exc: # noqa: BLE001 + logger.error("Upload failed: %s", exc) + raise HTTPException(status_code=500, detail="Upload to storage failed") from exc + + file_url = get_presigned_url(object_name, expires=3600) + + if project_id: + project = db.query(Project).filter(Project.id == project_id).first() + if project: + project.video_path = object_name + db.commit() + logger.info("Linked upload to project_id=%s", project_id) + else: + logger.warning("Project id=%s not found for upload linkage", project_id) + + # TODO: enqueue async parsing job (Celery / background task) + logger.info("Upload complete: %s (size=%d bytes). Async parsing queued.", object_name, len(data)) + + return { + "object_name": object_name, + "file_url": file_url, + "size": len(data), + "message": "Upload successful. Parsing job queued.", + } + + +@router.post( + "/parse", + status_code=status.HTTP_202_ACCEPTED, + summary="Trigger frame extraction", +) +def parse_media( + project_id: int, + source_type: str = "video", # video | dicom + db: Session = Depends(get_db), +) -> dict: + """Trigger frame extraction for a project's uploaded media. + + * video: uses FFmpeg or OpenCV fallback. + * dicom: uses pydicom to read DCM frames. + + Extracted frames are uploaded to MinIO and registered in the database. + """ + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + if not project.video_path: + raise HTTPException(status_code=400, detail="Project has no media uploaded") + + # Download from MinIO to a temp directory + from minio_client import download_file + + try: + media_bytes = download_file(project.video_path) + except Exception as exc: # noqa: BLE001 + logger.error("Failed to download media for parsing: %s", exc) + raise HTTPException(status_code=500, detail="Failed to retrieve media from storage") from exc + + tmp_dir = tempfile.mkdtemp(prefix=f"seg_parse_{project_id}_") + local_path = os.path.join(tmp_dir, Path(project.video_path).name) + + with open(local_path, "wb") as f: + f.write(media_bytes) + + output_dir = os.path.join(tmp_dir, "frames") + os.makedirs(output_dir, exist_ok=True) + + try: + if source_type == "dicom": + # For DICOM, treat local_path as a directory if it contains multiple .dcm + # If a single .dcm file was uploaded, put it in its own sub-dir + dcm_dir = os.path.join(tmp_dir, "dcm") + os.makedirs(dcm_dir, exist_ok=True) + if local_path.lower().endswith(".dcm"): + shutil.move(local_path, os.path.join(dcm_dir, os.path.basename(local_path))) + else: + shutil.unpack_archive(local_path, dcm_dir) if shutil.which("unzip") else shutil.move(local_path, dcm_dir) + frame_files = parse_dicom(dcm_dir, output_dir) + else: + frame_files = parse_video(local_path, output_dir, fps=30) + except Exception as exc: # noqa: BLE001 + logger.error("Frame extraction failed: %s", exc) + shutil.rmtree(tmp_dir, ignore_errors=True) + raise HTTPException(status_code=500, detail="Frame extraction failed") from exc + + # Upload frames to MinIO + try: + object_names = upload_frames_to_minio(frame_files, project_id) + except Exception as exc: # noqa: BLE001 + logger.error("Frame upload failed: %s", exc) + shutil.rmtree(tmp_dir, ignore_errors=True) + raise HTTPException(status_code=500, detail="Frame upload to storage failed") from exc + + # Register frames in DB + frames_out = [] + for idx, obj_name in enumerate(object_names): + # Get image dimensions + local_frame = frame_files[idx] + try: + import cv2 + img = cv2.imread(local_frame) + h, w = img.shape[:2] if img is not None else (None, None) + except Exception: # noqa: BLE001 + h, w = None, None + + frame = Frame( + project_id=project_id, + frame_index=idx, + image_url=obj_name, + width=w, + height=h, + ) + db.add(frame) + frames_out.append(frame) + + db.commit() + for f in frames_out: + db.refresh(f) + + # Cleanup temp files + shutil.rmtree(tmp_dir, ignore_errors=True) + + project.status = "ready" + db.commit() + + logger.info("Parsed %d frames for project_id=%s", len(frames_out), project_id) + return { + "project_id": project_id, + "frames_extracted": len(frames_out), + "status": "ready", + "message": "Frame extraction completed successfully.", + } diff --git a/backend/routers/projects.py b/backend/routers/projects.py new file mode 100644 index 0000000..3d8e150 --- /dev/null +++ b/backend/routers/projects.py @@ -0,0 +1,165 @@ +"""Project and Frame CRUD endpoints.""" + +import logging +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from database import get_db +from models import Project, Frame +from schemas import ProjectCreate, ProjectOut, ProjectUpdate, FrameCreate, FrameOut + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/projects", tags=["Projects"]) + + +# --------------------------------------------------------------------------- +# Projects +# --------------------------------------------------------------------------- +@router.post( + "", + response_model=ProjectOut, + status_code=status.HTTP_201_CREATED, + summary="Create a new project", +) +def create_project(payload: ProjectCreate, db: Session = Depends(get_db)) -> Project: + """Create a new segmentation project.""" + project = Project(**payload.model_dump()) + db.add(project) + db.commit() + db.refresh(project) + logger.info("Created project id=%s name=%s", project.id, project.name) + return project + + +@router.get( + "", + response_model=List[ProjectOut], + summary="List all projects", +) +def list_projects(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)) -> List[Project]: + """Retrieve a paginated list of projects.""" + return db.query(Project).offset(skip).limit(limit).all() + + +@router.get( + "/{project_id}", + response_model=ProjectOut, + summary="Get a single project", +) +def get_project(project_id: int, db: Session = Depends(get_db)) -> Project: + """Retrieve a project by its ID.""" + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + return project + + +@router.patch( + "/{project_id}", + response_model=ProjectOut, + summary="Update a project", +) +def update_project( + project_id: int, + payload: ProjectUpdate, + db: Session = Depends(get_db), +) -> Project: + """Update project fields partially.""" + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + for key, value in payload.model_dump(exclude_unset=True).items(): + setattr(project, key, value) + + db.commit() + db.refresh(project) + logger.info("Updated project id=%s", project_id) + return project + + +@router.delete( + "/{project_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete a project", +) +def delete_project(project_id: int, db: Session = Depends(get_db)) -> None: + """Delete a project and all related frames and annotations.""" + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + db.delete(project) + db.commit() + logger.info("Deleted project id=%s", project_id) + + +# --------------------------------------------------------------------------- +# Frames +# --------------------------------------------------------------------------- +@router.post( + "/{project_id}/frames", + response_model=FrameOut, + status_code=status.HTTP_201_CREATED, + summary="Add a frame to a project", +) +def create_frame( + project_id: int, + payload: FrameCreate, + db: Session = Depends(get_db), +) -> Frame: + """Register a new frame under a project.""" + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + frame = Frame(project_id=project_id, **payload.model_dump(exclude={"project_id"})) + db.add(frame) + db.commit() + db.refresh(frame) + return frame + + +@router.get( + "/{project_id}/frames", + response_model=List[FrameOut], + summary="List frames for a project", +) +def list_frames( + project_id: int, + skip: int = 0, + limit: int = 1000, + db: Session = Depends(get_db), +) -> List[Frame]: + """Retrieve all frames belonging to a project.""" + project = db.query(Project).filter(Project.id == project_id).first() + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + return ( + db.query(Frame) + .filter(Frame.project_id == project_id) + .order_by(Frame.frame_index) + .offset(skip) + .limit(limit) + .all() + ) + + +@router.get( + "/{project_id}/frames/{frame_id}", + response_model=FrameOut, + summary="Get a single frame", +) +def get_frame(project_id: int, frame_id: int, db: Session = Depends(get_db)) -> Frame: + """Retrieve a specific frame by ID.""" + frame = ( + db.query(Frame) + .filter(Frame.project_id == project_id, Frame.id == frame_id) + .first() + ) + if not frame: + raise HTTPException(status_code=404, detail="Frame not found") + return frame diff --git a/backend/routers/templates.py b/backend/routers/templates.py new file mode 100644 index 0000000..d26fb2c --- /dev/null +++ b/backend/routers/templates.py @@ -0,0 +1,97 @@ +"""Template (Ontology) CRUD endpoints.""" + +import logging +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from database import get_db +from models import Template +from schemas import TemplateCreate, TemplateOut, TemplateUpdate + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/templates", tags=["Templates"]) + + +@router.post( + "", + response_model=TemplateOut, + status_code=status.HTTP_201_CREATED, + summary="Create a new template", +) +def create_template(payload: TemplateCreate, db: Session = Depends(get_db)) -> Template: + """Create a new ontology template / segmentation class.""" + template = Template(**payload.model_dump()) + db.add(template) + db.commit() + db.refresh(template) + logger.info("Created template id=%s name=%s", template.id, template.name) + return template + + +@router.get( + "", + response_model=List[TemplateOut], + summary="List all templates", +) +def list_templates( + skip: int = 0, + limit: int = 100, + db: Session = Depends(get_db), +) -> List[Template]: + """Retrieve all ontology templates.""" + return db.query(Template).offset(skip).limit(limit).all() + + +@router.get( + "/{template_id}", + response_model=TemplateOut, + summary="Get a single template", +) +def get_template(template_id: int, db: Session = Depends(get_db)) -> Template: + """Retrieve a template by its ID.""" + template = db.query(Template).filter(Template.id == template_id).first() + if not template: + raise HTTPException(status_code=404, detail="Template not found") + return template + + +@router.patch( + "/{template_id}", + response_model=TemplateOut, + summary="Update a template", +) +def update_template( + template_id: int, + payload: TemplateUpdate, + db: Session = Depends(get_db), +) -> Template: + """Update template fields partially.""" + template = db.query(Template).filter(Template.id == template_id).first() + if not template: + raise HTTPException(status_code=404, detail="Template not found") + + for key, value in payload.model_dump(exclude_unset=True).items(): + setattr(template, key, value) + + db.commit() + db.refresh(template) + logger.info("Updated template id=%s", template_id) + return template + + +@router.delete( + "/{template_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete a template", +) +def delete_template(template_id: int, db: Session = Depends(get_db)) -> None: + """Delete a template. Associated annotations will have template_id set to NULL.""" + template = db.query(Template).filter(Template.id == template_id).first() + if not template: + raise HTTPException(status_code=404, detail="Template not found") + + db.delete(template) + db.commit() + logger.info("Deleted template id=%s", template_id) diff --git a/backend/schemas.py b/backend/schemas.py new file mode 100644 index 0000000..a07b83e --- /dev/null +++ b/backend/schemas.py @@ -0,0 +1,157 @@ +"""Pydantic schemas for request/response validation.""" + +from datetime import datetime +from typing import Optional, Any +from pydantic import BaseModel, ConfigDict + + +# --------------------------------------------------------------------------- +# Project schemas +# --------------------------------------------------------------------------- +class ProjectBase(BaseModel): + name: str + description: Optional[str] = None + video_path: Optional[str] = None + status: Optional[str] = "pending" + + +class ProjectCreate(ProjectBase): + pass + + +class ProjectUpdate(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + video_path: Optional[str] = None + status: Optional[str] = None + + +class ProjectOut(ProjectBase): + model_config = ConfigDict(from_attributes=True) + + id: int + created_at: datetime + updated_at: datetime + + +# --------------------------------------------------------------------------- +# Frame schemas +# --------------------------------------------------------------------------- +class FrameBase(BaseModel): + frame_index: int + image_url: str + width: Optional[int] = None + height: Optional[int] = None + + +class FrameCreate(FrameBase): + project_id: int + + +class FrameOut(FrameBase): + model_config = ConfigDict(from_attributes=True) + + id: int + project_id: int + created_at: datetime + + +# --------------------------------------------------------------------------- +# Template schemas +# --------------------------------------------------------------------------- +class TemplateBase(BaseModel): + name: str + color: str + z_index: int = 0 + mapping_rules: Optional[dict[str, Any]] = None + + +class TemplateCreate(TemplateBase): + pass + + +class TemplateUpdate(BaseModel): + name: Optional[str] = None + color: Optional[str] = None + z_index: Optional[int] = None + mapping_rules: Optional[dict[str, Any]] = None + + +class TemplateOut(TemplateBase): + model_config = ConfigDict(from_attributes=True) + + id: int + created_at: datetime + + +# --------------------------------------------------------------------------- +# Annotation schemas +# --------------------------------------------------------------------------- +class AnnotationBase(BaseModel): + project_id: int + frame_id: Optional[int] = None + template_id: Optional[int] = None + mask_data: Optional[dict[str, Any]] = None + points: Optional[list[list[float]]] = None + bbox: Optional[list[float]] = None + + +class AnnotationCreate(AnnotationBase): + pass + + +class AnnotationUpdate(BaseModel): + mask_data: Optional[dict[str, Any]] = None + points: Optional[list[list[float]]] = None + bbox: Optional[list[float]] = None + template_id: Optional[int] = None + + +class AnnotationOut(AnnotationBase): + model_config = ConfigDict(from_attributes=True) + + id: int + created_at: datetime + updated_at: datetime + + +# --------------------------------------------------------------------------- +# Mask schemas +# --------------------------------------------------------------------------- +class MaskBase(BaseModel): + annotation_id: int + mask_url: str + format: str = "png" + + +class MaskCreate(MaskBase): + pass + + +class MaskOut(MaskBase): + model_config = ConfigDict(from_attributes=True) + + id: int + created_at: datetime + + +# --------------------------------------------------------------------------- +# AI schemas +# --------------------------------------------------------------------------- +class PredictRequest(BaseModel): + image_id: int + prompt_type: str # point / box / semantic + prompt_data: Any + + +class PredictResponse(BaseModel): + polygons: list[list[list[float]]] + scores: Optional[list[float]] = None + + +# --------------------------------------------------------------------------- +# Export schemas +# --------------------------------------------------------------------------- +class ExportStatus(BaseModel): + url: str + format: str diff --git a/backend/services/__init__.py b/backend/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/services/frame_parser.py b/backend/services/frame_parser.py new file mode 100644 index 0000000..8c263d8 --- /dev/null +++ b/backend/services/frame_parser.py @@ -0,0 +1,186 @@ +"""Video/DICOM frame parsing and MinIO upload utilities.""" + +import logging +import os +import shutil +import subprocess +from pathlib import Path +from typing import List, Optional + +import cv2 +import numpy as np +from pydicom import dcmread + +from minio_client import upload_file, BUCKET_NAME + +logger = logging.getLogger(__name__) + + +def parse_video( + video_path: str, + output_dir: str, + fps: int = 30, + max_frames: Optional[int] = None, +) -> List[str]: + """Extract frames from a video file using FFmpeg or OpenCV fallback. + + Args: + video_path: Path to the input video file. + output_dir: Directory to save extracted frames. + fps: Target frame extraction rate. + max_frames: Optional maximum number of frames to extract. + + Returns: + List of paths to extracted frame images. + """ + os.makedirs(output_dir, exist_ok=True) + frame_paths: List[str] = [] + + # Try FFmpeg first + if shutil.which("ffmpeg"): + try: + pattern = os.path.join(output_dir, "frame_%06d.png") + cmd = [ + "ffmpeg", + "-i", video_path, + "-vf", f"fps={fps},scale='min(1920,iw)':-1", + "-pix_fmt", "rgb24", + "-y", + pattern, + ] + logger.info("Running FFmpeg: %s", " ".join(cmd)) + result = subprocess.run(cmd, capture_output=True, text=True, check=False) + if result.returncode == 0: + frame_paths = sorted( + [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".png")] + ) + if max_frames: + frame_paths = frame_paths[:max_frames] + logger.info("Extracted %d frames via FFmpeg", len(frame_paths)) + return frame_paths + else: + logger.warning("FFmpeg failed: %s", result.stderr) + except Exception as exc: # noqa: BLE001 + logger.warning("FFmpeg exception: %s", exc) + + # OpenCV fallback + logger.info("Falling back to OpenCV frame extraction") + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise RuntimeError(f"Cannot open video: {video_path}") + + video_fps = cap.get(cv2.CAP_PROP_FPS) or 30 + interval = max(1, int(round(video_fps / fps))) + count = 0 + saved = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + if count % interval == 0: + path = os.path.join(output_dir, f"frame_{saved:06d}.png") + cv2.imwrite(path, frame) + frame_paths.append(path) + saved += 1 + if max_frames and saved >= max_frames: + break + count += 1 + + cap.release() + logger.info("Extracted %d frames via OpenCV", len(frame_paths)) + return frame_paths + + +def parse_dicom( + dicom_dir: str, + output_dir: str, + max_frames: Optional[int] = None, +) -> List[str]: + """Extract frames from DICOM files in a directory. + + Args: + dicom_dir: Directory containing .dcm files. + output_dir: Directory to save extracted frames. + max_frames: Optional maximum number of frames to extract. + + Returns: + List of paths to extracted frame images. + """ + os.makedirs(output_dir, exist_ok=True) + dcm_files = sorted( + [f for f in os.listdir(dicom_dir) if f.lower().endswith(".dcm")] + ) + + frame_paths: List[str] = [] + for idx, fname in enumerate(dcm_files): + if max_frames and idx >= max_frames: + break + path = os.path.join(dicom_dir, fname) + try: + ds = dcmread(path) + pixel_array = ds.pixel_array + + # Normalize to 8-bit + if pixel_array.dtype != np.uint8: + pixel_array = pixel_array.astype(np.float32) + pixel_array = ( + (pixel_array - pixel_array.min()) + / (pixel_array.max() - pixel_array.min() + 1e-8) + * 255 + ) + pixel_array = pixel_array.astype(np.uint8) + + # Handle multi-frame DICOM + if pixel_array.ndim == 3: + for f in range(pixel_array.shape[0]): + out_path = os.path.join(output_dir, f"frame_{idx:06d}_{f:03d}.png") + cv2.imwrite(out_path, pixel_array[f]) + frame_paths.append(out_path) + else: + out_path = os.path.join(output_dir, f"frame_{idx:06d}.png") + cv2.imwrite(out_path, pixel_array) + frame_paths.append(out_path) + except Exception as exc: # noqa: BLE001 + logger.error("Failed to read DICOM %s: %s", path, exc) + + logger.info("Extracted %d frames from DICOM", len(frame_paths)) + return frame_paths + + +def upload_frames_to_minio( + frames: List[str], + project_id: int, + object_prefix: Optional[str] = None, +) -> List[str]: + """Upload a list of local frame images to MinIO. + + Args: + frames: List of local file paths. + project_id: Project ID used for bucket path organization. + object_prefix: Optional prefix override. + + Returns: + List of object names (paths) in MinIO. + """ + prefix = object_prefix or f"projects/{project_id}/frames" + object_names: List[str] = [] + + for frame_path in frames: + fname = os.path.basename(frame_path) + object_name = f"{prefix}/{fname}" + try: + with open(frame_path, "rb") as f: + data = f.read() + upload_file( + object_name, + data, + content_type="image/png", + length=len(data), + ) + object_names.append(object_name) + except Exception as exc: # noqa: BLE001 + logger.error("Failed to upload %s: %s", frame_path, exc) + + logger.info("Uploaded %d/%d frames to MinIO", len(object_names), len(frames)) + return object_names diff --git a/backend/services/sam2_engine.py b/backend/services/sam2_engine.py new file mode 100644 index 0000000..bb670cf --- /dev/null +++ b/backend/services/sam2_engine.py @@ -0,0 +1,234 @@ +"""SAM 2 engine wrapper with lazy loading and fallback stubs.""" + +import logging +import os +from typing import Optional + +import numpy as np + +from config import settings + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Attempt to import SAM 2; fall back to stubs if unavailable. +# --------------------------------------------------------------------------- +try: + import torch + from sam2.build_sam import build_sam2 + from sam2.sam2_image_predictor import SAM2ImagePredictor + + SAM2_AVAILABLE = True + logger.info("SAM2 library imported successfully.") +except Exception as exc: # noqa: BLE001 + SAM2_AVAILABLE = False + logger.warning("SAM2 import failed (%s). Using stub engine.", exc) + + +class SAM2Engine: + """Lazy-loaded SAM 2 inference engine.""" + + def __init__(self) -> None: + self._predictor: Optional[SAM2ImagePredictor] = None + self._model_loaded = False + + # ----------------------------------------------------------------------- + # Internal helpers + # ----------------------------------------------------------------------- + def _load_model(self) -> None: + """Load the SAM 2 model and predictor on first use.""" + if self._model_loaded: + return + + if not SAM2_AVAILABLE: + logger.warning("SAM2 not available; skipping model load.") + self._model_loaded = True + return + + if not os.path.isfile(settings.sam_model_path): + logger.error("SAM checkpoint not found at %s", settings.sam_model_path) + self._model_loaded = True + return + + try: + model = build_sam2( + settings.sam_model_config, + settings.sam_model_path, + device="cuda", + ) + self._predictor = SAM2ImagePredictor(model) + self._model_loaded = True + logger.info("SAM 2 model loaded from %s", settings.sam_model_path) + except Exception as exc: # noqa: BLE001 + logger.error("Failed to load SAM 2 model: %s", exc) + self._model_loaded = True # Prevent repeated load attempts + + def _ensure_ready(self) -> bool: + """Ensure the model is loaded; return whether it is usable.""" + self._load_model() + return SAM2_AVAILABLE and self._predictor is not None + + # ----------------------------------------------------------------------- + # Public API + # ----------------------------------------------------------------------- + def predict_points( + self, + image: np.ndarray, + points: list[list[float]], + labels: list[int], + ) -> tuple[list[list[list[float]]], list[float]]: + """Run point-prompt segmentation. + + Args: + image: HWC numpy array (uint8). + points: List of [x, y] normalized coordinates (0-1). + labels: 1 for foreground, 0 for background. + + Returns: + Tuple of (polygons, scores). + """ + if not self._ensure_ready(): + logger.warning("SAM2 not ready; returning dummy masks.") + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + + try: + h, w = image.shape[:2] + pts = np.array([[p[0] * w, p[1] * h] for p in points], dtype=np.float32) + lbls = np.array(labels, dtype=np.int32) + + with torch.inference_mode(): # type: ignore[name-defined] + self._predictor.set_image(image) + masks, scores, _ = self._predictor.predict( + point_coords=pts, + point_labels=lbls, + multimask_output=True, + ) + + polygons = [] + for m in masks: + poly = self._mask_to_polygon(m) + if poly: + polygons.append(poly) + + return polygons, scores.tolist() + except Exception as exc: # noqa: BLE001 + logger.error("SAM2 point prediction failed: %s", exc) + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + + def predict_box( + self, + image: np.ndarray, + box: list[float], + ) -> tuple[list[list[list[float]]], list[float]]: + """Run box-prompt segmentation. + + Args: + image: HWC numpy array (uint8). + box: [x1, y1, x2, y2] normalized coordinates. + + Returns: + Tuple of (polygons, scores). + """ + if not self._ensure_ready(): + logger.warning("SAM2 not ready; returning dummy masks.") + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + + try: + h, w = image.shape[:2] + bbox = np.array( + [box[0] * w, box[1] * h, box[2] * w, box[3] * h], + dtype=np.float32, + ) + + with torch.inference_mode(): # type: ignore[name-defined] + self._predictor.set_image(image) + masks, scores, _ = self._predictor.predict( + box=bbox[None, :], + multimask_output=False, + ) + + polygons = [] + for m in masks: + poly = self._mask_to_polygon(m) + if poly: + polygons.append(poly) + + return polygons, scores.tolist() + except Exception as exc: # noqa: BLE001 + logger.error("SAM2 box prediction failed: %s", exc) + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + + def predict_auto(self, image: np.ndarray) -> tuple[list[list[list[float]]], list[float]]: + """Run automatic mask generation (grid of points). + + Args: + image: HWC numpy array (uint8). + + Returns: + Tuple of (polygons, scores). + """ + if not self._ensure_ready(): + logger.warning("SAM2 not ready; returning dummy masks.") + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + + try: + with torch.inference_mode(): # type: ignore[name-defined] + self._predictor.set_image(image) + # Generate a uniform 16x16 grid of point prompts + h, w = image.shape[:2] + grid = np.mgrid[0:1:17j, 0:1:17j].reshape(2, -1).T + pts = grid * np.array([w, h]) + lbls = np.ones(pts.shape[0], dtype=np.int32) + + masks, scores, _ = self._predictor.predict( + point_coords=pts, + point_labels=lbls, + multimask_output=True, + ) + + polygons = [] + for m in masks[:3]: # Limit to top 3 masks + poly = self._mask_to_polygon(m) + if poly: + polygons.append(poly) + + return polygons, scores[:3].tolist() + except Exception as exc: # noqa: BLE001 + logger.error("SAM2 auto prediction failed: %s", exc) + return self._dummy_polygons(image.shape[1], image.shape[0]), [0.5] + + # ----------------------------------------------------------------------- + # Helpers + # ----------------------------------------------------------------------- + @staticmethod + def _mask_to_polygon(mask: np.ndarray) -> list[list[float]]: + """Convert a binary mask to a normalized polygon.""" + import cv2 + + if mask.dtype != np.uint8: + mask = (mask > 0).astype(np.uint8) + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + h, w = mask.shape[:2] + largest = [] + for cnt in contours: + if len(cnt) > len(largest): + largest = cnt + if len(largest) < 3: + return [] + return [[float(pt[0][0]) / w, float(pt[0][1]) / h] for pt in largest] + + @staticmethod + def _dummy_polygons(w: int, h: int) -> list[list[list[float]]]: + """Return a dummy rectangle polygon for fallback mode.""" + return [ + [ + [0.25, 0.25], + [0.75, 0.25], + [0.75, 0.75], + [0.25, 0.75], + ] + ] + + +# Singleton instance +sam_engine = SAM2Engine() diff --git a/package-lock.json b/package-lock.json index 27d5fce..1f1d6c7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,7 @@ "@google/genai": "^1.29.0", "@tailwindcss/vite": "^4.1.14", "@vitejs/plugin-react": "^5.0.4", + "axios": "^1.15.2", "clsx": "^2.1.1", "dotenv": "^17.2.3", "express": "^4.21.2", @@ -22,7 +23,8 @@ "react-konva": "^19.2.3", "tailwind-merge": "^3.5.0", "use-image": "^1.1.4", - "vite": "^6.2.0" + "vite": "^6.2.0", + "zustand": "^5.0.12" }, "devDependencies": { "@types/express": "^4.17.21", @@ -1670,6 +1672,12 @@ "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", "license": "MIT" }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "license": "MIT" + }, "node_modules/autoprefixer": { "version": "10.5.0", "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.5.0.tgz", @@ -1707,6 +1715,17 @@ "postcss": "^8.1.0" } }, + "node_modules/axios": { + "version": "1.15.2", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.15.2.tgz", + "integrity": "sha512-wLrXxPtcrPTsNlJmKjkPnNPK2Ihe0hn0wGSaTEiHRPxwjvJwT3hKmXF4dpqxmPO9SoNb2FsYXj/xEo0gHN+D5A==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.15.11", + "form-data": "^4.0.5", + "proxy-from-env": "^2.1.0" + } + }, "node_modules/base64-js": { "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", @@ -1908,6 +1927,18 @@ "node": ">=6" } }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "license": "MIT", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/content-disposition": { "version": "0.5.4", "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.4.tgz", @@ -1983,6 +2014,15 @@ } } }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/depd": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", @@ -2110,6 +2150,21 @@ "node": ">= 0.4" } }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/esbuild": { "version": "0.27.7", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.7.tgz", @@ -2316,6 +2371,42 @@ "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", "license": "MIT" }, + "node_modules/follow-redirects": { + "version": "1.16.0", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.16.0.tgz", + "integrity": "sha512-y5rN/uOsadFT/JfYwhxRS5R7Qce+g3zG97+JrtFZlC9klX/W5hD7iiLzScI4nZqUS7DNUdhPgw4xI8W2LuXlUw==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/form-data": { + "version": "4.0.5", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz", + "integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==", + "license": "MIT", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, "node_modules/formdata-polyfill": { "version": "4.0.10", "resolved": "https://registry.npmjs.org/formdata-polyfill/-/formdata-polyfill-4.0.10.tgz", @@ -2553,6 +2644,21 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "license": "MIT", + "dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/hasown": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.3.tgz", @@ -3346,6 +3452,15 @@ "node": ">= 0.10" } }, + "node_modules/proxy-from-env": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-2.1.0.tgz", + "integrity": "sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA==", + "license": "MIT", + "engines": { + "node": ">=10" + } + }, "node_modules/qs": { "version": "6.14.2", "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.2.tgz", @@ -4461,6 +4576,35 @@ "resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz", "integrity": "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==", "license": "ISC" + }, + "node_modules/zustand": { + "version": "5.0.12", + "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.12.tgz", + "integrity": "sha512-i77ae3aZq4dhMlRhJVCYgMLKuSiZAaUPAct2AksxQ+gOtimhGMdXljRT21P5BNpeT4kXlLIckvkPM029OljD7g==", + "license": "MIT", + "engines": { + "node": ">=12.20.0" + }, + "peerDependencies": { + "@types/react": ">=18.0.0", + "immer": ">=9.0.6", + "react": ">=18.0.0", + "use-sync-external-store": ">=1.2.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "immer": { + "optional": true + }, + "react": { + "optional": true + }, + "use-sync-external-store": { + "optional": true + } + } } } } diff --git a/package.json b/package.json index 9b4af12..4b43b20 100644 --- a/package.json +++ b/package.json @@ -15,6 +15,7 @@ "@google/genai": "^1.29.0", "@tailwindcss/vite": "^4.1.14", "@vitejs/plugin-react": "^5.0.4", + "axios": "^1.15.2", "clsx": "^2.1.1", "dotenv": "^17.2.3", "express": "^4.21.2", @@ -26,7 +27,8 @@ "react-konva": "^19.2.3", "tailwind-merge": "^3.5.0", "use-image": "^1.1.4", - "vite": "^6.2.0" + "vite": "^6.2.0", + "zustand": "^5.0.12" }, "devDependencies": { "@types/express": "^4.17.21", diff --git a/src/App.tsx b/src/App.tsx index ec99945..8ac49e0 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -1,4 +1,6 @@ -import React, { useState } from 'react'; +import React, { useEffect } from 'react'; +import { useStore } from './store/useStore'; +import { getProjects } from './lib/api'; import { Sidebar } from './components/Sidebar'; import { Dashboard } from './components/Dashboard'; import { ProjectLibrary } from './components/ProjectLibrary'; @@ -10,16 +12,30 @@ import { Login } from './components/Login'; export type ActiveModule = 'dashboard' | 'projects' | 'ai' | 'workspace' | 'templates'; export default function App() { - const [activeModule, setActiveModule] = useState('workspace'); - const [isAuthenticated, setIsAuthenticated] = useState(false); + const isAuthenticated = useStore((state) => state.isAuthenticated); + const activeModule = useStore((state) => state.activeModule); + const setActiveModule = useStore((state) => state.setActiveModule); + const setProjects = useStore((state) => state.setProjects); + const setError = useStore((state) => state.setError); + + useEffect(() => { + if (isAuthenticated) { + getProjects() + .then((data) => setProjects(data)) + .catch((err) => { + console.error('Failed to fetch projects:', err); + setError('获取项目列表失败'); + }); + } + }, [isAuthenticated, setProjects, setError]); if (!isAuthenticated) { - return setIsAuthenticated(true)} />; + return ; } return (
- +
{activeModule === 'dashboard' && } {activeModule === 'projects' && setActiveModule('workspace')} />} diff --git a/src/components/AISegmentation.tsx b/src/components/AISegmentation.tsx index b974671..39600e3 100644 --- a/src/components/AISegmentation.tsx +++ b/src/components/AISegmentation.tsx @@ -1,20 +1,28 @@ -import React, { useState } from 'react'; -import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, Settings2, Cpu, Image as ImageIcon, SendToBack, Tags, Undo, Redo } from 'lucide-react'; +import React, { useState, useCallback } from 'react'; +import { Target, PlusCircle, MinusCircle, SquareDashed, Sparkles, SendToBack, Image as ImageIcon, Undo, Redo, Loader2 } from 'lucide-react'; import { cn } from '../lib/utils'; import { Stage, Layer, Image as KonvaImage, Circle, Path, Group } from 'react-konva'; import useImage from 'use-image'; import { OntologyInspector } from './OntologyInspector'; +import { useStore } from '../store/useStore'; +import { predictMask } from '../lib/api'; interface AISegmentationProps { onSendToWorkspace: () => void; } export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { - const [activeTool, setActiveTool] = useState('point_pos'); + const storeActiveTool = useStore((state) => state.activeTool); + const setActiveTool = useStore((state) => state.setActiveTool); + const masks = useStore((state) => state.masks); + const addMask = useStore((state) => state.addMask); + const clearMasks = useStore((state) => state.clearMasks); + const [modelSize, setModelSize] = useState('vit_l'); const [semanticText, setSemanticText] = useState(''); const [autoDeleteBg, setAutoDeleteBg] = useState(true); const [cropMode, setCropMode] = useState(false); + const [isInferencing, setIsInferencing] = useState(false); // Canvas state const [scale, setScale] = useState(1); @@ -23,6 +31,8 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { const [cursorPos, setCursorPos] = useState({ x: 0, y: 0 }); const [image] = useImage('https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop'); + const effectiveTool = storeActiveTool; + const handleWheel = (e: any) => { e.evt.preventDefault(); const scaleBy = 1.1; @@ -51,13 +61,43 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) { } }; + const runInference = useCallback(async () => { + if (points.length === 0 && !semanticText.trim()) return; + setIsInferencing(true); + try { + const result = await predictMask({ + imageUrl: 'https://images.unsplash.com/photo-1549317661-bd32c8ce0be2?q=80&w=2070&auto=format&fit=crop', + points: points.map((p) => ({ x: p.x, y: p.y, type: p.type })), + text: semanticText.trim() || undefined, + modelSize, + }); + + result.masks.forEach((m) => { + addMask({ + id: m.id, + frameId: 'frame-ai-1', + pathData: m.pathData, + label: m.label, + color: m.color, + segmentation: m.segmentation, + bbox: m.bbox, + area: m.area, + }); + }); + } catch (err) { + console.error('AI inference failed:', err); + } finally { + setIsInferencing(false); + } + }, [points, semanticText, modelSize, addMask]); + const handleStageClick = (e: any) => { - if (activeTool === 'move') return; - if (activeTool === 'point_pos' || activeTool === 'point_neg') { + if (effectiveTool === 'move') return; + if (effectiveTool === 'point_pos' || effectiveTool === 'point_neg') { const stage = e.target.getStage(); const pos = stage.getRelativePointerPosition(); if (pos) { - setPoints([...points, { x: pos.x, y: pos.y, type: activeTool === 'point_pos' ? 'pos' : 'neg' }]); + setPoints([...points, { x: pos.x, y: pos.y, type: effectiveTool === 'point_pos' ? 'pos' : 'neg' }]); } } }; @@ -68,7 +108,7 @@ export function AISegmentation({ onSendToWorkspace }: AISegmentationProps) {