2026-05-08-03-57-51 接入DICOM SEG双切面展示

This commit is contained in:
2026-05-08 04:19:24 +08:00
parent 946c0f4ef3
commit fe4b90abcd
6 changed files with 626 additions and 25 deletions

View File

@@ -55,6 +55,19 @@ PREVIEW_CACHE_DIR = LIBRARY_DIR / "_preview_cache"
MODEL_DIR = LIBRARY_DIR / "_stl_models"
STL_MODEL_CACHE = {}
STL_MODEL_CACHE_LOCK = threading.Lock()
SEGMENTATION_CACHE = {}
SEGMENTATION_CACHE_LOCK = threading.Lock()
SEGMENTATION_CACHE_LIMIT = 4
SEGMENTATION_COLORS = [
(255, 112, 32),
(34, 211, 238),
(168, 85, 247),
(74, 222, 128),
(250, 204, 21),
(244, 114, 182),
(96, 165, 250),
(251, 146, 60),
]
VIEWER_WINDOWS = {
"default": {"label": "默认", "low": -500, "high": 1200},
"bone": {"label": "骨窗", "low": -500, "high": 1800},
@@ -83,6 +96,15 @@ def safe_model_filename(name):
return name if name.lower().endswith(".stl") else f"{name}.stl"
def safe_segmentation_filename(name):
name = safe_filename(name or "segmentation.dcm")
path = Path(name)
suffix = path.suffix.lower()
if suffix in {".dcm", ".dicom"}:
return f"{path.stem}{suffix}"
return f"{name}.dcm"
def normalized_username(username):
username = str(username or "").strip()
return username or "anonymous"
@@ -343,6 +365,244 @@ def load_stl_model(model_id):
return model
def segmentation_root_for_item(item):
return Path(item["dicomPath"]).resolve().parent / "segmentations"
def segmentation_file_for_id(item, segmentation_id):
segmentation_id = safe_filename(segmentation_id)
if not segmentation_id:
return None
root = segmentation_root_for_item(item)
matches = sorted(root.glob(f"{segmentation_id}_*.dcm")) + sorted(root.glob(f"{segmentation_id}_*.dicom"))
return matches[0] if matches else None
def segmentation_meta_path(segmentation_path):
return segmentation_path.with_suffix(segmentation_path.suffix + ".json")
def segmentation_signature(segmentation_path):
segmentation_path = Path(segmentation_path).resolve()
return (
str(segmentation_path),
segmentation_path.stat().st_size,
segmentation_path.stat().st_mtime,
)
def segment_labels_from_dataset(ds):
labels = {}
for segment in getattr(ds, "SegmentSequence", []) or []:
try:
number = int(getattr(segment, "SegmentNumber", 0))
except Exception:
number = 0
if number > 0:
labels[number] = str(getattr(segment, "SegmentLabel", f"Segment {number}"))
return labels
def frame_segment_number(ds, frame_index):
per_frame = getattr(ds, "PerFrameFunctionalGroupsSequence", None)
if per_frame and frame_index < len(per_frame):
segment_sequence = getattr(per_frame[frame_index], "SegmentIdentificationSequence", None)
if segment_sequence:
try:
return int(getattr(segment_sequence[0], "ReferencedSegmentNumber", 1))
except Exception:
return 1
return 1
def frame_position_patient(ds, frame_index):
per_frame = getattr(ds, "PerFrameFunctionalGroupsSequence", None)
if per_frame and frame_index < len(per_frame):
plane_position = getattr(per_frame[frame_index], "PlanePositionSequence", None)
if plane_position and hasattr(plane_position[0], "ImagePositionPatient"):
return np.asarray(plane_position[0].ImagePositionPatient, dtype=np.float64)
if hasattr(ds, "ImagePositionPatient"):
return np.asarray(ds.ImagePositionPatient, dtype=np.float64)
return None
def normalize_segmentation_frames(pixel_array):
array = np.asarray(pixel_array)
if array.ndim == 2:
return array[np.newaxis, :, :]
if array.ndim == 3:
return array
if array.ndim == 4:
if array.shape[-1] == 1:
return array[..., 0]
return np.max(array, axis=-1)
raise RuntimeError("Segmentation Mask 像素维度不受支持。")
def resize_label_frame(frame, rows, cols):
frame = np.asarray(frame)
if frame.shape == (rows, cols):
return frame
image = Image.fromarray(frame.astype(np.uint16))
image = image.resize((cols, rows), Image.Resampling.NEAREST)
return np.asarray(image)
def load_segmentation_mask(item_id, segmentation_id):
item = find_library_item(item_id)
if not item:
raise RuntimeError("影像库中没有找到该数据。")
segmentation_path = segmentation_file_for_id(item, segmentation_id)
if not segmentation_path:
raise RuntimeError("没有找到已上传的 DICOM Segmentation Mask。")
signature = segmentation_signature(segmentation_path)
cache_key = f"{item_id}:{safe_filename(segmentation_id)}"
with SEGMENTATION_CACHE_LOCK:
cached = SEGMENTATION_CACHE.get(cache_key)
if cached and cached["signature"] == signature:
cached["last_access"] = time.time()
return cached["data"]
dicom_files = sorted_dicom_files(item["dicomPath"])
if not dicom_files:
raise RuntimeError("该影像数据没有可配准的 CT DICOM。")
first_ct = pydicom.dcmread(str(dicom_files[0]), stop_before_pixels=True, force=True)
ct_rows = int(getattr(first_ct, "Rows", 0) or 0)
ct_cols = int(getattr(first_ct, "Columns", 0) or 0)
if ct_rows <= 0 or ct_cols <= 0:
raise RuntimeError("CT DICOM 缺少 Rows/Columns 信息,无法配准 Segmentation Mask。")
ds = pydicom.dcmread(str(segmentation_path), force=True)
frames = normalize_segmentation_frames(ds.pixel_array)
geometry = dicom_geometry(item["dicomPath"])
label_volume = np.zeros((len(dicom_files), ct_rows, ct_cols), dtype=np.uint16)
labels = segment_labels_from_dataset(ds)
for frame_index, frame in enumerate(frames):
frame = resize_label_frame(frame, ct_rows, ct_cols)
active = frame > 0
if not np.any(active):
continue
position = frame_position_patient(ds, frame_index)
if geometry and position is not None:
voxel = (position - geometry["origin"]) @ geometry["inverse"].T
slice_index = int(round(float(voxel[0])))
elif len(frames) == len(dicom_files):
slice_index = frame_index
else:
slice_index = min(len(dicom_files) - 1, frame_index)
slice_index = max(0, min(len(dicom_files) - 1, slice_index))
if int(frame.max()) > 1 and not labels:
label_volume[slice_index][active] = np.maximum(label_volume[slice_index][active], frame[active].astype(np.uint16))
else:
label = max(1, frame_segment_number(ds, frame_index))
label_volume[slice_index][active] = label
labels.setdefault(label, f"Segment {label}")
unique_labels = sorted(int(value) for value in np.unique(label_volume) if int(value) > 0)
if not unique_labels:
raise RuntimeError("DICOM Segmentation Mask 中没有可渲染的分割像素。")
if not labels:
labels = {value: f"Label {value}" for value in unique_labels}
data = {
"itemId": item_id,
"segId": safe_filename(segmentation_id),
"path": segmentation_path,
"name": segmentation_path.name.split("_", 1)[1] if "_" in segmentation_path.name else segmentation_path.name,
"volume": label_volume,
"frameCount": int(frames.shape[0]),
"segmentCount": len(unique_labels),
"labels": [
{"value": int(value), "label": labels.get(int(value), f"Segment {int(value)}")}
for value in unique_labels
],
}
with SEGMENTATION_CACHE_LOCK:
SEGMENTATION_CACHE[cache_key] = {
"signature": signature,
"data": data,
"last_access": time.time(),
}
while len(SEGMENTATION_CACHE) > SEGMENTATION_CACHE_LIMIT:
oldest_key = min(
SEGMENTATION_CACHE,
key=lambda key: SEGMENTATION_CACHE[key].get("last_access", 0),
)
if oldest_key == cache_key:
break
SEGMENTATION_CACHE.pop(oldest_key, None)
return data
def serialize_segmentation(segmentation):
return {
"segId": segmentation["segId"],
"name": segmentation["name"],
"frameCount": segmentation["frameCount"],
"segmentCount": segmentation["segmentCount"],
"labels": segmentation["labels"],
}
def list_segmentations(item_id):
item = find_library_item(item_id)
if not item:
raise RuntimeError("影像库中没有找到该数据。")
root = segmentation_root_for_item(item)
if not root.exists():
return []
segmentations = []
for path in sorted(list(root.glob("*.dcm")) + list(root.glob("*.dicom"))):
meta = read_json_file(segmentation_meta_path(path), None)
if meta:
segmentations.append(meta)
continue
seg_id = path.name.split("_", 1)[0]
try:
segmentations.append(serialize_segmentation(load_segmentation_mask(item_id, seg_id)))
except Exception:
segmentations.append({
"segId": seg_id,
"name": path.name.split("_", 1)[1] if "_" in path.name else path.name,
"frameCount": 0,
"segmentCount": 0,
"labels": [],
})
return segmentations
def save_uploaded_segmentation(headers, body):
if not body:
raise RuntimeError("上传的 Segmentation Mask 文件为空。")
item_id = safe_filename(unquote(headers.get("x-library-id", "")))
item = find_library_item(item_id)
if not item:
raise RuntimeError("影像库中没有找到要绑定的 DICOM 数据。")
source_name = safe_segmentation_filename(unquote(headers.get("x-file-name", "segmentation.dcm")))
seg_id = uuid.uuid4().hex[:12]
root = segmentation_root_for_item(item)
safe_mkdir(root)
segmentation_path = root / f"{seg_id}_{source_name}"
segmentation_path.write_bytes(body)
try:
segmentation = load_segmentation_mask(item_id, seg_id)
meta = serialize_segmentation(segmentation)
write_json_file(segmentation_meta_path(segmentation_path), meta)
return meta
except Exception:
try:
segmentation_path.unlink()
except Exception:
pass
raise
def dicom_geometry(dicom_dir):
dicom_files = sorted_dicom_files(dicom_dir)
if not dicom_files:
@@ -512,6 +772,64 @@ def render_mask_only_preview(mask, size):
return preview_rgba.convert("RGB")
def render_segmentation_label_preview(label_slice):
labels = np.asarray(label_slice, dtype=np.uint16)
rgb = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
rgb[:, :] = np.asarray((8, 13, 28), dtype=np.uint8)
for value in sorted(int(item) for item in np.unique(labels) if int(item) > 0):
color = SEGMENTATION_COLORS[(value - 1) % len(SEGMENTATION_COLORS)]
rgb[labels == value] = color
return Image.fromarray(rgb, mode="RGB")
def normalize_reformat_index(raw_index, count):
if str(raw_index) == "middle":
return count // 2
try:
return int(raw_index)
except Exception:
return count // 2
def make_segmentation_reformat_preview(item_id, segmentation_id, plane, index):
segmentation = load_segmentation_mask(item_id, segmentation_id)
plane = plane if plane in {"coronal", "sagittal"} else "coronal"
volume = segmentation["volume"]
if plane == "coronal":
count = volume.shape[1]
index = normalize_reformat_index(index, count)
index = max(0, min(index, count - 1))
label_slice = volume[:, index, :]
else:
count = volume.shape[2]
index = normalize_reformat_index(index, count)
index = max(0, min(index, count - 1))
label_slice = volume[:, :, index]
mask_pixels = int(np.count_nonzero(label_slice))
cache_dir = PREVIEW_CACHE_DIR / item_id / "segmentation"
safe_mkdir(cache_dir)
preview_path = cache_dir / f"{plane}_{index:04d}_seg_{safe_filename(segmentation_id)}.png"
if not preview_path.exists():
preview = render_segmentation_label_preview(label_slice)
preview = fit_image(preview, 960, 720)
preview.save(preview_path, format="PNG")
return {
"imageUrl": f"/api/file?path={quote(str(preview_path.resolve()), safe='')}",
"index": index,
"count": count,
"plane": plane,
"window": "segmentation",
"windowLabel": "Segmentation Mask",
"patientId": segmentation["itemId"],
"segId": segmentation["segId"],
"maskPixels": mask_pixels,
"segmentCount": segmentation["segmentCount"],
"labels": segmentation["labels"],
}
def make_library_reformat_preview(item_id, plane, index, window, model_id="", mask_only=False):
item = find_library_item(item_id)
if not item:
@@ -521,14 +839,6 @@ def make_library_reformat_preview(item_id, plane, index, window, model_id="", ma
window = window if window in VIEWER_WINDOWS else "default"
volume = load_cached_dicom_volume(item["dicomPath"])
def normalize_reformat_index(raw_index, count):
if str(raw_index) == "middle":
return count // 2
try:
return int(raw_index)
except Exception:
return count // 2
mask = None
mask_pixels = 0
if plane == "coronal":
@@ -1256,6 +1566,21 @@ class Handler(BaseHTTPRequestHandler):
self.send_json(make_library_reformat_preview(item_id, plane, index, window, model_id, mask_only))
return
if parsed.path == "/api/segmentation/list":
params = parse_qs(parsed.query)
item_id = params.get("id", [""])[0]
self.send_json({"items": list_segmentations(item_id)})
return
if parsed.path == "/api/segmentation/preview":
params = parse_qs(parsed.query)
item_id = params.get("id", [""])[0]
segmentation_id = params.get("segId", [""])[0]
plane = params.get("plane", ["coronal"])[0]
index = params.get("index", ["0"])[0]
self.send_json(make_segmentation_reformat_preview(item_id, segmentation_id, plane, index))
return
if parsed.path == "/api/library/info":
params = parse_qs(parsed.query)
item_id = params.get("id", [""])[0]
@@ -1328,6 +1653,11 @@ class Handler(BaseHTTPRequestHandler):
self.send_json(save_uploaded_stl(self.headers, body), status=201)
return
if parsed.path == "/api/segmentation/upload":
body = self.read_bytes()
self.send_json(save_uploaded_segmentation(self.headers, body), status=201)
return
body = self.read_json()
if parsed.path == "/api/demo/reset":
self.send_json(reset_demo_environment())
@@ -1513,7 +1843,7 @@ class Handler(BaseHTTPRequestHandler):
def send_cors_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Methods", "GET,POST,OPTIONS")
self.send_header("Access-Control-Allow-Headers", "Content-Type")
self.send_header("Access-Control-Allow-Headers", "Content-Type, x-file-name, x-library-id")
def send_json(self, payload, status=200):
data = json.dumps(payload, ensure_ascii=False, default=json_default).encode("utf-8")