"""Head-and-neck CT "head extension" deformation tool with a Tkinter GUI.

The tool reads a DICOM CT series and shows a quick 2D sagittal preview of the
requested head-extension angle. On request it does the following:

- builds a 3D rotation displacement field (via DeformHeadCT / platipy);
- blends that field in four ways (the "states" in ``STATE_LABELS``);
- applies each blend to the CT;
- writes every state as a complete DICOM series, plus PNG comparison figures.
"""

import json
import os
import shutil
import tempfile
import threading
import time
import uuid
from pathlib import Path
from tkinter import (
    BOTH,
    DISABLED,
    END,
    HORIZONTAL,
    NORMAL,
    Button,
    Entry,
    Frame,
    Label,
    Scale,
    StringVar,
    Tk,
    filedialog,
    messagebox,
)

import numpy as np
import pydicom
import SimpleITK as sitk
from PIL import Image, ImageDraw, ImageFont, ImageTk
from platipy.imaging.generation.mask import get_external_mask
from platipy.imaging.registration.utils import apply_transform

# ``np.alen`` no longer exists in current NumPy releases. This alias is
# installed before the DeformHeadCT imports, presumably because a downstream
# dependency still calls it.
# NOTE(review): confirm which package needs it.
np.alen = len

from DeformHeadCT.VolumeInfo import VolumeDeformation, convert_nifti_to_dicom_series  # noqa: E402
from DeformHeadCT.deformation import HeadDeformation, generate_field_rotation  # noqa: E402

APP_DIR = Path(__file__).resolve().parent
# Scratch area for per-run copies and conversions. Presumably placed in the OS
# temp dir so downstream tools see a plain path (see copy_dicom_to_ascii_folder).
# NOTE(review): confirm.
RUNTIME_DIR = Path(tempfile.gettempdir()) / "HeadExtensionApp"
PREVIEW_DIR = APP_DIR / "app_previews"

# Points defining the cut-off plane. cutoff_center_z() uses index 0 of each
# point as the axial (z) voxel index; the other two axes are presumably y, x.
DEFAULT_COORDINATES_CUTOFF = [[145, 305, 256], [135, 205, 256]]
# Vertebral joint positions written into the DeformHeadCT info JSON. These use
# the same point layout as DEFAULT_COORDINATES_CUTOFF (assumed).
DEFAULT_VERTEBRAE = {
    "Oc-C1": [220, 255, 256],
    "C1-C2": [208, 255, 256],
    "C2-C3": [190, 255, 256],
    "C3-C4": [172, 258, 256],
    "C4-C5": [154, 262, 256],
    "C5-C6": [136, 268, 256],
    "C6-C7": [118, 274, 256],
    "C7-T1": [100, 282, 256],
}
# Each entry is (state key, display label, output sub-folder name).
STATE_LABELS = [
    ("original", "Original", "ct_original"),
    ("hard_boundary", "Hard boundary", "ct_hard_boundary"),
    ("gaussian_smooth", "Gaussian smooth", "ct_gaussian_smooth"),
    ("soft_transition", "Soft transition", "ct_soft_transition"),
]

# Fraction of the sagittal image height kept by crop_head_neck(). draw_cutoff_line()
# must use the same value to map voxel z onto panel rows.
_CROP_BOTTOM_FRACTION = 0.72
# Sigma passed to sitk.SmoothingRecursiveGaussian for the smoothed states.
_DVF_SMOOTHING_SIGMA = 3
_WINDOWS_ARIAL = Path(r"C:\Windows\Fonts\arial.ttf")
# Font names Pillow can usually resolve on non-Windows systems.
_FALLBACK_FONTS = ("arial.ttf", "DejaVuSans.ttf")


def safe_mkdir(path):
    """Create ``path`` (and any parents) if it does not exist yet."""
    Path(path).mkdir(parents=True, exist_ok=True)


def _is_dicom_file(file_path):
    """Return True for regular files with a ``.dcm`` extension (case-insensitive)."""
    return file_path.is_file() and file_path.suffix.lower() == ".dcm"


def _dicom_float(ds, keyword, default):
    """Return DICOM attribute ``keyword`` as float, or ``default`` if missing or empty."""
    value = getattr(ds, keyword, None)
    return default if value in (None, "") else float(value)


def _load_font(size):
    """Return a TrueType font of ``size`` pixels, or Pillow's default bitmap font.

    Windows Arial is tried first. A few font names Pillow can resolve on other
    platforms are tried next.
    """
    for candidate in (str(_WINDOWS_ARIAL), *_FALLBACK_FONTS):
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            continue
    return ImageFont.load_default()


def copy_dicom_to_ascii_folder(input_dir, run_id):
    """Copy every ``*.dcm`` file of ``input_dir`` into this run's scratch folder.

    Returns:
        ``(target_dir, copied_count)``.

    Raises:
        RuntimeError: if ``input_dir`` contains no ``.dcm`` files.
    """
    target = RUNTIME_DIR / run_id / "input_ct"
    safe_mkdir(target)
    count = 0
    for file_path in Path(input_dir).iterdir():
        if _is_dicom_file(file_path):
            shutil.copy2(file_path, target / file_path.name)
            count += 1
    if count == 0:
        raise RuntimeError("输入文件夹里没有找到 .dcm 文件。")
    return target, count


def load_dicom_volume(input_dir):
    """Load the ``*.dcm`` image slices of ``input_dir`` into a float32 volume.

    Slices are ordered by ``InstanceNumber``. When that tag is missing or empty,
    file-name order is used instead. Pixel values are rescaled with
    RescaleSlope / RescaleIntercept, which default to 1 and 0. Files without
    pixel data (e.g. structure sets) are skipped.

    Returns:
        ``np.ndarray`` of shape ``(n_slices, rows, cols)``.

    Raises:
        RuntimeError: if no image slices are found.
    """
    items = []
    # Sorted so that the fallback ordering is deterministic.
    for file_path in sorted(Path(input_dir).iterdir()):
        if not _is_dicom_file(file_path):
            continue
        ds = pydicom.dcmread(str(file_path), force=True)
        if "PixelData" not in ds:
            continue
        instance = getattr(ds, "InstanceNumber", None)
        order = len(items) if instance in (None, "") else int(instance)
        items.append((order, ds))
    if not items:
        raise RuntimeError("输入文件夹里没有找到 .dcm 文件。")

    items.sort(key=lambda item: item[0])
    volume = []
    for _, ds in items:
        image = ds.pixel_array.astype(np.float32)
        image = image * _dicom_float(ds, "RescaleSlope", 1.0)
        image = image + _dicom_float(ds, "RescaleIntercept", 0.0)
        volume.append(image)
    return np.stack(volume, axis=0)


def ct_window(image, low=-500, high=1200):
    """Map values in ``[low, high]`` linearly onto 0..255 (uint8), clipping outside."""
    image = np.clip((image - low) / (high - low), 0, 1)
    return (image * 255).astype(np.uint8)


def sagittal_mip(volume):
    """Return an RGB maximum-intensity projection of a slab around the middle column.

    ``volume`` is indexed ``[slice, row, col]``. The projection runs over
    roughly 45 columns around the middle of the last axis.
    """
    mid = volume.shape[2] // 2
    x0 = max(0, mid - 21)
    x1 = min(volume.shape[2], mid + 24)
    image = ct_window(np.max(volume[:, :, x0:x1], axis=2))
    return Image.fromarray(image).convert("RGB")


def crop_head_neck(image):
    """Crop a sagittal panel to the head/neck region using fixed size fractions."""
    width, height = image.size
    left = int(width * 0.09)
    right = int(width * 0.89)
    bottom = int(height * _CROP_BOTTOM_FRACTION)
    return image.crop((left, 0, right, bottom))


def cutoff_center_z(coordinates_cutoff):
    """Return the mean z index (element 0) of the cut-off points."""
    return float(np.mean([point[0] for point in coordinates_cutoff]))


def draw_cutoff_line(panel, image_depth, coordinates_cutoff=None):
    """Return a copy of ``panel`` with a horizontal line at the cut-off height.

    Args:
        panel: cropped sagittal panel (see crop_head_neck).
        image_depth: number of slices in the volume the panel came from.
        coordinates_cutoff: cut-off points; defaults to DEFAULT_COORDINATES_CUTOFF.

    The panel is returned unchanged when the line would fall outside it.
    """
    if coordinates_cutoff is None:
        coordinates_cutoff = DEFAULT_COORDINATES_CUTOFF
    panel = panel.copy()
    crop_height = int(image_depth * _CROP_BOTTOM_FRACTION)
    if crop_height <= 0:
        return panel
    # Rows are counted from the last slice (depth - 1 - z), then rescaled from
    # the crop height to the panel height.
    center_z = cutoff_center_z(coordinates_cutoff)
    line_y = int(round((image_depth - 1 - center_z) * panel.height / crop_height))
    if line_y < 0 or line_y >= panel.height:
        return panel
    draw = ImageDraw.Draw(panel)
    shadow = (0, 0, 0)
    line_color = (255, 215, 60)
    draw.line((0, line_y, panel.width, line_y), fill=shadow, width=6)
    draw.line((0, line_y, panel.width, line_y), fill=line_color, width=3)
    return panel


def fit_image(image, width, height):
    """Letterbox ``image`` into a black ``width`` x ``height`` RGB canvas, keeping aspect."""
    canvas = Image.new("RGB", (width, height), (0, 0, 0))
    if image.width == 0 or image.height == 0:
        return canvas
    scale = min(width / image.width, height / image.height)
    resized = image.resize(
        (max(1, int(image.width * scale)), max(1, int(image.height * scale))),
        Image.Resampling.LANCZOS,
    )
    canvas.paste(resized, ((width - resized.width) // 2, (height - resized.height) // 2))
    return canvas


def preview_deform_2d(image, angle_degrees):
    """Fast visual preview only. The DICOM output uses the real 3D field.

    The image is rotated about a fixed pivot. The rotation angle is at full
    strength above 50 % of the image height and fades to zero at 92 % with a
    smoothstep. It is also scaled by 0.90–1.00 across the width. If SciPy is
    unavailable, ``image`` is returned unchanged.
    """
    try:
        from scipy.ndimage import map_coordinates
    except ImportError:
        return image
    arr = np.asarray(image.convert("RGB")).astype(np.float32)
    h, w, _ = arr.shape
    yy, xx = np.mgrid[0:h, 0:w]
    pivot_x = int(w * 0.55)
    pivot_y = int(h * 0.62)
    full_motion_y = h * 0.50
    fixed_y = h * 0.92
    t = np.clip((yy - full_motion_y) / (fixed_y - full_motion_y), 0, 1)
    weight = 1 - (t * t * (3 - 2 * t))
    x_soft = np.clip((xx - w * 0.15) / (w * 0.75), 0, 1)
    x_soft = x_soft * x_soft * (3 - 2 * x_soft)
    weight = np.clip(weight * (0.90 + 0.10 * x_soft), 0, 1)
    theta = np.deg2rad(angle_degrees) * weight
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    dx = xx - pivot_x
    dy = yy - pivot_y
    # Inverse mapping: for each output pixel, find where to sample the source.
    src_x = pivot_x + cos_t * dx + sin_t * dy
    src_y = pivot_y - sin_t * dx + cos_t * dy
    warped_channels = [
        map_coordinates(arr[..., channel], [src_y, src_x], order=1, mode="constant", cval=0)
        for channel in range(arr.shape[2])
    ]
    warped = np.stack(warped_channels, axis=2)
    return Image.fromarray(np.clip(warped, 0, 255).astype(np.uint8)).convert("RGB")


def hard_boundary_weight(image, coordinates_cutoff):
    """Per-voxel weight: 1 for slices at or above the cut-off z index, else 0.

    Returns:
        float32 array of shape ``(size_z, size_y, size_x)``, matching
        ``sitk.GetArrayFromImage(image)``.
    """
    size_x, size_y, size_z = image.GetSize()
    center_z = cutoff_center_z(coordinates_cutoff)
    z_grid = np.arange(size_z, dtype=np.float32)[:, None, None]
    weight = (z_grid >= center_z).astype(np.float32)
    return np.broadcast_to(weight, (size_z, size_y, size_x)).astype(np.float32)


def transition_weight(image, coordinates_cutoff, width_voxels):
    """Per-voxel weight rising smoothly from 0 to 1 across the cut-off plane.

    The weight is a smoothstep over the z range ``center ± width_voxels`` and is
    constant within each axial slice. A non-positive width degenerates to the
    hard step of hard_boundary_weight(); with width 0 the smoothstep would
    divide by zero, and a negative width would invert the transition.

    Returns:
        float32 array of shape ``(size_z, size_y, size_x)``.
    """
    width = float(width_voxels)
    if width <= 0:
        return hard_boundary_weight(image, coordinates_cutoff)
    size_x, size_y, size_z = image.GetSize()
    center_z = cutoff_center_z(coordinates_cutoff)
    edge0 = center_z - width
    edge1 = center_z + width
    z_grid = np.arange(size_z, dtype=np.float32)[:, None, None]
    weight = np.clip((z_grid - edge0) / (edge1 - edge0), 0, 1)
    weight = weight * weight * (3 - 2 * weight)
    return np.broadcast_to(weight, (size_z, size_y, size_x)).astype(np.float32)


def image_from_dvf_array(dvf_arr, reference_image):
    """Wrap a ``(z, y, x, 3)`` displacement array as a vector image on ``reference_image``'s grid."""
    # isVector is explicit: older SimpleITK defaults to False, which would build
    # a 4D scalar image and make CopyInformation fail.
    dvf = sitk.GetImageFromArray(dvf_arr.astype(np.float32), isVector=True)
    dvf.CopyInformation(reference_image)
    return dvf


def apply_dvf_to_ct(image_ct, dvf):
    """Resample ``image_ct`` through displacement field ``dvf`` (linear interpolation).

    Voxels that map outside the image are filled with the CT's minimum value.
    """
    # DisplacementFieldTransform requires a float64 vector field.
    transform = sitk.DisplacementFieldTransform(sitk.Cast(dvf, sitk.sitkVectorFloat64))
    return apply_transform(
        image_ct,
        transform=transform,
        interpolator=sitk.sitkLinear,
        default_value=int(sitk.GetArrayViewFromImage(image_ct).min()),
    )


def reset_folder(path):
    """Delete ``path`` if it exists and recreate it empty."""
    path = Path(path)
    if path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)


def write_dicom_series(image, reference_dicom_dir, output_dir, run_root, state_key):
    """Write ``image`` as a DICOM series into a freshly emptied ``output_dir``.

    The series is first written under ``run_root/dicom_output/<state_key>``, then
    the ``*.dcm`` files are copied to ``output_dir``.
    """
    output_dir = Path(output_dir)
    reset_folder(output_dir)
    ascii_output_dir = run_root / "dicom_output" / state_key
    reset_folder(ascii_output_dir)
    convert_nifti_to_dicom_series(
        image=image,
        reference_dcm=str(reference_dicom_dir),
        output_directory=str(ascii_output_dir),
    )
    for dicom_path in ascii_output_dir.glob("*.dcm"):
        shutil.copy2(dicom_path, output_dir / dicom_path.name)


def write_info_json(info_path, input_dir, temp_dir, output_dir, angle_degrees, transition_width):
    """Write the DeformHeadCT info file describing one run (forward-slash paths)."""
    data = {
        "name": "HEAD_EXTENSION",
        "InputDirectory": str(input_dir).replace("\\", "/"),
        "TempDirectory": str(temp_dir).replace("\\", "/"),
        "OutputDirectory": str(output_dir).replace("\\", "/"),
        "axes": [0, 0, -1],
        "angles": [float(angle_degrees)],
        "coordinates_cutoff": DEFAULT_COORDINATES_CUTOFF,
        "transition_width_voxels": int(transition_width),
    }
    data.update(DEFAULT_VERTEBRAE)
    info_path.write_text(json.dumps(data, indent=2), encoding="utf-8")


def _find_patient_image_dir(vol_info):
    """Return the converted NIfTI ``IMAGES`` folder for the patient.

    If the folder named after ``vol_info.patientunderscore`` is missing, the
    first ``*/IMAGES`` folder is used and ``patientunderscore`` is updated to match.

    Raises:
        RuntimeError: if no ``IMAGES`` folder exists at all.
    """
    patient_image_dir = vol_info.nifti_directory / vol_info.patientunderscore / "IMAGES"
    if patient_image_dir.exists():
        return patient_image_dir
    patient_image_dir = next(vol_info.nifti_directory.glob("*/IMAGES"), None)
    if patient_image_dir is None:
        raise RuntimeError(f"没有找到转换后的 IMAGES 文件夹:{vol_info.nifti_directory}")
    vol_info.patientunderscore = patient_image_dir.parent.name
    return patient_image_dir


def _build_state_dvfs(image_ct, external_mask, full_rotation_dvf, coordinates_cutoff,
                      transition_width):
    """Blend the full rotation field into the hard / gaussian / soft state fields.

    All fields are masked by ``external_mask``. The states differ as follows:

    - hard: a z step at the cut-off;
    - gaussian: the hard field, Gaussian-smoothed;
    - soft: a smoothstep weight, then Gaussian-smoothed.

    Returns:
        ``(hard_dvf, gaussian_dvf, soft_dvf)`` as SimpleITK vector images.
    """
    base_dvf_arr = sitk.GetArrayFromImage(full_rotation_dvf).astype(np.float32)
    mask_arr = sitk.GetArrayFromImage(external_mask).astype(np.float32)
    hard_weight = hard_boundary_weight(image_ct, coordinates_cutoff)
    soft_weight = transition_weight(image_ct, coordinates_cutoff, int(transition_width))
    hard_dvf = image_from_dvf_array(base_dvf_arr * (hard_weight * mask_arr)[..., None], image_ct)
    gaussian_dvf = sitk.SmoothingRecursiveGaussian(hard_dvf, _DVF_SMOOTHING_SIGMA)
    soft_dvf = image_from_dvf_array(base_dvf_arr * (soft_weight * mask_arr)[..., None], image_ct)
    soft_dvf = sitk.SmoothingRecursiveGaussian(soft_dvf, _DVF_SMOOTHING_SIGMA)
    return hard_dvf, gaussian_dvf, soft_dvf


def run_deformation(input_dir, output_dir, angle_degrees, transition_width, progress):
    """Run the full 3D pipeline and write all four states as DICOM series.

    Args:
        input_dir: folder containing the source ``*.dcm`` slices.
        output_dir: destination. One sub-folder per state (plus a legacy ``ct``
            folder) is recreated from scratch.
        angle_degrees: head-extension angle in degrees.
        transition_width: half-width, in voxels, of the soft transition zone.
        progress: callable receiving human-readable status strings.

    Returns:
        ``(output_paths, preview_paths)``: dicts of the written folders and files.

    Note:
        The scratch folder under RUNTIME_DIR is left in place after the run.
    """
    run_id = time.strftime("%Y%m%d_%H%M%S_") + uuid.uuid4().hex[:8]
    run_root = RUNTIME_DIR / run_id
    temp_dir = run_root / "temp"
    output_dir = Path(output_dir)

    ascii_input_dir, copied_count = copy_dicom_to_ascii_folder(input_dir, run_id)
    safe_mkdir(output_dir)
    info_path = run_root / "head_extension.json"
    write_info_json(
        info_path,
        ascii_input_dir,
        temp_dir,
        output_dir,
        angle_degrees,
        transition_width,
    )
    progress(f"已复制 {copied_count} 张 DICOM,开始转换体数据...")

    vol_info = VolumeDeformation(InfoFile=str(info_path))
    vol_info.PrepareDcmData()
    patient_image_dir = _find_patient_image_dir(vol_info)
    # Mirror the patient's NIfTI images into <nifti_directory>/IMAGES as well.
    root_image_dir = vol_info.nifti_directory / "IMAGES"
    safe_mkdir(root_image_dir)
    for image_path in patient_image_dir.glob("*.nii.gz"):
        shutil.copy2(image_path, root_image_dir / image_path.name)

    progress("正在生成四种状态的三维位移场...")
    head_def = HeadDeformation(vol_info.nifti_directory, vol_info.patientunderscore, 0)
    image_ct = head_def.image_ct
    external_mask = get_external_mask(image_ct)
    _, _, full_rotation_dvf, _ = generate_field_rotation(
        image_ct,
        external_mask,
        tuple(vol_info.point_of_rotation[0]),
        axis_of_rotation=vol_info.axes,
        angle=-vol_info.angles[0] * np.pi / 180,
        gaussian_smooth=0,
    )
    hard_dvf, gaussian_dvf, soft_dvf = _build_state_dvfs(
        image_ct,
        external_mask,
        full_rotation_dvf,
        vol_info.coordinates_cutoff,
        transition_width,
    )

    progress("正在应用形变...")
    state_images = {
        "original": image_ct,
        "hard_boundary": apply_dvf_to_ct(image_ct, hard_dvf),
        "gaussian_smooth": apply_dvf_to_ct(image_ct, gaussian_dvf),
        "soft_transition": apply_dvf_to_ct(image_ct, soft_dvf),
    }

    progress("正在写出四种状态的完整 DICOM 序列...")
    reference_dicom_dir = vol_info.nifti_directory / "dicom" / "ct"
    output_paths = {}
    for state_key, _, folder_name in STATE_LABELS:
        state_output_dir = output_dir / folder_name
        write_dicom_series(
            state_images[state_key],
            reference_dicom_dir,
            state_output_dir,
            run_root,
            state_key,
        )
        output_paths[state_key] = state_output_dir

    # Backward-compatible copy of the soft-transition result in "<output>/ct".
    legacy_soft_dir = output_dir / "ct"
    write_dicom_series(
        state_images["soft_transition"],
        reference_dicom_dir,
        legacy_soft_dir,
        run_root,
        "legacy_soft_transition",
    )
    output_paths["legacy_soft"] = legacy_soft_dir

    progress("正在生成四状态过程对比图...")
    preview_paths = make_four_state_preview(state_images, output_dir, angle_degrees)
    make_output_preview_from_images(
        state_images["original"],
        state_images["soft_transition"],
        output_dir,
        angle_degrees,
    )
    return output_paths, preview_paths


def _render_before_after_slide(before, after, output_dir, angle_degrees):
    """Compose a 2560x1440 before/after slide, save it as PNG, and return its path."""
    slide = Image.new("RGB", (2560, 1440), (0, 0, 0))
    draw = ImageDraw.Draw(slide)
    title_font = _load_font(58)
    slide.paste(fit_image(before, 920, 675), (280, 390))
    slide.paste(fit_image(after, 920, 675), (1360, 390))
    draw.text((300, 190), "Before", font=title_font, fill=(255, 255, 255))
    draw.text(
        (1380, 190),
        f"After: {angle_degrees:g} deg head extension",
        font=title_font,
        fill=(255, 255, 255),
    )
    arrow = (255, 210, 60)
    x0, y0, x1, y1 = 1685, 545, 1855, 455
    draw.line((x0, y0, x1, y1), fill=arrow, width=10)
    draw.polygon([(x1, y1), (x1 - 42, y1 + 5), (x1 - 15, y1 + 36)], fill=arrow)
    preview_path = Path(output_dir) / "before_after_preview.png"
    slide.save(preview_path, quality=95)
    return preview_path


def make_output_preview(original_dir, deformed_dicom_dir, output_dir, angle_degrees):
    """Build the before/after slide from two DICOM folders on disk.

    The deformed volume is reversed along the slice axis before projection.
    NOTE(review): presumably the written series' InstanceNumber order is the
    reverse of the input's; confirm.
    """
    safe_mkdir(PREVIEW_DIR)
    orig = load_dicom_volume(original_dir)
    deformed = load_dicom_volume(deformed_dicom_dir)[::-1]
    before = crop_head_neck(sagittal_mip(orig))
    after = crop_head_neck(sagittal_mip(deformed))
    return _render_before_after_slide(before, after, output_dir, angle_degrees)


def sitk_sagittal_panel(image, coordinates_cutoff=None):
    """Return a cropped sagittal MIP panel of a SimpleITK image.

    The array is flipped along z first. A cut-off line is drawn when
    ``coordinates_cutoff`` is given.
    """
    volume = sitk.GetArrayFromImage(image)[::-1]
    panel = crop_head_neck(sagittal_mip(volume))
    if coordinates_cutoff is not None:
        panel = draw_cutoff_line(panel, volume.shape[0], coordinates_cutoff)
    return panel


def make_output_preview_from_images(original_image, deformed_image, output_dir, angle_degrees,
                                    coordinates_cutoff=None):
    """Build the before/after slide from two in-memory SimpleITK images."""
    before = sitk_sagittal_panel(original_image, coordinates_cutoff)
    after = sitk_sagittal_panel(deformed_image, coordinates_cutoff)
    return _render_before_after_slide(before, after, output_dir, angle_degrees)


def make_four_state_preview(state_images, output_dir, angle_degrees, coordinates_cutoff=None):
    """Save one PNG per state plus a side-by-side four-state comparison slide.

    Returns:
        ``{"comparison": <slide path>, "screenshots": <per-state PNG folder>}``.
    """
    output_dir = Path(output_dir)
    screenshot_dir = output_dir / "process_screenshots"
    reset_folder(screenshot_dir)

    panels = []
    for state_key, label, _ in STATE_LABELS:
        panel = sitk_sagittal_panel(state_images[state_key], coordinates_cutoff)
        panel.save(screenshot_dir / f"{state_key}.png", quality=95)
        panels.append((label, panel))

    slide = Image.new("RGB", (2560, 720), (0, 0, 0))
    draw = ImageDraw.Draw(slide)
    title_font = _load_font(36)
    small_font = _load_font(24)
    panel_width = 560
    panel_height = 430
    margin = 55
    gap = 70
    y_image = 145
    y_title = 62
    for index, (label, panel) in enumerate(panels):
        x = margin + index * (panel_width + gap)
        slide.paste(fit_image(panel, panel_width, panel_height), (x, y_image))
        draw.text((x, y_title), label, font=title_font, fill=(255, 255, 255))
    draw.text(
        (margin, 650),
        f"Head extension angle: {angle_degrees:g} deg",
        font=small_font,
        fill=(190, 190, 190),
    )
    comparison_path = output_dir / "process_comparison_4states.png"
    slide.save(comparison_path, quality=95)
    return {
        "comparison": comparison_path,
        "screenshots": screenshot_dir,
    }


class HeadExtensionApp:
    """Tkinter front end: choose folders, tune parameters, preview, and run the pipeline."""

    def __init__(self, root):
        self.root = root
        self.root.title("头颈部 CT 仰头形变工具")
        self.root.geometry("1180x820")
        self.input_dir = StringVar(value=str(APP_DIR / "input_ct_2F"))
        self.output_dir = StringVar(value=str(APP_DIR / "app_output"))
        self.status = StringVar(value="请选择 DICOM 文件夹,调节角度后可先预览,再生成四状态 DICOM 和过程对比图。")
        self.angle_text = StringVar(value="12.0°")
        # Matches the transition slider's initial value set in build_ui().
        self.transition_text = StringVar(value="90")
        self.cached_volume = None
        # Folder the cached volume was read from. This lets update_preview()
        # notice paths typed into the entry box, not only dialog selections.
        self.cached_volume_dir = None
        # Tk keeps no reference to the PhotoImage, so one is held here to stop
        # it from being garbage-collected while it is displayed.
        self.preview_photo = None
        self.preview_after_id = None
        self.build_ui()

    def build_ui(self):
        """Create and lay out all widgets."""
        top = Frame(self.root)
        top.pack(fill="x", padx=12, pady=10)
        Label(top, text="输入 DICOM 文件夹").grid(row=0, column=0, sticky="w")
        Entry(top, textvariable=self.input_dir, width=92).grid(row=0, column=1, padx=8)
        Button(top, text="选择", command=self.choose_input).grid(row=0, column=2)
        Label(top, text="输出文件夹").grid(row=1, column=0, sticky="w", pady=6)
        Entry(top, textvariable=self.output_dir, width=92).grid(row=1, column=1, padx=8)
        Button(top, text="选择", command=self.choose_output).grid(row=1, column=2)

        controls = Frame(self.root)
        controls.pack(fill="x", padx=12)
        Label(controls, text="仰头角度").grid(row=0, column=0, sticky="w")
        self.angle = Scale(
            controls,
            from_=0,
            to=20,
            orient=HORIZONTAL,
            resolution=0.5,
            length=420,
            command=self.on_angle_change,
        )
        self.angle.set(12)
        self.angle.grid(row=0, column=1, sticky="w", padx=8)
        Label(controls, textvariable=self.angle_text, width=8, anchor="w").grid(
            row=0, column=2, sticky="w"
        )
        Label(controls, text="过渡平滑宽度").grid(row=0, column=3, sticky="w", padx=(20, 0))
        self.transition = Scale(
            controls,
            from_=50,
            to=160,
            orient=HORIZONTAL,
            resolution=10,
            length=300,
            command=self.on_transition_change,
        )
        self.transition.set(90)
        self.transition.grid(row=0, column=4, sticky="w", padx=8)
        Label(controls, textvariable=self.transition_text, width=8, anchor="w").grid(
            row=0, column=5, sticky="w"
        )

        buttons = Frame(self.root)
        buttons.pack(fill="x", padx=12, pady=8)
        self.preview_button = Button(buttons, text="更新预览", command=self.update_preview, width=16)
        self.preview_button.pack(side="left", padx=(0, 8))
        self.run_all_button = Button(
            buttons,
            text="保存过程对比图+四状态DICOM",
            command=self.start_run,
            width=28,
        )
        self.run_all_button.pack(side="left")

        self.preview_label = Label(self.root, bg="black")
        self.preview_label.pack(fill=BOTH, expand=True, padx=12, pady=8)
        Label(self.root, textvariable=self.status, anchor="w").pack(fill="x", padx=12, pady=(0, 8))
        # NOTE(review): nothing in this file writes to this entry.
        self.log = Entry(self.root)
        self.log.pack(fill="x", padx=12, pady=(0, 10))

    def on_angle_change(self, value):
        """Slider callback: update the angle label and schedule a preview refresh."""
        self.angle_text.set(f"{float(value):.1f}°")
        self.schedule_preview_refresh()

    def on_transition_change(self, value):
        """Slider callback: update the transition-width label (the preview ignores it)."""
        self.transition_text.set(f"{int(float(value))}")

    def schedule_preview_refresh(self):
        """Debounce preview refreshes: redraw 250 ms after the last slider move."""
        if self.preview_after_id is not None:
            self.root.after_cancel(self.preview_after_id)
        self.preview_after_id = self.root.after(250, self.update_preview)

    def choose_input(self):
        """Ask for the input DICOM folder and drop the cached preview volume."""
        path = filedialog.askdirectory(title="选择 DICOM 文件夹")
        if path:
            self.input_dir.set(path)
            self.cached_volume = None
            self.cached_volume_dir = None

    def choose_output(self):
        """Ask for the output folder."""
        path = filedialog.askdirectory(title="选择输出文件夹")
        if path:
            self.output_dir.set(path)

    def set_busy(self, busy):
        """Enable or disable the action buttons."""
        state = DISABLED if busy else NORMAL
        self.preview_button.config(state=state)
        self.run_all_button.config(state=state)

    def update_status(self, message):
        """Thread-friendly status update: schedules the change on the Tk event loop."""
        self.root.after(0, lambda: self.status.set(message))

    def _load_preview_volume(self):
        """Return the preview volume, re-reading it when the input folder changed."""
        input_dir = self.input_dir.get()
        if self.cached_volume is None or self.cached_volume_dir != input_dir:
            self.status.set("正在读取 DICOM 生成预览...")
            self.cached_volume = load_dicom_volume(input_dir)
            self.cached_volume_dir = input_dir
        return self.cached_volume

    def update_preview(self):
        """Redraw the 2D before/after preview for the current input folder and angle."""
        self.preview_after_id = None
        try:
            volume = self._load_preview_volume()
            angle = float(self.angle.get())
            before = crop_head_neck(sagittal_mip(volume))
            before_with_line = draw_cutoff_line(before, volume.shape[0])
            after = preview_deform_2d(before_with_line, angle)
            canvas = Image.new("RGB", (1120, 610), (0, 0, 0))
            draw = ImageDraw.Draw(canvas)
            canvas.paste(fit_image(before_with_line, 520, 500), (30, 80))
            canvas.paste(fit_image(after, 520, 500), (570, 80))
            draw.text((40, 25), "Before", fill=(255, 255, 255))
            draw.text((580, 25), f"Preview: {angle:g} deg", fill=(255, 255, 255))
            self.preview_photo = ImageTk.PhotoImage(canvas)
            self.preview_label.config(image=self.preview_photo)
            self.status.set("预览已更新。预览是快速 2D 示意,最终输出会使用 3D 位移场。")
        except Exception as exc:  # GUI boundary: report any failure to the user
            messagebox.showerror("预览失败", str(exc))
            self.status.set("预览失败,请检查输入 DICOM 文件夹。")

    def _collect_run_params(self):
        """Read the run parameters from the widgets (must be called on the Tk thread)."""
        return {
            "input_dir": self.input_dir.get(),
            "output_dir": self.output_dir.get(),
            "angle_degrees": float(self.angle.get()),
            "transition_width": int(self.transition.get()),
        }

    def start_run(self):
        """Start the full pipeline on a background thread."""
        # Widget values are read here, on the Tk thread. Tkinter variables and
        # widgets are not safe to query from the worker thread.
        params = self._collect_run_params()
        self.set_busy(True)
        self.status.set("开始生成四状态完整输出,请稍等。300 层 CT 通常需要 1-3 分钟。")
        thread = threading.Thread(target=self.run_worker, args=(params,), daemon=True)
        thread.start()

    def run_worker(self, params=None):
        """Background body of start_run(). All UI updates go through ``root.after``.

        Args:
            params: dict from _collect_run_params(). When omitted, the values are
                read from the widgets (only safe when called on the Tk thread).
        """
        if params is None:
            params = self._collect_run_params()
        try:
            output_paths, preview_paths = run_deformation(
                params["input_dir"],
                params["output_dir"],
                params["angle_degrees"],
                params["transition_width"],
                self.update_status,
            )
            self.update_status(f"完成。四状态 DICOM 与过程对比图已输出到:{params['output_dir']}")
            self.root.after(
                0,
                lambda: messagebox.showinfo(
                    "完成",
                    "四状态完整输出已生成:\n"
                    f"Original:{output_paths['original']}\n"
                    f"Hard boundary:{output_paths['hard_boundary']}\n"
                    f"Gaussian smooth:{output_paths['gaussian_smooth']}\n"
                    f"Soft transition:{output_paths['soft_transition']}\n\n"
                    f"过程对比图:\n{preview_paths['comparison']}\n\n"
                    f"兼容旧版本的 Soft transition 输出:\n{output_paths['legacy_soft']}",
                ),
            )
        except Exception as exc:  # thread boundary: surface every failure in the UI
            # Some exceptions have an empty str(); fall back to the type name.
            error_message = str(exc) or type(exc).__name__
            self.update_status("生成失败。")
            self.root.after(0, lambda: messagebox.showerror("生成失败", error_message))
        finally:
            self.root.after(0, lambda: self.set_busy(False))


def main():
    """Create the runtime folders, build the window, draw a first preview, run Tk."""
    safe_mkdir(RUNTIME_DIR)
    safe_mkdir(PREVIEW_DIR)
    root = Tk()
    app = HeadExtensionApp(root)
    app.update_preview()
    root.mainloop()


if __name__ == "__main__":
    main()