first commit
This commit is contained in:
153
Seg_All_In_One_SegModel/config.py
Normal file
153
Seg_All_In_One_SegModel/config.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# --- 1. 核心目录设置 (Core Directories) ---
|
||||
# 使用 pathlib 进行路径管理,更现代、更健壮
|
||||
HARDISK_DIR = Path('../Hardisk')
|
||||
DATA_SETS_DIR = Path('../DataSet_Public_outputs') # 模型保存位置
|
||||
PREDICT_ALL_BEST_MODELS_DIR = Path('../BestMode_Predict_Results_DataSet_Public') # 预测模型存储位置
|
||||
|
||||
# V1: 1_CholecSeg8k-13Type-1920x1080
|
||||
# DATA_DIR = Path('../DataSet_Public/1_CholecSeg8k-13Type-1920x1080') # Path(__file__).parent.parent【有中文放弃了】 # 项目根目录 (假设 config.py 在 src/ 下)
|
||||
# OUTPUTS_DIR = DATA_SETS_DIR / "1_CholecSeg8k-13Type-1920x1080_outputs-SegModel" # 所有输出文件的根目录 # train 中 Path(config.OUTPUTS_DIR / architecture)
|
||||
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "1_CholecSeg8k-13Type-1920x1080_outputs-SegModel" # 最优模型位置 # train 中 Path(config.PREDICT_BEST_MODEL_DIR / architecture)
|
||||
# # V2:2_AutoLaparo-10Type-1920x1080
|
||||
# DATA_DIR = Path('../DataSet_Public/2_AutoLaparo-10Type-1920x1080') # Path(__file__).parent.parent【有中文放弃了】 # 项目根目录 (假设 config.py 在 src/ 下)
|
||||
# OUTPUTS_DIR = DATA_SETS_DIR / "2_AutoLaparo-10Type-1920x1080_outputs-SegModel" # 所有输出文件的根目录 # train 中 Path(config.OUTPUTS_DIR / architecture)
|
||||
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "2_AutoLaparo-10Type-1920x1080_outputs-SegModel" # 最优模型位置 # train 中 Path(config.PREDICT_BEST_MODEL_DIR / architecture)
|
||||
# # V3:3_1_Endovis_2017-8Type-512x512
|
||||
# DATA_DIR = Path('../DataSet_Public/3_1_Endovis_2017-8Type-512x512') # Path(__file__).parent.parent【有中文放弃了】 # 项目根目录 (假设 config.py 在 src/ 下)
|
||||
# OUTPUTS_DIR = DATA_SETS_DIR / "3_1_Endovis_2017-8Type-512x512_outputs-SegModel" # 所有输出文件的根目录 # train 中 Path(config.OUTPUTS_DIR / architecture)
|
||||
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "3_1_Endovis_2017-8Type-512x512_outputs-SegModel" # 最优模型位置 # train 中 Path(config.PREDICT_BEST_MODEL_DIR / architecture)
|
||||
# # V4:3_2_Endovis_2018-8Type-512x512
|
||||
# DATA_DIR = Path('../DataSet_Public/3_2_Endovis_2018-8Type-512x512') # Path(__file__).parent.parent【有中文放弃了】 # 项目根目录 (假设 config.py 在 src/ 下)
|
||||
# OUTPUTS_DIR = DATA_SETS_DIR / "3_2_Endovis_2018-8Type-512x512_outputs-SegModel" # 所有输出文件的根目录 # train 中 Path(config.OUTPUTS_DIR / architecture)
|
||||
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "3_2_Endovis_2018-8Type-512x512_outputs-SegModel" # 最优模型位置 # train 中 Path(config.PREDICT_BEST_MODEL_DIR / architecture)
|
||||
# # V5:4_Dresden-11Type-512x512
|
||||
# DATA_DIR = Path('../DataSet_Public/4_Dresden-11Type-512x512') # Path(__file__).parent.parent【有中文放弃了】 # 项目根目录 (假设 config.py 在 src/ 下)
|
||||
# OUTPUTS_DIR = DATA_SETS_DIR / "4_Dresden-11Type-512x512_outputs-SegModel" # 所有输出文件的根目录 # train 中 Path(config.OUTPUTS_DIR / architecture)
|
||||
# PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "4_Dresden-11Type-512x512_outputs-SegModel" # 最优模型位置 # train 中 Path(config.PREDICT_BEST_MODEL_DIR / architecture)
|
||||
# # Test_V1:5_Predict_Video
|
||||
DATA_DIR = Path('../DataSet_Public/5_Predict_Video/LC_Video_1') # Path(__file__).parent.parent【有中文放弃了】 # 项目根目录 (假设 config.py 在 src/ 下)
|
||||
OUTPUTS_DIR = None # 所有输出文件的根目录 # train 中 Path(config.OUTPUTS_DIR / architecture)
|
||||
PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / "LC_Video_1_outputs-SegModel" # 最优模型位置 # train 中 Path(config.PREDICT_BEST_MODEL_DIR / architecture)
|
||||
|
||||
# --- 2. 训练与验证数据路径 (Training & Validation Paths for train.py) ---
|
||||
# 在 train.py 中已使用
|
||||
TRAIN_IMAGE_DIR = DATA_DIR / "images" / "train"
|
||||
TRAIN_MASK_DIR = DATA_DIR / "labels_GT" / "train"
|
||||
VAL_IMAGE_DIR = DATA_DIR / "images" / "val" # TODO "val_images"
|
||||
VAL_MASK_DIR = DATA_DIR / "labels_GT" / "val" # TODO "val_masks"
|
||||
|
||||
# --- 3. 预测数据路径 (Prediction Paths for predict.py) ---
|
||||
TEST_IMAGE_DIR = DATA_DIR / "images" / "val" # 测试图像目录 # TODO "test_images"
|
||||
TEST_MASK_DIR = DATA_DIR / "labels_GT" / "val" # 测试掩码目录 (用于评估) # TODO "test_masks"
|
||||
|
||||
# --- 4. 输出文件与目录路径 (Output Files & Directories) ---
|
||||
RAW_MASK_FOLDER = "predicted_raw_masks" # 存放预测出的单通道原始掩码
|
||||
ANALYSIS_RESULTS_FOLDER = "prediction_analysis" # 存放对比图、曲线和指标CSV
|
||||
|
||||
# 训练过程中的输出文件
|
||||
BEST_MODEL_SAVE_NAME = "best_model.pth"
|
||||
METRICS_CSV_NAME = "training_metrics.csv"
|
||||
|
||||
# --- 5. 模型与数据参数 (Model & Data Parameters) ---
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# --- 5.1. 【新增】模型选择与参数化配置 ---
|
||||
# ※ 不定义训练用模型,在train、predict中定义 ※
|
||||
|
||||
# 这是一个“配置库”,存放了所有模型的参数设置 # 具体请参考: https://smp.readthedocs.io/en/latest/models.html
|
||||
ALL_MODEL_CONFIGS = {
|
||||
'Unet': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
'decoder_channels': (256, 128, 64, 32, 16),
|
||||
'decoder_attention_type': None, # 可选 'scse'
|
||||
},
|
||||
'UnetPlusPlus': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
'decoder_channels': (256, 128, 64, 32, 16),
|
||||
'decoder_attention_type': None, # 可选 'scse'
|
||||
},
|
||||
'FPN': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'PSPNet': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'DeepLabV3': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'DeepLabV3Plus': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
|
||||
'Linknet': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'MAnet': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'PAN': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'UPerNet': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'Segformer': {
|
||||
'encoder_name': 'resnet34',
|
||||
'encoder_weights': 'imagenet',
|
||||
},
|
||||
'DPT': {
|
||||
'encoder_name': 'tu-vit_base_patch16_224.augreg_in21k',
|
||||
'encoder_weights': 'imagenet',
|
||||
}
|
||||
}
|
||||
|
||||
SEG_MODE = "multiclass" # 分割模式:'multiclass' 或 'multilabel'
|
||||
|
||||
# TODO 评估参数排除项:在计算评估指标 (如 IoU, F1-score) 时要忽略的类别列表。 TODO
|
||||
EVALUATION_CLASSES_TO_IGNORE = [] # 如果列表为空 [], 则评估所有类别。
|
||||
# EVALUATION_CLASSES_TO_IGNORE = ['background'] # 例如,设置为 ['background'] 将在计算总体指标时排除背景类。
|
||||
IGNORE_INDEX = -100 # 当不包含背景时,掩码中背景像素的值,损失函数会忽略它
|
||||
|
||||
# V1:1_CholecSeg8k-13Type-1920x1080
|
||||
CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12'] # 根据您的数据集修改
|
||||
CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # 根据您的数据集修改
|
||||
# V2:2_AutoLaparo-10Type-1920x1080
|
||||
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7', '8', '9'] # 根据您的数据集修改
|
||||
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # 根据您的数据集修改
|
||||
# V3:3_1_Endovis_2017-8Type-512x512
|
||||
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7'] # 根据您的数据集修改
|
||||
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7] # 根据您的数据集修改
|
||||
# V4:3_2_Endovis_2018-8Type-512x512
|
||||
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7'] # 根据您的数据集修改
|
||||
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7] # 根据您的数据集修改
|
||||
# V5:4_Dresden-11Type-512x512
|
||||
# CLASSES = ['background', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'] # 根据您的数据集修改
|
||||
# CLASS_RGB_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 根据您的数据集修改
|
||||
|
||||
# 图像尺寸
|
||||
IMAGE_HEIGHT = 256 # 512
|
||||
IMAGE_WIDTH = 256 # 512
|
||||
|
||||
# --- 6. 训练超参数 (Training Hyperparameters) ---
|
||||
BATCH_SIZE = 16
|
||||
NUM_WORKERS = 8
|
||||
EPOCHS = 300
|
||||
LEARNING_RATE = 1e-4 # 1e-3 # MANet_降低学习率至 1e-4
|
||||
WEIGHT_DECAY = 1e-4 # L2 正则化权重衰减
|
||||
PIN_MEMORY = True # 是否将数据加载到锁页内存中以加速传输到 GPU
|
||||
FBETA_BETA = 1.0 # F-beta score 的 beta 值。beta=1 等同于 F1-score
|
||||
EARLY_STOPPING_PATIENCE = 100 # 早停机制的耐心值
|
||||
CKPT_SAVE_INTERVAL = 10 # 每隔多少轮保存一次检查点模型
|
||||
Reference in New Issue
Block a user