- 新增 Seg_Server_Docker 自包含部署内容,包含前后端、FastAPI、Celery、PostgreSQL、Redis、MinIO、演示视频和 DICOM 数据。 - 保留 demo 数据以支持恢复演示出厂设置,排除 SAM 2.1 .pt 权重并在 README 中补充下载命令。 - 补充 GPU 部署、backend/worker 镜像复用、frpc/frps + NPM 公网域名反代部署说明。 - 在 .env/.env.example 中用 # XXXX 标注局域网和公网域名部署需要修改的配置项。 - 添加部署分支 .gitignore,忽略本地模型权重、构建产物、缓存和日志。
765 lines
27 KiB
Python
765 lines
27 KiB
Python
"""Annotation export endpoints (COCO, PNG masks)."""
|
||
|
||
import io
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import zipfile
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List
|
||
from urllib.parse import quote
|
||
|
||
import numpy as np
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||
from fastapi.responses import StreamingResponse
|
||
from sqlalchemy.orm import Session
|
||
|
||
from database import get_db
|
||
from minio_client import download_file
|
||
from models import Project, Annotation, Frame, Template, User
|
||
from routers.auth import get_current_user
|
||
|
||
# Module-level logger for this router (not referenced by any function in this
# module as shown).
logger = logging.getLogger(__name__)

# Router for every export endpoint below; all routes are mounted under
# /api/export and grouped under the "Export" tag.
router = APIRouter(prefix="/api/export", tags=["Export"])
|
||
|
||
|
||
def _mask_from_polygon(
|
||
polygon: List[List[float]],
|
||
width: int,
|
||
height: int,
|
||
) -> np.ndarray:
|
||
"""Render a normalized polygon to a binary mask."""
|
||
import cv2
|
||
|
||
pts = np.array(
|
||
[[int(p[0] * width), int(p[1] * height)] for p in polygon],
|
||
dtype=np.int32,
|
||
)
|
||
mask = np.zeros((height, width), dtype=np.uint8)
|
||
cv2.fillPoly(mask, [pts], 255)
|
||
return mask
|
||
|
||
|
||
def _annotation_z_index(annotation: Annotation) -> int:
|
||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||
if isinstance(class_meta, dict) and class_meta.get("zIndex") is not None:
|
||
try:
|
||
return int(class_meta["zIndex"])
|
||
except (TypeError, ValueError):
|
||
pass
|
||
if annotation.template and annotation.template.z_index is not None:
|
||
return int(annotation.template.z_index)
|
||
return 0
|
||
|
||
|
||
def _annotation_mask_id(annotation: Annotation) -> int | None:
|
||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||
if isinstance(class_meta, dict):
|
||
for key in ("maskId", "maskid", "mask_id"):
|
||
if class_meta.get(key) is None:
|
||
continue
|
||
try:
|
||
value = int(class_meta[key])
|
||
except (TypeError, ValueError):
|
||
continue
|
||
if value >= 0:
|
||
return value
|
||
return None
|
||
|
||
|
||
def _annotation_category_name(annotation: Annotation) -> str:
|
||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||
if isinstance(class_meta, dict) and class_meta.get("category"):
|
||
return str(class_meta["category"])
|
||
if annotation.template and annotation.template.name:
|
||
return str(annotation.template.name)
|
||
return ""
|
||
|
||
|
||
def _annotation_class_key(annotation: Annotation) -> str:
|
||
class_meta = (annotation.mask_data or {}).get("class") or {}
|
||
if isinstance(class_meta, dict):
|
||
if class_meta.get("id"):
|
||
return f"class:{class_meta['id']}"
|
||
if class_meta.get("name"):
|
||
return f"name:{class_meta['name']}"
|
||
if annotation.template_id:
|
||
return f"template:{annotation.template_id}"
|
||
return f"annotation:{annotation.id}"
|
||
|
||
|
||
def _annotation_label(annotation: Annotation) -> str:
|
||
mask_data = annotation.mask_data or {}
|
||
class_meta = mask_data.get("class") or {}
|
||
if isinstance(class_meta, dict) and class_meta.get("name"):
|
||
return str(class_meta["name"])
|
||
if mask_data.get("label"):
|
||
return str(mask_data["label"])
|
||
if annotation.template and annotation.template.name:
|
||
return str(annotation.template.name)
|
||
return f"Annotation {annotation.id}"
|
||
|
||
|
||
def _annotation_color(annotation: Annotation) -> str:
|
||
mask_data = annotation.mask_data or {}
|
||
class_meta = mask_data.get("class") or {}
|
||
if isinstance(class_meta, dict) and class_meta.get("color"):
|
||
return str(class_meta["color"])
|
||
if mask_data.get("color"):
|
||
return str(mask_data["color"])
|
||
if annotation.template and annotation.template.color:
|
||
return str(annotation.template.color)
|
||
return "#ffffff"
|
||
|
||
|
||
def _hex_to_rgb(color: str) -> list[int]:
|
||
value = str(color or "").strip()
|
||
if value.startswith("#"):
|
||
value = value[1:]
|
||
if len(value) == 3:
|
||
value = "".join(part * 2 for part in value)
|
||
if len(value) != 6:
|
||
return [255, 255, 255]
|
||
try:
|
||
return [int(value[i:i + 2], 16) for i in (0, 2, 4)]
|
||
except ValueError:
|
||
return [255, 255, 255]
|
||
|
||
|
||
def _safe_filename_part(value: Any, fallback: str = "unknown") -> str:
|
||
text = str(value or "").strip()
|
||
if not text:
|
||
text = fallback
|
||
text = re.sub(r"[\\/:*?\"<>|\s]+", "_", text)
|
||
text = re.sub(r"_+", "_", text).strip("._")
|
||
return text or fallback
|
||
|
||
|
||
def _project_video_name(project: Project) -> str:
    """Return a filename-safe base name for the project's video.

    Uses the video file name with its last extension removed. If there is no
    video path, or that leaves an empty stem (e.g. ``".mp4"``), the project
    name is used instead. ``project_<id>`` is the final fallback.
    """
    fallback = f"project_{project.id}"
    video_path = project.video_path
    if video_path:
        file_name = Path(video_path).name
        head, dot, tail = file_name.rpartition(".")
        # rpartition puts the whole name in `tail` when there is no dot.
        stem = head if dot else tail
        if stem:
            return _safe_filename_part(stem, fallback)
    return _safe_filename_part(project.name, fallback)
|
||
|
||
|
||
def _project_export_name(project: Project) -> str:
    """Return the project name as a filename-safe string (``project_<id>`` if unusable)."""
    default_name = f"project_{project.id}"
    return _safe_filename_part(project.name, default_name)
|
||
|
||
|
||
def _frame_timestamp_ms(frame: Frame, project: Project) -> float:
|
||
if frame.timestamp_ms is not None:
|
||
return float(frame.timestamp_ms)
|
||
fps = project.parse_fps or project.original_fps or 30.0
|
||
return float(frame.frame_index) * 1000.0 / max(float(fps), 1.0)
|
||
|
||
|
||
def _project_frame_number(frame: Frame) -> int:
|
||
return int(frame.frame_index) + 1
|
||
|
||
|
||
def _format_timestamp_ms(value: float) -> str:
|
||
total_ms = max(0, int(round(float(value))))
|
||
hours = total_ms // 3_600_000
|
||
minutes = (total_ms % 3_600_000) // 60_000
|
||
seconds = (total_ms % 60_000) // 1_000
|
||
milliseconds = total_ms % 1_000
|
||
return f"{hours}h{minutes:02d}m{seconds:02d}s{milliseconds:03d}ms"
|
||
|
||
|
||
def _frame_export_stem(project: Project, frame: Frame) -> str:
    """Build the per-frame export name ``<video>_<timestamp>_frame<NNNNNN>``."""
    video = _project_video_name(project)
    stamp = _format_timestamp_ms(_frame_timestamp_ms(frame, project))
    number = _project_frame_number(frame)
    return f"{video}_{stamp}_frame{number:06d}"
|
||
|
||
|
||
def _segmentation_results_filename(project: Project, frames: list[Frame]) -> str:
    """Return the results ZIP name, encoding the exported time and frame span.

    Format: ``<project>_seg_T_<first ts>-<last ts>_P_<first no>-<last no>.zip``.
    Only the first and last entries of ``frames`` are read (callers pass frames
    ordered by ``frame_index``). An empty list yields a zero span.
    """
    base = _project_export_name(project)
    if not frames:
        return f"{base}_seg_T_0h00m00s000ms-0h00m00s000ms_P_0-0.zip"
    head, tail = frames[0], frames[-1]
    start_ts = _format_timestamp_ms(_frame_timestamp_ms(head, project))
    end_ts = _format_timestamp_ms(_frame_timestamp_ms(tail, project))
    return (
        f"{base}_seg_T_{start_ts}-{end_ts}"
        f"_P_{_project_frame_number(head)}-{_project_frame_number(tail)}.zip"
    )
|
||
|
||
|
||
def _download_content_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition header value for ``filename``.

    Emits an ASCII-only, sanitized ``filename`` parameter plus an RFC 5987
    ``filename*`` parameter carrying the full UTF-8 name. If ``filename`` ends
    in ``.zip``, the ASCII fallback is forced to end in ``.zip`` as well.
    """
    ascii_fallback = filename.encode("ascii", "ignore").decode("ascii") or "segmentation_results.zip"
    ascii_fallback = _safe_filename_part(ascii_fallback, "segmentation_results.zip")
    if not ascii_fallback.endswith(".zip") and filename.endswith(".zip"):
        ascii_fallback = f"{ascii_fallback}.zip"
    # safe="" percent-encodes "/" too; RFC 5987 attr-char does not allow it
    # (quote() leaves "/" unescaped by default).
    encoded_name = quote(filename, safe="")
    return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{encoded_name}"
|
||
|
||
|
||
def _frame_image_extension(frame: Frame) -> str:
|
||
suffix = Path(frame.image_url or "").suffix.lower()
|
||
return suffix if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} else ".jpg"
|
||
|
||
|
||
def _project_or_404(project_id: int, db: Session, current_user: User) -> Project:
    """Load a project by id, or raise HTTP 404 if it does not exist.

    NOTE(review): ``current_user`` is accepted but not used, so no per-user
    ownership or permission check happens here. Confirm that every
    authenticated user is meant to be able to export every project.
    """
    del current_user  # intentionally unused (see note above)
    match = db.query(Project).filter(Project.id == project_id).first()
    if not match:
        raise HTTPException(status_code=404, detail="Project not found")
    return match
|
||
|
||
|
||
def _project_frames(project_id: int, db: Session) -> list[Frame]:
    """Return every frame of the project, ordered by ``frame_index``."""
    query = db.query(Frame).filter(Frame.project_id == project_id)
    return query.order_by(Frame.frame_index).all()
|
||
|
||
|
||
def _filter_frames(
|
||
frames: list[Frame],
|
||
*,
|
||
scope: str = "all",
|
||
start_frame: int | None = None,
|
||
end_frame: int | None = None,
|
||
frame_id: int | None = None,
|
||
) -> list[Frame]:
|
||
if scope == "current":
|
||
if frame_id is None:
|
||
raise HTTPException(status_code=400, detail="frame_id is required for current-frame export")
|
||
selected = [frame for frame in frames if frame.id == frame_id]
|
||
if not selected:
|
||
raise HTTPException(status_code=404, detail="Frame not found")
|
||
return selected
|
||
|
||
if scope == "range":
|
||
if start_frame is None or end_frame is None:
|
||
raise HTTPException(status_code=400, detail="start_frame and end_frame are required for range export")
|
||
start = max(1, min(int(start_frame), int(end_frame)))
|
||
end = max(1, max(int(start_frame), int(end_frame)))
|
||
return frames[start - 1:end]
|
||
|
||
return frames
|
||
|
||
|
||
def _filtered_annotations(project_id: int, frame_ids: set[int], db: Session) -> list[Annotation]:
|
||
if not frame_ids:
|
||
return []
|
||
return (
|
||
db.query(Annotation)
|
||
.filter(Annotation.project_id == project_id)
|
||
.filter(Annotation.frame_id.in_(frame_ids))
|
||
.all()
|
||
)
|
||
|
||
|
||
def _build_coco(project: Project, frames: list[Frame], annotations: list[Annotation], templates: list[Template]) -> dict[str, Any]:
|
||
images = []
|
||
for frame in frames:
|
||
images.append({
|
||
"id": frame.id,
|
||
"file_name": frame.image_url,
|
||
"width": frame.width or 1920,
|
||
"height": frame.height or 1080,
|
||
"frame_index": frame.frame_index,
|
||
})
|
||
|
||
categories = []
|
||
template_id_to_cat_id: Dict[int, int] = {}
|
||
for cat_idx, tmpl in enumerate(templates, start=1):
|
||
categories.append({
|
||
"id": cat_idx,
|
||
"name": tmpl.name,
|
||
"color": tmpl.color,
|
||
})
|
||
template_id_to_cat_id[tmpl.id] = cat_idx
|
||
|
||
coco_annotations = []
|
||
ann_id = 1
|
||
selected_frame_ids = {frame.id for frame in frames}
|
||
for ann in annotations:
|
||
if ann.frame_id not in selected_frame_ids or not ann.mask_data:
|
||
continue
|
||
polygons = ann.mask_data.get("polygons", [])
|
||
if not polygons:
|
||
continue
|
||
|
||
first_poly = polygons[0]
|
||
xs = [p[0] for p in first_poly]
|
||
ys = [p[1] for p in first_poly]
|
||
width = ann.frame.width if ann.frame else 1920
|
||
height = ann.frame.height if ann.frame else 1080
|
||
bbox = [
|
||
min(xs) * width,
|
||
min(ys) * height,
|
||
(max(xs) - min(xs)) * width,
|
||
(max(ys) - min(ys)) * height,
|
||
]
|
||
area = bbox[2] * bbox[3]
|
||
|
||
segmentation = []
|
||
for poly in polygons:
|
||
flat = []
|
||
for p in poly:
|
||
flat.append(p[0] * width)
|
||
flat.append(p[1] * height)
|
||
segmentation.append(flat)
|
||
|
||
coco_annotations.append({
|
||
"id": ann_id,
|
||
"image_id": ann.frame_id,
|
||
"category_id": template_id_to_cat_id.get(ann.template_id, 0),
|
||
"segmentation": segmentation,
|
||
"area": area,
|
||
"bbox": bbox,
|
||
"iscrowd": 0,
|
||
})
|
||
ann_id += 1
|
||
|
||
return {
|
||
"info": {
|
||
"description": f"Annotations for {project.name}",
|
||
"version": "1.0",
|
||
"year": datetime.now().year,
|
||
"date_created": datetime.now().isoformat(),
|
||
},
|
||
"images": images,
|
||
"annotations": coco_annotations,
|
||
"categories": categories,
|
||
}
|
||
|
||
|
||
def _class_mapping_entry(annotation: Annotation) -> dict[str, Any]:
    """Collect the per-class metadata used to build the maskid/GT mapping.

    ``className`` and ``chineseName`` both hold the annotation label.
    ``internalPriority`` is the z-index, and ``maskidHint`` is the optional
    mask id from the class metadata (``None`` when absent).
    """
    label = _annotation_label(annotation)
    return dict(
        key=_annotation_class_key(annotation),
        className=label,
        chineseName=label,
        categoryName=_annotation_category_name(annotation),
        color=_annotation_color(annotation),
        internalPriority=_annotation_z_index(annotation),
        maskidHint=_annotation_mask_id(annotation),
        template_id=annotation.template_id,
    )
|
||
|
||
|
||
def _build_gt_class_mapping(annotations: list[Annotation]) -> tuple[dict[str, int], list[dict[str, Any]]]:
    """Assign each annotation class an 8-bit maskid (GT pixel value).

    There is one entry per class key; the first annotation with polygons
    supplies its metadata. Classes are sorted by maskid hint (classes without
    a non-negative hint go last), then class name, then key.

    Values are assigned in two passes:

    1. Explicit hints are reserved first. Hint 0 maps to 0 (the same value as
       the zero-initialized GT background). A hint in 1-255 goes to the first
       class in sort order that claims it.
    2. Every remaining class (no hint, negative hint, or a hint already
       claimed) takes the smallest free value from 1 upward.

    Reserving hints up front means a class whose hint collided cannot take a
    value that another, later-sorted class explicitly asked for.

    Returns:
        ``(key_to_value, classes)``; ``classes`` is listed in sort order.

    Raises:
        HTTPException: 400 if a hint exceeds 255 or more than 255 values are
            needed.
    """
    limit_detail = "GT_label 仅支持 8-bit maskid,类别值必须在 1-255 之间"

    entries_by_key: dict[str, dict[str, Any]] = {}
    for annotation in annotations:
        if not annotation.mask_data or not annotation.mask_data.get("polygons"):
            continue
        entry = _class_mapping_entry(annotation)
        entries_by_key.setdefault(entry["key"], entry)

    ordered = sorted(
        entries_by_key.values(),
        key=lambda item: (
            item["maskidHint"] if isinstance(item.get("maskidHint"), int) and item["maskidHint"] >= 0 else 10_000_000,
            str(item["className"]),
            str(item["key"]),
        ),
    )

    assigned: dict[str, int] = {}
    used_maskids: set[int] = set()

    # Pass 1: honor explicit hints (first claimant in sort order wins).
    for entry in ordered:
        hint = entry.get("maskidHint")
        if not isinstance(hint, int):
            continue
        if hint > 255:
            raise HTTPException(status_code=400, detail=limit_detail)
        if hint == 0:
            assigned[entry["key"]] = 0
            used_maskids.add(0)
        elif hint > 0 and hint not in used_maskids:
            assigned[entry["key"]] = hint
            used_maskids.add(hint)

    # Pass 2: everything else takes the smallest free value from 1 upward.
    next_maskid = 1
    for entry in ordered:
        if entry["key"] in assigned:
            continue
        while next_maskid in used_maskids:
            next_maskid += 1
        if next_maskid > 255:
            raise HTTPException(status_code=400, detail=limit_detail)
        assigned[entry["key"]] = next_maskid
        used_maskids.add(next_maskid)
        next_maskid += 1

    key_to_value = {entry["key"]: assigned[entry["key"]] for entry in ordered}
    classes = [
        {
            "gt_pixel_value": key_to_value[entry["key"]],
            "maskid": key_to_value[entry["key"]],
            "chineseName": entry["chineseName"],
            "className": entry["className"],
            "categoryName": entry["categoryName"],
            "rgb": _hex_to_rgb(entry["color"]),
            "color": entry["color"],
            "key": entry["key"],
            "template_id": entry["template_id"],
        }
        for entry in ordered
    ]
    return key_to_value, classes
|
||
|
||
|
||
def _parse_result_outputs(mask_type: str, outputs: str | None) -> set[str]:
|
||
allowed = {"separate", "gt_label", "pro_label", "mix_label"}
|
||
if outputs:
|
||
parsed = {item.strip() for item in outputs.split(",") if item.strip()}
|
||
invalid = parsed - allowed
|
||
if invalid:
|
||
raise HTTPException(status_code=400, detail=f"Invalid outputs: {', '.join(sorted(invalid))}")
|
||
return parsed or allowed
|
||
|
||
if mask_type == "separate":
|
||
return {"separate"}
|
||
if mask_type == "gt_label":
|
||
return {"gt_label"}
|
||
if mask_type == "pro_label":
|
||
return {"pro_label"}
|
||
if mask_type == "mix_label":
|
||
return {"mix_label"}
|
||
return allowed
|
||
|
||
|
||
def _write_original_frames(
    zf: zipfile.ZipFile,
    project: Project,
    frames: list[Frame],
) -> dict[int, bytes]:
    """Copy each frame's source image into the ``原始图片/`` folder of the ZIP.

    Each image is fetched with ``download_file(frame.image_url)`` and stored
    as ``<frame export stem><image extension>``.

    Returns:
        The downloaded bytes keyed by frame id, so later outputs (e.g. the
        Mix_label overlay) can reuse them. All of these bytes stay in memory
        for as long as the caller holds the dict.
    """
    downloaded: dict[int, bytes] = {}
    for frame in frames:
        payload = download_file(frame.image_url)
        downloaded[frame.id] = payload
        entry_name = f"原始图片/{_frame_export_stem(project, frame)}{_frame_image_extension(frame)}"
        zf.writestr(entry_name, payload)
    return downloaded
|
||
|
||
|
||
def _decode_original_image(image_bytes: bytes | None, width: int, height: int) -> np.ndarray:
    """Decode encoded image bytes into a 3-channel ``(height, width, 3)`` array.

    Decoding uses ``cv2.IMREAD_COLOR`` (BGR channel order). Images of a
    different size are resized with ``INTER_AREA``. Missing or undecodable
    bytes yield a black image of the requested size.
    """
    import cv2

    decoded = None
    if image_bytes:
        buffer = np.frombuffer(image_bytes, dtype=np.uint8)
        decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if decoded is None:
        return np.zeros((height, width, 3), dtype=np.uint8)
    if (decoded.shape[1], decoded.shape[0]) != (width, height):
        decoded = cv2.resize(decoded, (width, height), interpolation=cv2.INTER_AREA)
    return decoded
|
||
|
||
|
||
def _write_result_mask_outputs(
    zf: zipfile.ZipFile,
    project: Project,
    frames: list[Frame],
    annotations: list[Annotation],
    *,
    outputs: set[str],
    class_values: dict[str, int],
    class_mapping: list[dict[str, Any]],
    original_images: dict[int, bytes],
    mix_opacity: float,
) -> None:
    """Write the selected per-frame mask artifacts into ``zf``.

    For every frame in ``frames`` that has at least one annotation with
    polygons:

    * ``separate``: one 0/255 PNG per class under
      ``分开Mask分割结果/<stem>_分别导出/``, ordered by class priority.
    * ``gt_label``: a single-channel PNG under ``GT_label图/`` whose pixel
      values are the maskids from ``class_values``.
    * ``pro_label``: a BGR color PNG under ``Pro_label彩色分割结果/`` using
      each class's ``rgb`` from ``class_mapping``.
    * ``mix_label``: the original frame with the color mask blended in at
      ``mix_opacity`` (clamped to [0, 1]), under
      ``Mix_label重叠覆盖彩色分割结果/``.

    Fused outputs paint annotations in ascending z-index, so a higher z-index
    wins where annotations overlap. Classes missing from ``class_values`` are
    skipped, and frames without annotations produce no mask files.

    Each annotation's polygons are rasterized once and reused by every output.
    The Mix_label blend region is tracked explicitly, so classes whose color is
    pure black (#000000) are still blended instead of being treated as
    background.
    """
    import cv2

    def png_bytes(image: np.ndarray) -> bytes:
        # Encode an array as PNG bytes for a ZIP entry.
        _, encoded = cv2.imencode(".png", image)
        return encoded.tobytes()

    def annotation_mask(annotation: Annotation, width: int, height: int) -> np.ndarray:
        # Union of all of an annotation's polygons as a 0/255 uint8 mask.
        combined = np.zeros((height, width), dtype=np.uint8)
        for poly in (annotation.mask_data or {}).get("polygons", []):
            np.maximum(combined, _mask_from_polygon(poly, width, height), out=combined)
        return combined

    include_individual = "separate" in outputs
    include_semantic = "gt_label" in outputs
    include_pro_label = "pro_label" in outputs
    include_mix_label = "mix_label" in outputs
    needs_color = include_pro_label or include_mix_label
    needs_fused_output = include_semantic or needs_color
    if not (include_individual or needs_fused_output):
        return
    # The clamp does not depend on the frame, so compute it once.
    opacity = min(max(float(mix_opacity), 0.0), 1.0)

    class_rgb_by_key = {
        item["key"]: item.get("rgb") or _hex_to_rgb(item.get("color", "#ffffff"))
        for item in class_mapping
    }

    selected_frame_ids = {frame.id for frame in frames}
    annotations_by_frame: dict[int, list[Annotation]] = {}
    for annotation in annotations:
        if annotation.frame_id not in selected_frame_ids or not annotation.mask_data:
            continue
        if not annotation.mask_data.get("polygons"):
            continue
        annotations_by_frame.setdefault(annotation.frame_id, []).append(annotation)

    for frame in frames:
        frame_annotations = annotations_by_frame.get(frame.id, [])
        if not frame_annotations:
            continue
        width = frame.width or 1920
        height = frame.height or 1080
        frame_stem = _frame_export_stem(project, frame)
        # Rasterize each annotation once; the pairs keep the original annotation order.
        pairs = [(annotation, annotation_mask(annotation, width, height)) for annotation in frame_annotations]

        if include_individual:
            class_masks: dict[str, np.ndarray] = {}
            class_meta: dict[str, dict[str, Any]] = {}
            for annotation, mask in pairs:
                key = _annotation_class_key(annotation)
                if key in class_masks:
                    np.maximum(class_masks[key], mask, out=class_masks[key])
                else:
                    # Copy, because the cached mask is reused by the fused outputs.
                    class_masks[key] = mask.copy()
                class_meta.setdefault(key, _class_mapping_entry(annotation))

            folder = f"分开Mask分割结果/{frame_stem}_分别导出"
            for key, mask in sorted(class_masks.items(), key=lambda item: int(class_meta[item[0]]["internalPriority"])):
                maskid = class_values.get(key)
                if maskid is None:
                    continue
                class_name = _safe_filename_part(class_meta[key]["className"], "class")
                zf.writestr(f"{folder}/{frame_stem}_{class_name}_maskid{maskid}.png", png_bytes(mask))

        if not needs_fused_output:
            continue

        semantic = np.zeros((height, width), dtype=np.uint8)
        pro_label = np.zeros((height, width, 3), dtype=np.uint8) if needs_color else None
        # Pixels painted by any mapped class; the blend region for Mix_label.
        covered = np.zeros((height, width), dtype=bool)
        for annotation, mask in sorted(pairs, key=lambda pair: _annotation_z_index(pair[0])):
            key = _annotation_class_key(annotation)
            value = class_values.get(key)
            if value is None:
                continue
            region = mask > 0
            semantic[region] = value
            covered |= region
            if pro_label is not None:
                rgb = class_rgb_by_key.get(key, [255, 255, 255])
                # OpenCV writes PNGs in BGR order.
                pro_label[region] = np.array([rgb[2], rgb[1], rgb[0]], dtype=np.uint8)

        if include_semantic:
            zf.writestr(f"GT_label图/{frame_stem}.png", png_bytes(semantic))

        if include_pro_label and pro_label is not None:
            zf.writestr(f"Pro_label彩色分割结果/{frame_stem}.png", png_bytes(pro_label))

        if include_mix_label and pro_label is not None:
            original = _decode_original_image(original_images.get(frame.id), width, height)
            mixed = original.copy()
            mixed[covered] = (
                original[covered].astype(np.float32) * (1.0 - opacity)
                + pro_label[covered].astype(np.float32) * opacity
            ).clip(0, 255).astype(np.uint8)
            zf.writestr(f"Mix_label重叠覆盖彩色分割结果/{frame_stem}.png", png_bytes(mixed))
|
||
|
||
|
||
def _write_mask_pngs(
    zf: zipfile.ZipFile,
    frames: list[Frame],
    annotations: list[Annotation],
    *,
    mask_type: str,
    individual_prefix: str = "",
    semantic_prefix: str = "",
    semantic_file_stem: str = "semantic_frame",
    semantic_dtype: Any = np.uint8,
) -> list[dict[str, Any]]:
    """Write per-annotation and/or fused semantic mask PNGs into ``zf``.

    * ``mask_type="separate"`` writes ``<individual_prefix>mask_<ann id>.png``
      (0/255) for each annotation with polygons on a selected frame.
    * ``mask_type="gt_label"`` writes one
      ``<semantic_prefix><semantic_file_stem>_<frame_index>.png`` per annotated
      frame, plus ``<semantic_prefix>semantic_classes.json``.
    * ``mask_type="both"`` writes both.

    Semantic values 1, 2, 3, ... are assigned per class key in the order
    classes are first painted (frames in order, annotations by ascending
    z-index), and a higher z-index wins where annotations overlap.

    Masks are sized from the selected frame's width/height, using the same
    1920x1080 fallback as the semantic image. A frame with NULL dimensions
    therefore no longer crashes rasterization, and mask and semantic shapes
    always match.

    Returns:
        The semantic class table (empty unless semantic output was requested).

    Raises:
        HTTPException: 400 if more classes are needed than an integer
            ``semantic_dtype`` can represent.
    """
    import cv2

    class_values: dict[str, int] = {}
    semantic_classes: list[dict[str, Any]] = []
    dtype = np.dtype(semantic_dtype)
    # Largest storable class value; None when the dtype is not an integer type.
    max_value = int(np.iinfo(dtype).max) if np.issubdtype(dtype, np.integer) else None

    def class_value(annotation: Annotation) -> int:
        # Return the class's semantic value, registering new classes on first use.
        key = _annotation_class_key(annotation)
        if key not in class_values:
            value = len(class_values) + 1
            if max_value is not None and value > max_value:
                # Otherwise NumPy may silently wrap the value or raise OverflowError.
                raise HTTPException(
                    status_code=400,
                    detail=f"Too many classes for {dtype.name} semantic masks (max {max_value})",
                )
            class_values[key] = value
            semantic_classes.append({
                "value": value,
                "key": key,
                "label": _annotation_label(annotation),
                "color": _annotation_color(annotation),
                "zIndex": _annotation_z_index(annotation),
                "template_id": annotation.template_id,
            })
        return class_values[key]

    include_individual = mask_type in {"separate", "both"}
    include_semantic = mask_type in {"gt_label", "both"}
    frame_by_id = {frame.id: frame for frame in frames}
    frame_masks: dict[int, list[tuple[Annotation, np.ndarray]]] = {}

    for ann in annotations:
        frame = frame_by_id.get(ann.frame_id)
        if frame is None or not ann.mask_data:
            continue
        polygons = ann.mask_data.get("polygons", [])
        if not polygons:
            continue

        width = frame.width or 1920
        height = frame.height or 1080
        combined = np.zeros((height, width), dtype=np.uint8)
        for poly in polygons:
            np.maximum(combined, _mask_from_polygon(poly, width, height), out=combined)

        if include_individual:
            _, encoded = cv2.imencode(".png", combined)
            zf.writestr(f"{individual_prefix}mask_{ann.id:06d}.png", encoded.tobytes())
        if include_semantic:
            frame_masks.setdefault(ann.frame_id, []).append((ann, combined))

    if include_semantic:
        for frame in frames:
            entries = frame_masks.get(frame.id, [])
            if not entries:
                continue
            semantic = np.zeros((frame.height or 1080, frame.width or 1920), dtype=semantic_dtype)
            for ann, mask in sorted(entries, key=lambda item: _annotation_z_index(item[0])):
                semantic[mask > 0] = class_value(ann)
            _, encoded = cv2.imencode(".png", semantic)
            zf.writestr(f"{semantic_prefix}{semantic_file_stem}_{frame.frame_index:06d}.png", encoded.tobytes())

        zf.writestr(
            f"{semantic_prefix}semantic_classes.json",
            json.dumps({"classes": semantic_classes}, ensure_ascii=False, indent=2).encode("utf-8"),
        )

    return semantic_classes
|
||
|
||
|
||
@router.get(
    "/{project_id}/coco",
    summary="Export annotations in COCO format",
)
def export_coco(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Download every annotation of a project as a COCO JSON attachment.

    The file is named ``project_<id>_coco.json``. Categories are built from
    all templates in the database.
    """
    project = _project_or_404(project_id, db, current_user)
    all_frames = _project_frames(project_id, db)
    frame_ids = {frame.id for frame in all_frames}
    payload = _build_coco(
        project,
        all_frames,
        _filtered_annotations(project_id, frame_ids, db),
        db.query(Template).all(),
    )

    body = json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8")
    disposition = f'attachment; filename="project_{project_id}_coco.json"'
    return StreamingResponse(
        io.BytesIO(body),
        media_type="application/json",
        headers={"Content-Disposition": disposition},
    )
|
||
|
||
|
||
@router.get(
    "/{project_id}/masks",
    summary="Export PNG masks as a ZIP archive",
)
def export_masks(
    project_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Download per-annotation masks plus z-index-fused semantic masks as one ZIP.

    The archive is built in memory and named ``project_<id>_masks.zip``.
    """
    _project_or_404(project_id, db, current_user)
    all_frames = _project_frames(project_id, db)
    project_annotations = _filtered_annotations(project_id, {frame.id for frame in all_frames}, db)

    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:
        _write_mask_pngs(
            archive,
            all_frames,
            project_annotations,
            mask_type="both",
            semantic_prefix="",
            individual_prefix="",
        )
    buffer.seek(0)

    disposition = f'attachment; filename="project_{project_id}_masks.zip"'
    return StreamingResponse(
        buffer,
        media_type="application/zip",
        headers={"Content-Disposition": disposition},
    )
|
||
|
||
|
||
@router.get(
    "/{project_id}/results",
    summary="Export segmentation results as a ZIP archive",
)
def export_results(
    project_id: int,
    scope: str = Query("all", pattern="^(all|range|current)$"),
    mask_type: str = Query("both", pattern="^(separate|gt_label|pro_label|mix_label|both|all)$"),
    outputs: str | None = Query(None),
    mix_opacity: float = Query(0.3, ge=0.0, le=1.0),
    start_frame: int | None = Query(None, ge=1),
    end_frame: int | None = Query(None, ge=1),
    frame_id: int | None = Query(None, ge=1),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """Download COCO JSON, the maskid mapping, original frames and the chosen masks as one ZIP.

    Frame selection: ``scope=all`` takes the whole video. ``scope=range`` uses
    1-based frame numbers within the project's sorted frame sequence.
    ``scope=current`` uses the concrete backend ``frame_id``.
    """
    project = _project_or_404(project_id, db, current_user)
    frames = _filter_frames(
        _project_frames(project_id, db),
        scope=scope,
        start_frame=start_frame,
        end_frame=end_frame,
        frame_id=frame_id,
    )
    annotations = _filtered_annotations(project_id, {frame.id for frame in frames}, db)
    coco = _build_coco(project, frames, annotations, db.query(Template).all())
    class_values, class_mapping = _build_gt_class_mapping(annotations)
    selected_outputs = _parse_result_outputs(mask_type, outputs)

    def as_json_bytes(obj: Any) -> bytes:
        # Pretty-printed UTF-8 JSON, keeping non-ASCII class names readable.
        return json.dumps(obj, ensure_ascii=False, indent=2).encode("utf-8")

    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as archive:
        archive.writestr("annotations_coco.json", as_json_bytes(coco))
        archive.writestr("maskid_GT像素值_类别映射.json", as_json_bytes({"classes": class_mapping}))
        _write_result_mask_outputs(
            archive,
            project,
            frames,
            annotations,
            outputs=selected_outputs,
            class_values=class_values,
            class_mapping=class_mapping,
            original_images=_write_original_frames(archive, project, frames),
            mix_opacity=mix_opacity,
        )
    buffer.seek(0)

    download_name = _segmentation_results_filename(project, frames)
    return StreamingResponse(
        buffer,
        media_type="application/zip",
        headers={"Content-Disposition": _download_content_disposition(download_name)},
    )
|