2026-05-08-02-36-12 实现STL模型切分mask
This commit is contained in:
251
web_backend.py
251
web_backend.py
@@ -2,6 +2,7 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@@ -17,6 +18,7 @@ from urllib.parse import parse_qs, quote, unquote, urlparse
|
||||
os.environ.setdefault("MPLCONFIGDIR", "/tmp/head_ct_morph_matplotlib")
|
||||
|
||||
import pydicom
|
||||
import numpy as np
|
||||
from pydicom.multival import MultiValue
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
@@ -50,6 +52,9 @@ RESULT_DIR = APP_DIR / "web_results"
|
||||
JOBS_META = RESULT_DIR / "jobs.json"
|
||||
USER_TASKS_META = RESULT_DIR / "user_tasks.json"
|
||||
PREVIEW_CACHE_DIR = LIBRARY_DIR / "_preview_cache"
|
||||
MODEL_DIR = LIBRARY_DIR / "_stl_models"
|
||||
STL_MODEL_CACHE = {}
|
||||
STL_MODEL_CACHE_LOCK = threading.Lock()
|
||||
VIEWER_WINDOWS = {
|
||||
"default": {"label": "默认", "low": -500, "high": 1200},
|
||||
"bone": {"label": "骨窗", "low": -500, "high": 1800},
|
||||
@@ -73,6 +78,11 @@ def safe_filename(name):
|
||||
return "".join(char if char.isalnum() or char in "._-" else "_" for char in Path(name).name)
|
||||
|
||||
|
||||
def safe_model_filename(name):
|
||||
name = safe_filename(name or "model.stl")
|
||||
return name if name.lower().endswith(".stl") else f"{name}.stl"
|
||||
|
||||
|
||||
def normalized_username(username):
|
||||
username = str(username or "").strip()
|
||||
return username or "anonymous"
|
||||
@@ -231,6 +241,222 @@ def clear_dicom_caches(dicom_dir=None):
|
||||
DICOM_VOLUME_CACHE.pop(cache_key, None)
|
||||
|
||||
|
||||
def parse_ascii_stl(text):
|
||||
vertices = []
|
||||
triangles = []
|
||||
for line in text.splitlines():
|
||||
parts = line.strip().split()
|
||||
if len(parts) == 4 and parts[0].lower() == "vertex":
|
||||
try:
|
||||
vertices.append([float(parts[1]), float(parts[2]), float(parts[3])])
|
||||
except ValueError:
|
||||
vertices = []
|
||||
break
|
||||
if len(vertices) == 3:
|
||||
triangles.append(vertices)
|
||||
vertices = []
|
||||
if not triangles:
|
||||
raise RuntimeError("STL 文件中没有可解析的三角面。")
|
||||
return np.asarray(triangles, dtype=np.float32)
|
||||
|
||||
|
||||
def parse_binary_stl(data):
|
||||
if len(data) < 84:
|
||||
raise RuntimeError("STL 文件过小,无法解析。")
|
||||
triangle_count = struct.unpack_from("<I", data, 80)[0]
|
||||
expected_size = 84 + triangle_count * 50
|
||||
if triangle_count <= 0 or expected_size > len(data):
|
||||
raise RuntimeError("Binary STL 三角面数量异常。")
|
||||
|
||||
triangles = np.zeros((triangle_count, 3, 3), dtype=np.float32)
|
||||
offset = 84
|
||||
for index in range(triangle_count):
|
||||
values = struct.unpack_from("<12fH", data, offset)
|
||||
triangles[index] = np.asarray(values[3:12], dtype=np.float32).reshape(3, 3)
|
||||
offset += 50
|
||||
return triangles
|
||||
|
||||
|
||||
def parse_stl_bytes(data):
|
||||
try:
|
||||
text = data.decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
text = ""
|
||||
if text.lstrip().lower().startswith("solid"):
|
||||
try:
|
||||
return parse_ascii_stl(text)
|
||||
except Exception:
|
||||
pass
|
||||
return parse_binary_stl(data)
|
||||
|
||||
|
||||
def save_uploaded_stl(headers, body):
|
||||
if not body:
|
||||
raise RuntimeError("上传的 STL 文件为空。")
|
||||
safe_mkdir(MODEL_DIR)
|
||||
source_name = safe_model_filename(unquote(headers.get("x-file-name", "model.stl")))
|
||||
model_id = uuid.uuid4().hex[:12]
|
||||
model_path = MODEL_DIR / f"{model_id}_{source_name}"
|
||||
model_path.write_bytes(body)
|
||||
triangles = parse_stl_bytes(body)
|
||||
bounds_min = triangles.reshape(-1, 3).min(axis=0).tolist()
|
||||
bounds_max = triangles.reshape(-1, 3).max(axis=0).tolist()
|
||||
with STL_MODEL_CACHE_LOCK:
|
||||
STL_MODEL_CACHE[model_id] = {
|
||||
"path": model_path,
|
||||
"triangles": triangles,
|
||||
"name": source_name,
|
||||
"bounds": [bounds_min, bounds_max],
|
||||
}
|
||||
return {
|
||||
"modelId": model_id,
|
||||
"name": source_name,
|
||||
"triangleCount": int(triangles.shape[0]),
|
||||
"bounds": [bounds_min, bounds_max],
|
||||
}
|
||||
|
||||
|
||||
def load_stl_model(model_id):
|
||||
model_id = safe_filename(model_id)
|
||||
if not model_id:
|
||||
raise RuntimeError("模型 ID 为空。")
|
||||
with STL_MODEL_CACHE_LOCK:
|
||||
cached = STL_MODEL_CACHE.get(model_id)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
matches = list(MODEL_DIR.glob(f"{model_id}_*.stl"))
|
||||
if not matches:
|
||||
raise RuntimeError("没有找到已上传的 STL 模型。")
|
||||
model_path = matches[0]
|
||||
triangles = parse_stl_bytes(model_path.read_bytes())
|
||||
bounds_min = triangles.reshape(-1, 3).min(axis=0).tolist()
|
||||
bounds_max = triangles.reshape(-1, 3).max(axis=0).tolist()
|
||||
model = {
|
||||
"path": model_path,
|
||||
"triangles": triangles,
|
||||
"name": model_path.name.split("_", 1)[1] if "_" in model_path.name else model_path.name,
|
||||
"bounds": [bounds_min, bounds_max],
|
||||
}
|
||||
with STL_MODEL_CACHE_LOCK:
|
||||
STL_MODEL_CACHE[model_id] = model
|
||||
return model
|
||||
|
||||
|
||||
def dicom_geometry(dicom_dir):
|
||||
dicom_files = sorted_dicom_files(dicom_dir)
|
||||
if not dicom_files:
|
||||
return None
|
||||
try:
|
||||
first = pydicom.dcmread(str(dicom_files[0]), stop_before_pixels=True, force=True)
|
||||
last = pydicom.dcmread(str(dicom_files[-1]), stop_before_pixels=True, force=True)
|
||||
orientation = np.asarray(first.ImageOrientationPatient, dtype=np.float64)
|
||||
col_dir = orientation[:3]
|
||||
row_dir = orientation[3:]
|
||||
slice_dir = np.cross(col_dir, row_dir)
|
||||
pixel_spacing = np.asarray(first.PixelSpacing, dtype=np.float64)
|
||||
row_spacing = float(pixel_spacing[0])
|
||||
col_spacing = float(pixel_spacing[1])
|
||||
first_pos = np.asarray(first.ImagePositionPatient, dtype=np.float64)
|
||||
last_pos = np.asarray(last.ImagePositionPatient, dtype=np.float64)
|
||||
slice_spacing = float(np.linalg.norm(last_pos - first_pos) / max(1, len(dicom_files) - 1))
|
||||
if slice_spacing <= 0:
|
||||
slice_spacing = float(getattr(first, "SliceThickness", 1) or 1)
|
||||
basis = np.column_stack([
|
||||
slice_dir * slice_spacing,
|
||||
row_dir * row_spacing,
|
||||
col_dir * col_spacing,
|
||||
])
|
||||
inverse = np.linalg.inv(basis)
|
||||
return {
|
||||
"origin": first_pos,
|
||||
"inverse": inverse,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def stl_triangles_to_voxels(triangles, dicom_dir):
|
||||
geometry = dicom_geometry(dicom_dir)
|
||||
if not geometry:
|
||||
return triangles.astype(np.float32)
|
||||
points = triangles.reshape(-1, 3).astype(np.float64)
|
||||
voxel_points = (points - geometry["origin"]) @ geometry["inverse"].T
|
||||
return voxel_points.reshape(triangles.shape).astype(np.float32)
|
||||
|
||||
|
||||
def triangle_plane_segment(triangle, axis, value):
|
||||
intersections = []
|
||||
for start, end in [(triangle[0], triangle[1]), (triangle[1], triangle[2]), (triangle[2], triangle[0])]:
|
||||
start_delta = float(start[axis] - value)
|
||||
end_delta = float(end[axis] - value)
|
||||
if abs(start_delta) < 1e-4 and abs(end_delta) < 1e-4:
|
||||
continue
|
||||
if abs(start_delta) < 1e-4:
|
||||
intersections.append(start)
|
||||
if start_delta * end_delta < 0:
|
||||
t = start_delta / (start_delta - end_delta)
|
||||
intersections.append(start + t * (end - start))
|
||||
elif abs(end_delta) < 1e-4:
|
||||
intersections.append(end)
|
||||
|
||||
unique = []
|
||||
for point in intersections:
|
||||
if not any(np.linalg.norm(point - existing) < 1e-3 for existing in unique):
|
||||
unique.append(point)
|
||||
if len(unique) >= 2:
|
||||
return unique[0], unique[1]
|
||||
return None
|
||||
|
||||
|
||||
def make_stl_slice_mask(triangles, plane, index, image_shape):
|
||||
height, width = image_shape
|
||||
axis = 1 if plane == "coronal" else 2
|
||||
mask_image = Image.new("L", (width, height), 0)
|
||||
draw = ImageDraw.Draw(mask_image)
|
||||
segment_count = 0
|
||||
|
||||
for triangle in triangles:
|
||||
segment = triangle_plane_segment(triangle, axis, index)
|
||||
if not segment:
|
||||
continue
|
||||
points = []
|
||||
for point in segment:
|
||||
if plane == "coronal":
|
||||
x_value = float(point[2])
|
||||
else:
|
||||
x_value = float(point[1])
|
||||
y_value = float(point[0])
|
||||
points.append((x_value, y_value))
|
||||
draw.line(points, fill=255, width=2)
|
||||
segment_count += 1
|
||||
|
||||
if segment_count == 0:
|
||||
return None, 0
|
||||
|
||||
mask = np.asarray(mask_image) > 0
|
||||
try:
|
||||
from scipy.ndimage import binary_fill_holes
|
||||
filled = binary_fill_holes(mask)
|
||||
if int(filled.sum()) > int(mask.sum()):
|
||||
mask = filled
|
||||
except Exception:
|
||||
pass
|
||||
mask_pixels = int(mask.sum())
|
||||
if mask_pixels == 0:
|
||||
return None, 0
|
||||
return Image.fromarray((mask.astype(np.uint8) * 255), mode="L"), mask_pixels
|
||||
|
||||
|
||||
def overlay_mask_on_preview(preview, mask):
|
||||
overlay = Image.new("RGBA", preview.size, (255, 120, 20, 0))
|
||||
alpha = mask.resize(preview.size, Image.Resampling.NEAREST)
|
||||
overlay.putalpha(alpha.point(lambda value: 118 if value else 0))
|
||||
preview_rgba = preview.convert("RGBA")
|
||||
preview_rgba.alpha_composite(overlay)
|
||||
return preview_rgba.convert("RGB")
|
||||
|
||||
|
||||
def find_library_item(item_id):
|
||||
return next((item for item in list_library() if item["id"] == item_id), None)
|
||||
|
||||
@@ -273,7 +499,7 @@ def make_library_slice_preview(item_id, index):
|
||||
}
|
||||
|
||||
|
||||
def make_library_reformat_preview(item_id, plane, index, window):
|
||||
def make_library_reformat_preview(item_id, plane, index, window, model_id=""):
|
||||
item = find_library_item(item_id)
|
||||
if not item:
|
||||
raise RuntimeError("影像库中没有找到该数据。")
|
||||
@@ -290,6 +516,8 @@ def make_library_reformat_preview(item_id, plane, index, window):
|
||||
except Exception:
|
||||
return count // 2
|
||||
|
||||
mask = None
|
||||
mask_pixels = 0
|
||||
if plane == "coronal":
|
||||
count = volume.shape[1]
|
||||
index = normalize_reformat_index(index, count)
|
||||
@@ -301,12 +529,20 @@ def make_library_reformat_preview(item_id, plane, index, window):
|
||||
index = max(0, min(index, count - 1))
|
||||
image = volume[:, :, index]
|
||||
|
||||
if model_id:
|
||||
model = load_stl_model(model_id)
|
||||
triangles = stl_triangles_to_voxels(model["triangles"], item["dicomPath"])
|
||||
mask, mask_pixels = make_stl_slice_mask(triangles, plane, index, image.shape)
|
||||
|
||||
cache_dir = PREVIEW_CACHE_DIR / item_id / "reformat"
|
||||
safe_mkdir(cache_dir)
|
||||
preview_path = cache_dir / f"{plane}_{window}_{index:04d}.png"
|
||||
model_suffix = f"_model_{safe_filename(model_id)}" if model_id else ""
|
||||
preview_path = cache_dir / f"{plane}_{window}_{index:04d}{model_suffix}.png"
|
||||
if not preview_path.exists():
|
||||
preset = VIEWER_WINDOWS[window]
|
||||
preview = Image.fromarray(ct_window(image, preset["low"], preset["high"])).convert("RGB")
|
||||
if mask is not None:
|
||||
preview = overlay_mask_on_preview(preview, mask)
|
||||
preview = fit_image(preview, 960, 720)
|
||||
preview.save(preview_path, format="PNG")
|
||||
|
||||
@@ -318,6 +554,8 @@ def make_library_reformat_preview(item_id, plane, index, window):
|
||||
"window": window,
|
||||
"windowLabel": VIEWER_WINDOWS[window]["label"],
|
||||
"patientId": item["patientId"],
|
||||
"modelId": model_id,
|
||||
"maskPixels": mask_pixels,
|
||||
}
|
||||
|
||||
|
||||
@@ -995,7 +1233,8 @@ class Handler(BaseHTTPRequestHandler):
|
||||
plane = params.get("plane", ["coronal"])[0]
|
||||
index = params.get("index", ["0"])[0]
|
||||
window = params.get("window", ["default"])[0]
|
||||
self.send_json(make_library_reformat_preview(item_id, plane, index, window))
|
||||
model_id = params.get("modelId", [""])[0]
|
||||
self.send_json(make_library_reformat_preview(item_id, plane, index, window, model_id))
|
||||
return
|
||||
|
||||
if parsed.path == "/api/library/info":
|
||||
@@ -1065,6 +1304,11 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.send_json(upload_library_item(self.headers, body), status=201)
|
||||
return
|
||||
|
||||
if parsed.path == "/api/model/upload":
|
||||
body = self.read_bytes()
|
||||
self.send_json(save_uploaded_stl(self.headers, body), status=201)
|
||||
return
|
||||
|
||||
body = self.read_json()
|
||||
if parsed.path == "/api/demo/reset":
|
||||
self.send_json(reset_demo_environment())
|
||||
@@ -1267,6 +1511,7 @@ def main():
|
||||
safe_mkdir(APP_DIR / "ppt_video")
|
||||
safe_mkdir(LIBRARY_DIR)
|
||||
safe_mkdir(RESULT_DIR)
|
||||
safe_mkdir(MODEL_DIR)
|
||||
load_persisted_jobs()
|
||||
server = ThreadingHTTPServer((HOST, PORT), Handler)
|
||||
print(f"Head CT Morph backend running at http://{HOST}:{PORT}")
|
||||
|
||||
Reference in New Issue
Block a user