2026-05-08-02-36-12 实现STL模型切分mask

This commit is contained in:
2026-05-08 02:45:12 +08:00
parent 7b7c555321
commit 8e0e54fc3c
6 changed files with 691 additions and 5 deletions

View File

@@ -2,6 +2,7 @@ import base64
import json
import os
import shutil
import struct
import threading
import time
import traceback
@@ -17,6 +18,7 @@ from urllib.parse import parse_qs, quote, unquote, urlparse
os.environ.setdefault("MPLCONFIGDIR", "/tmp/head_ct_morph_matplotlib")
import pydicom
import numpy as np
from pydicom.multival import MultiValue
from PIL import Image, ImageDraw
@@ -50,6 +52,9 @@ RESULT_DIR = APP_DIR / "web_results"
JOBS_META = RESULT_DIR / "jobs.json"
USER_TASKS_META = RESULT_DIR / "user_tasks.json"
PREVIEW_CACHE_DIR = LIBRARY_DIR / "_preview_cache"
MODEL_DIR = LIBRARY_DIR / "_stl_models"
STL_MODEL_CACHE = {}
STL_MODEL_CACHE_LOCK = threading.Lock()
VIEWER_WINDOWS = {
"default": {"label": "默认", "low": -500, "high": 1200},
"bone": {"label": "骨窗", "low": -500, "high": 1800},
@@ -73,6 +78,11 @@ def safe_filename(name):
return "".join(char if char.isalnum() or char in "._-" else "_" for char in Path(name).name)
def safe_model_filename(name):
name = safe_filename(name or "model.stl")
return name if name.lower().endswith(".stl") else f"{name}.stl"
def normalized_username(username):
username = str(username or "").strip()
return username or "anonymous"
@@ -231,6 +241,222 @@ def clear_dicom_caches(dicom_dir=None):
DICOM_VOLUME_CACHE.pop(cache_key, None)
def parse_ascii_stl(text):
vertices = []
triangles = []
for line in text.splitlines():
parts = line.strip().split()
if len(parts) == 4 and parts[0].lower() == "vertex":
try:
vertices.append([float(parts[1]), float(parts[2]), float(parts[3])])
except ValueError:
vertices = []
break
if len(vertices) == 3:
triangles.append(vertices)
vertices = []
if not triangles:
raise RuntimeError("STL 文件中没有可解析的三角面。")
return np.asarray(triangles, dtype=np.float32)
def parse_binary_stl(data):
if len(data) < 84:
raise RuntimeError("STL 文件过小,无法解析。")
triangle_count = struct.unpack_from("<I", data, 80)[0]
expected_size = 84 + triangle_count * 50
if triangle_count <= 0 or expected_size > len(data):
raise RuntimeError("Binary STL 三角面数量异常。")
triangles = np.zeros((triangle_count, 3, 3), dtype=np.float32)
offset = 84
for index in range(triangle_count):
values = struct.unpack_from("<12fH", data, offset)
triangles[index] = np.asarray(values[3:12], dtype=np.float32).reshape(3, 3)
offset += 50
return triangles
def parse_stl_bytes(data):
try:
text = data.decode("utf-8", errors="ignore")
except Exception:
text = ""
if text.lstrip().lower().startswith("solid"):
try:
return parse_ascii_stl(text)
except Exception:
pass
return parse_binary_stl(data)
def save_uploaded_stl(headers, body):
if not body:
raise RuntimeError("上传的 STL 文件为空。")
safe_mkdir(MODEL_DIR)
source_name = safe_model_filename(unquote(headers.get("x-file-name", "model.stl")))
model_id = uuid.uuid4().hex[:12]
model_path = MODEL_DIR / f"{model_id}_{source_name}"
model_path.write_bytes(body)
triangles = parse_stl_bytes(body)
bounds_min = triangles.reshape(-1, 3).min(axis=0).tolist()
bounds_max = triangles.reshape(-1, 3).max(axis=0).tolist()
with STL_MODEL_CACHE_LOCK:
STL_MODEL_CACHE[model_id] = {
"path": model_path,
"triangles": triangles,
"name": source_name,
"bounds": [bounds_min, bounds_max],
}
return {
"modelId": model_id,
"name": source_name,
"triangleCount": int(triangles.shape[0]),
"bounds": [bounds_min, bounds_max],
}
def load_stl_model(model_id):
model_id = safe_filename(model_id)
if not model_id:
raise RuntimeError("模型 ID 为空。")
with STL_MODEL_CACHE_LOCK:
cached = STL_MODEL_CACHE.get(model_id)
if cached:
return cached
matches = list(MODEL_DIR.glob(f"{model_id}_*.stl"))
if not matches:
raise RuntimeError("没有找到已上传的 STL 模型。")
model_path = matches[0]
triangles = parse_stl_bytes(model_path.read_bytes())
bounds_min = triangles.reshape(-1, 3).min(axis=0).tolist()
bounds_max = triangles.reshape(-1, 3).max(axis=0).tolist()
model = {
"path": model_path,
"triangles": triangles,
"name": model_path.name.split("_", 1)[1] if "_" in model_path.name else model_path.name,
"bounds": [bounds_min, bounds_max],
}
with STL_MODEL_CACHE_LOCK:
STL_MODEL_CACHE[model_id] = model
return model
def dicom_geometry(dicom_dir):
dicom_files = sorted_dicom_files(dicom_dir)
if not dicom_files:
return None
try:
first = pydicom.dcmread(str(dicom_files[0]), stop_before_pixels=True, force=True)
last = pydicom.dcmread(str(dicom_files[-1]), stop_before_pixels=True, force=True)
orientation = np.asarray(first.ImageOrientationPatient, dtype=np.float64)
col_dir = orientation[:3]
row_dir = orientation[3:]
slice_dir = np.cross(col_dir, row_dir)
pixel_spacing = np.asarray(first.PixelSpacing, dtype=np.float64)
row_spacing = float(pixel_spacing[0])
col_spacing = float(pixel_spacing[1])
first_pos = np.asarray(first.ImagePositionPatient, dtype=np.float64)
last_pos = np.asarray(last.ImagePositionPatient, dtype=np.float64)
slice_spacing = float(np.linalg.norm(last_pos - first_pos) / max(1, len(dicom_files) - 1))
if slice_spacing <= 0:
slice_spacing = float(getattr(first, "SliceThickness", 1) or 1)
basis = np.column_stack([
slice_dir * slice_spacing,
row_dir * row_spacing,
col_dir * col_spacing,
])
inverse = np.linalg.inv(basis)
return {
"origin": first_pos,
"inverse": inverse,
}
except Exception:
return None
def stl_triangles_to_voxels(triangles, dicom_dir):
geometry = dicom_geometry(dicom_dir)
if not geometry:
return triangles.astype(np.float32)
points = triangles.reshape(-1, 3).astype(np.float64)
voxel_points = (points - geometry["origin"]) @ geometry["inverse"].T
return voxel_points.reshape(triangles.shape).astype(np.float32)
def triangle_plane_segment(triangle, axis, value):
intersections = []
for start, end in [(triangle[0], triangle[1]), (triangle[1], triangle[2]), (triangle[2], triangle[0])]:
start_delta = float(start[axis] - value)
end_delta = float(end[axis] - value)
if abs(start_delta) < 1e-4 and abs(end_delta) < 1e-4:
continue
if abs(start_delta) < 1e-4:
intersections.append(start)
if start_delta * end_delta < 0:
t = start_delta / (start_delta - end_delta)
intersections.append(start + t * (end - start))
elif abs(end_delta) < 1e-4:
intersections.append(end)
unique = []
for point in intersections:
if not any(np.linalg.norm(point - existing) < 1e-3 for existing in unique):
unique.append(point)
if len(unique) >= 2:
return unique[0], unique[1]
return None
def make_stl_slice_mask(triangles, plane, index, image_shape):
height, width = image_shape
axis = 1 if plane == "coronal" else 2
mask_image = Image.new("L", (width, height), 0)
draw = ImageDraw.Draw(mask_image)
segment_count = 0
for triangle in triangles:
segment = triangle_plane_segment(triangle, axis, index)
if not segment:
continue
points = []
for point in segment:
if plane == "coronal":
x_value = float(point[2])
else:
x_value = float(point[1])
y_value = float(point[0])
points.append((x_value, y_value))
draw.line(points, fill=255, width=2)
segment_count += 1
if segment_count == 0:
return None, 0
mask = np.asarray(mask_image) > 0
try:
from scipy.ndimage import binary_fill_holes
filled = binary_fill_holes(mask)
if int(filled.sum()) > int(mask.sum()):
mask = filled
except Exception:
pass
mask_pixels = int(mask.sum())
if mask_pixels == 0:
return None, 0
return Image.fromarray((mask.astype(np.uint8) * 255), mode="L"), mask_pixels
def overlay_mask_on_preview(preview, mask):
overlay = Image.new("RGBA", preview.size, (255, 120, 20, 0))
alpha = mask.resize(preview.size, Image.Resampling.NEAREST)
overlay.putalpha(alpha.point(lambda value: 118 if value else 0))
preview_rgba = preview.convert("RGBA")
preview_rgba.alpha_composite(overlay)
return preview_rgba.convert("RGB")
def find_library_item(item_id):
return next((item for item in list_library() if item["id"] == item_id), None)
@@ -273,7 +499,7 @@ def make_library_slice_preview(item_id, index):
}
def make_library_reformat_preview(item_id, plane, index, window):
def make_library_reformat_preview(item_id, plane, index, window, model_id=""):
item = find_library_item(item_id)
if not item:
raise RuntimeError("影像库中没有找到该数据。")
@@ -290,6 +516,8 @@ def make_library_reformat_preview(item_id, plane, index, window):
except Exception:
return count // 2
mask = None
mask_pixels = 0
if plane == "coronal":
count = volume.shape[1]
index = normalize_reformat_index(index, count)
@@ -301,12 +529,20 @@ def make_library_reformat_preview(item_id, plane, index, window):
index = max(0, min(index, count - 1))
image = volume[:, :, index]
if model_id:
model = load_stl_model(model_id)
triangles = stl_triangles_to_voxels(model["triangles"], item["dicomPath"])
mask, mask_pixels = make_stl_slice_mask(triangles, plane, index, image.shape)
cache_dir = PREVIEW_CACHE_DIR / item_id / "reformat"
safe_mkdir(cache_dir)
preview_path = cache_dir / f"{plane}_{window}_{index:04d}.png"
model_suffix = f"_model_{safe_filename(model_id)}" if model_id else ""
preview_path = cache_dir / f"{plane}_{window}_{index:04d}{model_suffix}.png"
if not preview_path.exists():
preset = VIEWER_WINDOWS[window]
preview = Image.fromarray(ct_window(image, preset["low"], preset["high"])).convert("RGB")
if mask is not None:
preview = overlay_mask_on_preview(preview, mask)
preview = fit_image(preview, 960, 720)
preview.save(preview_path, format="PNG")
@@ -318,6 +554,8 @@ def make_library_reformat_preview(item_id, plane, index, window):
"window": window,
"windowLabel": VIEWER_WINDOWS[window]["label"],
"patientId": item["patientId"],
"modelId": model_id,
"maskPixels": mask_pixels,
}
@@ -995,7 +1233,8 @@ class Handler(BaseHTTPRequestHandler):
plane = params.get("plane", ["coronal"])[0]
index = params.get("index", ["0"])[0]
window = params.get("window", ["default"])[0]
self.send_json(make_library_reformat_preview(item_id, plane, index, window))
model_id = params.get("modelId", [""])[0]
self.send_json(make_library_reformat_preview(item_id, plane, index, window, model_id))
return
if parsed.path == "/api/library/info":
@@ -1065,6 +1304,11 @@ class Handler(BaseHTTPRequestHandler):
self.send_json(upload_library_item(self.headers, body), status=201)
return
if parsed.path == "/api/model/upload":
body = self.read_bytes()
self.send_json(save_uploaded_stl(self.headers, body), status=201)
return
body = self.read_json()
if parsed.path == "/api/demo/reset":
self.send_json(reset_demo_environment())
@@ -1267,6 +1511,7 @@ def main():
safe_mkdir(APP_DIR / "ppt_video")
safe_mkdir(LIBRARY_DIR)
safe_mkdir(RESULT_DIR)
safe_mkdir(MODEL_DIR)
load_persisted_jobs()
server = ThreadingHTTPServer((HOST, PORT), Handler)
print(f"Head CT Morph backend running at http://{HOST}:{PORT}")