"""Annotation export endpoints (COCO, PNG masks).

Routes are mounted under ``/api/export``. They stream a COCO-style JSON file,
a ZIP of PNG masks, or a ZIP of "segmentation results" (original frames,
COCO JSON, a maskid mapping and several rendered label images) for a project.
"""

import io
import json
import logging
import os
import re
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from urllib.parse import quote

import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from database import get_db
from minio_client import download_file
from models import Project, Annotation, Frame, Template, User
from routers.auth import get_current_user

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/export", tags=["Export"])

# Canvas size used when a Frame row has no stored width/height.
_DEFAULT_WIDTH = 1920
_DEFAULT_HEIGHT = 1080

# Characters that are unsafe in file names across common filesystems, plus whitespace.
_UNSAFE_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|\s]+")
_REPEATED_UNDERSCORES = re.compile(r"_+")
_HEX_RGB = re.compile(r"[0-9a-fA-F]{6}")

# Error raised when a class cannot be assigned an 8-bit GT pixel value.
_GT_MASKID_RANGE_ERROR = "GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间"


# ---------------------------------------------------------------------------
# Rasterisation helpers
# ---------------------------------------------------------------------------


def _frame_size(frame: Frame) -> tuple[int, int]:
    """Return ``(width, height)`` for *frame*, falling back to 1920x1080 when unset."""
    return int(frame.width or _DEFAULT_WIDTH), int(frame.height or _DEFAULT_HEIGHT)


def _polygon_points(polygon: List[List[float]], width: int, height: int) -> np.ndarray:
    """Scale a normalized ``[[x, y], ...]`` polygon to truncated int32 pixel coordinates."""
    return np.array(
        [[int(p[0] * width), int(p[1] * height)] for p in polygon],
        dtype=np.int32,
    )


def _mask_from_polygon(
    polygon: List[List[float]],
    width: int,
    height: int,
) -> np.ndarray:
    """Render a normalized polygon to a binary mask.

    Args:
        polygon: Vertices as ``[x, y]`` pairs in normalized coordinates; they are
            multiplied by ``width``/``height`` and truncated to ints.
        width: Mask width in pixels.
        height: Mask height in pixels.

    Returns:
        A ``(height, width)`` uint8 array holding 255 inside the polygon and 0
        elsewhere. An empty polygon yields an all-zero mask rather than passing
        an empty point array to OpenCV.
    """
    import cv2

    mask = np.zeros((height, width), dtype=np.uint8)
    if len(polygon) == 0:
        return mask
    cv2.fillPoly(mask, [_polygon_points(polygon, width, height)], 255)
    return mask


def _annotation_mask(annotation: Annotation, width: int, height: int) -> np.ndarray:
    """Render the union of all polygons in ``annotation.mask_data["polygons"]``.

    Produces the same 0/255 mask as OR-ing ``_mask_from_polygon`` over every
    polygon, but fills a single canvas instead of allocating one per polygon.
    Empty polygons are skipped.
    """
    import cv2

    mask = np.zeros((height, width), dtype=np.uint8)
    for polygon in (annotation.mask_data or {}).get("polygons") or []:
        if len(polygon) > 0:
            cv2.fillPoly(mask, [_polygon_points(polygon, width, height)], 255)
    return mask


# ---------------------------------------------------------------------------
# Annotation metadata accessors
# ---------------------------------------------------------------------------


def _annotation_class_meta(annotation: Annotation) -> dict[str, Any]:
    """Return ``mask_data["class"]`` when it is a dict, otherwise an empty dict."""
    class_meta = (annotation.mask_data or {}).get("class") or {}
    return class_meta if isinstance(class_meta, dict) else {}


def _annotation_z_index(annotation: Annotation) -> int:
    """Return the paint order of an annotation; higher values are drawn later (on top).

    Prefers ``mask_data["class"]["zIndex"]``, then the template's ``z_index``, then 0.
    """
    class_meta = _annotation_class_meta(annotation)
    if class_meta.get("zIndex") is not None:
        try:
            return int(class_meta["zIndex"])
        except (TypeError, ValueError):
            pass  # Malformed client value: fall back to the template's z-index.
    if annotation.template and annotation.template.z_index is not None:
        return int(annotation.template.z_index)
    return 0


def _annotation_mask_id(annotation: Annotation) -> int | None:
    """Return the first non-negative integer maskid hint stored on the class, if any.

    Keys ``maskId``, ``maskid`` and ``mask_id`` are checked in that order; values
    that cannot be converted to int, or are negative, are ignored.
    """
    class_meta = _annotation_class_meta(annotation)
    for key in ("maskId", "maskid", "mask_id"):
        if class_meta.get(key) is None:
            continue
        try:
            value = int(class_meta[key])
        except (TypeError, ValueError):
            continue
        if value >= 0:
            return value
    return None


def _annotation_category_name(annotation: Annotation) -> str:
    """Return the class category, falling back to the template name, then ``""``."""
    class_meta = _annotation_class_meta(annotation)
    if class_meta.get("category"):
        return str(class_meta["category"])
    if annotation.template and annotation.template.name:
        return str(annotation.template.name)
    return ""


def _annotation_class_key(annotation: Annotation) -> str:
    """Return a stable key grouping annotations of the same class.

    Priority: class id, class name, template id, and finally the annotation id
    (which makes an unclassified annotation its own class).
    """
    class_meta = _annotation_class_meta(annotation)
    if class_meta.get("id"):
        return f"class:{class_meta['id']}"
    if class_meta.get("name"):
        return f"name:{class_meta['name']}"
    if annotation.template_id:
        return f"template:{annotation.template_id}"
    return f"annotation:{annotation.id}"


def _annotation_label(annotation: Annotation) -> str:
    """Return a human-readable label: class name, mask label, template name, or a default."""
    mask_data = annotation.mask_data or {}
    class_meta = _annotation_class_meta(annotation)
    if class_meta.get("name"):
        return str(class_meta["name"])
    if mask_data.get("label"):
        return str(mask_data["label"])
    if annotation.template and annotation.template.name:
        return str(annotation.template.name)
    return f"Annotation {annotation.id}"


def _annotation_color(annotation: Annotation) -> str:
    """Return the display colour string: class colour, mask colour, template colour, or white."""
    mask_data = annotation.mask_data or {}
    class_meta = _annotation_class_meta(annotation)
    if class_meta.get("color"):
        return str(class_meta["color"])
    if mask_data.get("color"):
        return str(mask_data["color"])
    if annotation.template and annotation.template.color:
        return str(annotation.template.color)
    return "#ffffff"


def _hex_to_rgb(color: str) -> list[int]:
    """Parse ``#rgb`` / ``#rrggbb`` (leading ``#`` optional) into ``[r, g, b]``.

    Anything that is not exactly 3 or 6 hex digits yields white ``[255, 255, 255]``.
    The strict digit check rejects inputs such as ``"-fffff"`` that ``int(..., 16)``
    would otherwise accept and turn into out-of-range channel values.
    """
    value = str(color or "").strip()
    if value.startswith("#"):
        value = value[1:]
    if len(value) == 3:
        value = "".join(part * 2 for part in value)
    if not _HEX_RGB.fullmatch(value):
        return [255, 255, 255]
    return [int(value[i:i + 2], 16) for i in (0, 2, 4)]


# ---------------------------------------------------------------------------
# File naming
# ---------------------------------------------------------------------------


def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
    """Make *value* safe for use inside a file name.

    Unsafe characters and whitespace become ``_``, runs of ``_`` collapse, and
    leading/trailing ``.``/``_`` are stripped. Returns *fallback* if nothing is left.
    """
    text = str(value or "").strip() or fallback
    text = _UNSAFE_FILENAME_CHARS.sub("_", text)
    text = _REPEATED_UNDERSCORES.sub("_", text).strip("._")
    return text or fallback


def _project_video_name(project: Project) -> str:
    """Return a filename-safe video name (file name minus its last extension).

    Falls back to the project name, then ``project_<id>``.
    """
    if project.video_path:
        stem = Path(project.video_path).name
        if "." in stem:
            stem = stem.rsplit(".", 1)[0]
        if stem:
            return _safe_filename_part(stem, f"project_{project.id}")
    return _safe_filename_part(project.name, f"project_{project.id}")


def _project_export_name(project: Project) -> str:
    """Return a filename-safe project name, defaulting to ``project_<id>``."""
    return _safe_filename_part(project.name, f"project_{project.id}")


def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
    """Return the frame's timestamp in ms.

    Uses the stored value, otherwise derives it from ``frame_index`` and the
    project fps (``parse_fps``, then ``original_fps``, then 30; floored at 1).
    """
    if frame.timestamp_ms is not None:
        return float(frame.timestamp_ms)
    fps = project.parse_fps or project.original_fps or 30.0
    return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)


def _project_frame_number(frame: Frame) -> int:
    """Return the 1-based frame number shown to users (``frame_index + 1``)."""
    return int(frame.frame_index) + 1


def _format_timestamp_ms(value: float) -> str:
    """Format milliseconds as e.g. ``0h01m02s345ms``; negative values clamp to zero."""
    total_ms = max(0, int(round(float(value))))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, milliseconds = divmod(remainder, 1_000)
    return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"


def _frame_export_stem(project: Project, frame: Frame) -> str:
    """Return ``<video>_<timestamp>_frame<NNNNNN>`` used for every per-frame export file."""
    return "_".join([
        _project_video_name(project),
        _format_timestamp_ms(_frame_timestamp_ms(frame, project)),
        f"frame{_project_frame_number(frame):06d}",
    ])


def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
    """Return the results ZIP name encoding the time span (T) and frame-number span (P)."""
    if not frames:
        return f"{_project_export_name(project)}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
    first_frame = frames[0]
    last_frame = frames[-1]
    return (
        f"{_project_export_name(project)}"
        f"_seg_T_{_format_timestamp_ms(_frame_timestamp_ms(first_frame, project))}"
        f"-{_format_timestamp_ms(_frame_timestamp_ms(last_frame, project))}"
        f"_P_{_project_frame_number(first_frame)}-{_project_frame_number(last_frame)}.zip"
    )


def _download_content_disposition(filename: str) -> str:
    """Build a Content-Disposition header with an ASCII fallback and RFC 8187 ``filename*``.

    ``quote`` is called with ``safe=""`` because the RFC 8187 ``filename*`` value
    must percent-encode every character outside ``attr-char`` (including ``/``).
    """
    ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
    ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
    if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
        ascii_fallback = f"{ascii_fallback}.zip"
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quote(filename, safe='')}"


def _frame_image_extension(frame: Frame) -> str:
    """Return the frame image's lowercase extension when it is a known image type, else ``.jpg``."""
    suffix = Path(frame.image_url or "").suffix.lower()
    return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"


# ---------------------------------------------------------------------------
# Data access
# ---------------------------------------------------------------------------


def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
    """Load a project or raise HTTP 404.

    NOTE(review): ``current_user`` is accepted but not checked, so any
    authenticated user can export any project. Confirm whether project
    ownership/ACL should be enforced here.
    """
    _ = current_user
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    return project


def _project_frames(project_id: int, db: Session) -> list[Frame]:
    """Return all frames of a project ordered by ``frame_index``."""
    return (
        db.query(Frame)
        .filter(Frame.project_id == project_id)
        .order_by(Frame.frame_index)
        .all()
    )


def _filter_frames(
    frames: list[Frame],
    *,
    scope: str = "all",
    start_frame: int | None = None,
    end_frame: int | None = None,
    frame_id: int | None = None,
) -> list[Frame]:
    """Select the frames to export.

    ``scope="current"`` selects the frame whose database id is *frame_id*
    (400 if missing, 404 if absent). ``scope="range"`` selects 1-based
    positions ``start_frame..end_frame`` (inclusive, bounds swapped when
    reversed) from the ordered *frames* list. Any other scope returns *frames*.
    """
    if scope == "current":
        if frame_id is None:
            raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
        selected = [frame for frame in frames if frame.id == frame_id]
        if not selected:
            raise HTTPException(status_code=404, detail="Frame not found")
        return selected
    if scope == "range":
        if start_frame is None or end_frame is None:
            raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
        start = max(1, min(int(start_frame), int(end_frame)))
        end = max(1, max(int(start_frame), int(end_frame)))
        return frames[start - 1:end]
    return frames


def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
    """Return the project's annotations on *frame_ids*; skips the query when the set is empty."""
    if not frame_ids:
        return []
    return (
        db.query(Annotation)
        .filter(Annotation.project_id == project_id)
        .filter(Annotation.frame_id.in_(frame_ids))
        .all()
    )


# ---------------------------------------------------------------------------
# COCO / class mapping
# ---------------------------------------------------------------------------


def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
    """Build a COCO-style dict for *frames*.

    Polygon coordinates are treated as normalized. They are scaled by the same
    per-frame size that is reported in the ``images`` entries (stored size or
    1920x1080), so annotations and images always agree. The bbox spans every
    non-empty polygon of an annotation. Annotations on frames outside *frames*,
    or without any non-empty polygon, are skipped.
    """
    images = [
        {
            "id": frame.id,
            "file_name": frame.image_url,
            "width": frame.width or _DEFAULT_WIDTH,
            "height": frame.height or _DEFAULT_HEIGHT,
            "frame_index": frame.frame_index,
        }
        for frame in frames
    ]

    categories = []
    template_id_to_cat_id: Dict[int, int] = {}
    for cat_idx, tmpl in enumerate(templates, start=1):
        categories.append({
            "id": cat_idx,
            "name": tmpl.name,
            "color": tmpl.color,
        })
        template_id_to_cat_id[tmpl.id] = cat_idx

    frame_by_id = {frame.id: frame for frame in frames}
    coco_annotations: list[dict[str, Any]] = []
    for ann in annotations:
        frame = frame_by_id.get(ann.frame_id)
        if frame is None or not ann.mask_data:
            continue
        polygons = [poly for poly in ann.mask_data.get("polygons") or [] if len(poly) > 0]
        if not polygons:
            continue
        width, height = _frame_size(frame)

        xs = [p[0] for poly in polygons for p in poly]
        ys = [p[1] for poly in polygons for p in poly]
        bbox = [
            min(xs) * width,
            min(ys) * height,
            (max(xs) - min(xs)) * width,
            (max(ys) - min(ys)) * height,
        ]
        segmentation = [
            [coord for p in poly for coord in (p[0] * width, p[1] * height)]
            for poly in polygons
        ]
        coco_annotations.append({
            "id": len(coco_annotations) + 1,
            "image_id": ann.frame_id,
            # 0 marks "no matching template"; it has no entry in `categories`.
            "category_id": template_id_to_cat_id.get(ann.template_id, 0),
            "segmentation": segmentation,
            # NOTE(review): COCO defines `area` as the segmentation area; this
            # keeps the existing bbox-area behaviour. Confirm consumers before changing.
            "area": bbox[2] * bbox[3],
            "bbox": bbox,
            "iscrowd": 0,
        })

    now = datetime.now()
    return {
        "info": {
            "description": f"Annotations for {project.name}",
            "version": "1.0",
            "year": now.year,
            "date_created": now.isoformat(),
        },
        "images": images,
        "annotations": coco_annotations,
        "categories": categories,
    }


def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
    """Collect the class metadata used for the GT maskid mapping."""
    return {
        "key": _annotation_class_key(annotation),
        "className": _annotation_label(annotation),
        "chineseName": _annotation_label(annotation),
        "categoryName": _annotation_category_name(annotation),
        "color": _annotation_color(annotation),
        "internalPriority": _annotation_z_index(annotation),
        "maskidHint": _annotation_mask_id(annotation),
        "template_id": annotation.template_id,
    }


def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
    """Assign an 8-bit GT pixel value (maskid) to every class that has polygons.

    Classes are processed in this order:
    1. those with a maskid hint, by ascending hint;
    2. the rest, by class name and then key.

    A hint of 0 is honoured as-is. A hint in 1..255 is used unless an earlier
    class already took it. Every other class receives the lowest unused value
    starting at 1.

    Returns:
        ``(key_to_value, classes)``: class key -> maskid, and the per-class
        mapping records written to the export.

    Raises:
        HTTPException(400): a hint exceeds 255, or values 1..255 are exhausted.
    """
    entries_by_key: dict[str, dict[str, Any]] = {}
    for annotation in annotations:
        if not annotation.mask_data or not annotation.mask_data.get("polygons"):
            continue
        key = _annotation_class_key(annotation)
        if key not in entries_by_key:
            # The first annotation seen for a class defines its metadata.
            entries_by_key[key] = _class_mapping_entry(annotation)

    ordered = sorted(
        entries_by_key.values(),
        key=lambda item: (
            item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
            str(item["className"]),
            str(item["key"]),
        ),
    )

    key_to_value: dict[str, int] = {}
    classes: list[dict[str, Any]] = []
    used_maskids: set[int] = set()
    next_maskid = 1

    def next_available_maskid() -> int:
        nonlocal next_maskid
        while next_maskid in used_maskids:
            next_maskid += 1
        if next_maskid > 255:
            raise HTTPException(status_code=400, detail=_GT_MASKID_RANGE_ERROR)
        value = next_maskid
        used_maskids.add(value)
        next_maskid += 1
        return value

    for entry in ordered:
        hinted_maskid = entry.get("maskidHint")
        if isinstance(hinted_maskid, int) and hinted_maskid > 255:
            raise HTTPException(status_code=400, detail=_GT_MASKID_RANGE_ERROR)
        if isinstance(hinted_maskid, int) and hinted_maskid == 0:
            maskid = 0
            used_maskids.add(maskid)
        elif isinstance(hinted_maskid, int) and 0 < hinted_maskid <= 255 and hinted_maskid not in used_maskids:
            maskid = hinted_maskid
            used_maskids.add(maskid)
        else:
            maskid = next_available_maskid()
        key_to_value[entry["key"]] = maskid
        classes.append({
            "gt_pixel_value": maskid,
            "maskid": maskid,
            "chineseName": entry["chineseName"],
            "className": entry["className"],
            "categoryName": entry["categoryName"],
            "rgb": _hex_to_rgb(entry["color"]),
            "color": entry["color"],
            "key": entry["key"],
            "template_id": entry["template_id"],
        })
    return key_to_value, classes


def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
    """Resolve which result images to render.

    An explicit comma-separated *outputs* list wins; it is validated against the
    allowed names, and an empty list means "all". Otherwise a single-output
    *mask_type* selects itself, and any other value (e.g. ``both``/``all``)
    selects every output.
    """
    allowed = {"separate", "gt_label", "pro_label", "mix_label"}
    if outputs:
        parsed = {item.strip() for item in outputs.split(",") if item.strip()}
        invalid = parsed - allowed
        if invalid:
            raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
        return parsed or allowed
    if mask_type in allowed:
        return {mask_type}
    return allowed


# ---------------------------------------------------------------------------
# ZIP writers
# ---------------------------------------------------------------------------


def _write_original_frames(
    zf: zipfile.ZipFile,
    project: Project,
    frames: list[Frame],
) -> dict[int, bytes]:
    """Download each frame image, store it in the ZIP, and return the bytes keyed by frame id."""
    image_bytes_by_frame: dict[int, bytes] = {}
    for frame in frames:
        image_bytes = download_file(frame.image_url)
        image_bytes_by_frame[frame.id] = image_bytes
        zf.writestr(
            f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}",
            image_bytes,
        )
    return image_bytes_by_frame


def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
    """Decode an image to BGR, resized to ``(width, height)`` if needed.

    Missing or undecodable bytes yield a black image of that size.
    """
    import cv2

    if image_bytes:
        decoded = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        if decoded is not None:
            if decoded.shape[1] != width or decoded.shape[0] != height:
                decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
            return decoded
    return np.zeros((height, width, 3), dtype=np.uint8)


def _write_separate_class_masks(
    zf: zipfile.ZipFile,
    frame_stem: str,
    frame_annotations: list[Annotation],
    class_values: dict[str, int],
    width: int,
    height: int,
) -> None:
    """Write one 0/255 PNG per class for a frame, merging annotations of the same class.

    Files are written in ascending ``internalPriority`` order.
    """
    import cv2

    class_masks: dict[str, np.ndarray] = {}
    class_meta: dict[str, dict[str, Any]] = {}
    for annotation in frame_annotations:
        key = _annotation_class_key(annotation)
        mask = _annotation_mask(annotation, width, height)
        if key in class_masks:
            np.maximum(class_masks[key], mask, out=class_masks[key])
        else:
            class_masks[key] = mask
            class_meta[key] = _class_mapping_entry(annotation)

    folder = f"分开Mask分割结果/{frame_stem}_分别导出"
    for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
        maskid = class_values.get(key)
        if maskid is None:
            continue
        _, encoded = cv2.imencode(".png", mask)
        class_name = _safe_filename_part(class_meta[key]["className"], "class")
        zf.writestr(
            f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png",
            encoded.tobytes(),
        )


def _render_fused_frame(
    frame_annotations: list[Annotation],
    class_values: dict[str, int],
    class_rgb_by_key: dict[str, list[int]],
    width: int,
    height: int,
    *,
    with_color: bool,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
    """Paint a frame's annotations in ascending z-index order, so higher z wins.

    Returns:
        ``(semantic, color, covered)``:
        - ``semantic``: a ``(height, width)`` uint8 maskid image;
        - ``color``: a ``(height, width, 3)`` BGR image, or None when *with_color*
          is False;
        - ``covered``: a bool map of pixels painted by any mapped class. It is
          tracked explicitly so classes coloured pure black still count as covered.
    """
    semantic = np.zeros((height, width), dtype=np.uint8)
    color = np.zeros((height, width, 3), dtype=np.uint8) if with_color else None
    covered = np.zeros((height, width), dtype=bool)
    for annotation in sorted(frame_annotations, key=_annotation_z_index):
        key = _annotation_class_key(annotation)
        value = class_values.get(key)
        if value is None:
            continue
        region = _annotation_mask(annotation, width, height) > 0
        semantic[region] = value
        covered |= region
        if color is not None:
            rgb = class_rgb_by_key.get(key, [255, 255, 255])
            color[region] = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)  # OpenCV uses BGR
    return semantic, color, covered


def _write_result_mask_outputs(
    zf: zipfile.ZipFile,
    project: Project,
    frames: list[Frame],
    annotations: list[Annotation],
    *,
    outputs: set[str],
    class_values: dict[str, int],
    class_mapping: list[dict[str, Any]],
    original_images: dict[int, bytes],
    mix_opacity: float,
) -> None:
    """Render the requested label images for each frame that has polygon annotations.

    Each name in *outputs* enables one output:
    - ``separate``: per-class masks;
    - ``gt_label``: an 8-bit maskid image;
    - ``pro_label``: a colour label image;
    - ``mix_label``: the colour label blended over the original frame at
      *mix_opacity*, clamped to [0, 1].
    """
    import cv2

    include_individual = "separate" in outputs
    include_semantic = "gt_label" in outputs
    include_pro_label = "pro_label" in outputs
    include_mix_label = "mix_label" in outputs
    needs_fused_output = include_semantic or include_pro_label or include_mix_label
    opacity = min(max(float(mix_opacity), 0.0), 1.0)

    class_rgb_by_key = {
        item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
        for item in class_mapping
    }

    annotations_by_frame: dict[int, list[Annotation]] = {}
    selected_frame_ids = {frame.id for frame in frames}
    for annotation in annotations:
        if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
            continue
        if not annotation.mask_data.get("polygons"):
            continue
        annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)

    for frame in frames:
        frame_annotations = annotations_by_frame.get(frame.id, [])
        if not frame_annotations:
            continue
        width, height = _frame_size(frame)
        frame_stem = _frame_export_stem(project, frame)

        if include_individual:
            _write_separate_class_masks(zf, frame_stem, frame_annotations, class_values, width, height)

        if not needs_fused_output:
            continue
        semantic, pro_label, covered = _render_fused_frame(
            frame_annotations,
            class_values,
            class_rgb_by_key,
            width,
            height,
            with_color=include_pro_label or include_mix_label,
        )

        if include_semantic:
            _, encoded = cv2.imencode(".png", semantic)
            zf.writestr(f"GT_label图/{frame_stem}.png", encoded.tobytes())
        if include_pro_label and pro_label is not None:
            _, encoded = cv2.imencode(".png", pro_label)
            zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", encoded.tobytes())
        if include_mix_label and pro_label is not None:
            original = _decode_original_image(original_images.get(frame.id), width, height)
            mixed = original.copy()
            mixed[covered] = (
                original[covered].astype(np.float32) * (1.0 - opacity)
                + pro_label[covered].astype(np.float32) * opacity
            ).clip(0, 255).astype(np.uint8)
            _, encoded = cv2.imencode(".png", mixed)
            zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", encoded.tobytes())


def _write_mask_pngs(
    zf: zipfile.ZipFile,
    frames: list[Frame],
    annotations: list[Annotation],
    *,
    mask_type: str,
    individual_prefix: str = "",
    semantic_prefix: str = "",
    semantic_file_stem: str = "semantic_frame",
    semantic_dtype: Any = np.uint8,
) -> list[dict[str, Any]]:
    """Write per-annotation masks and/or z-index fused semantic masks (legacy layout).

    Args:
        mask_type: ``separate``, ``gt_label`` or ``both``. It selects per-annotation
            ``mask_<id>.png`` files, per-frame semantic PNGs plus
            ``semantic_classes.json``, or both.
        semantic_dtype: dtype of the semantic image; class values start at 1 in
            first-painted order.

    Individual and semantic masks share one canvas size per frame (stored size
    or 1920x1080), so they always line up.

    Returns:
        The semantic class records, which are also written to ``semantic_classes.json``.
    """
    import cv2

    class_values: dict[str, int] = {}
    semantic_classes: list[dict[str, Any]] = []

    def class_value(annotation: Annotation) -> int:
        key = _annotation_class_key(annotation)
        if key not in class_values:
            value = len(class_values) + 1
            class_values[key] = value
            semantic_classes.append({
                "value": value,
                "key": key,
                "label": _annotation_label(annotation),
                "color": _annotation_color(annotation),
                "zIndex": _annotation_z_index(annotation),
                "template_id": annotation.template_id,
            })
        return class_values[key]

    include_individual = mask_type in {"separate", "both"}
    include_semantic = mask_type in {"gt_label", "both"}
    frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}
    frame_by_id = {frame.id: frame for frame in frames}

    for ann in annotations:
        frame = frame_by_id.get(ann.frame_id)
        if frame is None or not ann.mask_data:
            continue
        if not ann.mask_data.get("polygons"):
            continue
        width, height = _frame_size(frame)
        combined = _annotation_mask(ann, width, height)
        if include_individual:
            _, encoded = cv2.imencode(".png", combined)
            zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
        if include_semantic:
            frame_masks.setdefault(ann.frame_id, []).append((ann, combined))

    if include_semantic:
        for frame in frames:
            entries = frame_masks.get(frame.id, [])
            if not entries:
                continue
            width, height = _frame_size(frame)
            semantic = np.zeros((height, width), dtype=semantic_dtype)
            for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
                semantic[mask > 0] = class_value(ann)
            _, encoded = cv2.imencode(".png", semantic)
            zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())
        zf.writestr(
            f"{semantic_prefix}semantic_classes.json",
            json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
        )

    return semantic_classes


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------


@router.get(
    "/{project_id}/coco",
    summary="Export annotations in COCO format",
)
def export_coco(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Export all annotations for a project as a COCO-format JSON file."""
    project = _project_or_404(project_id, db, current_user)
    frames = _project_frames(project_id, db)
    annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
    templates = db.query(Template).all()

    coco = _build_coco(project, frames, annotations, templates)
    data = json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8")
    filename = f"project_{project_id}_coco.json"
    return StreamingResponse(
        io.BytesIO(data),
        media_type="application/json",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )


@router.get(
    "/{project_id}/masks",
    summary="Export PNG masks as a ZIP archive",
)
def export_masks(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Export individual masks plus z-index fused semantic masks inside a ZIP."""
    _project_or_404(project_id, db, current_user)
    frames = _project_frames(project_id, db)
    annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        _write_mask_pngs(
            zf,
            frames,
            annotations,
            mask_type="both",
            semantic_prefix="",
            individual_prefix="",
        )

    zip_buffer.seek(0)
    filename = f"project_{project_id}_masks.zip"
    return StreamingResponse(
        zip_buffer,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )


@router.get(
    "/{project_id}/results",
    summary="Export segmentation results as a ZIP archive",
)
def export_results(
    project_id: int,
    scope: str = Query("all", pattern="^(all|range|current)$"),
    mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
    outputs: str | None = Query(None),
    mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
    start_frame: int | None = Query(None, ge=1),
    end_frame: int | None = Query(None, ge=1),
    frame_id: int | None = Query(None, ge=1),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Export JSON annotations plus selected PNG mask outputs inside one ZIP.

    `scope=all` exports the whole video.
    `scope=range` uses 1-based frame numbers from the sorted project frame sequence.
    `scope=current` uses the concrete backend `frame_id`.

    The archive is assembled fully in memory before it is streamed.
    """
    project = _project_or_404(project_id, db, current_user)
    frames = _filter_frames(
        _project_frames(project_id, db),
        scope=scope,
        start_frame=start_frame,
        end_frame=end_frame,
        frame_id=frame_id,
    )
    annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
    templates = db.query(Template).all()
    coco = _build_coco(project, frames, annotations, templates)
    class_values, class_mapping = _build_gt_class_mapping(annotations)
    selected_outputs = _parse_result_outputs(mask_type, outputs)

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.writestr(
            "annotations_coco.json",
            json.dumps(coco, ensure_ascii=False, indent=2).encode("utf-8"),
        )
        zf.writestr(
            "maskid_GT像素值_类别映射.json",
            json.dumps({"classes": class_mapping}, ensure_ascii=False, indent=2).encode("utf-8"),
        )
        original_images = _write_original_frames(zf, project, frames)
        _write_result_mask_outputs(
            zf,
            project,
            frames,
            annotations,
            outputs=selected_outputs,
            class_values=class_values,
            class_mapping=class_mapping,
            original_images=original_images,
            mix_opacity=mix_opacity,
        )

    zip_buffer.seek(0)
    filename = _segmentation_results_filename(project, frames)
    return StreamingResponse(
        zip_buffer,
        media_type="application/zip",
        headers={"Content-Disposition": _download_content_disposition(filename)},
    )