import logging, argparse, sys, utils import shutil, tempfile from pathlib import Path from typing import Dict, List, Tuple import albumentations as A import cv2 import matplotlib.pyplot as plt import numpy as np import segmentation_models_pytorch as smp import torch from albumentations.pytorch import ToTensorV2 from scipy.special import softmax from sklearn.metrics import (auc, average_precision_score, precision_recall_curve, roc_curve) from tqdm import tqdm # 本地应用/库的导入 import config from utils import log_metrics_to_csv, get_preprocessing_transform, setup_directories # --- 日志设置 --- # 配置日志记录器 logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler()] # 输出日志到控制台 ) # 为 matplotlib 设置中文字体,请根据您的系统环境选择合适的字体 plt.rcParams['font.sans-serif'] = ['Noto Sans CJK JP', 'SimHei', 'Microsoft YaHei'] # 例如:SimHei, Microsoft YaHei plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 # --- 核心工具函数 --- # 将单通道的类别索引掩码转换为 RGB 彩色掩码 def colorize_mask(mask: np.ndarray, class_rgb_values: List[Tuple[int, int, int]]) -> np.ndarray: """将单通道的类别索引掩码转换为 RGB 彩色掩码。""" color_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) # 一个通用的调色板,可以轻松扩展 palette = [ (0, 0, 0), (220, 20, 60), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255), (244, 164, 96), ] for class_idx, color in enumerate(palette): if class_idx < len(class_rgb_values): # 假设掩码中的像素值 [0, 1, 2...] 对应类别的索引 pixels_to_color = (mask == class_idx) color_mask[pixels_to_color] = color return color_mask # --- 分析与可视化函数 --- def generate_and_save_curves(y_true: np.ndarray, y_probs: np.ndarray, class_names: List[str], output_dir: Path) -> None: """ 根据整个测试集的聚合预测结果,为每个类别生成并保存 ROC 和 PR 曲线。 """ logging.info("正在生成 ROC 和 PR 曲线...") num_classes = len(class_names) plt.style.use('seaborn-v0_8-whitegrid') # --- 准备画布 --- fig_roc, ax_roc = plt.subplots(figsize=(10, 8)) ax_roc.plot([0, 1], [0, 1], 'k--', label='随机猜测') fig_pr, ax_pr = plt.subplots(figsize=(10, 8)) # --- 为每个类别(跳过背景)生成曲线 --- for i in range(1, num_classes): class_name = class_names[i] y_true_binary = (y_true == i).astype(int) y_class_probs = y_probs[:, i] # ROC 曲线 fpr, tpr, _ = roc_curve(y_true_binary, y_class_probs) roc_auc = auc(fpr, tpr) ax_roc.plot(fpr, tpr, label=f'类别 {class_name} (AUC = {roc_auc:.4f})') # PR 曲线 precision, recall, _ = precision_recall_curve(y_true_binary, y_class_probs) avg_precision = average_precision_score(y_true_binary, y_class_probs) ax_pr.plot(recall, precision, label=f'类别 {class_name} (AP = {avg_precision:.4f})') # --- 美化并保存 ROC 曲线图 --- ax_roc.set(xlabel='False positive rate (FPR)', ylabel='True positive rate (TPR)', title='Receiver operating characteristic (ROC) curve') ax_roc.legend(loc='lower right') ax_roc.set_aspect('equal', adjustable='box') fig_roc.tight_layout() roc_save_path = output_dir / "ROC_Curves.png" fig_roc.savefig(roc_save_path) plt.close(fig_roc) logging.info(f"-> ROC 曲线图已保存至: {roc_save_path}") # --- 美化并保存 PR 曲线图 --- ax_pr.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall (PR) Curve') ax_pr.legend(loc='lower left') ax_pr.set_aspect('equal', adjustable='box') fig_pr.tight_layout() pr_save_path = output_dir / "Precision_Recall_Curves.png" fig_pr.savefig(pr_save_path) plt.close(fig_pr) logging.info(f"-> PR 曲线图已保存至: {pr_save_path}") def save_visual_comparison( image_name: str, original_image: np.ndarray, pred_mask: np.ndarray, gt_mask: np.ndarray = None, stats: Dict = None, stats_text: str = "", save_dir: str = "" ) -> None: """保存原始图像、预测掩码以及可选的真实掩码的对比图。""" pred_color_mask = colorize_mask(pred_mask, config.CLASS_RGB_VALUES) if gt_mask is not None and stats is not None: num_plots = 3 figsize = (22, 8) gt_color_mask = colorize_mask(gt_mask, config.CLASS_RGB_VALUES) save_name = f"{Path(image_name).stem}-IOU_{stats['iou']:.4f}-ACC_{stats['acc']:.4f}.png" else: num_plots = 2 figsize = (14, 6) save_name = f"{Path(image_name).stem}.png" fig, axes = plt.subplots(1, num_plots, figsize=figsize) fig.suptitle(f"Predict: {image_name}", fontsize=16) axes[0].imshow(original_image) axes[0].set_title("ORI_IMG") axes[0].axis('off') pred_title = "Predicted Mask" if stats: pred_title += f"\nIoU: {stats['iou']:.4f} | Acc: {stats['acc']:.4f}" axes[1].imshow(pred_color_mask) axes[1].set_title(pred_title) axes[1].axis('off') if num_plots == 3: axes[2].imshow(gt_color_mask) axes[2].set_title("Ground Truth") axes[2].axis('off') plt.figtext(0.5, 0.01, stats_text, ha="center", fontsize=10, bbox={"facecolor":"white", "alpha":0.7, "pad":5}, family='monospace') fig.subplots_adjust(bottom=0.3) plt.tight_layout() plt.savefig(save_dir / save_name) plt.close(fig) # --- 主要逻辑函数 --- def load_model(model_architecture: str, model_save_path: Path) -> torch.nn.Module: """ 加载已训练的分割模型。 此函数经过优化,可以智能处理由 torch.nn.DataParallel 保存的 (带有 'module.' 前缀) 和常规保存的模型权重。 """ logging.info(f"正在从以下路径加载模型: {model_save_path}") num_classes = len(config.CLASSES) # 1. 【修改】根据 config 动态创建基础模型架构 logging.info(f"正在重建模型架构: '{model_architecture}'") try: model_params = config.ALL_MODEL_CONFIGS[model_architecture].copy() # model_class = getattr(smp, model_architecture) except AttributeError: logging.error(f"模型 '{model_architecture}' 在 segmentation_models_pytorch 库中不存在!") raise # 2. 【修改】准备参数,与训练时保持一致 # 注意:加载权重时,我们不希望再次下载预训练权重,所以设为 None params = model_params params['in_channels'] = 3 params['classes'] = num_classes params['encoder_weights'] = None # 非常重要!避免在预测时重新下载ImageNet权重 # ======================== 【新增代码段开始】 ======================== # # 自动检测 encoder_name 是否包含 'vit' encoder_name = model_params.get('encoder_name', '') if 'vit' in encoder_name.lower(): params['dynamic_img_size'] = True logging.info(f"检测到 ViT 编码器 ('{encoder_name}')。自动设置 dynamic_img_size=True。") # ======================== 【新增代码段结束】 ======================== # # 3. 【修改】使用 **kwargs 创建模型实例 model = model_class(**params).to(config.DEVICE) # 4. 从文件加载 state_dict state_dict = torch.load(model_save_path, map_location=torch.device(config.DEVICE)) # 3. 检查是否存在 'module.' 前缀 (判断是否为 DataParallel 保存的权重) # 我们通过检查字典中任何一个键是否以 'module.' 开头来判断 is_dataparallel = any(key.startswith('module.') for key in state_dict.keys()) if is_dataparallel: logging.info("检测到模型由 DataParallel 保存,将移除 'module.' 前缀。") # 创建一个新的、不带前缀的 state_dict new_state_dict = {key.replace('module.', ''): value for key, value in state_dict.items()} model.load_state_dict(new_state_dict) else: # 如果没有 'module.' 前缀,直接加载 logging.info("直接加载标准模型权重。") model.load_state_dict(state_dict) model.eval() return model def calculate_and_save_final_metrics(stats_list: List[Dict], num_classes: int, save_dir: str) -> None: """聚合每张图像的统计数据,计算总体指标,并保存到 CSV 文件。""" if not stats_list: logging.warning("未收集到统计数据,跳过最终指标计算。") return logging.info("\n正在聚合结果以计算最终指标...") # 堆叠所有统计张量并按批次维度求和 tp = torch.cat([s['tp'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,) fp = torch.cat([s['fp'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,) fn = torch.cat([s['fn'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,) tn = torch.cat([s['tn'] for s in stats_list], dim=0).sum(dim=0) # (N×B, C) -> (C,) # ======================== 【新增代码段开始】 ======================== # # 根据 config.py 中的设置,筛选出用于计算宏观指标的类别 ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES] keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device) if ignore_indices: keep_mask[ignore_indices] = False logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}") tp_filtered = tp[keep_mask] fp_filtered = fp[keep_mask] fn_filtered = fn[keep_mask] tn_filtered = tn[keep_mask] # ======================== 【新增代码段结束】 ======================== # # ======================== 【修改代码段开始】 ======================== # # 使用过滤后的统计数据计算总体指标 metrics = { "iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), } # ======================== 【修改代码段结束】 ======================== # # 添加每个类别的原始统计数据(使用未经过滤的数据) for i in range(num_classes): metrics[f'tp_class_{i}'] = tp[i].item() metrics[f'tn_class_{i}'] = tn[i].item() metrics[f'fp_class_{i}'] = fp[i].item() metrics[f'fn_class_{i}'] = fn[i].item() csv_save_path = save_dir / "test_set_metrics.csv" log_metrics_to_csv(metrics, str(csv_save_path)) logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}") # 2. 【新增】一个帮助函数,用于交互式选择模型目录 def select_trained_model_path(base_dir: Path, model_architecture: str) -> Path: """ 查找指定架构的所有训练运行目录,并让用户选择一个。 """ logging.info(f"正在 '{base_dir}' 中搜索模型 '{model_architecture}' 的训练记录...") # 查找所有以架构名开头的文件夹 run_dirs = sorted([d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith(model_architecture+'_')]) if not run_dirs: logging.error(f"未找到任何 '{model_architecture}' 模型的训练记录。请先运行 train.py。") return None print("\n请选择要用于预测的模型:") for i, dir_path in enumerate(run_dirs): print(f" [{i+1}] {dir_path.name}") while True: try: choice = input(f"请输入选项编号 (1-{len(run_dirs)}) 或按 Enter 退出: ") if not choice: return None choice_idx = int(choice) - 1 if 0 <= choice_idx < len(run_dirs): selected_dir = run_dirs[choice_idx] logging.info(f"已选择模型: {selected_dir}") return selected_dir else: print("无效的选项,请重新输入。") except (ValueError, IndexError): print("无效的输入,请输入一个数字。") # 【修改】函数签名和内部逻辑,不再接收一个列表,而是直接接收聚合后的统计数据 def calculate_and_save_final_metrics(tp, fp, fn, tn, num_classes: int, save_dir: Path) -> None: """根据聚合后的统计数据,计算总体指标,并保存到 CSV 文件。""" logging.info("\n正在计算最终的整体指标...") # 根据 config.py 中的设置,筛选出用于计算宏观指标的类别 ignore_indices = [config.CLASSES.index(cls) for cls in config.EVALUATION_CLASSES_TO_IGNORE if cls in config.CLASSES] keep_mask = torch.ones(num_classes, dtype=torch.bool, device=tp.device) if ignore_indices: keep_mask[ignore_indices] = False logging.info(f"最终评估时将忽略类别: {config.EVALUATION_CLASSES_TO_IGNORE}") tp_filtered = tp[keep_mask] fp_filtered = fp[keep_mask] fn_filtered = fn[keep_mask] tn_filtered = tn[keep_mask] # 使用过滤后的统计数据计算总体指标 metrics = { "iou_score": smp.metrics.iou_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "f1_score": smp.metrics.f1_score(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "accuracy": smp.metrics.accuracy(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "recall": smp.metrics.recall(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), "precision": smp.metrics.precision(tp_filtered, fp_filtered, fn_filtered, tn_filtered, reduction='micro').item(), } # 添加每个类别的原始统计数据(使用未经过滤的数据) for i in range(num_classes): metrics[f'tp_class_{i}'] = tp[i].item() metrics[f'tn_class_{i}'] = tn[i].item() metrics[f'fp_class_{i}'] = fp[i].item() metrics[f'fn_class_{i}'] = fn[i].item() csv_save_path = save_dir / "test_set_metrics.csv" log_metrics_to_csv(metrics, str(csv_save_path)) logging.info(f"-> 整体测试集指标已保存至: {csv_save_path}") def main(model_architecture: str) -> None: """ 主函数,执行模型加载、图像预测、评估,并生成分析结果。 """ # --- 1. 【修改】选择模型并动态设置路径 --- selected_run_dir = select_trained_model_path(Path(config.PREDICT_BEST_MODEL_DIR), model_architecture) if not selected_run_dir: logging.info("未选择任何模型,程序退出。") sys.exit() # 【重要】动态覆写 config 中的路径,使其指向所选模型的目录 # 这样,后续所有代码都会自动在正确的文件夹中加载模型和保存结果 best_model_save_path = selected_run_dir / "best_model.pth" raw_mask_dir = selected_run_dir / "predicted_raw_masks" analysis_results_dir = selected_run_dir / "prediction_analysis" if not best_model_save_path.exists(): logging.error(f"错误: 在 '{selected_run_dir}' 中未找到 'best_model.pth' 文件!") sys.exit() # setup_directories 函数现在用于创建预测结果的子目录 utils.setup_directories(raw_mask_dir) utils.setup_directories(analysis_results_dir) model = load_model(model_architecture, best_model_save_path) # 1. 加载模型 # ======================== 【新增/修改的代码段开始】 ======================== # # --- 2. 探测模型以确定正确的图像尺寸 --- logging.info("正在探测模型以确定预测时所需的输入尺寸...") # target_height, target_width = config.IMAGE_HEIGHT, config.IMAGE_WIDTH # try: # 检查编码器是否需要固定输入尺寸 if model.encoder.is_fixed_input_size: required_size = model.encoder.input_size target_height = required_size[1] target_width = required_size[2] logging.warning( f"模型 '{model_architecture}' 的编码器需要固定的输入尺寸 " f"({target_height}x{target_width})。" f"将忽略 config.py 中的尺寸设置进行预测。" ) else: logging.info(f"模型 '{model_architecture}' 支持动态输入尺寸。将使用 config.py 中定义的尺寸 ({target_height}x{target_width})。") except AttributeError: logging.info("无法自动检测模型尺寸要求。将使用 config.py 中定义的尺寸。") # --- 3. 获取用于预测的图像预处理流程 --- # 使用探测到的或配置中指定的目标尺寸 transform = get_preprocessing_transform(target_height, target_width) # # ======================== 【新增/修改的代码段结束】 ======================== # # --- 4. 获取图像列表并判断模式 --- image_filenames = sorted(list(config.TEST_IMAGE_DIR.iterdir())) evaluate_mode = config.TEST_MASK_DIR.is_dir() if evaluate_mode: logging.info("正在以 [评估模式] 运行。将使用真实掩码进行评估。") else: logging.info("正在以 [仅预测模式] 运行。未找到真实掩码。") # --- 3. 初始化用于增量计算的变量 --- # 【修改】不再使用 list 来累积,而是直接累加 TP, FP, FN, TN num_classes = len(config.CLASSES) total_tp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE) total_fp = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE) total_fn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE) total_tn = torch.zeros(num_classes, dtype=torch.long, device=config.DEVICE) # 【新增】为 ROC/PR 曲线创建临时文件,用磁盘空间换内存 # 使用 tempfile 模块来安全地创建临时文件 temp_gt_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npy') temp_probs_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npy') logging.info(f"为ROC/PR曲线创建临时文件: {temp_gt_file.name}, {temp_probs_file.name}") # --- 4. 主推理循环 - 预测图片 --- for image_path in tqdm(image_filenames, desc="正在处理测试图像"): # --- 4.1. 读取图片并预处理 --- original_image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) original_h, original_w = original_image.shape[:2] image_tensor = transform(image=original_image)['image'].unsqueeze(0).to(config.DEVICE) # --- 4.2. 图片推理并保存 --- with torch.no_grad(): pred_logits_tensor = model(image_tensor) pred_mask_tensor = torch.argmax(pred_logits_tensor, dim=1) pred_mask_numpy = pred_mask_tensor.squeeze().cpu().numpy().astype(np.uint8) pred_mask_resized_raw = cv2.resize(pred_mask_numpy, (original_w, original_h), interpolation=cv2.INTER_NEAREST) # 保存原始(单通道)的预测掩码 cv2.imwrite(str(raw_mask_dir / f"{image_path.stem}.png"), pred_mask_resized_raw) # --- 4.3. 评估与可视化 --- gt_mask_numpy, per_image_stats, stats_text_display, gt_mask_raw = None, None, "", None if evaluate_mode: mask_path = config.TEST_MASK_DIR / image_path.name if mask_path.exists(): gt_mask_raw = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) gt_mask_numpy = cv2.resize( gt_mask_raw, (target_width, target_height), interpolation=cv2.INTER_NEAREST ) gt_mask_tensor = torch.from_numpy(gt_mask_numpy).long().to(config.DEVICE).unsqueeze(0) # ======================== 【修复代码段开始】 ======================== # # --- 1. [修复] 计算单张图像的统计数据 --- # 这个代码块是必需的,它为下面的 per_image_stats 和 stats_text_display 定义了 tp, fp, fn, tn tp, fp, fn, tn = smp.metrics.get_stats( pred_mask_tensor, gt_mask_tensor, mode=config.SEG_MODE, num_classes=num_classes ) # --- 2. 累加到总数,用于最终的整体评估 --- total_tp += tp.squeeze(0).to(config.DEVICE) total_fp += fp.squeeze(0).to(config.DEVICE) total_fn += fn.squeeze(0).to(config.DEVICE) total_tn += tn.squeeze(0).to(config.DEVICE) # --- 3. 将曲线数据写入临时文件 (这部分逻辑正确,无需修改) --- probs_tensor = torch.softmax(pred_logits_tensor, dim=1) gt_flat = gt_mask_tensor.cpu().numpy().flatten() probs_flat = probs_tensor.permute(0, 2, 3, 1).reshape(-1, num_classes).cpu().numpy() with open(temp_gt_file.name, 'ab') as f_gt, open(temp_probs_file.name, 'ab') as f_probs: np.save(f_gt, gt_flat) np.save(f_probs, probs_flat) # --- 4. [修复] 现在可以安全地计算单图指标用于可视化 --- # 因为上面的代码块已经定义了 tp, fp, fn, tn,所以这里不再报错 per_image_stats = { 'iou': smp.metrics.iou_score(tp, fp, fn, tn, reduction='micro').item(), 'acc': smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro').item() } # --- 5. [修复] 修正生成统计文本时的 Bug --- # `get_stats` 返回的 tp 张量形状是 (1, num_classes), # 因此我们应该遍历第1维度 (类别),并使用 [0, i] 进行索引 stats_text_display = "Each Class TP / FP / FN / TN:\n" + "-"*35 + "\n" for i in range(1, tp.shape[1]): # 修正: 遍历类别维度 tp.shape[1] if i < len(config.CLASSES): class_name = config.CLASSES[i] stats_text_display += ( f" {class_name:<10}: " # 修正: 使用正确的索引 tp[0, i] f"{tp[0, i]:>5} / {fp[0, i]:>5} / {fn[0, i]:>5} / {tn[0, i]:>7}\n" ) save_visual_comparison( image_name=image_path.name, original_image=original_image, pred_mask=pred_mask_resized_raw, gt_mask=gt_mask_raw, stats=per_image_stats, stats_text=stats_text_display, save_dir = analysis_results_dir ) logging.info("\n预测流程已完成!") logging.info(f"原始预测掩码已保存至: {raw_mask_dir}") logging.info(f"分析图像已保存至: {analysis_results_dir}") # --- 5. 最终分析(循环结束后) --- if evaluate_mode: # calculate_and_save_final_metrics(all_stats_for_metrics, len(config.CLASSES), save_dir = analysis_results_dir) calculate_and_save_final_metrics(total_tp.cpu(), total_fp.cpu(), total_fn.cpu(), total_tn.cpu(), num_classes, save_dir=analysis_results_dir) # if all_gt_for_curves: # logging.info("\n正在聚合结果以生成 ROC 和 PR 曲线...") # full_gt = torch.cat(all_gt_for_curves).numpy().flatten() # full_logits = torch.cat(all_logits_for_curves) # # 调整维度并计算概率 # num_classes = len(config.CLASSES) # full_probs = softmax( # full_logits.permute(0, 2, 3, 1).reshape(-1, num_classes).numpy(), # axis=1 # ) # generate_and_save_curves( # y_true=full_gt, y_probs=full_probs, # class_names=config.CLASSES, output_dir=analysis_results_dir # ) # 【新增】从临时文件中分块读取数据来生成曲线 try: logging.info("\n正在从临时文件加载数据以生成 ROC 和 PR 曲线...") temp_gt_file.close() temp_probs_file.close() all_gt_parts, all_probs_parts = [], [] with open(temp_gt_file.name, 'rb') as f_gt, open(temp_probs_file.name, 'rb') as f_probs: while True: try: all_gt_parts.append(np.load(f_gt)) all_probs_parts.append(np.load(f_probs)) # 【修复】将 EOFError 添加到要捕获的异常列表中 except (IOError, ValueError, EOFError): break # 文件读取完毕 if all_gt_parts: full_gt = np.concatenate(all_gt_parts) full_probs = np.concatenate(all_probs_parts) generate_and_save_curves( y_true=full_gt, y_probs=full_probs, class_names=config.CLASSES, output_dir=analysis_results_dir ) else: logging.warning("临时文件中没有数据,跳过 ROC/PR 曲线生成。") finally: # 【新增】确保在最后删除临时文件 import os os.remove(temp_gt_file.name) os.remove(temp_probs_file.name) logging.info("已清理临时文件。") if __name__ == "__main__": # 创建一个参数解析器 parser = argparse.ArgumentParser(description="训练图像分割模型。") # 添加 --architecture 参数 # - `choices` 会自动从你的 config 文件中获取所有可用的模型名称 # - `default` 设置了在不提供参数时的默认模型 parser.add_argument( "-a", "--architecture", type=str, # default='Unet', choices=list(config.ALL_MODEL_CONFIGS.keys()), required=True, help="选择要训练的模型架构。" ) # 解析命令行传入的参数 args = parser.parse_args() # 将解析出的架构名称传入 main 函数并执行 main(model_architecture=args.architecture)