Files
Head_CT_Morph/head_extension_app.py

674 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import os
import shutil
import tempfile
import threading
import time
import uuid
from pathlib import Path
from tkinter import (
BOTH,
DISABLED,
END,
HORIZONTAL,
NORMAL,
Button,
Entry,
Frame,
Label,
Scale,
StringVar,
Tk,
filedialog,
messagebox,
)
import numpy as np
import pydicom
import SimpleITK as sitk
from PIL import Image, ImageDraw, ImageFont, ImageTk
from platipy.imaging.generation.mask import get_external_mask
from platipy.imaging.registration.utils import apply_transform
# Compatibility shim: ``np.alen`` no longer exists in recent NumPy releases.
# NOTE(review): presumably a dependency imported after this line (DeformHeadCT)
# still calls ``np.alen``, which is why it is patched here before those
# imports; confirm and remove once the dependency no longer needs it.
np.alen = len
from DeformHeadCT.VolumeInfo import VolumeDeformation, convert_nifti_to_dicom_series
from DeformHeadCT.deformation import HeadDeformation, generate_field_rotation
# Folder containing this script; the default input/output folders live here.
APP_DIR = Path(__file__).resolve().parent
# Per-run scratch area under the system temp folder. Input DICOMs are copied
# here before processing. NOTE(review): presumably so the pipeline only sees an
# ASCII-only path (see copy_dicom_to_ascii_folder) — confirm.
RUNTIME_DIR = Path(tempfile.gettempdir()) / "HeadExtensionApp"
# Created by main() at startup and by make_output_preview.
PREVIEW_DIR = APP_DIR / "app_previews"
# Two points defining the deformation cut-off. Element [0] of each point is
# used as the slice (z) index by cutoff_center_z and the weight functions;
# the meaning of the other two elements is defined by DeformHeadCT — TODO confirm.
DEFAULT_COORDINATES_CUTOFF = [[145, 305, 256], [135, 205, 256]]
# Per-joint points merged verbatim into the DeformHeadCT info JSON by
# write_info_json. NOTE(review): looks like voxel coordinates of each
# intervertebral joint, cranial to caudal — confirm against DeformHeadCT.
DEFAULT_VERTEBRAE = {
    "Oc-C1": [220, 255, 256],
    "C1-C2": [208, 255, 256],
    "C2-C3": [190, 255, 256],
    "C3-C4": [172, 258, 256],
    "C4-C5": [154, 262, 256],
    "C5-C6": [136, 268, 256],
    "C6-C7": [118, 274, 256],
    "C7-T1": [100, 282, 256],
}
# (state key, display label, output sub-folder name) for the four generated
# states; this order drives both the DICOM outputs and the comparison strip.
STATE_LABELS = [
    ("original", "Original", "ct_original"),
    ("hard_boundary", "Hard boundary", "ct_hard_boundary"),
    ("gaussian_smooth", "Gaussian smooth", "ct_gaussian_smooth"),
    ("soft_transition", "Soft transition", "ct_soft_transition"),
]
def safe_mkdir(path):
    """Create directory ``path`` and any missing parents; existing dirs are fine."""
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
def copy_dicom_to_ascii_folder(input_dir, run_id):
    """Copy every ``*.dcm`` file in ``input_dir`` into the run's scratch folder.

    The destination is ``RUNTIME_DIR / run_id / "input_ct"`` (created if
    needed). Only regular files directly inside ``input_dir`` whose suffix is
    ``.dcm`` (case-insensitive) are copied, with metadata (``shutil.copy2``).

    Returns:
        ``(destination_folder, number_of_files_copied)``.

    Raises:
        RuntimeError: if ``input_dir`` holds no ``.dcm`` file.
    """
    destination = RUNTIME_DIR / run_id / "input_ct"
    safe_mkdir(destination)
    dicom_files = [
        candidate
        for candidate in Path(input_dir).iterdir()
        if candidate.is_file() and candidate.suffix.lower() == ".dcm"
    ]
    for source in dicom_files:
        shutil.copy2(source, destination / source.name)
    if not dicom_files:
        raise RuntimeError("输入文件夹里没有找到 .dcm 文件。")
    return destination, len(dicom_files)
def load_dicom_volume(input_dir):
items = []
for file_path in Path(input_dir).iterdir():
if file_path.is_file() and file_path.suffix.lower() == ".dcm":
ds = pydicom.dcmread(str(file_path), force=True)
instance = int(getattr(ds, "InstanceNumber", len(items)))
items.append((instance, ds))
if not items:
raise RuntimeError("输入文件夹里没有找到 .dcm 文件。")
items.sort(key=lambda item: item[0])
volume = []
for _, ds in items:
image = ds.pixel_array.astype(np.float32)
image = image * float(getattr(ds, "RescaleSlope", 1))
image = image + float(getattr(ds, "RescaleIntercept", 0))
volume.append(image)
return np.stack(volume, axis=0)
def ct_window(image, low=-500, high=1200):
    """Map values in the window [low, high] linearly onto 0..255 as uint8.

    Values below ``low`` become 0 and above ``high`` become 255; fractional
    results are truncated by the uint8 cast.
    """
    span = high - low
    normalized = np.clip((image - low) / span, 0, 1)
    scaled = normalized * 255
    return scaled.astype(np.uint8)
def sagittal_mip(volume):
    """Maximum-intensity projection of a thin slab around the middle of axis 2.

    The slab spans indices [mid - 21, mid + 24) along axis 2 (clamped to the
    array), where mid = ``volume.shape[2] // 2``. The projection is windowed
    with ``ct_window`` and returned as an RGB PIL image.
    NOTE(review): the name implies axis 2 is left-right for the callers'
    volumes — confirm per caller.
    """
    depth = volume.shape[2]
    middle = depth // 2
    slab = volume[:, :, max(0, middle - 21):min(depth, middle + 24)]
    projection = np.max(slab, axis=2)
    return Image.fromarray(ct_window(projection)).convert("RGB")
def crop_head_neck(image):
    """Crop a PIL image to the region shown in previews.

    Keeps columns from 9 % to 89 % of the width and rows from the top edge down
    to 72 % of the height (fractions truncated to whole pixels).
    """
    w, h = image.size
    box = (int(w * 0.09), 0, int(w * 0.89), int(h * 0.72))
    return image.crop(box)
def cutoff_center_z(coordinates_cutoff):
    """Return the mean of element [0] across the cut-off points, as a Python float."""
    first_coords = np.array([pt[0] for pt in coordinates_cutoff])
    return float(first_coords.mean())
def draw_cutoff_line(panel, image_depth, coordinates_cutoff=DEFAULT_COORDINATES_CUTOFF):
    """Return a copy of ``panel`` with a horizontal line at the cut-off slice.

    ``panel`` is expected to be the top 72 % (the ``crop_head_neck`` crop,
    possibly resized) of an image with ``image_depth`` rows. The cut-off index
    z = ``cutoff_center_z(coordinates_cutoff)`` is mapped to row
    ``image_depth - 1 - z`` and scaled to the panel height. If the crop is
    empty or the row falls outside the panel, the unmodified copy is returned.
    """
    annotated = panel.copy()
    visible_rows = int(image_depth * 0.72)
    if visible_rows <= 0:
        return annotated
    row_from_top = image_depth - 1 - cutoff_center_z(coordinates_cutoff)
    y = int(round(row_from_top * annotated.height / visible_rows))
    if not 0 <= y < annotated.height:
        return annotated
    pen = ImageDraw.Draw(annotated)
    segment = (0, y, annotated.width, y)
    # A wide black stroke under a narrower yellow one outlines the line.
    pen.line(segment, fill=(0, 0, 0), width=6)
    pen.line(segment, fill=(255, 215, 60), width=3)
    return annotated
def fit_image(image, width, height):
    """Letterbox ``image`` into a black ``width`` x ``height`` RGB canvas.

    The image is scaled uniformly with Lanczos resampling to the largest size
    that fits, then centred; the remaining area stays black.
    """
    scale = min(width / image.width, height / image.height)
    # Clamp each side to at least 1 px: for very elongated inputs int() can
    # truncate a side to 0, which PIL's resize rejects.
    new_size = (
        max(1, int(image.width * scale)),
        max(1, int(image.height * scale)),
    )
    resized = image.resize(new_size, Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (width, height), (0, 0, 0))
    offset = ((width - resized.width) // 2, (height - resized.height) // 2)
    canvas.paste(resized, offset)
    return canvas
def preview_deform_2d(image, angle_degrees):
    """Fast visual preview only. The DICOM output uses the real 3D field.

    Warps a 2D PIL image with a spatially varying rotation of up to
    ``angle_degrees`` about a pivot at (55 % width, 62 % height). Each pixel's
    angle is scaled by a weight:
    - 1 above 50 % of the height, easing (smoothstep) to 0 at 92 % of the height;
    - additionally multiplied by 0.90 on the left, rising to 1.0 on the right.

    Sampling is bilinear, pixels sampled from outside the image become black,
    and the grayscale result is returned as RGB. If SciPy is not installed,
    the input image is returned unchanged.
    """
    try:
        from scipy.ndimage import map_coordinates
    except ImportError:
        # SciPy is optional for this cosmetic preview: show the image unwarped.
        # Other import failures propagate so a broken install is reported.
        return image
    arr = np.asarray(image.convert("L")).astype(np.float32)
    h, w = arr.shape
    yy, xx = np.mgrid[0:h, 0:w]
    # Rotation pivot, as fractions of the panel size.
    pivot_x = int(w * 0.55)
    pivot_y = int(h * 0.62)
    # Vertical weight: 1 above full_motion_y, smoothstep down to 0 at fixed_y.
    full_motion_y = h * 0.50
    fixed_y = h * 0.92
    t = np.clip((yy - full_motion_y) / (fixed_y - full_motion_y), 0, 1)
    weight = 1 - (t * t * (3 - 2 * t))
    # Horizontal modulation: scales the weight from 0.90 (left) to 1.0 (right).
    x_soft = np.clip((xx - w * 0.15) / (w * 0.75), 0, 1)
    x_soft = x_soft * x_soft * (3 - 2 * x_soft)
    weight = np.clip(weight * (0.90 + 0.10 * x_soft), 0, 1)
    theta = np.deg2rad(angle_degrees) * weight
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Inverse mapping: rotate each output pixel's offset from the pivot to get
    # the source coordinate to sample.
    dx = xx - pivot_x
    dy = yy - pivot_y
    src_x = pivot_x + cos_t * dx + sin_t * dy
    src_y = pivot_y - sin_t * dx + cos_t * dy
    warped = map_coordinates(arr, [src_y, src_x], order=1, mode="constant", cval=0)
    return Image.fromarray(np.clip(warped, 0, 255).astype(np.uint8)).convert("RGB")
def transition_weight(image, coordinates_cutoff, width_voxels):
    """Smooth 0-to-1 ramp along the third image axis, centred on the cut-off.

    For the SimpleITK ``image`` with ``GetSize() == (size_x, size_y, size_z)``
    this returns a float32 array of shape (size_z, size_y, size_x) whose value
    depends only on the slice index z. With
    center = ``cutoff_center_z(coordinates_cutoff)`` it is:
    - 0 for z <= center - width_voxels;
    - 1 for z >= center + width_voxels;
    - a smoothstep in between.

    A non-positive ``width_voxels`` leaves no ramp to interpolate over (the ramp
    formula would divide by zero and produce NaN), so it degenerates to a hard
    step: 1 for z >= center, else 0.
    """
    size_x, size_y, size_z = image.GetSize()
    center_z = cutoff_center_z(coordinates_cutoff)
    half_width = float(width_voxels)
    z_grid = np.arange(size_z, dtype=np.float32)[:, None, None]
    if half_width <= 0:
        weight = (z_grid >= center_z).astype(np.float32)
    else:
        edge0 = center_z - half_width
        edge1 = center_z + half_width
        weight = np.clip((z_grid - edge0) / (edge1 - edge0), 0, 1)
        weight = weight * weight * (3 - 2 * weight)
    # broadcast_to returns a read-only view; astype materialises a real array.
    return np.broadcast_to(weight, (size_z, size_y, size_x)).astype(np.float32)
def hard_boundary_weight(image, coordinates_cutoff):
    """Binary step along the third image axis at the cut-off slice.

    For the SimpleITK ``image`` with ``GetSize() == (nx, ny, nz)`` this returns
    a float32 array of shape (nz, ny, nx): 1.0 on slices whose index
    z >= ``cutoff_center_z(coordinates_cutoff)``, 0.0 elsewhere.
    """
    nx, ny, nz = image.GetSize()
    threshold = cutoff_center_z(coordinates_cutoff)
    slice_index = np.arange(nz, dtype=np.float32).reshape(nz, 1, 1)
    step = np.where(slice_index >= threshold, 1.0, 0.0).astype(np.float32)
    return np.broadcast_to(step, (nz, ny, nx)).astype(np.float32)
def image_from_dvf_array(dvf_arr, reference_image):
    """Wrap a displacement-field array as a SimpleITK vector image.

    ``dvf_arr`` is expected in SimpleITK array order with the vector components
    last, e.g. (z, y, x, 3) as returned by ``sitk.GetArrayFromImage`` on a 3D
    displacement field. Values are cast to float32 and the geometry is copied
    from ``reference_image`` via ``CopyInformation``.
    """
    # Request a vector image explicitly rather than relying on the shape-based
    # default of GetImageFromArray, which has differed across SimpleITK versions.
    dvf = sitk.GetImageFromArray(dvf_arr.astype(np.float32), isVector=True)
    dvf.CopyInformation(reference_image)
    return dvf
def apply_dvf_to_ct(image_ct, dvf):
    """Resample ``image_ct`` through displacement field ``dvf`` with linear interpolation.

    The field is cast to 64-bit vectors, as required by
    ``sitk.DisplacementFieldTransform``. The image's minimum value (truncated
    to int) is passed to platipy's ``apply_transform`` as ``default_value``,
    presumably the fill for voxels mapped from outside the volume.
    """
    field64 = sitk.Cast(dvf, sitk.sitkVectorFloat64)
    displacement = sitk.DisplacementFieldTransform(field64)
    background = int(sitk.GetArrayViewFromImage(image_ct).min())
    return apply_transform(
        image_ct,
        transform=displacement,
        interpolator=sitk.sitkLinear,
        default_value=background,
    )
def reset_folder(path):
    """Recreate ``path`` as an empty directory.

    Destructive: if ``path`` exists, everything under it is deleted first.
    Missing parent directories are created.
    """
    target = Path(path)
    if not target.exists():
        target.mkdir(parents=True, exist_ok=True)
        return
    shutil.rmtree(target)
    target.mkdir(parents=True, exist_ok=True)
def write_dicom_series(image, reference_dicom_dir, output_dir, run_root, state_key):
    """Write ``image`` as a DICOM series into ``output_dir``, replacing its contents.

    The series is first produced by ``convert_nifti_to_dicom_series`` (with
    ``reference_dicom_dir`` as ``reference_dcm``) in the staging folder
    ``run_root / "dicom_output" / state_key``, which is also reset. The
    resulting ``*.dcm`` files are then copied into ``output_dir``.
    ``run_root`` must be a ``pathlib.Path``.
    """
    final_dir = Path(output_dir)
    staging_dir = run_root / "dicom_output" / state_key
    reset_folder(final_dir)
    reset_folder(staging_dir)
    convert_nifti_to_dicom_series(
        image=image,
        reference_dcm=str(reference_dicom_dir),
        output_directory=str(staging_dir),
    )
    for produced in staging_dir.glob("*.dcm"):
        shutil.copy2(produced, final_dir / produced.name)
def write_info_json(info_path, input_dir, temp_dir, output_dir, angle_degrees, transition_width):
    """Write the configuration JSON later passed to ``VolumeDeformation(InfoFile=...)``.

    Directory paths are stored with forward slashes. The rotation is a single
    angle about axis ``[0, 0, -1]``. ``coordinates_cutoff`` is always
    ``DEFAULT_COORDINATES_CUTOFF``, and the ``DEFAULT_VERTEBRAE`` entries are
    appended as top-level keys. ``info_path`` must be a ``pathlib.Path``; the
    file is written as UTF-8 with 2-space indentation.
    """
    def forward_slashes(location):
        return str(location).replace("\\", "/")

    settings = {
        "name": "HEAD_EXTENSION",
        "InputDirectory": forward_slashes(input_dir),
        "TempDirectory": forward_slashes(temp_dir),
        "OutputDirectory": forward_slashes(output_dir),
        "axes": [0, 0, -1],
        "angles": [float(angle_degrees)],
        "coordinates_cutoff": DEFAULT_COORDINATES_CUTOFF,
        "transition_width_voxels": int(transition_width),
        **DEFAULT_VERTEBRAE,
    }
    info_path.write_text(json.dumps(settings, indent=2), encoding="utf-8")
def _stage_patient_images(vol_info):
    """Copy the converted NIfTI images into ``<nifti_directory>/IMAGES``.

    If ``<nifti_directory>/<patientunderscore>/IMAGES`` is missing, the first
    ``*/IMAGES`` folder found is used instead and ``vol_info.patientunderscore``
    is updated to match. NOTE(review): the root-level copy is presumably what
    ``HeadDeformation`` reads — confirm.

    Raises:
        RuntimeError: if the conversion produced no ``IMAGES`` folder.
    """
    patient_image_dir = vol_info.nifti_directory / vol_info.patientunderscore / "IMAGES"
    if not patient_image_dir.exists():
        fallback = next(vol_info.nifti_directory.glob("*/IMAGES"), None)
        if fallback is None:
            # A bare StopIteration would surface as an empty error message.
            raise RuntimeError(
                f"转换后的数据中没有找到 IMAGES 文件夹: {vol_info.nifti_directory}"
            )
        patient_image_dir = fallback
        vol_info.patientunderscore = patient_image_dir.parent.name
    root_image_dir = vol_info.nifti_directory / "IMAGES"
    safe_mkdir(root_image_dir)
    for image_path in patient_image_dir.glob("*.nii.gz"):
        shutil.copy2(image_path, root_image_dir / image_path.name)


def _build_state_dvfs(head_def, vol_info, transition_width):
    """Build the displacement fields for the three deformed states.

    A full rotation field (configured angle converted to radians and negated,
    unsmoothed) is restricted to the external body mask and weighted:
    - ``hard_boundary``: binary step at the cut-off slice;
    - ``gaussian_smooth``: the hard field smoothed with sigma 3;
    - ``soft_transition``: smoothstep ramp of half-width ``transition_width``,
      then smoothed with sigma 3.
    """
    image_ct = head_def.image_ct
    external_mask = get_external_mask(image_ct)
    _, _, full_rotation_dvf, _ = generate_field_rotation(
        image_ct,
        external_mask,
        tuple(vol_info.point_of_rotation[0]),
        axis_of_rotation=vol_info.axes,
        angle=-vol_info.angles[0] * np.pi / 180,
        gaussian_smooth=0,
    )
    base_dvf_arr = sitk.GetArrayFromImage(full_rotation_dvf).astype(np.float32)
    mask_arr = sitk.GetArrayFromImage(external_mask).astype(np.float32)
    hard_weight = hard_boundary_weight(image_ct, vol_info.coordinates_cutoff)
    soft_weight = transition_weight(
        image_ct,
        vol_info.coordinates_cutoff,
        int(transition_width),
    )
    # [..., None] broadcasts the scalar weight over the vector components.
    hard_dvf = image_from_dvf_array(
        base_dvf_arr * (hard_weight * mask_arr)[..., None],
        image_ct,
    )
    gaussian_dvf = sitk.SmoothingRecursiveGaussian(hard_dvf, 3)
    soft_dvf = image_from_dvf_array(
        base_dvf_arr * (soft_weight * mask_arr)[..., None],
        image_ct,
    )
    soft_dvf = sitk.SmoothingRecursiveGaussian(soft_dvf, 3)
    return {
        "hard_boundary": hard_dvf,
        "gaussian_smooth": gaussian_dvf,
        "soft_transition": soft_dvf,
    }


def _write_state_series(state_images, reference_dicom_dir, output_root, run_root):
    """Write every state to ``output_root/<folder>`` plus the legacy ``ct`` copy.

    Returns a dict mapping each state key (and ``"legacy_soft"``) to its folder.
    """
    output_paths = {}
    for state_key, _, folder_name in STATE_LABELS:
        state_output_dir = output_root / folder_name
        write_dicom_series(
            state_images[state_key],
            reference_dicom_dir,
            state_output_dir,
            run_root,
            state_key,
        )
        output_paths[state_key] = state_output_dir
    # The soft-transition result is also written to "ct" for older consumers.
    legacy_soft_dir = output_root / "ct"
    write_dicom_series(
        state_images["soft_transition"],
        reference_dicom_dir,
        legacy_soft_dir,
        run_root,
        "legacy_soft_transition",
    )
    output_paths["legacy_soft"] = legacy_soft_dir
    return output_paths


def run_deformation(input_dir, output_dir, angle_degrees, transition_width, progress):
    """Run the full head-extension pipeline and write all outputs.

    Steps:
    1. Copy the input DICOMs to a per-run scratch folder under ``RUNTIME_DIR``.
    2. Convert them with DeformHeadCT.
    3. Build the hard / Gaussian / soft displacement fields and apply them.
    4. Write the four DICOM states (plus the legacy ``ct`` copy) and the
       comparison images into ``output_dir``.

    The scratch folder is deleted afterwards, whether the run succeeds or
    fails; all results are copied into ``output_dir`` before that.

    Args:
        input_dir: folder containing the input ``*.dcm`` slices.
        output_dir: destination folder (created if missing).
        angle_degrees: head-extension angle in degrees.
        transition_width: half-width, in slices, of the soft transition ramp.
        progress: callable receiving human-readable status strings.

    Returns:
        ``(output_paths, preview_paths)``. ``output_paths`` maps state keys and
        ``"legacy_soft"`` to folders; ``preview_paths`` is the dict returned by
        ``make_four_state_preview``.
    """
    run_id = time.strftime("%Y%m%d_%H%M%S_") + uuid.uuid4().hex[:8]
    run_root = RUNTIME_DIR / run_id
    output_root = Path(output_dir)
    try:
        temp_dir = run_root / "temp"
        ascii_input_dir, copied_count = copy_dicom_to_ascii_folder(input_dir, run_id)
        safe_mkdir(output_dir)
        info_path = run_root / "head_extension.json"
        write_info_json(
            info_path,
            ascii_input_dir,
            temp_dir,
            output_root,
            angle_degrees,
            transition_width,
        )
        progress(f"已复制 {copied_count} 张 DICOM开始转换体数据...")
        vol_info = VolumeDeformation(InfoFile=str(info_path))
        vol_info.PrepareDcmData()
        _stage_patient_images(vol_info)

        progress("正在生成四种状态的三维位移场...")
        head_def = HeadDeformation(vol_info.nifti_directory, vol_info.patientunderscore, 0)
        dvfs = _build_state_dvfs(head_def, vol_info, transition_width)

        progress("正在应用形变...")
        state_images = {
            "original": head_def.image_ct,
            "hard_boundary": apply_dvf_to_ct(head_def.image_ct, dvfs["hard_boundary"]),
            "gaussian_smooth": apply_dvf_to_ct(head_def.image_ct, dvfs["gaussian_smooth"]),
            "soft_transition": apply_dvf_to_ct(head_def.image_ct, dvfs["soft_transition"]),
        }

        progress("正在写出四种状态的完整 DICOM 序列...")
        reference_dicom_dir = vol_info.nifti_directory / "dicom" / "ct"
        output_paths = _write_state_series(
            state_images, reference_dicom_dir, output_root, run_root
        )

        progress("正在生成四状态过程对比图...")
        preview_paths = make_four_state_preview(
            state_images,
            output_root,
            angle_degrees,
            vol_info.coordinates_cutoff,
        )
        make_output_preview_from_images(
            state_images["original"],
            state_images["soft_transition"],
            output_root,
            angle_degrees,
            vol_info.coordinates_cutoff,
        )
        return output_paths, preview_paths
    finally:
        # Each run copies a whole CT series and writes five DICOM series here.
        # Nothing reads the folder after this point, so remove it to avoid
        # unbounded temp-disk growth. Best effort: open handles must not turn a
        # successful run into a failure.
        shutil.rmtree(run_root, ignore_errors=True)
def make_output_preview(original_dir, deformed_dicom_dir, output_dir, angle_degrees):
    """Build a before/after slide from two DICOM folders and save it as PNG.

    Both series are read with ``load_dicom_volume``. The deformed volume is
    reversed along axis 0 before projection. NOTE(review): presumably because
    the written series' slice order is opposite to the input's — confirm.

    Each volume is rendered as a cropped ``sagittal_mip`` and placed on a
    2560x1440 black slide with titles and an arrow. The titles use Windows
    Arial when present, else PIL's default font.

    ``output_dir`` may be a ``str`` or a ``Path``. The slide is written to
    ``<output_dir>/before_after_preview.png`` and that path is returned.
    ``PREVIEW_DIR`` is also created, although nothing is written there.
    """
    safe_mkdir(PREVIEW_DIR)
    # Accept plain strings too; joining a str with "/" would raise TypeError.
    output_dir = Path(output_dir)
    orig = load_dicom_volume(original_dir)
    deformed = load_dicom_volume(deformed_dicom_dir)[::-1]
    before = crop_head_neck(sagittal_mip(orig))
    after = crop_head_neck(sagittal_mip(deformed))
    slide = Image.new("RGB", (2560, 1440), (0, 0, 0))
    draw = ImageDraw.Draw(slide)
    font_path = Path(r"C:\Windows\Fonts\arial.ttf")
    title_font = ImageFont.truetype(str(font_path), 58) if font_path.exists() else ImageFont.load_default()
    slide.paste(fit_image(before, 920, 675), (280, 390))
    slide.paste(fit_image(after, 920, 675), (1360, 390))
    draw.text((300, 190), "Before", font=title_font, fill=(255, 255, 255))
    draw.text(
        (1380, 190),
        f"After: {angle_degrees:g} deg head extension",
        font=title_font,
        fill=(255, 255, 255),
    )
    arrow = (255, 210, 60)
    x0, y0, x1, y1 = 1685, 545, 1855, 455
    draw.line((x0, y0, x1, y1), fill=arrow, width=10)
    draw.polygon([(x1, y1), (x1 - 42, y1 + 5), (x1 - 15, y1 + 36)], fill=arrow)
    preview_path = output_dir / "before_after_preview.png"
    slide.save(preview_path, quality=95)
    return preview_path
def sitk_sagittal_panel(image, coordinates_cutoff=None):
    """Render a SimpleITK volume as a cropped MIP preview panel.

    The voxel array is reversed along axis 0, projected with ``sagittal_mip``
    and cropped with ``crop_head_neck``. When ``coordinates_cutoff`` is given,
    the cut-off line is drawn with ``draw_cutoff_line``.
    """
    flipped = sitk.GetArrayFromImage(image)[::-1]
    panel = crop_head_neck(sagittal_mip(flipped))
    if coordinates_cutoff is None:
        return panel
    return draw_cutoff_line(panel, flipped.shape[0], coordinates_cutoff)
def _compose_before_after_slide(before, after, angle_degrees):
    """Place two PIL panels side by side on a 2560x1440 black slide with titles and an arrow."""
    slide = Image.new("RGB", (2560, 1440), (0, 0, 0))
    pen = ImageDraw.Draw(slide)
    font_file = Path(r"C:\Windows\Fonts\arial.ttf")
    if font_file.exists():
        heading = ImageFont.truetype(str(font_file), 58)
    else:
        heading = ImageFont.load_default()
    for panel, left in ((before, 280), (after, 1360)):
        slide.paste(fit_image(panel, 920, 675), (left, 390))
    white = (255, 255, 255)
    pen.text((300, 190), "Before", font=heading, fill=white)
    pen.text(
        (1380, 190),
        f"After: {angle_degrees:g} deg head extension",
        font=heading,
        fill=white,
    )
    accent = (255, 210, 60)
    tail_x, tail_y, tip_x, tip_y = 1685, 545, 1855, 455
    pen.line((tail_x, tail_y, tip_x, tip_y), fill=accent, width=10)
    pen.polygon(
        [(tip_x, tip_y), (tip_x - 42, tip_y + 5), (tip_x - 15, tip_y + 36)],
        fill=accent,
    )
    return slide


def make_output_preview_from_images(original_image, deformed_image, output_dir, angle_degrees, coordinates_cutoff=None):
    """Save a before/after slide built from two SimpleITK volumes.

    Both volumes are rendered with ``sitk_sagittal_panel`` (with the cut-off
    line when ``coordinates_cutoff`` is given). The slide is written to
    ``<output_dir>/before_after_preview.png`` and that path is returned.
    """
    slide = _compose_before_after_slide(
        sitk_sagittal_panel(original_image, coordinates_cutoff),
        sitk_sagittal_panel(deformed_image, coordinates_cutoff),
        angle_degrees,
    )
    destination = Path(output_dir) / "before_after_preview.png"
    slide.save(destination, quality=95)
    return destination
def make_four_state_preview(state_images, output_dir, angle_degrees, coordinates_cutoff=None):
    """Save one panel per state plus a side-by-side comparison strip.

    For each ``STATE_LABELS`` entry, ``state_images[key]`` is rendered with
    ``sitk_sagittal_panel`` and saved as
    ``<output_dir>/process_screenshots/<key>.png``; that folder is wiped first.
    The panels are then laid out left to right on a 2560x720 black slide with
    their labels and an angle caption, and saved as
    ``<output_dir>/process_comparison_4states.png``.

    Returns:
        dict with ``"comparison"`` (slide path) and ``"screenshots"`` (folder).
    """
    root = Path(output_dir)
    shots_dir = root / "process_screenshots"
    reset_folder(shots_dir)

    labelled_panels = []
    for key, label, _folder in STATE_LABELS:
        rendered = sitk_sagittal_panel(state_images[key], coordinates_cutoff)
        rendered.save(shots_dir / f"{key}.png", quality=95)
        labelled_panels.append((label, rendered))

    font_file = Path(r"C:\Windows\Fonts\arial.ttf")
    has_arial = font_file.exists()
    label_font = ImageFont.truetype(str(font_file), 36) if has_arial else ImageFont.load_default()
    caption_font = ImageFont.truetype(str(font_file), 24) if has_arial else ImageFont.load_default()

    slide = Image.new("RGB", (2560, 720), (0, 0, 0))
    pen = ImageDraw.Draw(slide)
    # 2 * 55 margin + 4 * 560 panels + 3 * 70 gaps == 2560, the slide width.
    cell_w, cell_h = 560, 430
    left_margin, spacing = 55, 70
    for position, (label, rendered) in enumerate(labelled_panels):
        x = left_margin + position * (cell_w + spacing)
        slide.paste(fit_image(rendered, cell_w, cell_h), (x, 145))
        pen.text((x, 62), label, font=label_font, fill=(255, 255, 255))
    pen.text(
        (left_margin, 650),
        f"Head extension angle: {angle_degrees:g} deg",
        font=caption_font,
        fill=(190, 190, 190),
    )

    comparison = root / "process_comparison_4states.png"
    slide.save(comparison, quality=95)
    return {"comparison": comparison, "screenshots": shots_dir}
class HeadExtensionApp:
    """Tkinter front-end for the head-extension tool.

    Users pick input/output folders and the angle and transition width. The
    preview is a fast 2D approximation (``preview_deform_2d``) of a MIP panel.
    The save button runs ``run_deformation`` on a daemon thread.
    """

    def __init__(self, root):
        """Set up state and widgets on ``root`` (a ``Tk`` instance)."""
        self.root = root
        self.root.title("头颈部 CT 仰头形变工具")
        self.root.geometry("1180x820")
        self.input_dir = StringVar(value=str(APP_DIR / "input_ct_2F"))
        self.output_dir = StringVar(value=str(APP_DIR / "app_output"))
        self.status = StringVar(value="请选择 DICOM 文件夹,调节角度后可先预览,再生成四状态 DICOM 和过程对比图。")
        self.angle_text = StringVar(value="12.0°")
        self.transition_text = StringVar(value="45")
        # Preview volume plus the folder it was read from. The folder is
        # compared on every refresh, so a path typed into the entry (not just
        # one chosen via the dialog) invalidates the cache.
        self.cached_volume = None
        self.cached_volume_dir = None
        # Presumably held so the displayed PhotoImage is not garbage-collected.
        self.preview_photo = None
        # Id of a pending debounced preview refresh (root.after), if any.
        self.preview_after_id = None
        self.build_ui()

    def build_ui(self):
        """Create the folder pickers, sliders, buttons, preview and status widgets."""
        top = Frame(self.root)
        top.pack(fill="x", padx=12, pady=10)
        Label(top, text="输入 DICOM 文件夹").grid(row=0, column=0, sticky="w")
        Entry(top, textvariable=self.input_dir, width=92).grid(row=0, column=1, padx=8)
        Button(top, text="选择", command=self.choose_input).grid(row=0, column=2)
        Label(top, text="输出文件夹").grid(row=1, column=0, sticky="w", pady=6)
        Entry(top, textvariable=self.output_dir, width=92).grid(row=1, column=1, padx=8)
        Button(top, text="选择", command=self.choose_output).grid(row=1, column=2)
        controls = Frame(self.root)
        controls.pack(fill="x", padx=12)
        Label(controls, text="仰头角度").grid(row=0, column=0, sticky="w")
        self.angle = Scale(
            controls,
            from_=0,
            to=20,
            orient=HORIZONTAL,
            resolution=0.5,
            length=420,
            command=self.on_angle_change,
        )
        self.angle.set(12)
        self.angle.grid(row=0, column=1, sticky="w", padx=8)
        Label(controls, textvariable=self.angle_text, width=8, anchor="w").grid(
            row=0, column=2, sticky="w"
        )
        Label(controls, text="过渡平滑宽度").grid(row=0, column=3, sticky="w", padx=(20, 0))
        self.transition = Scale(
            controls,
            from_=50,
            to=160,
            orient=HORIZONTAL,
            resolution=10,
            length=300,
            command=self.on_transition_change,
        )
        self.transition.set(90)
        self.transition.grid(row=0, column=4, sticky="w", padx=8)
        Label(controls, textvariable=self.transition_text, width=8, anchor="w").grid(
            row=0, column=5, sticky="w"
        )
        buttons = Frame(self.root)
        buttons.pack(fill="x", padx=12, pady=8)
        self.preview_button = Button(buttons, text="更新预览", command=self.update_preview, width=16)
        self.preview_button.pack(side="left", padx=(0, 8))
        self.run_all_button = Button(
            buttons,
            text="保存过程对比图+四状态DICOM",
            command=self.start_run,
            width=28,
        )
        self.run_all_button.pack(side="left")
        self.preview_label = Label(self.root, bg="black")
        self.preview_label.pack(fill=BOTH, expand=True, padx=12, pady=8)
        Label(self.root, textvariable=self.status, anchor="w").pack(fill="x", padx=12, pady=(0, 8))
        self.log = Entry(self.root)
        self.log.pack(fill="x", padx=12, pady=(0, 10))

    def on_angle_change(self, value):
        """Show the new angle and schedule a (debounced) preview refresh."""
        self.angle_text.set(f"{float(value):.1f}°")
        self.schedule_preview_refresh()

    def on_transition_change(self, value):
        """Show the new transition width (the 2D preview does not use it)."""
        self.transition_text.set(f"{int(float(value))}")

    def schedule_preview_refresh(self):
        """Refresh the preview 250 ms after the last slider move."""
        if self.preview_after_id is not None:
            self.root.after_cancel(self.preview_after_id)
        self.preview_after_id = self.root.after(250, self.update_preview)

    def choose_input(self):
        """Ask for the input DICOM folder and drop the cached preview volume."""
        path = filedialog.askdirectory(title="选择 DICOM 文件夹")
        if path:
            self.input_dir.set(path)
            self.cached_volume = None

    def choose_output(self):
        """Ask for the output folder."""
        path = filedialog.askdirectory(title="选择输出文件夹")
        if path:
            self.output_dir.set(path)

    def set_busy(self, busy):
        """Disable (``busy=True``) or re-enable both action buttons."""
        state = DISABLED if busy else NORMAL
        self.preview_button.config(state=state)
        self.run_all_button.config(state=state)

    def update_status(self, message):
        """Post ``message`` to the status bar via ``root.after`` (used from the worker)."""
        self.root.after(0, lambda: self.status.set(message))

    def update_preview(self):
        """Render the before/after 2D preview for the current folder and angle.

        Errors are shown in a message box and the status bar instead of
        propagating.
        """
        self.preview_after_id = None
        try:
            input_dir = self.input_dir.get()
            if self.cached_volume is None or self.cached_volume_dir != input_dir:
                self.status.set("正在读取 DICOM 生成预览...")
                self.cached_volume = load_dicom_volume(input_dir)
                self.cached_volume_dir = input_dir
            angle = float(self.angle.get())
            before = crop_head_neck(sagittal_mip(self.cached_volume))
            after = preview_deform_2d(before, angle)
            after = draw_cutoff_line(after, self.cached_volume.shape[0])
            canvas = Image.new("RGB", (1120, 610), (0, 0, 0))
            draw = ImageDraw.Draw(canvas)
            before_panel = fit_image(before, 520, 500)
            after_panel = fit_image(after, 520, 500)
            canvas.paste(before_panel, (30, 80))
            canvas.paste(after_panel, (570, 80))
            draw.text((40, 25), "Before", fill=(255, 255, 255))
            draw.text((580, 25), f"Preview: {angle:g} deg", fill=(255, 255, 255))
            self.preview_photo = ImageTk.PhotoImage(canvas)
            self.preview_label.config(image=self.preview_photo)
            self.status.set("预览已更新。预览是快速 2D 示意,最终输出会使用 3D 位移场。")
        except Exception as exc:
            messagebox.showerror("预览失败", str(exc))
            self.status.set("预览失败,请检查输入 DICOM 文件夹。")

    def start_run(self):
        """Start the full pipeline on a daemon thread.

        Widget values are captured here, on the Tk thread. The worker then
        never touches Tk variables, and the run and its completion message use
        the settings active when the button was pressed, even if the fields are
        edited while it runs.
        """
        params = (
            self.input_dir.get(),
            self.output_dir.get(),
            float(self.angle.get()),
            int(self.transition.get()),
        )
        self.set_busy(True)
        self.status.set("开始生成四状态完整输出请稍等。300 层 CT 通常需要 1-3 分钟。")
        thread = threading.Thread(target=self.run_worker, args=params, daemon=True)
        thread.start()

    def run_worker(self, input_dir=None, output_dir=None, angle_degrees=None, transition_width=None):
        """Background-thread body: run ``run_deformation`` and report the outcome.

        Omitted arguments fall back to the current widget values. Results,
        errors and re-enabling the buttons are posted back with ``root.after``.
        """
        try:
            if input_dir is None:
                input_dir = self.input_dir.get()
            if output_dir is None:
                output_dir = self.output_dir.get()
            if angle_degrees is None:
                angle_degrees = float(self.angle.get())
            if transition_width is None:
                transition_width = int(self.transition.get())
            output_paths, preview_paths = run_deformation(
                input_dir,
                output_dir,
                angle_degrees,
                transition_width,
                self.update_status,
            )
            self.update_status(f"完成。四状态 DICOM 与过程对比图已输出到:{output_dir}")
            summary = (
                "四状态完整输出已生成:\n"
                f"Original: {output_paths['original']}\n"
                f"Hard boundary: {output_paths['hard_boundary']}\n"
                f"Gaussian smooth: {output_paths['gaussian_smooth']}\n"
                f"Soft transition: {output_paths['soft_transition']}\n\n"
                f"过程对比图:\n{preview_paths['comparison']}\n\n"
                f"兼容旧版本的 Soft transition 输出:\n{output_paths['legacy_soft']}"
            )
            self.root.after(0, lambda: messagebox.showinfo("完成", summary))
        except Exception as exc:
            # Some exceptions stringify to "" — show at least the type name.
            error_message = str(exc) or type(exc).__name__
            self.update_status("生成失败。")
            self.root.after(0, lambda: messagebox.showerror("生成失败", error_message))
        finally:
            self.root.after(0, lambda: self.set_busy(False))
def main():
    """Create the runtime/preview folders, open the app window and run Tk."""
    for folder in (RUNTIME_DIR, PREVIEW_DIR):
        safe_mkdir(folder)
    window = Tk()
    app = HeadExtensionApp(window)
    # Show a first preview immediately instead of waiting for user input.
    app.update_preview()
    window.mainloop()


if __name__ == "__main__":
    main()