first commit

This commit is contained in:
admin
2026-05-20 15:05:35 +08:00
commit ac09b26253
2048 changed files with 189478 additions and 0 deletions

View File

@@ -0,0 +1,117 @@
#!/bin/bash
# +----------------------------------------------------------------------------+
# | 脚本: 将指定的 .pt 文件 (默认为 best.pt) 从 Nas_BackUp_Seg 复制到 BestMode... |
# | 功能: 1. 自动查找源目录中所有 "weights/[指定文件名].pt" 文件。 |
# | 2. 将其复制到目标目录,并从路径中移除 "/weights" 中间层。
# | 3. 仅在目标文件不存在或内容不一致时才执行复制,提高效率。
# | 4. 提供可视化进度条和最终统计。
# | V3.0: «-- 修改: 改为使用 --pt_name 命名参数 |
# | 用法: ./script.sh [--pt_name "文件名.pt"] (例如: ./script.sh --pt_name "100.pt") «-- 修改
# +----------------------------------------------------------------------------+
# --- 1. 目录设置 ---
SOURCE_BASE_DIR=$(realpath "../Hardisk")
DEST_BASE_DIR=$(realpath "../BestMode_Predict_Results_DataSet_Public")
# --- 1b. 参数设置 --- «-- 修改: 替换为参数解析循环
# 设置默认文件名
TARGET_FILENAME="best.pt"
# 解析命令行参数
while [[ "$#" -gt 0 ]]; do
case $1 in
--pt_name)
if [[ -n "$2" && ! "$2" =~ ^-- ]]; then
TARGET_FILENAME="$2"
shift # 消耗 "--pt_name"
shift # 消耗它的值
else
echo "错误: '--pt_name' 需要一个文件名参数。" >&2
exit 1
fi
;;
*)
echo "未知参数: $1" >&2
exit 1
;;
esac
done
# --- 2. 环境检查 ---
if [ ! -d "$SOURCE_BASE_DIR" ]; then
echo "错误: 源目录 '$SOURCE_BASE_DIR' 不存在。"
exit 1
fi
mkdir -p "$DEST_BASE_DIR"
echo "正在扫描源目录: $SOURCE_BASE_DIR"
echo "目标模型文件: $TARGET_FILENAME" # 告知用户当前在找哪个文件
# --- 3. 查找所有符合条件的文件 ---
echo "正在查找所有符合 '...-Yolo/YOLO*_*/weights/$TARGET_FILENAME' 结构的文件..."
readarray -d '' files_to_process < <(find "$SOURCE_BASE_DIR"/*-Yolo -path "*/YOLO*_*/weights/$TARGET_FILENAME" -type f -print0 2>/dev/null)
TOTAL_FILES=${#files_to_process[@]}
if [ "$TOTAL_FILES" -eq 0 ]; then
echo "在源目录中没有找到任何 '$TARGET_FILENAME' 文件。脚本退出。"
exit 0
fi
echo "总共找到 $TOTAL_FILES 个 '$TARGET_FILENAME' 文件需要处理。"
echo "--------------------------------------------------"
# --- 4. 进度条函数 ---
print_progress() {
local current=$1
local total=$2
local term_width=$(tput cols)
local bar_width=$((term_width - 35)) # 为文本留出更多空间
if [ "$bar_width" -lt 10 ]; then
bar_width=10
fi
local percent=$((current * 100 / total))
local filled_len=$((bar_width * percent / 100))
local bar=""
for ((i=0; i<filled_len; i++)); do bar+="="; done
for ((i=filled_len; i<bar_width; i++)); do bar+=" "; done
printf "\r拷贝进度: %3d%% [%s] %d/%d" $percent "$bar" $current $total
}
# --- 5. 执行复制操作 ---
CURRENT_FILE=0
COPIED_COUNT=0
SKIPPED_COUNT=0
for source_file_path in "${files_to_process[@]}"; do
if [[ -z "$source_file_path" ]]; then
continue
fi
((CURRENT_FILE++))
relative_path="${source_file_path#"$SOURCE_BASE_DIR"}"
dest_file_path="${DEST_BASE_DIR}${relative_path}"
dest_dir_path=$(dirname "$dest_file_path")
if [ ! -f "$dest_file_path" ] || ! cmp -s "$source_file_path" "$dest_file_path"; then
mkdir -p "$dest_dir_path"
cp "$source_file_path" "$dest_file_path"
((COPIED_COUNT++))
else
((SKIPPED_COUNT++))
fi
print_progress $CURRENT_FILE $TOTAL_FILES
done
# --- 6. 打印最终总结 ---
echo ""
echo "--------------------------------------------------"
echo "所有操作已完成!"
echo " - 总共处理文件: $TOTAL_FILES (目标文件: $TARGET_FILENAME)"
echo " - 已复制或更新的文件: $COPIED_COUNT"
echo " - 因已存在且内容一致而跳过的文件: $SKIPPED_COUNT"

View File

@@ -0,0 +1,2 @@
# 本手册主要目的在于判断给Yolo模型的各个层是否能成功生成热图
使用方法: python yolo_layer_tester.py --model "YOLOv8m-seg" --cam_method "GradCAM" --pt_name "best.pt"

View File

@@ -0,0 +1,359 @@
import logging
import argparse
from pathlib import Path
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import random
import sys
import os
from typing import Dict, Any, Tuple, List # [!!] 确保导入 List
# --- [!! 关键修改 !!] 动态添加项目根目录到 sys.path ---
# (来自您之前的目录结构)
script_path = Path(__file__).resolve()
project_root = script_path.parent.parent
sys.path.insert(0, str(project_root))
# ----------------------------------------------------
# 现在下面这行导入就可以正常工作了
import yolo_config as config
# --- [导入 V3 的 CAM 库] ---
from pytorch_grad_cam import (
GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM,
HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.base_cam import BaseCAM
# --- 日志设置 ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()]
)
logging.getLogger("pytorch_grad_cam").setLevel(logging.WARNING)
# --- [!! 新增 !!] 从 V1 移植的函数 ---
def find_trained_models(outputs_dir: Path, model_key: str, pt_name: str) -> List[str]:
"""
扫描输出目录,查找特定基础模型的所有有效的、已完成的训练项目。
(函数来自 yolo_predict_visualize_nn_V1.py)
"""
trained_models = []
if not outputs_dir.is_dir():
logging.warning(f"输出目录不存在: {outputs_dir}")
return []
for project_folder in outputs_dir.iterdir():
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
if (project_folder / 'weights' / pt_name).exists():
trained_models.append(project_folder.name)
trained_models.sort()
return trained_models
# -------------------------------------
# --- [保留] V3 的 CAM 方法字典 ---
CAM_METHODS = {
"GradCAM": GradCAM,
"GradCAMPlusPlus": GradCAMPlusPlus,
"XGradCAM": XGradCAM,
"EigenCAM": EigenCAM,
"HiResCAM": HiResCAM,
"LayerCAM": LayerCAM,
"RandomCAM": RandomCAM,
"EigenGradCAM": EigenGradCAM,
}
# --- [保留] V3 的 Target ---
class ActivationMaximizationTarget:
def __init__(self, channel=0):
self.channel = channel
def __call__(self, model_output):
if model_output.ndim == 4:
return model_output[:, self.channel, :, :].mean()
elif model_output.ndim == 3:
return model_output[self.channel, :, :].mean()
else:
raise ValueError(f"Unsupported model_output shape: {model_output.shape}")
# --- [保留] V3 的预处理 ---
def preprocess_image(bgr_img: np.ndarray, imgsz: int, device: str) -> (torch.Tensor, Dict[str, Any]):
# ... (此函数内容与上一版完全相同,为节省篇幅已折叠) ...
orig_h, orig_w = bgr_img.shape[:2]
rgb_img = bgr_img[:, :, ::-1]
scale = min(imgsz / orig_w, imgsz / orig_h)
new_w, new_h = int(orig_w * scale), int(orig_h * scale)
resized_image = cv2.resize(rgb_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
pad_w, pad_h = imgsz - new_w, imgsz - new_h
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
padded_image = cv2.copyMakeBorder(resized_image, pad_top, pad_bottom, pad_left, pad_right,
cv2.BORDER_CONSTANT, value=(114, 114, 114))
img_float = np.float32(padded_image) / 255.0
input_tensor = torch.from_numpy(img_float).permute(2, 0, 1).unsqueeze(0).to(device)
padding_info = {
"pad_top": pad_top, "new_h": new_h,
"pad_left": pad_left, "new_w": new_w,
"orig_h": orig_h, "orig_w": orig_w
}
return input_tensor, padding_info
# --- [保留] V1/V3 的关键修复 ---
def reshape_transform(outputs):
if isinstance(outputs, tuple):
return outputs[0]
return outputs
# --- [!! 核心功能 !!] (此函数本身无需修改) ---
def generate_heatmaps_for_all_layers(model_path: str, image_path: str, cam_method_name: str, model_name: str):
"""
加载模型和图像,遍历所有层,并为每个成功的层生成热图。
(此函数内容与上一版完全相同,为节省篇幅已折叠)
"""
# --- 1. 获取 CAM 方法 ---
cam_class = CAM_METHODS.get(cam_method_name)
if not cam_class:
logging.error(f"无效的 CAM 方法: {cam_method_name}")
return
# --- 2. 定义 Targets ---
targets = None
if cam_method_name != "EigenCAM":
targets = [ActivationMaximizationTarget(channel=0)]
logging.info(f"已为 {cam_method_name} 设置 ActivationMaximizationTarget(channel=0)。")
else:
logging.info(f"将使用 {cam_method_name} (无需 Targets)。")
# --- 3. 加载模型 ---
try:
logging.info(f"正在加载模型: {model_path}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO(model_path).to(device)
model.model.eval()
imgsz = model.model.args['imgsz']
logging.info(f"模型加载成功。设备: {device}, 图像尺寸: {imgsz}")
except Exception as e:
logging.error(f"无法加载模型 {model_path}: {e}", exc_info=True)
return
# --- 4. 加载和预处理图像 ---
try:
logging.info(f"正在加载测试图像: {image_path}")
orig_img_bgr = cv2.imread(image_path)
if orig_img_bgr is None:
logging.error(f"无法读取图像: {image_path}")
return
orig_img_rgb_float = np.float32(orig_img_bgr[:, :, ::-1]) / 255.0
input_tensor, pad_info = preprocess_image(orig_img_bgr, imgsz, device)
input_tensor.requires_grad_()
logging.info(f"图像加载并预处理完成。Tensor shape: {input_tensor.shape}")
except Exception as e:
logging.error(f"处理图像时出错: {e}", exc_info=True)
return
# --- 5. 创建输出目录 (您的自定义逻辑) ---
image_stem = Path(image_path).stem
output_dir_name = f"./{cam_method_name}_{model_name}_{image_stem}_All_HeartMap"
output_dir = Path(output_dir_name)
output_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"所有热图将保存到: {output_dir.resolve()}")
# --- 6. 遍历并测试所有层 ---
base_model_layers = model.model.model
successful_layers, failed_layers = [], []
logging.info(f"\n" + "="*50)
logging.info(f"开始遍历 {len(base_model_layers)} 个层 (使用 {cam_method_name})")
logging.info("\n" + "="*50 + "\n")
for i, layer in enumerate(base_model_layers):
layer_name = f"model.model.model[{i}]"
layer_type = layer.__class__.__name__
print(f"--- [ {i:02d} / {len(base_model_layers)-1} ] 尝试层: {layer_name:<25} (类型: {layer_type}) ---")
try:
with cam_class(model=model.model, target_layers=[layer], reshape_transform=reshape_transform) as cam:
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
if grayscale_cam is None: raise ValueError("CAM-Method 返回了 None")
# ... (结果提取、裁剪、保存逻辑) ...
if isinstance(grayscale_cam, (list, tuple)): result = grayscale_cam[0]
else: result = grayscale_cam
if isinstance(result, torch.Tensor): result = result.detach().cpu().numpy()
if result.ndim == 4: result = result[0].mean(axis=0)
elif result.ndim == 3: result = result[0, :]
elif result.ndim != 2: raise TypeError(f"CAM 输出维度异常 {result.shape}")
cam_cropped = result[pad_info['pad_top'] : pad_info['pad_top'] + pad_info['new_h'],
pad_info['pad_left'] : pad_info['pad_left'] + pad_info['new_w']]
cam_upscaled = cv2.resize(cam_cropped, (pad_info['orig_w'], pad_info['orig_h']), interpolation=cv2.INTER_LINEAR)
cam_normalized = (cam_upscaled - cam_upscaled.min()) / (cam_upscaled.max() + 1e-8)
visualization = show_cam_on_image(orig_img_rgb_float, cam_normalized, use_rgb=True, image_weight=0.5)
viz_bgr = visualization[:, :, ::-1]
heatmap_uint8 = np.uint8(255 * cam_normalized)
heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
base_filename = f"Layer_{i:02d}_{layer_type}"
cv2.imwrite(str(output_dir / f"{base_filename}_overlay.jpg"), viz_bgr)
# cv2.imwrite(str(output_dir / f"{base_filename}_heatmap.jpg"), heatmap_color)
print(f" [√] 成功! 已保存 {base_filename}_overlay.jpg 和 _heatmap.jpg")
successful_layers.append((i, layer_name, layer_type))
except Exception as e:
error_msg = str(e).split('\n')[0]
print(f" [X] 失败. 错误: {error_msg}")
failed_layers.append((i, layer_name, layer_type, error_msg))
# --- 7. 打印总结报告 ---
# ... (报告逻辑与上一版相同) ...
print("\n" + "="*60)
print(" 全层热图生成报告")
print("="*60)
print(f"模型: {Path(model_path).name} (来自: {Path(model_path).parent.parent})")
print(f"测试图像: {Path(image_path).name}")
print(f"测试方法: {cam_method_name}")
print(f"输出目录: {output_dir.resolve()}")
print(f"总计: {len(successful_layers)} 个成功 (已生成), {len(failed_layers)} 个失败/跳过")
print("\n" + "--- [√] 成功生成的层 ---")
if not successful_layers: print(" (无)")
else:
for i, name, type in successful_layers: print(f" - [{i:02d}] {name:<25} (类型: {type})")
print("\n" + "--- [X] 失败或跳过的层 ---")
if not failed_layers: print(" (无)")
else:
for i, name, type, error in failed_layers:
print(f" - [{i:02d}] {name:<25} (类型: {type})")
print(f" └> 失败原因: {error[:100]}...")
print("\n" + "="*60)
if __name__ == "__main__":
# --- [!! 重构 !!] (以下逻辑来自 yolo_predict_visualize_nn_V1.py) ---
config.show_config_summary()
parser = argparse.ArgumentParser(description="全层热图生成器 (V1 模型选择模式)")
# 1. (来自 V1) --model 参数
parser.add_argument(
"--model",
type=str,
required=True,
choices=list(config.MODEL_CONFIGS.keys()),
help="选择一个基础模型类型来筛选其训练历史。"
)
# 2. (来自 V1) --pt_name 参数
parser.add_argument(
"--pt_name",
type=str,
default="best.pt",
help="要使用的权重文件名 (例如 'best.pt''epoch100.pt')。"
)
# 3. (来自您) --source 参数
parser.add_argument(
"--source",
type=str,
default="random",
help="测试图像源: 'random' (从 config.TEST_IMAGE_DIR 随机选), "
"'path/to/image.jpg' (指定文件), "
"'path/to/dir/' (从该目录随机选)"
)
# 4. (来自您) --cam_method 参数
parser.add_argument(
"--cam_method",
type=str,
default="GradCAM",
choices=list(CAM_METHODS.keys()),
help=f"选择要使用的 CAM 可视化方法。默认为: GradCAM。"
)
args = parser.parse_args()
# --- (来自 V1) 模型搜索和选择逻辑 ---
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model, args.pt_name)
run_to_use = None
if not available_runs:
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中未找到模型 '{args.model}' (使用 '{args.pt_name}') 的有效训练记录。")
sys.exit(1)
elif len(available_runs) == 1:
run_to_use = available_runs[0]
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
else:
print(f"\n为模型 '{args.model}' (使用 '{args.pt_name}') 找到多个训练版本:")
for i, run_name in enumerate(available_runs, 1):
print(f" [{i}] {run_name}")
while True:
try:
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
choice_index = int(choice)
if 1 <= choice_index <= len(available_runs):
run_to_use = available_runs[choice_index - 1]
break
else:
print("错误:输入无效,请输入列表中的序号。")
except ValueError:
print("错误:请输入一个数字。")
except (KeyboardInterrupt, EOFError):
print("\n操作已取消。")
sys.exit(0)
# --- (!! 融合 !!) 执行逻辑 ---
if run_to_use:
# 1. (来自 V1) 确定最终模型路径
model_to_use_path = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / args.pt_name
if not model_to_use_path.exists():
logging.error(f"严重错误:找不到权重文件 {model_to_use_path}")
sys.exit(1)
# 2. (来自您) 确定最终图像路径
source_input = args.source
image_to_test = None
source_dir_for_random = None
if source_input.lower() == "random":
source_dir_for_random = config.TEST_IMAGE_DIR
logging.info(f"Source='random'。将从 config.TEST_IMAGE_DIR 随机选择图像: {source_dir_for_random}")
elif Path(source_input).is_file():
image_to_test = source_input
logging.info(f"Source 是一个特定文件: {image_to_test}")
elif Path(source_input).is_dir():
source_dir_for_random = Path(source_input)
logging.info(f"Source 是一个目录。将从该目录随机选择图像: {source_dir_for_random}")
else:
logging.error(f"错误: 指定的源路径无效: {source_input}")
sys.exit(1)
# 如果需要随机选择...
if image_to_test is None:
if not source_dir_for_random or not source_dir_for_random.is_dir():
logging.error(f"错误: 用于随机选择的目录无效: {source_dir_for_random}")
sys.exit(1)
image_paths = sorted(
list(source_dir_for_random.glob("*.jpg")) +
list(source_dir_for_random.glob("*.jpeg")) +
list(source_dir_for_random.glob("*.png"))
)
if not image_paths:
logging.error(f"错误: 在目录 {source_dir_for_random} 中未找到任何图像 (.jpg, .jpeg, .png)。")
sys.exit(1)
image_to_test = str(random.choice(image_paths))
logging.info(f"已随机选择图像: {Path(image_to_test).name}")
# 3. (融合) 调用核心函数
if image_to_test:
generate_heatmaps_for_all_layers(
model_path=str(model_to_use_path), # <-- 使用 V1 找到的路径
image_path=image_to_test, # <-- 使用您的逻辑找到的路径
cam_method_name=args.cam_method, # <-- 使用您的参数
model_name = args.model
)

View File

@@ -0,0 +1,261 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import sys
from pathlib import Path
from typing import Set, Dict, Tuple
# <--- 新增: 导入 Pillow 库用于图像处理 ---
try:
from PIL import Image
# 禁用解压缩炸弹检查,以防万一标签图非常大且简单
Image.MAX_IMAGE_PIXELS = None
except ImportError:
print("错误: 本脚本需要 'Pillow' 库来检查图像尺寸。", file=sys.stderr)
print("请运行: pip install Pillow", file=sys.stderr)
sys.exit(1)
# --- 新增结束 ---
# 定义支持的文件扩展名
VALID_EXTENSIONS = {'.png', '.jpg', '.jpeg'}
def process_directory(path: Path,
prefix: str = "",
suffix: str = "") -> Tuple[Set[str], Dict[str, Path], Dict[Tuple[int, int], int]]: # <--- 修改:更新了返回类型
"""
处理单个目录提取所有有效文件的文件名stem并根据前缀和后缀进行规范化。
同时统计所有图像的尺寸。
返回:
normalized_stems (Set[str]): 规范化处理后的文件名集合。
normalized_to_full_path (Dict[str, Path]): 规范化文件名 -> 原始完整路径 的映射。
size_counts (Dict[Tuple[int, int], int]): (宽度, 高度) -> 数量 的映射。 # <--- 新增
"""
if not path.is_dir():
print(f"错误: 路径 '{path}' 不是一个有效的目录。", file=sys.stderr)
return set(), {}, {} # <--- 修改
normalized_stems: Set[str] = set()
normalized_to_full_path: Dict[str, Path] = {}
size_counts: Dict[Tuple[int, int], int] = {} # <--- 新增:用于统计尺寸
for file_path in path.glob('*'):
# 确保是文件,并且扩展名在我们的有效列表中
if file_path.is_file() and file_path.suffix.lower() in VALID_EXTENSIONS:
# <--- 新增: 尝试获取图像尺寸 ---
try:
# 使用 with...as... 确保文件被正确关闭
with Image.open(file_path) as img:
size: Tuple[int, int] = img.size
size_counts[size] = size_counts.get(size, 0) + 1
except Exception as e:
print(f"警告: 无法读取图像 '{file_path.name}' 的尺寸. 错误: {e}", file=sys.stderr)
# --- 尺寸获取结束 ---
original_stem = file_path.stem
normalized_stem = original_stem
# 仅当提供了前缀/后缀时才进行处理
if prefix and normalized_stem.startswith(prefix):
normalized_stem = normalized_stem[len(prefix):]
if suffix and normalized_stem.endswith(suffix):
normalized_stem = normalized_stem[:-len(suffix)]
# 检查处理后是否重名,如果重名则发出警告
if normalized_stem in normalized_to_full_path:
print(f"警告: 规范化后文件名发生冲突。")
print(f" '{file_path.name}''{normalized_to_full_path[normalized_stem].name}' 都变成了 '{normalized_stem}'")
normalized_stems.add(normalized_stem)
normalized_to_full_path[normalized_stem] = file_path
return normalized_stems, normalized_to_full_path, size_counts # <--- 修改
# <--- 新增: 打印尺寸统计的辅助函数 ---
def print_size_report(title: str, size_counts: Dict[Tuple[int, int], int]):
"""
格式化并打印尺寸统计报告。
"""
print(f"\n{title}")
if not size_counts:
print(" (未找到或无法读取任何图像文件)")
return
total_files = sum(size_counts.values())
print(f" 总共 {total_files} 个文件,分布如下:")
# 按尺寸 (宽, 高) 排序输出
for size, count in sorted(size_counts.items()):
width, height = size
print(f" - 尺寸 (宽, 高) {width}x{height}: {count} 个文件")
# --- 新增结束 ---
def main():
parser = argparse.ArgumentParser(description="比较两个文件夹中的文件名是否匹配,并可选择删除不匹配项。")
parser.add_argument("-i", "--image",
type=Path,
default="./ORI",
help="Image 文件夹路径")
parser.add_argument("-l", "--label",
type=Path,
default="./Label",
help="Label 文件夹路径")
parser.add_argument("-p", "--prefix",
type=str,
default="",
help="在 Label 文件名中要忽略的前缀")
parser.add_argument("-s", "--suffix",
type=str,
default="",
help="在 Label 文件名中要忽略的后缀")
parser.add_argument("-y", "--yes",
action="store_true",
help="自动确认删除所有不匹配的文件,跳过交互式提示。")
args = parser.parse_args()
# 1. 处理 Image 文件夹
print(f"--- 正在处理 Image 文件夹: {args.image} ---")
image_stems, image_norm_to_path, image_size_counts = process_directory(args.image) # <--- 修改
if not image_stems:
print(f"未在 Image 文件夹中找到任何 .png 或 .jpg 文件。")
# <--- 新增: 打印 Image 尺寸报告 ---
print_size_report("--- Image 文件夹尺寸统计 ---", image_size_counts)
# 2. 处理 Label 文件夹
print(f"\n--- 正在处理 Label 文件夹: {args.label} ---")
print(f"(忽略前缀: '{args.prefix}', 忽略后缀: '{args.suffix}')")
label_stems_normalized, label_norm_to_path, label_size_counts = process_directory( # <--- 修改
args.label,
args.prefix,
args.suffix
)
if not label_stems_normalized:
print(f"未在 Label 文件夹中找到任何 .png 或 .jpg 文件。")
# <--- 新增: 打印 Label 尺寸报告 ---
print_size_report("--- Label 文件夹尺寸统计 ---", label_size_counts)
# 3. 执行比较 (使用集合运算)
matching_stems = image_stems.intersection(label_stems_normalized)
extra_in_image = image_stems.difference(label_stems_normalized)
extra_in_label_normalized = label_stems_normalized.difference(image_stems)
extra_in_label_original_stems = {label_norm_to_path[stem].stem for stem in extra_in_label_normalized}
# 4. 输出结果
print("\n" + "="*30)
print(" 匹配结果报告")
print("="*30)
print(f"\n匹配的文件总数: {len(matching_stems)}")
# <--- 新增: 检查匹配项的尺寸是否一致 ---
print("\n--- 正在检查匹配文件的尺寸 ---")
mismatched_size_pairs = []
if not matching_stems:
print("(无匹配文件,跳过尺寸检查)")
else:
print(f"正在检查 {len(matching_stems)} 对文件...")
for stem in sorted(matching_stems):
try:
img_path = image_norm_to_path[stem]
lbl_path = label_norm_to_path[stem]
# 再次打开以比较尺寸
with Image.open(img_path) as img:
img_size = img.size
with Image.open(lbl_path) as lbl:
lbl_size = lbl.size
if img_size != lbl_size:
mismatched_size_pairs.append((img_path, img_size, lbl_path, lbl_size))
except Exception as e:
print(f" 错误: 无法比较 '{stem}' 的尺寸. 错误: {e}", file=sys.stderr)
if not mismatched_size_pairs:
print(f"√ 成功: 所有 {len(matching_stems)} 对匹配文件均具有相同的尺寸。")
else:
print(f"\n!!! 警告: 发现 {len(mismatched_size_pairs)} 对文件尺寸不匹配:")
for img_path, img_size, lbl_path, lbl_size in mismatched_size_pairs:
print(f" - [Image] {img_path.name} {img_size} != [Label] {lbl_path.name} {lbl_size}")
# --- 尺寸检查结束 ---
print("\n--- Image 文件夹中多余的文件 (共 {} 个) ---".format(len(extra_in_image)))
if not extra_in_image:
print("(无)")
else:
for file_stem in sorted(extra_in_image):
print(file_stem)
print("\n--- Label 文件夹中多余的文件 (共 {} 个) ---".format(len(extra_in_label_original_stems)))
print("(显示的是原始文件名,非规范化名称)")
if not extra_in_label_original_stems:
print("(无)")
else:
for file_stem in sorted(extra_in_label_original_stems):
print(file_stem)
print("\n" + "="*30)
# 5. 删除交互部分 (无修改)
if not extra_in_image and not extra_in_label_normalized:
print("\n所有文件均完美匹配,无需删除。")
# <--- 新增: 如果文件名匹配,但尺寸不匹配,也在这里提醒一下 ---
if mismatched_size_pairs:
print("注意:虽然文件名匹配,但有 {} 对文件的尺寸不一致,请检查上面的报告。".format(len(mismatched_size_pairs)))
return
# ... (后续的删除逻辑保持不变) ...
confirm_delete = False
if args.yes:
print("\n检测到 -y/--yes 参数,将自动删除不匹配的文件。")
confirm_delete = True
else:
try:
choice = input("\n警告: 是否要永久删除所有上述 '多余' 的文件? (y/n): ").strip().lower()
if choice in ['y', 'yes']:
confirm_delete = True
except (KeyboardInterrupt, EOFError):
print("\n操作被用户中断。")
confirm_delete = False
if confirm_delete:
print("\n--- 正在删除 Image 文件夹中的多余文件 ---")
deleted_count = 0
for stem in sorted(extra_in_image):
file_path = image_norm_to_path[stem]
try:
file_path.unlink()
print(f" 已删除: {file_path.name}")
deleted_count += 1
except OSError as e:
print(f" 删除失败: {file_path.name} (错误: {e})", file=sys.stderr)
print(f"--- Image 文件夹共删除 {deleted_count} 个文件 ---")
print("\n--- 正在删除 Label 文件夹中的多余文件 ---")
deleted_count = 0
for norm_stem in sorted(extra_in_label_normalized):
file_path = label_norm_to_path[norm_stem]
try:
file_path.unlink()
print(f" 已删除: {file_path.name} (原始文件名)")
deleted_count += 1
except OSError as e:
print(f" 删除失败: {file_path.name} (错误: {e})", file=sys.stderr)
print(f"--- Label 文件夹共删除 {deleted_count} 个文件 ---")
print("\n删除操作完成。")
else:
print("\n操作已取消。未删除任何文件。")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,122 @@
#!/bin/bash
usage() {
echo "Usage: $0 -i <ori_image_directory> -l <ori_label_directory> -r <stack_result_directory> [ -a <alpha> -p <prefix> -s <suffix> -h]"
echo "对image图片和label图片进行匹配-i、-l -r均不能为空-p -s默认为空"" -a默认为\"0.3\") "
echo "-i:原始image的路径-l:原始label的路径-p:前缀内容,-s:后缀内容(不用管文件后缀名)-h:帮助"
echo "e.g. bash 0_2_TOOL_stack_pics.sh -i ./ori -l ./label -r ./result_0.3透明度 -a 0.3 -p Prefix -s _label"
}
ori_image_directorys=""
ori_label_directorys=""
stack_result_directorys=""
prefix=""
suffix=""
alpha="0.3"
while getopts "hl:i:r:p:s:a:" opt; do
case $opt in
h)
usage
exit 0
;;
i)
ori_image_directorys=$OPTARG
;;
l)
ori_label_directorys=$OPTARG
;;
p)
prefix=$OPTARG
;;
s)
suffix=$OPTARG
;;
r)
stack_result_directorys=$OPTARG
;;
a)
alpha=$OPTARG
;;
*)
echo -e '\033[31m!!! Error, Illegal input !!!\033[0m'
usage
exit 1
;;
esac
done
# 判断输入地址是否为空
if [ -z "$ori_label_directorys" ] || [ -z "$ori_image_directorys" ] || [ -z "$stack_result_directorys" ]; then
echo -e "\033[31m输入地址 -i -l -z 存在空地址\033[0m"
usage
exit 1
fi
# 地址转化
ori_image_directory=$(readlink -f "$ori_image_directorys")
ori_label_directory=$(readlink -f "$ori_label_directorys")
stack_result_directory=$(readlink -f "$stack_result_directorys")
if [ -z "$ori_label_directory" ] || [ -z "$ori_image_directory" ]|| [ -z "$stack_result_directory" ]; then
echo "image、label、result存在无法解析地址程序退出"
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directorys"
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
echo -e "\033[31mori_label_directory\033[0m: $stack_result_directorys"
exit 1
fi
if [ ! -d "$ori_label_directory" ] || [ ! -d "$ori_image_directory" ]; then
echo "image、label两目录有一个不存在程序退出"
echo -e "\033[31mori_image_directory\033[0m: $ori_image_directory"
echo -e "\033[31mori_label_directory\033[0m: $ori_label_directorys"
exit 1
fi
# 获取当前脚本的路径和名称
script_path=$(dirname "$0")
# 将当前目录更改为脚本所在的路径
cd "$script_path"
# 激活conda环境
source /home/"$USER"/miniconda/bin/activate Deal_Data
echo -e "\033[32m_____ 0_2_TOOL_stack_pics.sh _____\033[0m"
echo -e "\033[33mimage所在文件夹为$ori_image_directory\nlable所在文件夹为$ori_label_directory\033[0m"
# 遍历label目录
for file_path in "$ori_label_directory"/*; do
# 判断是否是文件
if [[ -f "$file_path" ]]; then
file_name=$(basename "$file_path")
# 判断文件名是否符合规范
if [[ "$file_name" =~ .*\.(jpg|png|bmp|JPG|PNG|BMP) ]]; then # 判断是否有为图片
# if [[ "$file_name" =~ "$prefix".*"$suffix".*\.(jpg|png|bmp|JPG|PNG|BMP)$ ]]; then # 判断是否有满足要求的文件名
# 抽取文件名(有前缀、后缀的抽取前缀、后缀里面的,没有的返回整个)
if [ -z $prefix ];then
file_name_extract=$(echo $file_name | sed "s/"$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
else
file_name_extract=$(echo $file_name | sed "s/".*$prefix"\(.*\)"$suffix".*/\1/" | sed "s/\(.*\)\.\(jpg\|png\|bmp\|JPG\|PNG\|BMP\)$/\1/")
fi
# 从label目录中看是否有此文件
file_name_other=$(ls $ori_image_directory | grep "^${file_name_extract}\.")
file_name_other=$(echo "$(echo "$file_name_other" | sed '/^$/d')" | head -n1) # 提取出文件名
# 如果另一个目录没有此文件的话
if [ -z "$file_name_other" ]; then
echo "$file_name label中对应内容未在$ori_image_directory搜索到"
# 建立相关存储文件夹
if [ ! -d "$ori_label_directory/Not_pair_pics" ]; then
mkdir -p "$ori_label_directory/Not_pair_pics" # 建立存储文件夹
fi
# 移动相关文件
cp "$ori_label_directory/$file_name" "$ori_label_directory/Not_pair_pics"
echo "$file_name" >> "$ori_label_directory/Not_pair_pics/not_pair.txt"
else # 如果另一个目录有此配对文件的话,则运行相关程序
echo "image中的$file_name_other与lable中的$file_name"
mkdir -p "$stack_result_directory"
python 0_2_stack_picture.py "$ori_image_directory/$file_name_other" "$ori_label_directory/$file_name" "$stack_result_directory" "$alpha"
echo ""
fi
fi
else
echo "$file_path不是文件"
fi
done

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import cv2, os, sys
def Stack_pic(Background_path, Overlay_path, Result_dir, alpha=0.3):
# 读取两张没有alpha通道的图片
img1 = cv2.imread(Background_path) # 底层图片
img2 = cv2.imread(Overlay_path) # 顶层图片
Result_name = os.path.splitext(os.path.basename(Background_path))[0]
# 将img2调整为与img1大小相同
img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
# 将img2的透明度调整为20%
overlay_alpha = alpha
# 将img2叠加到img1上
overlay = cv2.addWeighted(img1, 1 - overlay_alpha, img2, overlay_alpha, 0)
# 保存结果
if not os.path.exists(Result_dir):
os.makedirs(Result_dir)
cv2.imwrite(os.path.join(Result_dir, Result_name+'.png'), overlay)
print("堆叠图片写入地址:", os.path.join(Result_dir, Result_name+'.png'))
if __name__ == '__main__':
Background_path = sys.argv[1] # 背景所在路径
Overlay_path = sys.argv[2] # 上层图片所在路径
Result_dir = sys.argv[3] # 结果所在目录
# 透明度默认为0.3
try:
alpha = float(sys.argv[4])
if(alpha > 1 or alpha < 0):
print("alpha 透明度输入不正确其值应该在0~1之间")
alpha = 0.3
except:
alpha = 0.3
# 进行对叠程序
Stack_pic(Background_path, Overlay_path, Result_dir, alpha)

View File

@@ -0,0 +1,442 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*
import os,time,sys,threading, colorsys, argparse
import asyncio, cv2, multiprocessing, random
from PIL import Image
import numpy as np
from Tool_deal_labels import edge_detection, detect_connected_regions, Tool_color_connected_array, fill_white_regions, color_connected_regions
from Tool_Classes_And_Palette import Annotate_CLASSES, Annotate_PALETTE, bg_PALETTE
def getFileList(dir,Filelist=[], ext=None, Max_layer=1, layer=0, Donot_Search=['1_边缘检测并膨胀', '2_连通区域检测', '3_分水岭算法填充']):
"""
获取文件夹及其子文件夹中文件列表
输入 dir文件夹根目录
输入 ext: 扩展名
返回: 文件路径列表
"""
newDir = dir
if os.path.isfile(dir):
if ext is None:
Filelist.append(dir)
else:
if ext in dir[-3:]:
Filelist.append(dir)
elif os.path.isdir(dir):
file_name = os.path.basename(dir)
# 判断是否在禁搜名单中
if file_name in Donot_Search:
return Filelist
for s in os.listdir(dir):
newDir=os.path.join(dir,s)
if layer <= Max_layer:
getFileList(newDir, Filelist, ext, Max_layer, layer+1)
return Filelist
class Deal_image():
def __init__(self, Annotate_CLASSES = ('肝脏','胆囊'), Annotate_PALETTE = [[255,91,0],[255,234,0]], src_label_fold = "./Label", save_pro_label_fold = "./LABEL_PNG_new", save_GT_label_fold = "./Label_Generate", GT_channel = 1, pro_append_name="_label", GT_append_name="", ori_img_folder="./ORI_PNG", res_label_folder="./Result_label", save_merge_pic_folder="./Result_merge", back_gnd_color=0, first_class_color=1, pic_type="png", Max_width = 10000, Label_Max_Search_layer=1000, save_process_pics=False, bg_PALETTE = [0,0,0]):
# 背景最好放在最后
# self.src_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉','背景')
# self.src_PALETTE = np.array([[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[177,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118], [0,157,142], [181,85,105], [42,8,66],[0,0,0]])
# self.src_CLASSES_NUM = np.shape(self.src_CLASSES)[0]
self.bg_PALETTE = bg_PALETTE # 背景颜色 TODO
self.Annotate_CLASSES = Annotate_CLASSES # 待分类的类
self.Annotate_PALETTE = np.array(Annotate_PALETTE) # 每一类的像素直
self.Annotate_CLASSES_NUM = np.shape(Annotate_CLASSES)[0] # 类数量
self.save_process_pics = save_process_pics # 保存中间过程图片
self.src_label_fold = src_label_fold # 原始标签图片 保存位置
self.save_pro_label_fold = save_pro_label_fold # 优化后标签图片 保存位置
self.save_GT_label_fold = save_GT_label_fold # GT标签图片 保存位置
self.ori_img_folder = ori_img_folder # 最原始手术图片 保存位置
self.res_label_folder = res_label_folder # 训练出来的label 保存位置
self.save_merge_pic_folder = save_merge_pic_folder # 融合图像保存位置
self.pro_append_name = pro_append_name # 优化后标签图片后缀
self.GT_append_name = GT_append_name # GT标签图片后缀
self.GT_channel = GT_channel # GT标签图片通道数
self.Max_width = Max_width # 最大图片宽度(匹配时候用)
self.pic_type = pic_type # 图片类型
self.back_gnd_color = back_gnd_color # 背景颜色
self.first_class_color = first_class_color # 第一类上的颜色
self.Label_Max_Search_layer=Label_Max_Search_layer # 文件夹最大搜索深度
try:
self.labellist_src = getFileList(src_label_fold, [], pic_type, self.Label_Max_Search_layer)
print('本次执行检索到ori_label图片 '+str(len(self.labellist_src))+' 张图像')
except:
self.labellist_src = None
print("没有ori_label相关文件")
try:
# print(save_pro_label_fold)
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
except:
self.labellist_pro = None
print("没有pro_label相关文件")
try:
self.imglist_src = getFileList(ori_img_folder, [], pic_type, self.Label_Max_Search_layer)
self.reslist_src = getFileList(res_label_folder, [], pic_type, self.Label_Max_Search_layer)
print('本次执行检索到ori原始图片 '+str(len(self.imglist_src))+' 张图像')
print('本次执行检索到训练train_result图片 '+str(len(self.reslist_src))+' 张图像')
except:
self.imglist_src = None
self.reslist_src = None
print("没有train_result和原始图片相关文件")
# 获取单张图片各个通路信息
def get_single_pic_rgb(self, imgpath):
print(imgpath)
image = Image.open(imgpath).convert('RGB') # 转为RGB图片
# 将 RGB 色值分离
image.load()
r, g, b = image.split()
r = np.array(r)
g = np.array(g)
b = np.array(b)
return image, r, g, b
# 将单个pro图片变成GT图片
def Conver_pro_label_pic_2_GT_pic(self, imgpath, imgname):
time_start=time.time() # 记录开始时间
# 获取单张图片各个通路信息
image, r,g,b = self.get_single_pic_rgb(imgpath)
result_gt = np.ones(np.shape(image))*self.back_gnd_color # 初始化填充内容为back_gnd_color
gt_number = self.first_class_color # 第一类上色颜色确定
# PALETTE中排除掉 '背景' [0,0,0]
PALETTE_No_Bg = self.Annotate_PALETTE[~np.all(self.Annotate_PALETTE == self.bg_PALETTE, axis=1)]
# 遍历所有待识别颜色
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in PALETTE_No_Bg:
# 查找三原色匹配位置
locate_r = np.where( r == Annotate_PALETTE_r )
locate_g = np.where( g == Annotate_PALETTE_g )
locate_b = np.where( b == Annotate_PALETTE_b )
# 查找都匹配位置(交集)
# 将矩阵换一种表示形式
locate_r = np.array(locate_r[0]) * self.Max_width + np.array(locate_r[1])
locate_g = np.array(locate_g[0]) * self.Max_width + np.array(locate_g[1])
locate_b = np.array(locate_b[0]) * self.Max_width + np.array(locate_b[1])
# 用自带函数寻找匹配项
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
result_gt[matched[0],matched[1], :] = gt_number
gt_number = gt_number + 1
# 输出GT图片
if(int(self.GT_channel) == 1):
result_gt = result_gt[:,:,0]
elif(int(self.GT_channel) == 3):
result_gt = cv2.cvtColor(np.float32(result_gt), cv2.COLOR_RGB2BGR) # rgb颜色互换
else:
print("GT_channel 必须为1或3")
quit
try: # 新建文件夹
os.mkdir(self.save_GT_label_fold)
except:
print("已有"+self.save_GT_label_fold)
if imgname.lower().endswith(('.jpg', '.png')):
save_dir = os.path.join(self.save_GT_label_fold, os.path.basename(imgname).rpartition('.')[0]+self.GT_append_name+'.'+self.pic_type)
else:
save_dir = os.path.join(self.save_GT_label_fold, os.path.basename(imgname)+self.GT_append_name+'.'+self.pic_type)
cv2.imwrite(save_dir, result_gt)
print("GT图片已保存", save_dir)
time_end=time.time() # 输出结束时间
print('time cost',time_end-time_start,'s')
# 将处理好的图片转化为GT图片
def Conver_pro_label_pic_2_GT_pic_all(self):
print("\033[33m**** 进行转换将Pro_label_pic转换为GT_label_pic ****\033[0m")
print("\033[33mPro_label_pic存储位置为\033[0m", self.save_pro_label_fold)
print("\033[33mGT_label_pic生成位置为\033[0m", self.save_GT_label_fold)
try:
# print(save_pro_label_fold)
self.labellist_pro = getFileList(save_pro_label_fold, [], pic_type, self.Label_Max_Search_layer)
print('本次执行检索到pro_label图片 '+str(len(self.labellist_pro))+' 张图像')
except:
self.labellist_pro = None
print("没有pro_label相关文件")
try:
os.mkdir(self.save_GT_label_fold) # 新建存储文件夹
except:
print("已有"+self.save_GT_label_fold)
# 指定最大进程数为 3
max_processes = 20
# 创建Pool对象
pool = multiprocessing.Pool(processes=max_processes)
# 创建并启动进程
args_list1 = []
args_list2 = []
# 遍历整个文件夹
for imgpath in self.labellist_pro:
imgname = os.path.basename(imgpath).rpartition('.')[0].replace(self.pro_append_name,"")
args_list1.append(imgpath)
args_list2.append(imgname)
args_list = zip(args_list1, args_list2)
# 使用进程池并行执行任务
pool.starmap(self.Conver_pro_label_pic_2_GT_pic, args_list)
# 关闭进程池
pool.close()
pool.join()
def Conver_ori_label_pic_2_pro_pic(self, imgpath, imgname):
time_start=time.time() # 记录开始时间
# 获取单张图片各个通路信息
image = cv2.imread(imgpath)
# 1. 边缘检测并膨胀
dilated_image = edge_detection(image)
# 如果需要存储中间态图片
if(self.save_process_pics == True):
if imgname.lower().endswith(('.jpg', '.png')):
save_dir = os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_Edge'+'.'+self.pic_type)
else:
save_dir = os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀', os.path.basename(imgname)+self.pro_append_name+'_Edge'+'.'+self.pic_type)
cv2.imwrite(save_dir, dilated_image)
print("中间态-边缘检测并膨胀 图片已保存", save_dir)
time_end=time.time() # 输出结束时间
print('time cost',time_end-time_start,'s')
# 2. 检测连通区域
filtered_labeled_array, _ = detect_connected_regions(dilated_image)
colored_image_filtered = Tool_color_connected_array(filtered_labeled_array)
# 如果需要存储中间态图片
if(self.save_process_pics == True):
if imgname.lower().endswith(('.jpg', '.png')):
save_dir = os.path.join(self.save_pro_label_fold, '2_连通区域检测', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_Region'+'.'+self.pic_type)
else:
save_dir = os.path.join(self.save_pro_label_fold, '2_连通区域检测', os.path.basename(imgname)+self.pro_append_name+'_Region'+'.'+self.pic_type)
cv2.imwrite(save_dir, colored_image_filtered)
print("中间态-连通区域检测 图片已保存", save_dir)
time_end=time.time() # 输出结束时间
print('time cost',time_end-time_start,'s')
# 3. 分水岭填充白色区域
filled_labeled_array = fill_white_regions(filtered_labeled_array)
colored_image_filled = Tool_color_connected_array(filled_labeled_array)
# 如果需要存储中间态图片
if(self.save_process_pics == True):
if imgname.lower().endswith(('.jpg', '.png')):
save_dir = os.path.join(self.save_pro_label_fold, '3_分水岭算法填充', os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'_FillEdge'+'.'+self.pic_type)
else:
save_dir = os.path.join(self.save_pro_label_fold, '3_分水岭算法填充', os.path.basename(imgname)+self.pro_append_name+'_FillEdge'+'.'+self.pic_type)
cv2.imwrite(save_dir, colored_image_filled)
print("中间态-分水岭算法填充 图片已保存", save_dir)
time_end=time.time() # 输出结束时间
print('time cost',time_end-time_start,'s')
# 4. 对连通区域最终上色
ori_labeled_image = image
result_pro = color_connected_regions(filled_labeled_array, filtered_labeled_array, ori_labeled_image, self.Annotate_PALETTE)
if imgname.lower().endswith(('.jpg', '.png')):
save_dir = os.path.join(self.save_pro_label_fold, os.path.basename(imgname).rpartition('.')[0]+self.pro_append_name+'.'+self.pic_type)
else:
save_dir = os.path.join(self.save_pro_label_fold, os.path.basename(imgname)+self.pro_append_name+'.'+self.pic_type)
print("Pro图片已保存", save_dir)
cv2.imwrite(save_dir, result_pro)
time_end=time.time() # 输出结束时间
print('time cost',time_end-time_start,'s')
# 将原始src图片转化为处理好的pro图片
def Conver_ori_label_pic_2_pro_pic_all(self):
print("\033[33m**** 进行转换将Ori_label_pic转换为Pro_label_pic ****\033[0m")
print("\033[33mOri_label_pic存储位置为\033[0m", self.src_label_fold)
print("\033[33mPro_label_pic生成位置为\033[0m", self.save_pro_label_fold)
# 输出颜色预处理图片
try:
os.mkdir(self.save_pro_label_fold) # 新建存储文件夹
except:
print("已有"+self.save_pro_label_fold)
if(self.save_process_pics == True):
try:
os.mkdir(os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀')) # 新建存储1_边缘检测并膨胀文件夹
except:
print("已有"+os.path.join(self.save_pro_label_fold, '1_边缘检测并膨胀'))
try:
os.mkdir(os.path.join(self.save_pro_label_fold, '2_连通区域检测')) # 新建存储2_连通区域检测文件夹
except:
print("已有"+os.path.join(self.save_pro_label_fold, '2_连通区域检测'))
try:
os.mkdir(os.path.join(self.save_pro_label_fold, '3_分水岭算法填充')) # 新建存储1_边缘检测并膨胀文件夹
except:
print("已有"+os.path.join(self.save_pro_label_fold, '3_分水岭算法填充'))
# 指定最大进程数为 20多参数函数并行
max_processes = 20
# 创建Pool对象
pool = multiprocessing.Pool(processes=max_processes)
# 创建并启动进程
args_list1 = []
args_list2 = []
# 遍历整个文件夹
for imgpath in self.labellist_src:
if imgpath.lower().endswith(('.jpg', '.png')):
imgname= os.path.basename(imgpath).rpartition('.')[0].replace(self.pro_append_name,"")
else:
imgname= os.path.basename(imgpath).replace(self.pro_append_name,"")
try:
print("Processing: ", imgname, "...")
# self.Conver_ori_label_pic_2_pro_pic(imgpath, imgname)s
# args_list.append({'imgpath': imgpath, 'imgname': imgname})
args_list1.append(imgpath)
args_list2.append(imgname)
except:
os.system("echo "+imgname+" >> error_1.txt")
args_list = zip(args_list1, args_list2)
# 使用进程池并行执行任务
pool.starmap(self.Conver_ori_label_pic_2_pro_pic, args_list) # 使用starmap进行多参数并行
# 关闭进程池
pool.close()
pool.join()
# 图片堆叠
def Merge_ori_pic_and_label_pic(self, res_img_path, res_imgname):
time_start=time.time() # 记录开始时间
# 获取单张图片各个通路信息
ori_img_path = os.path.join(self.ori_img_folder, res_imgname+'.'+self.pic_type)
if not os.path.exists(ori_img_path):
print("****照片不存在:****", ori_img_path)
return -1
ori_image, ori_r, ori_g, ori_b = self.get_single_pic_rgb(ori_img_path)
res_image, res_r, res_g, res_b = self.get_single_pic_rgb(res_img_path)
merge_img = np.array(ori_image) # merge图片初始化默认图片背景为0.0.0
# 遍历所有待识别颜色
for [Annotate_PALETTE_r, Annotate_PALETTE_g, Annotate_PALETTE_b] in self.Annotate_PALETTE:
# 查找三原色匹配位置
locate_r = np.where( res_r == Annotate_PALETTE_r )
locate_g = np.where( res_g == Annotate_PALETTE_g )
locate_b = np.where( res_b == Annotate_PALETTE_b )
# 查找都匹配位置(交集)
# 将矩阵换一种表示形式
locate_r = np.array(locate_r[0]) * self.Max_width + np.array(locate_r[1])
locate_g = np.array(locate_g[0]) * self.Max_width + np.array(locate_g[1])
locate_b = np.array(locate_b[0]) * self.Max_width + np.array(locate_b[1])
# 用自带函数寻找匹配项
matched = np.intersect1d(np.intersect1d(locate_r, locate_g), locate_b)
matched = np.concatenate(([matched // self.Max_width], [np.mod(matched, self.Max_width)]), 0)
merge_img[matched[0],matched[1], 0] = Annotate_PALETTE_r
merge_img[matched[0],matched[1], 1] = Annotate_PALETTE_g
merge_img[matched[0],matched[1], 2] = Annotate_PALETTE_b
# 转成cv2形式
merge_img = cv2.cvtColor(np.float32(merge_img), cv2.COLOR_RGB2BGR)
try: # 新建文件夹
os.mkdir(self.save_merge_pic_folder)
except:
print("已有"+self.save_merge_pic_folder)
if res_imgname.lower().endswith(('.jpg', '.png')):
save_dir = os.path.join(self.save_merge_pic_folder, os.path.basename(res_imgname).rpartition('.')[0]+'.'+self.pic_type)
else:
save_dir = os.path.join(self.save_merge_pic_folder, os.path.basename(res_imgname)+'.'+self.pic_type)
cv2.imwrite(save_dir, merge_img)
print("Merge图片已保存", save_dir)
time_end=time.time() # 输出结束时间
print('time cost',time_end-time_start,'s')
# 将label图片与原图片重合
def Merge_ori_pic_and_label_pic_all(self):
# 遍历整个文件夹
for res_img_path in self.reslist_src:
if res_img_path.lower().endswith(('.jpg', '.png')):
res_imgname = os.path.basename(res_img_path).rpartition('.')[0].replace(self.pro_append_name,"")
else:
res_imgname = os.path.basename(res_img_path).replace(self.pro_append_name,"")
print("Processing: ", res_imgname, "...")
self.Merge_ori_pic_and_label_pic(res_img_path, res_imgname)
if __name__ == "__main__":
# 创建参数解析器
parser = argparse.ArgumentParser(description='Process some files.')
# 添加参数选项
parser.add_argument('-src_fold', dest='src_label_fold', default='./', help='source label folder')
parser.add_argument('-save_pro_fold', dest='save_pro_label_fold', default='./ORI_pro_label_fold', help='processed label folder')
parser.add_argument('-save_GT_fold', dest='save_GT_label_fold', default='./ORI_GT_label_fold', help='ground truth folder')
parser.add_argument('-fold_search_depth', dest='Label_Max_Search_layer', default='1000', type=int, help='Folder Search Depth')
parser.add_argument('-pro_suffix_name', dest='pro_append_name', default='_label', help='Pro file suffix')
parser.add_argument('-GT_suffix_name', dest='GT_append_name', default='', help='GT file suffix')
parser.add_argument('-GT_channel', dest='GT_channel', default='1', type=int, help='GT file channel(1 or 3)')
parser.add_argument('-back_gnd_color', dest='back_gnd_color', default='0', type=int, help='Color of "Back ground"(0 or 255)')
parser.add_argument('-first_class_color', dest='first_class_color', default='1', type=int, help='Color of "First Class"')
parser.add_argument('-pic_type', dest='pic_type', default='png', help='type of picture(Do not add ".")')
parser.add_argument('-Max_width', dest='Max_width', default='10000', type=int, help='Max width of picture')
parser.add_argument('-Rebuild_from', dest='Rebuild_from', default='label', help='Source to Rebuild Labels(label/pro)')
parser.add_argument('-Rebuild_to', dest='Rebuild_to', default='GT', help='Destination of Rebuild Labels(pro/GT)')
parser.add_argument('-save_process_pics', dest='save_process_pics', default='False', help='Save the processed pics(e.g.Gray_pics,Color_pics) in generating pro_pics')
# 解析命令行参数
args = parser.parse_args()
src_label_fold = args.src_label_fold
save_pro_label_fold = args.save_pro_label_fold
save_GT_label_fold = args.save_GT_label_fold
Label_Max_Search_layer = args.Label_Max_Search_layer
pro_append_name = args.pro_append_name
GT_append_name = args.GT_append_name
GT_channel = args.GT_channel
back_gnd_color = args.back_gnd_color
first_class_color = args.first_class_color
pic_type = args.pic_type
Max_width = args.Max_width
Rebuild_from = args.Rebuild_from
Rebuild_to = args.Rebuild_to
save_process_pics = args.save_process_pics
try: # 遍历文件深度最小为1
Label_Max_Search_layer=int(Label_Max_Search_layer)
except:
Label_Max_Search_layer=1000
try: # GT标签图片通道数
GT_channel=int(GT_channel)
except:
GT_channel=1
try: # 背景颜色背景选择0或255)
back_gnd_color=int(back_gnd_color)
except:
back_gnd_color=0
try: # 第一类上的颜色(如果背景为0,选择1)
first_class_color=int(first_class_color)
except:
first_class_color=1
try: # 最大图片宽度(匹配时候用)
Max_width=int(Max_width)
except:
Max_width=10000
if(save_process_pics.lower() == 'false'):
save_process_pics = False
elif(save_process_pics.lower() == 'true'):
save_process_pics = True
else:
save_process_pics = False
D = Deal_image(Annotate_CLASSES=Annotate_CLASSES, Annotate_PALETTE=Annotate_PALETTE, src_label_fold=src_label_fold, save_pro_label_fold=save_pro_label_fold, save_GT_label_fold=save_GT_label_fold, GT_channel=GT_channel, pro_append_name=pro_append_name, GT_append_name=GT_append_name, back_gnd_color=back_gnd_color, first_class_color=first_class_color, pic_type=pic_type, Max_width=Max_width, Label_Max_Search_layer=Label_Max_Search_layer, save_process_pics=save_process_pics, bg_PALETTE = bg_PALETTE)
# print(D.src_CLASSES_NUM)
if Rebuild_from == 'label':
# 1.先将所有原始图片转为pro图片
D.Conver_ori_label_pic_2_pro_pic_all()
pass
if Rebuild_to == 'GT':
# 2.再将pro图片转为GT图片
D.Conver_pro_label_pic_2_GT_pic_all()
pass

View File

@@ -0,0 +1,169 @@
import os
import cv2
import numpy as np
from collections import Counter, defaultdict
from PIL import Image
from Tool_Classes_And_Palette import Annotate_CLASSES, Annotate_PALETTE, bg_PALETTE
# ※ 需要修改输入输出路径 ※ #
input_dir = './ORI_GT_label_fold'
output_dir_1 = 'Data/labels/train'
output_dir_2 = 'Data/labels/val'
# ----------------------------------------------------
# ※※※ 1. 修改: 移除背景项并统一格式 ※※※
# ----------------------------------------------------
# (1) 从 Annotate_CLASSES (元组) 中移除 '背景'
if Annotate_CLASSES and Annotate_CLASSES[-1] == '背景':
Annotate_CLASSES = Annotate_CLASSES[:-1]
# (2) 从 Annotate_PALETTE (列表) 中移除背景颜色 [0,0,0]
if Annotate_PALETTE and (Annotate_PALETTE[-1] == [0, 0, 0] or Annotate_PALETTE[-1] == (0, 0, 0)):
Annotate_PALETTE.pop()
# (3) 确保调色板中的所有颜色都是元组 (Tuple),以便后续查找
Annotate_PALETTE = [tuple(color) for color in Annotate_PALETTE]
# (4) 确保 bg_PALETTE 是元组,以匹配 Counter 中的键类型
bg_PALETTE = tuple(bg_PALETTE)
# (5) 检查类别和调色板长度是否一致
if len(Annotate_CLASSES) != len(Annotate_PALETTE):
print(f"[警告] 移除背景后,类别 ({len(Annotate_CLASSES)}) 和调色板 ({len(Annotate_PALETTE)}) 长度不一致!")
else:
print(f"--- 成功移除背景项,剩余 {len(Annotate_CLASSES)} 个有效类别。 ---")
os.makedirs(output_dir_1, exist_ok=True)
os.makedirs(output_dir_2, exist_ok=True)
# 全局统计颜色频率与 class 像素频率
global_color_counter = Counter()
global_class_counter = Counter() # 将用于统计ori_label 模式)
color_class_counter = defaultdict(int) # (R,G,B) → count
color_to_old_class = {} # (R,G,B) → class_id
# *** ori_label 模式: 移除了 remap_class_dict ***
# 自动提取颜色映射(跳过背景)
def extract_color_mapping(img_path):
img = Image.open(img_path).convert('RGB')
pixels = list(img.getdata())
counter = Counter(pixels)
color_map = {}
for color, count in counter.items():
global_color_counter[color] += count
if color != bg_PALETTE and color[0] == color[1] == color[2]: # 使用元组 bg_PALETTE
class_id = color[0] - 1
if class_id >= 0:
color_map[color] = class_id
global_class_counter[class_id] += count # 使用 global_class_counter 统计
color_class_counter[color] += count
color_to_old_class[color] = class_id
return color_map
# 处理单张图片
def process_image(img_path, save_path_list):
color_to_class = extract_color_mapping(img_path)
if not color_to_class:
print(f"[跳过] {os.path.basename(img_path)} 无有效目标")
return
img = cv2.imread(img_path)
h, w = img.shape[:2]
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
lines = []
for rgb, old_class_id in color_to_class.items():
mask = np.all(img_rgb == rgb, axis=-1).astype(np.uint8) * 255
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
if contour.shape[0] < 3:
continue
norm_pts = contour.squeeze(1).astype(np.float32)
norm_pts[:, 0] /= w
norm_pts[:, 1] /= h
flat = norm_pts.flatten()
# *** ori_label 模式: 直接使用 old_class_id ***
line = f"{old_class_id} " + " ".join([f"{x:.6f}" for x in flat])
lines.append(line)
if lines:
for save_path in save_path_list:
with open(save_path, 'w') as f:
for line in lines:
f.write(line + "\n")
print(f"[✔] 转换成功: {os.path.basename(save_path)},共 {len(lines)} 个实例")
else:
print(f"[⚠] {os.path.basename(img_path)} 没有轮廓")
# 第一次遍历图像,仅提取颜色信息用于构建 class 映射
for fname in os.listdir(input_dir):
if fname.lower().endswith('.png'):
img_path = os.path.join(input_dir, fname)
extract_color_mapping(img_path)
# *** ori_label 模式: 移除了重映射 (remap) 逻辑块 ***
# (保留 class_pixel_count 用于统计)
class_pixel_count = {color_to_old_class[c]: count for c, count in color_class_counter.items()}
# 第二次遍历图像,正式处理并生成标签
for fname in os.listdir(input_dir):
if fname.lower().endswith('.png'):
img_path = os.path.join(input_dir, fname)
base_name = os.path.splitext(fname)[0]
txt_path_1 = os.path.join(output_dir_1, base_name + '.txt')
txt_path_2 = os.path.join(output_dir_2, base_name + '.txt')
process_image(img_path, [txt_path_1, txt_path_2])
# ----------------------------------------------------
# ※※※ 2. 修改: 打印详细的颜色统计 ※※※
# ----------------------------------------------------
print("\n📊 所有图像颜色统计:")
def get_label_info(class_id):
"""辅助函数:根据 class_id 获取标签名和颜色"""
if 0 <= class_id < len(Annotate_CLASSES) and 0 <= class_id < len(Annotate_PALETTE):
return Annotate_CLASSES[class_id], Annotate_PALETTE[class_id]
else:
return "未知标签", (255, 255, 255) # 返回一个默认值
for color, count in global_color_counter.most_common():
if color == bg_PALETTE: # 使用 bg_PALETTE 变量
print(f"背景颜色 {color} 出现次数: {count}")
elif color[0] == color[1] == color[2]:
class_id = color[0] - 1
label_name, label_color = get_label_info(class_id)
print(f"颜色 {color} → class {class_id} (标签: '{label_name}', 颜色: {label_color}),出现次数: {count}")
else:
# 此情况理论上不应出现,因为 extract_color_mapping 已过滤
print(f"[⚠] 非灰阶颜色 {color},出现次数: {count} (此颜色不应被处理)")
# ----------------------------------------------------
# ※※※ 3. 修改: 打印详细的类别统计 (ori_label 模式) ※※※
# ----------------------------------------------------
print("\n✅ 有效类别统计(按 原 class_id 排序 → 总像素数):")
# 按照 old_id (原类别索引) 排序输出
sorted_classes = sorted(class_pixel_count.items(), key=lambda item: item[0])
for old_id, pixel_count in sorted_classes:
label_name, label_color = get_label_info(old_id)
print(f"class {old_id} (标签: '{label_name}', 颜色: {label_color}): {pixel_count} pixels")
# ----------------------------------------------------
# ※※※ 4. 修改: 额外输出 原 class_id 到颜色的映射 (ori_label 模式) ※※※
# ----------------------------------------------------
print(f"\n🎨 找到的 原 class_id 与标签颜色 (Annotate_PALETTE) 映射表:")
# 1. 按照 old_id (0, 1, 2...) 排序并按指定格式打印
# 我们使用 class_pixel_count.keys() 来获取所有实际找到的 old_id
sorted_old_ids = sorted(class_pixel_count.keys())
for old_id in sorted_old_ids:
if 0 <= old_id < len(Annotate_PALETTE):
color_tuple = Annotate_PALETTE[old_id]
color_list = list(color_tuple)
print(f" {old_id}: {color_list}")
else:
# 预防性代码
print(f" {old_id}: [颜色未在调色板中定义]")
# ----------------------------------------------------
# ※※※ 修改结束 ※※※
# ----------------------------------------------------
print(f"\n✅ 全部图像处理完毕。标签输出目录:{output_dir_1}{output_dir_2}")

View File

@@ -0,0 +1,204 @@
import os
import cv2
import numpy as np
from collections import Counter, defaultdict
from PIL import Image
from Tool_Classes_And_Palette import Annotate_CLASSES, Annotate_PALETTE, bg_PALETTE
# ※ 需要修改输入输出路径 ※ #
input_dir = 'ORI_GT_label_fold'
output_dir_1 = 'Data/labels/train'
output_dir_2 = 'Data/labels/val'
output_GT_dir_1 = 'Data/labels_GT/train'
output_GT_dir_2 = 'Data/labels_GT/val'
# 1. --- 您提供的类别和调色板 ---
# (1) 从 Annotate_CLASSES (元组) 中移除 '背景'
if Annotate_CLASSES and Annotate_CLASSES[-1] == '背景':
Annotate_CLASSES = Annotate_CLASSES[:-1]
# (2) 从 Annotate_PALETTE (列表) 中移除背景颜色 [0,0,0]
if Annotate_PALETTE and (Annotate_PALETTE[-1] == [0, 0, 0] or Annotate_PALETTE[-1] == (0, 0, 0)):
Annotate_PALETTE.pop()
# (3) 确保调色板中的所有颜色都是元组 (Tuple),以便后续查找
Annotate_PALETTE = [tuple(color) for color in Annotate_PALETTE]
# (4) 确保 bg_PALETTE 是元组,以匹配 Counter 中的键类型
bg_PALETTE = tuple(bg_PALETTE)
# (5) 检查类别和调色板长度是否一致
if len(Annotate_CLASSES) != len(Annotate_PALETTE):
print(f"[警告] 移除背景后,类别 ({len(Annotate_CLASSES)}) 和调色板 ({len(Annotate_PALETTE)}) 长度不一致!")
else:
print(f"--- 成功移除背景项,剩余 {len(Annotate_CLASSES)} 个有效类别。 ---")
os.makedirs(output_dir_1, exist_ok=True)
os.makedirs(output_dir_2, exist_ok=True)
os.makedirs(output_GT_dir_1, exist_ok=True)
os.makedirs(output_GT_dir_2, exist_ok=True)
# 全局统计颜色频率与 class 像素频率
global_color_counter = Counter()
global_class_counter = Counter()
color_class_counter = defaultdict(int) # (R,G,B) → count
color_to_old_class = {} # (R,G,B) → class_id
remap_class_dict = {} # old_class_id → new_class_id
# 自动提取颜色映射(跳过背景)
def extract_color_mapping(img_path):
img = Image.open(img_path).convert('RGB')
pixels = list(img.getdata())
counter = Counter(pixels)
color_map = {}
for color, count in counter.items():
global_color_counter[color] += count
if color != (0, 0, 0) and color[0] == color[1] == color[2]:
class_id = color[0] - 1
if class_id >= 0:
color_map[color] = class_id
global_class_counter[class_id] += count
color_class_counter[color] += count
color_to_old_class[color] = class_id
return color_map
# --- 修改:函数签名增加了 gt_save_paths ---
def process_image(img_path, txt_save_paths, gt_save_paths):
color_to_class = extract_color_mapping(img_path)
if not color_to_class:
print(f"[跳过] {os.path.basename(img_path)} 无有效目标")
return
img = cv2.imread(img_path)
h, w = img.shape[:2]
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# --- 新增:创建新的 GT 掩码 (灰度图)0 为背景 ---
new_gt_mask = np.zeros((h, w), dtype=np.uint8)
lines = []
for rgb, old_class_id in color_to_class.items():
# --- 修改:获取 new_class_id 并填充 new_gt_mask ---
new_class_id = remap_class_dict.get(old_class_id, old_class_id)
# 创建布尔掩码
mask_bool = np.all(img_rgb == rgb, axis=-1)
# --- 新增:在 new_gt_mask 上填充新的 class_id
# (使用 new_class_id + 1因为 0 是背景)
new_gt_mask[mask_bool] = new_class_id + 1
# --- 修改:基于布尔掩码创建 255 掩码用于 findContours ---
mask_255 = mask_bool.astype(np.uint8) * 255
contours, _ = cv2.findContours(mask_255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
if contour.shape[0] < 3:
continue
norm_pts = contour.squeeze(1).astype(np.float32)
norm_pts[:, 0] /= w
norm_pts[:, 1] /= h
flat = norm_pts.flatten()
# --- 修改:此处 new_class_id 已在循环外层定义 ---
line = f"{new_class_id} " + " ".join([f"{x:.6f}" for x in flat])
lines.append(line)
# --- 新增:保存处理后的 GT 掩码图像 ---
# 只要存在有效类别 (color_to_class 非空),就保存 GT 掩码
for save_path in gt_save_paths:
try:
# 使用 PIL 保存灰度 PNG 图像
Image.fromarray(new_gt_mask).save(save_path)
print(f"[✔] GT 掩码保存成功: {os.path.basename(save_path)}")
except Exception as e:
print(f"[✘] GT 掩码保存失败: {os.path.basename(save_path)} - {e}")
# 保存 .txt 标签文件 (仅当找到轮廓时)
if lines:
for save_path in txt_save_paths:
with open(save_path, 'w') as f:
for line in lines:
f.write(line + "\n")
print(f"[✔] 转换成功: {os.path.basename(save_path)},共 {len(lines)} 个实例")
else:
# 即使没有轮廓GT 掩码也已保存
print(f"[⚠] {os.path.basename(img_path)} 没有轮廓 (但 GT 掩码已保存)")
# 第一次遍历图像,仅提取颜色信息用于构建 class 映射
for fname in os.listdir(input_dir):
if fname.lower().endswith('.png'):
img_path = os.path.join(input_dir, fname)
extract_color_mapping(img_path)
# 构建 class_id 重映射表(按像素数从大到小排序)
class_pixel_count = {color_to_old_class[c]: count for c, count in color_class_counter.items()}
sorted_classes = sorted(class_pixel_count.items(), key=lambda x: x[1], reverse=True)
remap_class_dict = {old_cls: new_idx for new_idx, (old_cls, _) in enumerate(sorted_classes)}
# 第二次遍历图像,正式处理并生成标签
for fname in os.listdir(input_dir):
if fname.lower().endswith('.png'):
img_path = os.path.join(input_dir, fname)
base_name = os.path.splitext(fname)[0]
txt_path_1 = os.path.join(output_dir_1, base_name + '.txt')
txt_path_2 = os.path.join(output_dir_2, base_name + '.txt')
# --- 新增:定义 GT 掩码输出路径 (保存为 .png) ---
gt_path_1 = os.path.join(output_GT_dir_1, base_name + '.png')
gt_path_2 = os.path.join(output_GT_dir_2, base_name + '.png')
process_image(img_path, [txt_path_1, txt_path_2], [gt_path_1, gt_path_2])
# 打印颜色统计和类别映射
print("\n📊 所有图像颜色统计:")
def get_label_info(class_id):
"""辅助函数:根据 class_id 获取标签名和颜色"""
if 0 <= class_id < len(Annotate_CLASSES) and 0 <= class_id < len(Annotate_PALETTE):
return Annotate_CLASSES[class_id], Annotate_PALETTE[class_id]
else:
return "未知标签", (255, 255, 255) # 返回一个默认值
for color, count in global_color_counter.most_common():
if color == bg_PALETTE: # 使用 bg_PALETTE 变量
print(f"背景颜色 {color} 出现次数: {count}")
elif color[0] == color[1] == color[2]:
class_id = color[0] - 1
label_name, label_color = get_label_info(class_id)
print(f"颜色 {color} → class {class_id} (标签: '{label_name}', 颜色: {label_color}),出现次数: {count}")
else:
# 此情况理论上不应出现,因为 extract_color_mapping 已过滤
print(f"[⚠] 非灰阶颜色 {color},出现次数: {count} (此颜色不应被处理)")
# 打印详细的类别统计
print("\n✅ 有效类别统计(原 class_id → 新 class_id → 总像素数):")
# 按照 new_id (新类别索引) 排序输出
sorted_remap = sorted(remap_class_dict.items(), key=lambda item: item[1])
for old_id, new_id in sorted_remap:
label_name, label_color = get_label_info(old_id)
print(f"class {old_id} (标签: '{label_name}', 颜色: {label_color}) → {new_id}: {class_pixel_count[old_id]} pixels")
# 4. 额外输出新 class_id 到颜色的映射 ※※※
print(f"\n🎨 新 class_id 与标签颜色 (Annotate_PALETTE) 映射表:")
# 1. 创建一个 new_id -> color 的映射
new_id_to_color = {}
for old_id, new_id in remap_class_dict.items():
if 0 <= old_id < len(Annotate_PALETTE):
# 从 Annotate_PALETTE 获取原始颜色(它是一个元组)
color_tuple = Annotate_PALETTE[old_id]
# 转换为列表 [R, G, B] 以匹配您要的格式
new_id_to_color[new_id] = list(color_tuple)
else:
# 预防性代码,以防 old_id 超出范围
new_id_to_color[new_id] = [-1, -1, -1] # 表示错误/未找到
# 2. 按照 new_id (0, 1, 2...) 排序并按指定格式打印
sorted_new_ids = sorted(new_id_to_color.keys())
for new_id in sorted_new_ids:
color_list = new_id_to_color[new_id]
# 格式化输出: {id}: {color_list}
# (注意:颜色列表的格式会自然包含逗号和空格)
print(f" {new_id}: {color_list}")
print(f"\n✅ 全部图像处理完毕。")
print(f" 标签 (.txt) 输出目录: {output_dir_1}, {output_dir_2}")
print(f" GT 掩码 (.png) 输出目录: {output_GT_dir_1}, {output_GT_dir_2}")

View File

@@ -0,0 +1,20 @@
# # 胆囊标注
# Annotate_CLASSES = ('肝脏','胆囊','分离钳','止血海绵','肝总管','胆总管','吸引器','剪刀','止血纱布','生物夹','无损伤钳','喷洒','胆囊管','胆囊动脉','电凝','标本袋','引流管','纱布','金属钛夹','术中超声','吻合器','乳胶管','推结器','肝带','钳夹','超声刀','脂肪','双极电凝','棉球','血管阻断夹','肿瘤','针','线','韧带','胆囊静脉','背景') # 待分类的类
# Annotate_PALETTE = [[255,91,0],[255,234,0],[85, 111, 181],[181, 227, 14],[72, 0, 255],[0, 155, 33],[255,0,255],[29, 32, 136],[160, 15, 95],[0,160,233],[52,184,178],[90,120,41],[255,0,0],[117,0,0],[167,24,233],[112,113,150],[0,255,0],[255,255,255],[0,255,255],[138,251,213],[136,162,196],[197,83,181],[202,202,200],[113,102,140],[66,115,82],[240,16,116],[155,132,0],[155,62,0],[146,175,236],[255,172,159],[245,161,0],[134,124,118], [0,157,142], [181,85,105], [42,8,66],[0,0,0]] # 每一类的像素直
# # 甲状腺标注
# Annotate_CLASSES = ('甲状腺', '针', '双极电凝', '止血海绵', '止血纱布', '喷洒', '标本袋', '肝蒂', '神经', '线', '肝脏', '肝总管', '胆总管', '胆囊管', '引流管', '推结器', '肌肉', '肿瘤', '胆囊', '吸引器', '生物夹', '动脉', '纱布', '乳胶管', '血管阻断夹', '分离钳', '肌肉', '无损伤钳', '电凝', '金属钛夹', '吻合器', '棉球', '脂肪', '超声刀', '钳夹', '静脉', '韧带', '术中超声','背景') # 待分类的类
# Annotate_PALETTE = [[255, 148, 81], [134, 124, 118], [155, 62, 0], [181, 227, 14], [160, 15, 95], [90, 120, 41], [112, 113, 150], [133, 102, 140], [168, 162, 252], [0, 157, 142], [255, 91, 0], [72, 0, 255], [0, 155, 33], [255, 0, 0], [0, 255, 0], [202, 202, 200], [254, 141, 179], [254, 161, 0], [255, 234, 0], [255, 0, 255], [0, 160, 233], [117, 0, 0], [255, 255, 255], [197, 83, 181],[255, 172, 159], [85, 111, 181], [29, 32, 136], [52, 184, 178], [166, 24, 232], [0, 254, 254], [136, 162, 196], [146, 175, 236], [155, 132, 0], [240, 16, 116], [66, 115, 82], [42, 8, 66], [181, 85, 105], [138, 251, 213],[0,0,0]] # 每一类的像素直
# # 胃癌标注+去雾影像标注
Annotate_CLASSES = ('', '', '双极电凝', '止血海绵', '止血纱布', '喷洒', '标本袋', '肝蒂', '小肠', '线', '肝脏', '脾脏', '胆总管', '胆囊管', '引流管', '推结器', '淋巴结', '胰腺', '胆囊', '吸引器', '生物夹', '动脉', '纱布', '乳胶管', '分离钳', '超声刀', '无损伤钳', '电凝', '金属钛夹', '吻合器', '脂肪', '剪刀', '钳夹', '静脉', '韧带','背景') # 待分类的类
Annotate_PALETTE = [(237, 35, 85), (134, 124, 118), (155, 62, 0), (187, 227, 14), (160, 15, 95), (90, 120, 41), (112, 113, 150), (133, 102, 140), (110, 255, 166), (0, 157, 142), (255, 91, 0), (72, 0, 255), (0, 155, 33), (255, 0, 0), (0, 255, 0), (202, 202, 200), (201, 255, 74), (245, 161, 0), (255, 234, 0), (255, 0, 255), (0, 160, 233), (117, 0, 0), (255, 255, 255), (197, 83, 181), (85, 111, 181), (29, 32, 136), (52, 184, 178), (167, 24, 233), (0, 255, 255), (136, 162, 196), (155, 132, 0), (240, 16, 116), (66, 115, 82), (42, 8, 66), (181, 85, 105), [0,0,0]] # 每一类的像素直
# # 甲状腺标注
# Annotate_CLASSES = ('甲状旁腺', '喉返神经', '电凝', '无损伤钳', '超声刀', '分离钳', '纱布', '背景') # 待分类的类
# Annotate_PALETTE = [(238, 25, 30), (24, 124, 248), (198, 24, 248), (24, 248, 240), (248, 119, 24), (24, 248, 114), (255,255,255), [0,0,0]] # 每一类的像素值
# # 磁器械标注
# Annotate_CLASSES = ( '双极电凝', '止血海绵', '止血纱布', '喷洒', '标本袋', '肝脏', '肝总管', '胆总管', '胆囊管', '引流管', '胆囊', '磁器械', '生物夹', '胆囊动脉', '纱布', '分离钳', '剪刀', '无损伤钳','电凝', '金属钛夹', '脂肪', '背景') # 待分类的类
# Annotate_PALETTE = [ (155, 62, 0), (181, 227, 14), (160, 15, 95), (90, 120, 41), (112, 113, 150), (255, 91, 0), (72, 0, 255), (0, 155, 33), (255, 0, 0), (0, 255, 0), (255, 234, 0), (255, 0, 255), (0, 160, 233), (117, 0, 0), (255, 255, 255), (85, 111, 181), (29, 32, 136), (52, 184, 178), (167, 24, 233), (0, 255, 255), (155, 132, 0), [0,0,0]] # 每一类的像素值
# # 二分类标注
# Annotate_CLASSES = ( '特殊部位', '背景') # 待分类的类
# Annotate_PALETTE = [ [255,255,255], [0,0,0]] # 每一类的像素值
bg_PALETTE = [0,0,0] # 背景的RGB

View File

@@ -0,0 +1,167 @@
import os
import argparse
from PIL import Image
import concurrent.futures
# 确保使用的是 os.cpu_count() 来自动检测核心数
try:
DEFAULT_WORKERS = os.cpu_count() or 1
except AttributeError:
try:
import multiprocessing
DEFAULT_WORKERS = multiprocessing.cpu_count()
except (ImportError, NotImplementedError):
DEFAULT_WORKERS = 4
def _process_single_file(source_path, png_path, delete_source):
"""
工作函数处理单个文件的转换、缩放并保存为PNG。
(此函数与上一版完全相同)
"""
filename = os.path.basename(source_path)
png_filename = os.path.basename(png_path)
try:
with Image.open(source_path) as img:
width, height = img.size
min_side = min(width, height)
resize_info = ""
if min_side > 1080:
scale_factor = 1080 / min_side
new_width = int(width * scale_factor)
new_height = int(height * scale_factor)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
resize_info = f" (已缩放至 {new_width}x{new_height})"
img.save(png_path, 'PNG')
if delete_source:
os.remove(source_path)
return "success_del", filename, f"{png_filename}{resize_info}"
else:
return "success", filename, f"{png_filename}{resize_info}"
except Exception as e:
return "fail", filename, str(e)
def process_images(folder_path, delete_source=False, max_workers=DEFAULT_WORKERS):
"""
批量将指定文件夹 (及其所有子文件夹) 内的 .bmp, .jpg, .jpeg 图像
转换为 .png, 并按比例缩放 (最小边<=1080px)。
"""
if not os.path.isdir(folder_path):
print(f"错误:文件夹 '{folder_path}' 不存在或不是一个有效的目录。")
return
print(f"--- 开始递归处理文件夹: {folder_path} ---")
print(f"--- 使用最多 {max_workers} 个进程并行处理 ---")
if delete_source:
print("警告:已启用源文件删除模式。")
# --- 1. 收集所有需要处理的任务 ---
tasks = []
supported_extensions = ('.bmp', '.jpg', '.jpeg')
# <--- 更改:从 os.listdir() 切换到 os.walk() 以支持递归
print("--- 正在扫描所有子文件夹...")
for dirpath, dirnames, filenames in os.walk(folder_path):
for filename in filenames:
if filename.lower().endswith(supported_extensions):
# 构建完整的文件路径
source_path = os.path.join(dirpath, filename)
# 构建PNG输出路径 (保持在同一个子文件夹内)
base_name = os.path.splitext(filename)[0]
png_path = os.path.join(dirpath, base_name + '.png')
# 避免重复转换
if os.path.exists(png_path):
# 打印相对路径,使其更清晰
relative_path = os.path.relpath(png_path, folder_path)
print(f" [跳过] {relative_path} (目标文件已存在)")
continue
tasks.append((source_path, png_path, delete_source))
# <--- 更改结束 ---
if not tasks:
print(f"未在文件夹中找到新的 {supported_extensions} 文件。")
return
print(f"--- 发现 {len(tasks)} 个新图像文件,开始转换... ---")
converted_count = 0
failed_count = 0
# --- 2. 使用 ProcessPoolExecutor 执行任务 (此部分不变) ---
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
future_to_source = {
executor.submit(_process_single_file, source, png, delete): source
for source, png, delete in tasks
}
# --- 3. 实时获取已完成的结果 (此部分不变) ---
for future in concurrent.futures.as_completed(future_to_source):
source_path_orig = future_to_source[future]
try:
status, name, result = future.result()
# 获取相对路径以便于阅读
relative_dir = os.path.relpath(os.path.dirname(source_path_orig), folder_path)
# 如果是根目录relative_dir 会是 ".",我们将其替换为空字符串
if relative_dir == ".":
log_name = name
else:
log_name = os.path.join(relative_dir, name)
if status == "success":
print(f" [成功] {log_name} -> {result}")
converted_count += 1
elif status == "success_del":
print(f" [成功] {log_name} -> {result} (并已删除源文件)")
converted_count += 1
elif status == "fail":
print(f" [失败] 转换 {log_name} 时出错: {result} (源文件未删除)")
failed_count += 1
except Exception as e:
print(f" [严重失败] 处理 {os.path.basename(source_path_orig)} 时进程出错: {e}")
failed_count += 1
print("--- 处理完毕 ---")
if converted_count > 0:
print(f"成功转换 {converted_count} 个文件。")
if failed_count > 0:
print(f"失败 {failed_count} 个文件。")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="将指定文件夹(及其所有子文件夹)中的所有 .bmp, .jpg, .jpeg 文件批量转换为 .png 文件 (并行加速),并按比例缩放 (最小边<=1080px)。"
)
parser.add_argument(
"folder",
type=str,
help="包含 .bmp, .jpg, .jpeg 图像的 (根) 目标文件夹路径 (例如 ./ABC)"
)
parser.add_argument(
"-d", "--delete-source",
action="store_true",
help="在成功转换为 .png 后,删除原始的 .bmp/.jpg/.jpeg 文件。"
)
parser.add_argument(
"-w", "--workers",
type=int,
default=DEFAULT_WORKERS,
help=f"指定用于转换的工作进程数 (默认: {DEFAULT_WORKERS}, 即本机CPU核心数)"
)
args = parser.parse_args()
process_images(args.folder, args.delete_source, args.workers)

View File

@@ -0,0 +1,180 @@
import cv2
import random
import numpy as np
from scipy.ndimage import label, distance_transform_edt
def skeletonize(image):
"""骨架化函数确保线条连通性并缩减为1像素宽"""
skeleton = np.zeros_like(image)
temp_image = np.copy(image)
while True:
eroded = cv2.erode(temp_image, None) # 腐蚀操作
temp_dilate = cv2.dilate(eroded, None) # 膨胀操作
temp = cv2.subtract(temp_image, temp_dilate) # 提取边缘
skeleton = cv2.bitwise_or(skeleton, temp) # 将边缘加入骨架
temp_image = np.copy(eroded)
if cv2.countNonZero(temp_image) == 0:
break
return skeleton
# 1. *** 边缘检测并膨胀 ***
def edge_detection(image):
"""对图像的各个通道进行边缘检测并进行膨胀处理"""
b_channel, g_channel, r_channel = cv2.split(image)
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
edges_b = cv2.Canny(b_channel, 100, 200)
edges_g = cv2.Canny(g_channel, 100, 200)
edges_r = cv2.Canny(r_channel, 100, 200)
edges_gray = cv2.Canny(gray_image, 100, 200)
# 合并所有边缘检测结果
edges = cv2.bitwise_or(edges_b, edges_g)
edges = cv2.bitwise_or(edges, edges_r)
edges = cv2.bitwise_or(edges, edges_gray)
# 创建膨胀核并进行膨胀操作
kernel = np.ones((3, 3), np.uint8)
dilated_image = cv2.dilate(edges, kernel, iterations=1)
return dilated_image
# 2. ** 检测连通区域 ***
def detect_connected_regions(dilated_image):
"""检测图像中的连通区域并过滤掉小区域"""
_, binary = cv2.threshold(dilated_image, 1, 255, cv2.THRESH_BINARY_INV)
binary[binary > 0] = 1 # 转换为二值图像
# 标记连通区域
structure = np.ones((3, 3), dtype=int)
labeled_array, num_features = label(binary, structure=structure)
# 清除掉小于100像素的区域
filtered_labeled_array = np.copy(labeled_array)
for label_num in range(1, num_features + 1):
area = np.sum(labeled_array == label_num)
if area < 100:
filtered_labeled_array[filtered_labeled_array == label_num] = 0
# 重新标记过滤后的连通区域
filtered_labeled_array, num_features = label(filtered_labeled_array, structure=structure)
return filtered_labeled_array, num_features # 返回过滤后的labeled_array和未过滤的labeled_array用于寻找颜色
# 3. *** 分水岭填充白色区域 ***
def fill_white_regions(filtered_labeled_array):
"""使用分水岭算法填充白色区域"""
# 准备三通道图像作为分水岭算法的输入
color_image = np.zeros((filtered_labeled_array.shape[0], filtered_labeled_array.shape[1], 3), dtype=np.uint8)
# 将过滤后的 labeled_array 转化为 32 位整型,作为分水岭的 markers
markers = np.copy(filtered_labeled_array).astype(np.int32)
markers[markers == 0] = -1 # 背景标记为 -1
# 执行分水岭算法
cv2.watershed(color_image, markers)
# 更新分水岭结果
filled_labeled_array = markers.astype(int)
filled_labeled_array[filled_labeled_array == -1] = 0
# 使用距离变换计算边缘像素0最近的非零值
non_zero_mask = filled_labeled_array != 0
distance, nearest_indices = distance_transform_edt(non_zero_mask == 0, return_indices=True)
nearest_values = filled_labeled_array[tuple(nearest_indices)]
filled_labeled_array[filled_labeled_array == 0] = nearest_values[filled_labeled_array == 0]
return filled_labeled_array
# 4. *** 对连通区域上色(使用“filtered_labeled_array”作为颜色判断给“filled_labeled_array”上色) ***
def color_connected_regions(filled_labeled_array, filtered_labeled_array, ori_labeled_image, Annotate_PALETTE):
"""根据原始图像的颜色和注解调色板给连通区域上色"""
# 初始化一个三通道的彩色图像
colored_image = np.zeros((*filled_labeled_array.shape, 3), dtype=np.uint8)
# 遍历filtered_labeled_array中的每个标签
unique_labels = np.unique(filtered_labeled_array)
for label_num in unique_labels:
if label_num == 0:
continue # 跳过背景标签
# 找到filtered_labeled_array中等于当前标签的区域
mask_filtered = (filtered_labeled_array == label_num)
# 获取ori_labeled_image中对应区域的RGB值
region_rgb_values = ori_labeled_image[mask_filtered]
if len(region_rgb_values) == 0:
continue
# 计算区域的RGB平均值
average_rgb = np.mean(region_rgb_values, axis=0)
# 找到Annotate_PALETTE中与average_rgb最接近的颜色
closest_palette_color = find_closest_palette_color(average_rgb, Annotate_PALETTE)
# 将该颜色赋给filled_labeled_array对应区域的元素
mask_filled = (filled_labeled_array == label_num)
colored_image[mask_filled] = closest_palette_color
return colored_image
# 寻找最近邻颜色
def find_closest_palette_color(average_rgb, Annotate_PALETTE):
"""根据平均RGB值找到Annotate_PALETTE中最接近的颜色"""
Annotate_PALETTE = [[color[2], color[1], color[0]] for color in Annotate_PALETTE]
average_rgb = np.array(average_rgb)
min_distance = float('inf')
closest_color = None
# 遍历调色板计算每个颜色与平均RGB的欧几里得距离
for palette_color in Annotate_PALETTE:
palette_color = np.array(palette_color)
distance = np.linalg.norm(average_rgb - palette_color) # 欧几里得距离
if distance < min_distance:
min_distance = distance
closest_color = palette_color
return closest_color
# 5. 对Array区域上色4的简化版
def Tool_color_connected_array(Array):
colored_image = np.zeros((*Array.shape, 3), dtype=np.uint8)
for label_num in range(1, np.max(Array) + 1):
color = [np.random.randint(0, 254) for _ in range(3)]
colored_image[Array == label_num] = color
return colored_image
if __name__ == '__main__':
"""超参数"""
image_path = './2023_02_03_09_13_48.00_08_04_21.Still085.png'
"""主函数,处理图像并保存结果"""
# 读取图像
image = cv2.imread(image_path)
# 1. 边缘检测并膨胀
dilated_image = edge_detection(image)
cv2.imwrite('./1_1_range_image.png', dilated_image)
# 2. 检测连通区域
filtered_labeled_array, _ = detect_connected_regions(dilated_image)
colored_image_filtered = Tool_color_connected_array(filtered_labeled_array)
cv2.imwrite('./2_colored_image_filtered.png', colored_image_filtered)
# 3. 分水岭填充白色区域
filled_labeled_array = fill_white_regions(filtered_labeled_array)
colored_image_filled = Tool_color_connected_array(filled_labeled_array)
cv2.imwrite('./3_colored_image_filled.png', colored_image_filled)
# 4. 对连通区域上色
ori_labeled_image = image
colored_image_final = color_connected_regions(filled_labeled_array, filtered_labeled_array, ori_labeled_image, Annotate_PALETTE)
cv2.imwrite('./4_color_image_Final.png', colored_image_final)
print("处理后的图片已保存。")

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
from PIL import Image
import concurrent.futures
import time
# --- 1. 复制自您脚本中的 CPU 核心数检测逻辑 ---
try:
DEFAULT_WORKERS = os.cpu_count() or 1
except AttributeError:
try:
import multiprocessing
DEFAULT_WORKERS = multiprocessing.cpu_count()
except (ImportError, NotImplementedError):
DEFAULT_WORKERS = 4
# --- 2. 定义支持的图片格式 ---
SUPPORTED_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif')
# --- 3. 确定Pillow的重采样过滤器 ---
try:
RESAMPLE_FILTER = Image.Resampling.LANCZOS
except AttributeError:
RESAMPLE_FILTER = Image.LANCZOS
def _process_single_image(image_path, target_min_side):
"""
工作函数:处理单个文件的缩放。(此函数与上一版完全相同)
:param image_path: 完整图片路径
:param target_min_side: 目标最小边长 (例如 1080)
:return: (状态, 原始文件名, 结果/错误信息)
"""
filename = os.path.basename(image_path)
try:
with Image.open(image_path) as img:
img_info = img.info.copy()
width, height = img.size
min_side = min(width, height)
# --- 核心逻辑:检查是否需要缩放 ---
if min_side <= target_min_side:
# 返回相对路径以便于阅读
return "skipped", image_path, f"最小边 ({min_side}px) 已 <= {target_min_side}px"
# --- 计算新尺寸 ---
scale_ratio = target_min_side / min_side
new_width = int(width * scale_ratio)
new_height = int(height * scale_ratio)
# --- 执行缩放 ---
resized_img = img.resize((new_width, new_height), RESAMPLE_FILTER)
# --- 覆盖保存 ---
resized_img.save(image_path, **img_info)
# 返回相对路径以便于阅读
return "success", image_path, f"{width}x{height} -> {new_width}x{new_height}"
except Exception as e:
# 返回相对路径以便于阅读
return "fail", image_path, str(e)
def resize_images_in_folder(folder_path, target_min_side=1080, max_workers=DEFAULT_WORKERS):
"""
*** [已更新] ***
批量将文件夹及其所有子文件夹中最小边 > target_min_side 的图片等比例缩小。
:param folder_path: 目标根文件夹路径
:param target_min_side: 目标最小边长
:param max_workers: 使用的进程数
"""
if not os.path.isdir(folder_path):
print(f"错误:文件夹 '{folder_path}' 不存在或不是一个有效的目录。")
return
print(f"--- 开始 **递归** 处理文件夹: {folder_path} ---")
print(f"--- 目标最小边: {target_min_side}px ---")
print(f"--- 使用最多 {max_workers} 个进程并行处理 ---")
start_time = time.time()
# --- 1. 收集所有需要处理的任务 (*** [核心修改] ***) ---
tasks = []
print("--- 正在递归扫描所有子文件夹... ---")
# 使用 os.walk 递归遍历
for dirpath, dirnames, filenames in os.walk(folder_path):
for filename in filenames:
# 检查文件扩展名是否在支持的列表中
if filename.lower().endswith(SUPPORTED_EXTENSIONS):
# 构造完整的文件路径
image_path = os.path.join(dirpath, filename)
tasks.append((image_path, target_min_side))
if not tasks:
print(f"未在 {folder_path} 及其子文件夹中找到支持的图片文件。")
return
print(f"--- 发现 {len(tasks)} 个图片文件,开始处理... ---")
resized_count = 0
skipped_count = 0
failed_count = 0
# 转换为相对路径,使输出更简洁
base_folder_path = os.path.abspath(folder_path)
# --- 2. 使用 ProcessPoolExecutor 执行任务 (与上一版相同) ---
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
future_to_path = {
executor.submit(_process_single_image, path, size): path
for path, size in tasks
}
# --- 3. 实时获取已完成的结果 (与上一版相同) ---
for future in concurrent.futures.as_completed(future_to_path):
try:
# status, path_or_name, result
status, image_path, result = future.result()
# 转换为相对路径
try:
display_path = os.path.relpath(image_path, base_folder_path)
except ValueError:
display_path = image_path # 如果不在同一驱动器(Windows),则显示完整路径
if status == "success":
print(f" [成功] {display_path}: {result}")
resized_count += 1
elif status == "skipped":
print(f" [跳过] {display_path}: {result}")
skipped_count += 1
elif status == "fail":
print(f" [失败] {display_path}: {result}")
failed_count += 1
except Exception as e:
orig_path = future_to_path[future]
print(f" [严重失败] 处理 {orig_path} 时进程出错: {e}")
failed_count += 1
end_time = time.time()
print("--- 处理完毕 ---")
if resized_count > 0:
print(f"成功缩放 {resized_count} 个文件。")
if skipped_count > 0:
print(f"跳过 {skipped_count} 个文件 (无需缩放)。")
if failed_count > 0:
print(f"失败 {failed_count} 个文件。")
print(f"总耗时: {end_time - start_time:.2f} 秒。")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="*** [已更新] *** 批量 **递归** 缩放文件夹及其子文件夹中的图片,使最小边不超过指定值。"
)
# 位置参数:文件夹路径 (必需)
parser.add_argument(
"folder",
type=str,
help="包含图片的 **根** 文件夹路径 (例如 ./MyImages),将递归处理所有子文件夹。"
)
# 选项参数:目标尺寸 (可选)
parser.add_argument(
"-s", "--size",
type=int,
default=1080,
help="指定目标最小边长 (默认: 1080)。如果图片最小边已小于此值,则跳过。"
)
# 选项参数:控制进程数 (可选)
parser.add_argument(
"-w", "--workers",
type=int,
default=DEFAULT_WORKERS,
help=f"指定用于转换的工作进程数 (默认: {DEFAULT_WORKERS}, 即本机CPU核心数)"
)
args = parser.parse_args()
# 将文件夹路径、"目标尺寸" 和 "工作进程数" 传递给函数
resize_images_in_folder(args.folder, args.size, args.workers)

View File

@@ -0,0 +1,69 @@
#### 0. 准备工作 ####
# 清理旧文件
rm -r ./Label/* ./ORI/*
rm -r ./result_stack_*透明度 ./Data
rm -r ./ORI_GT_label_fold ./ORI_pro_label_fold ./__pycache__
conda activate SMP
# A. 图像移动
cp 磁辅助分割图像-有效图/* Label/ # TODO TODO 将标注后图片放入Label中 # 保证Label中图片与ORI中图片一一对应且命名相同
cp 磁辅助分割-原图/* ORI/ # TODO TODO 将原始图片放入ORI中
# B.1 修改类别颜色
vim Tool_Classes_And_Palette.py # 修改 Annotate_CLASSES、Annotate_PALETTE 以匹配标注时的类别与颜色
# B.2 将bmp、jpg图片转为png
python Tool_convert_bmp_jpg_to_png.py ./Label --delete-source
python Tool_convert_bmp_jpg_to_png.py ./ORI --delete-source
# B.3 将图片转为最大边限制为1080
python Tool_resize_pics.py ./Label # -s 1080
python Tool_resize_pics.py ./ORI # -s 1080
# C. 检测图片是否匹配
python 0_1_check_picture_pair.py # -i "./ORI" -l "./Label" -p "" -s ""
# python 0_1_check_picture_pair.py # -i "../../DataSet_Public/6_CWK_2_cfz/images/train" -l "../../DataSet_Public/6_CWK_2_cfz/labels_GT/train" -p "" -s ""
# D. 生成堆叠图片(可视化标签效果)
bash 0_2_TOOL_stack_pics.sh -i "./ORI" -l ./Label -r ./result_stack_0.3透明度 -a 0.3 -p "" -s "_label"
# E. ※下载图片,查看匹配的是否有问题※
#### 1. 批量化生成训练、测试集图片 ####
# A. 将图片转为GT图片
python 1_deal_labels.py -src_fold ./Label
# B. 新建数据最终存储文件夹
rm -r ./Data
mkdir -p Data/images/train Data/images/val
cp ORI/* Data/images/train/
cp ORI/* Data/images/val/ # 或根据需要分配训练集、验证集
mkdir -p Data/labels/train Data/labels/val
mkdir -p Data/labels_GT/train Data/labels_GT/val
# C. 生成labels 、 labels_GT图片到 Data/labels_GT/train 中
# Way 1※推荐※将类别压缩
python 2_Check_and_Gen_Txt_Label_sort_label.py
# Way 2使用原始类别
# python 2_Check_and_Gen_Txt_Label_ori_label.py # Way 2
# cp ORI_GT_label_fold/* Data/labels_GT/train/ # Way 2
# cp ORI_GT_label_fold/* Data/labels_GT/val/ # Way 2
#### 2. 纳入到Yolo训练体系 ####
# A. 移动数据
cp -r ./Data ../../DataSet_Public/8_Haze_Baidu_Plus # TODO TODO 将 Data 文件夹改名后 放入 ../../DataSet_Public 下
# B. 设置输出颜色参考
将 程序输出的信息存入 ../../Seg_All_In_One_YoloModel/dataset.yaml 的 color中作为最终输出颜色的参考
# C. 根据 Seg_All_In_One_YoloModel 中的使用手册新增数据集、进行配置、训练
修改 dataset.yaml 中 训练数据集、修改 yolo_config.py 中 EPOCHS、PATIENCE
############################## 进行训练推理(整体流程) ############################################
# 1. 批量化训练
conda activate SMP
cd ~/Desktop/Seg/Seg_All_In_One_YoloModel
bash yolo_train.sh
# 2. 复制最优模型到预测文件夹
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch100.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch50.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch150.pt"
# 3. 批量化预测+热度图可视化
bash yolo_predict.sh --conf 0.2 --pt_name "epoch100.pt" && bash yolo_predict.sh --conf 0.2 && bash yolo_predict.sh --conf 0.2 --pt_name "epoch50.pt" && bash yolo_predict.sh --conf 0.2 --pt_name "epoch150.pt"
bash ./yolo_predict.sh --pt_name "best.pt" --heatmap_method "All" # && bash ./yolo_predict.sh --pt_name "epoch100.pt" --heatmap_method "All"
# 4. 横向对比结果
python yolo_predict_V2_compare_all.py
# 5. 打包预测结果(不包含*.pt模型文件)
cd /home/wkmgc/Desktop/Seg/BestMode_Predict_Results_DataSet_Public/
zip -r /home/wkmgc/Desktop/8_Haze_Baidu_Plus-Yolo.zip 8_Haze_Baidu_Plus*-Yolo -x "*.pt" # TODO TODO
# 6. 打包训练结果(只有.png、.jpg、.csv文件)
cd /home/wkmgc/Desktop/Seg/Hardisk/
zip -r /home/wkmgc/Desktop/8_Haze_Baidu_Plus-Yolo_train.zip 8_Haze_Baidu_Plus-Yolo -i \*.png \*.jpg \*.csv # TODO TODO

View File

@@ -0,0 +1,318 @@
############ 第一部分:数据集跟路径 ############
# # V1: 1_CholecSeg8k-13Type-1920x1080
# path: ../DataSet_Public/1_CholecSeg8k-13Type-1920x1080
# # V22_AutoLaparo-10Type-1920x1080
# path: ../DataSet_Public/2_AutoLaparo-10Type-1920x1080
# # V33_1_Endovis_2017-8Type-512x512
# path: ../DataSet_Public/3_1_Endovis_2017-8Type-512x512
# # V43_2_Endovis_2018-8Type-512x512
# path: ../DataSet_Public/3_2_Endovis_2018-8Type-512x512
# # V54_Dresden-11Type-512x512
# path: ../DataSet_Public/4_Dresden-11Type-512x512
# # V65_LC_1_blood_verssel
# path: ../DataSet_Public/5_LC_1_blood_verssel
# # V75_LC_2_artery
# path: ../DataSet_Public/5_LC_2_artery
# # V85_LC_3_cystic_duct
# path: ../DataSet_Public/5_LC_3_cystic_duct
# # V95_LC_4_foreigner
# path: ../DataSet_Public/5_LC_4_foreigner
# # V105_LC_5_stop_bleed
# path: ../DataSet_Public/5_LC_5_stop_bleed
# # V116_CWK_1_yws
# path: ../DataSet_Public/6_CWK_1_yws
# # V126_CWK_2_cfz
# path: ../DataSet_Public/6_CWK_2_cfz
# # V13 5_TQY # TODO
# path: ../DataSet_Public/5_TQY
# # V14 5_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus # TODO
# path: ../DataSet_Public/8_Haze_Baidu_Plus
# 默认使用本机已存在的公开数据集,便于 yolo_config.py smoke test 直接通过。
# 切换实验数据集时,请同步修改 path/test/train/val 和 names。
path: ../DataSet_Public/3_1_Endovis_2017-8Type-512x512
# # # Test_V15_Predict_Video
# path: ../DataSet_Public/5_Predict_Video/LC_Video_1
############ 第二部分:测试集相对路径 ############
# 训练集和验证集图片路径 (相对于 'path')
# # V1: 1_CholecSeg8k-13Type-1920x1080
# test: images/val
# # V22_AutoLaparo-10Type-1920x1080
# test: images/val
# # V33_1_Endovis_2017-8Type-512x512
# test: images/val
# # V43_2_Endovis_2018-8Type-512x512
# test: images/val
# # V54_Dresden-11Type-512x512
# test: images/val # images/test
# # V65_LC_1_blood_verssel
test: images/val
# # V75_LC_2_artery
# test: images/val
# # V85_LC_3_cystic_duct
# test: images/val
# # V95_LC_4_foreigner
# test: images/val
# # V105_LC_5_stop_bleed
# test: images/val
# # V116_CWK_1_yws
# test: images/val
# # V126_CWK_2_cfz
# test: images/val
# # V135_TQY
# test: images/val
# # V145_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus
# test: images/val
# # Test_V15_Predict_Video
# test: images/val
############ 第三部分:训练集、验证集相对路径 ############
train: images/train
val: images/val
############ 第四部分类别名称【从0开始】 ############
# # V1: 1_CholecSeg8k-13Type-1920x1080
# names:
# 0: background
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# 8: 8
# 9: 9
# 10: 10
# 11: 11
# 12: 12
# V22_AutoLaparo-10Type-1920x1080
# names:
# 0: background
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# 8: 8
# 9: 9
# # V33_1_Endovis_2017-8Type-512x512
# names:
# 0: background
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# # V43_2_Endovis_2018-8Type-512x512
# names:
# 0: background
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# # V54_Dresden-11Type-512x512
# names:
# 0: background
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# 8: 8
# 9: 9
# 10: 10
# # V65_LC_1_blood_verssel # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# # V75_LC_2_artery # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# 8: 8
# 9: 9
# 10: 10
# 11: 11
# 12: 12
# 13: 13
# # V85_LC_3_cystic_duct # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# 8: 8
# 9: 9
# # V95_LC_4_foreigner # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# # V105_LC_5_stop_bleed # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# 6: 6
# 7: 7
# 8: 8
# 9: 9
# 10: 10
# 11: 11
# # V116_CWK_1_yws # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# # V126_CWK_2_cfz # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 4: 4
# 5: 5
# # V135_TQY # TODO
# names:
# 0: 0
# # V145_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus # TODO
# names:
# 0: 0
# 1: 1
# 2: 2
# 3: 3
# 当前默认V3 3_1_Endovis_2017-8Type-512x512
names:
0: background
1: 1
2: 2
3: 3
4: 4
5: 5
6: 6
7: 7
############ 第五部分:最终上色 ############
# # V65_LC_1_blood_verssel # TODO
# colors:
# 0: [255, 91, 0]
# 1: [255, 234, 0]
# 2: [167, 24, 233]
# 3: [52, 184, 178]
# 4: [255, 0, 0]
# 5: [0, 155, 33]
# # V75_LC_2_artery # TODO
# colors:
# 0: [255, 91, 0]
# 1: [255, 234, 0]
# 2: [255, 0, 0]
# 3: [0, 160, 233]
# 4: [0, 155, 33]
# 5: [52, 184, 178]
# 6: [167, 24, 233]
# 7: [255, 255, 255]
# 8: [117, 0, 0]
# 9: [72, 0, 255]
# 10: [85, 111, 181]
# 11: [0, 255, 255]
# 12: [42, 8, 66]
# 13: [66, 115, 82]
# # V85_LC_3_cystic_duct # TODO
# colors:
# 0: [255, 91, 0]
# 1: [255, 234, 0]
# 2: [255, 0, 0]
# 3: [167, 24, 233]
# 4: [0, 155, 33]
# 5: [0, 160, 233]
# 6: [72, 0, 255]
# 7: [52, 184, 178]
# 8: [255, 255, 255]
# 9: [117, 0, 0]
# # V95_LC_4_foreigner # TODO
# colors:
# 0: [255, 91, 0]
# 1: [255, 234, 0]
# 2: [255, 0, 0]
# 3: [0, 155, 33]
# 4: [52, 184, 178]
# 5: [167, 24, 233]
# # V105_LC_5_stop_bleed # TODO
# colors:
# 0: [255, 91, 0]
# 1: [181, 227, 14]
# 2: [160, 15, 95]
# 3: [85, 111, 181]
# 4: [52, 184, 178]
# 5: [255, 0, 0]
# 6: [255, 255, 255]
# 7: [255, 0, 255]
# 8: [167, 24, 233]
# 9: [255, 234, 0]
# 10: [0, 160, 233]
# 11: [113, 102, 140]
# # V116_CWK_1_yws # TODO
# colors:
# 0: [255, 91, 0]
# 1: [155, 132, 0]
# 2: [255, 234, 0]
# 3: [255, 0, 255]
# 4: [52, 184, 178]
# 5: [85, 111, 181]
# # # V126_CWK_2_cfz # TODO
# colors:
# 0: [255, 91, 0]
# 1: [155, 132, 0]
# 2: [255, 234, 0]
# 3: [52, 184, 178]
# 4: [255, 0, 255]
# 5: [85, 111, 181]
# # V13 5_TQY # TODO
# colors:
# 0: [225, 182, 193]
# V145_Haze_ori、6_Haze_AOD_Net、7_Haze_Baidu、8_Haze_Baidu_Plus
colors:
0: [167, 24, 233]
1: [52, 184, 178]
2: [255, 255, 255]
3: [0, 160, 233]

View File

@@ -0,0 +1,48 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# YOLO12-seg instance segmentation model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolo12
# Task docs: https://docs.ultralytics.com/tasks/segment
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo12n-seg.yaml' will call yolo12-seg.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 294 layers, 2,855,056 parameters, 2,855,040 gradients, 10.6 GFLOPs
s: [0.50, 0.50, 1024] # summary: 294 layers, 9,938,592 parameters, 9,938,576 gradients, 35.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 314 layers, 22,505,376 parameters, 22,505,360 gradients, 123.5 GFLOPs
l: [1.00, 1.00, 512] # summary: 510 layers, 28,756,992 parameters, 28,756,976 gradients, 145.1 GFLOPs
x: [1.00, 1.50, 512] # summary: 510 layers, 64,387,264 parameters, 64,387,248 gradients, 324.6 GFLOPs
# YOLO12n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 4, A2C2f, [512, True, 4]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 8
# YOLO12n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, A2C2f, [512, False, -1]] # 11
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, A2C2f, [256, False, -1]] # 14
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P4
- [-1, 2, A2C2f, [512, False, -1]] # 17
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 8], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
- [[14, 17, 20], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)

View File

@@ -0,0 +1,123 @@
import yaml, sys
import torch
from pathlib import Path
# --- 工具函数:动态设备选择 (Dynamic Device Selection) ---
def get_auto_device(verbose = False):
"""自动检测并返回最合适的设备"""
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
# 如果有多个GPU使用所有GPU
device_str = ",".join(str(i) for i in range(gpu_count))
if verbose:
print(f"检测到 {gpu_count} 个可用的GPU。将使用所有GPU: {device_str}")
return device_str
else:
# 如果只有1个GPU
if verbose:
print("检测到 1 个可用的GPU。将使用 GPU: 0")
return '0'
else:
# 如果没有可用的GPU使用CPU
if verbose:
print("未检测到可用的GPU。将使用CPU。")
return 'cpu'
# --- 1. 核心目录设置 (Core Directories) ---
HARDISK_DIR = Path.home() / "Desktop" / "Seg" / "Hardisk" # 硬盘根目录
BASE_DIR = Path(__file__).parent.parent # 当前脚本所在目录上级
DATASET_YAML_PATH = Path(__file__).parent / "dataset.yaml" # 数据集的 YAML 配置文件路径
OUTPUTS_DIR = BASE_DIR / "DataSet_Public_outputs" # 输出结果目录 # 最终为OUTPUTS_DIR / dataset_name
PREDICT_ALL_BEST_MODELS_DIR = BASE_DIR / 'BestMode_Predict_Results_DataSet_Public' # 所有最佳模型存放目录
TEST_IMAGE_DIR = None # 初始化 TEST_IMAGE_DIR # 测试文件地址 dataset.path / TODO "val" / "test" TODO
# try:
# 3. 读取并解析 YAML 文件
with open(DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
# 4. 从解析后的数据中获取 'path' 的值
relative_path_from_yaml = yaml_data.get('path')
test_path_from_yaml = yaml_data.get('test')
dataset_name = Path(relative_path_from_yaml).name
if relative_path_from_yaml:
# 5. 【核心步骤】构建绝对路径
# relative_path_from_yaml 是相对于 .yaml 文件本身的路径。
# 所以,我们需要获取 .yaml 文件所在的目录,然后与这个相对路径拼接。
yaml_file_directory = DATASET_YAML_PATH.parent
TEST_IMAGE_DIR = (DATASET_YAML_PATH.parent / relative_path_from_yaml / test_path_from_yaml).resolve() # 这里val 或 test在dataset.yaml中定义
# 使用 .exists() 方法来检查路径是否存在
if not TEST_IMAGE_DIR.exists():
# 如果路径不存在,则执行这里的代码
print(f"警告: 测试图片目录不存在: {TEST_IMAGE_DIR}")
sys.exit(1)
# 6. 获取OUTPUTS_DIR
if dataset_name and dataset_name not in {'.', '..'}:
dataset_name_ = dataset_name + "-Yolo"
# 设定输出路径
OUTPUTS_DIR = OUTPUTS_DIR / dataset_name_
# ../BestMode_Predict_Results_DataSet_Public/"dataset_name+"-Yolo"" # 需要用到 Tool_Yolo_Copy_Best_Model.sh
PREDICT_BEST_MODEL_DIR = PREDICT_ALL_BEST_MODELS_DIR / dataset_name_ # 最优模型保存位置 PREDICT_ALL_BEST_MODELS_DIR / dataset_name_ / weights / best.pt
else:
print(f"警告: 提取的dataset_name: '{dataset_name}' 无效(为空、'.''..')。")
sys.exit(1)
else:
print(f"警告: 在 '{DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
sys.exit(1)
# except FileNotFoundError:
# print(f"错误: YAML 配置文件未找到: '{DATASET_YAML_PATH}'")
# except Exception as e:
# print(f"读取或解析 YAML 文件时发生错误: {e}")
# --- 修改:模型、图像大小和批处理大小的集成配置 ---
# 将所有模型的配置(权重、图像大小、批处理大小)集中管理
MODEL_CONFIGS = {
# YOLOv8
'YOLOv8n-seg': {'weights': 'yolov8n-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLOv8s-seg': {'weights': 'yolov8s-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLOv8m-seg': {'weights': 'yolov8m-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLOv8l-seg': {'weights': 'yolov8l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401612.5GB
'YOLOv8x-seg': {'weights': 'yolov8x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401615.5GB # 示例X模型使用1280分辨率
# YOLOv9
'YOLOv9c-seg': {'weights': 'yolov9c-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401613GB
'YOLOv9e-seg': {'weights': 'yolov9e-seg.pt', 'image_size': 640, 'batch_size': 8}, # 64016内存超了 # 示例E模型最大使用1280分辨率和更小的batch
# YOLOv11 (假设)
'YOLO11n-seg': {'weights': 'yolo11n-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLO11s-seg': {'weights': 'yolo11s-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLO11m-seg': {'weights': 'yolo11m-seg.pt', 'image_size': 640, 'batch_size': 16},
'YOLO11l-seg': {'weights': 'yolo11l-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401612.5GB
'YOLO11x-seg': {'weights': 'yolo11x-seg.pt', 'image_size': 640, 'batch_size': 16}, # 6401619.5GB
# YOLOv12 (假设)
'YOLO12-seg': {'weights': str(Path(__file__).parent / 'yolo12-seg.yaml'), 'image_size': 640, 'batch_size': 16}, # 640163GB
}
# --- 4. 训练超参数 (Training Hyperparameters) ---
DEVICE = get_auto_device(verbose=False) # 调用函数来设置设备
EPOCHS = 300 # 训练轮次 # TODO
PATIENCE = 150 # 提前停止训练的轮数
SAVE_PERIOD = 10 # 每隔多少轮保存一次模型 # TODO
# BATCH_SIZE 已在上面根据模型自动设置
LEARNING_RATE = 0.01 # 初始学习率('SGD:0.01', 'Adam:1e-4', 'AdamW:1e-4', 'auto:0.01'
OPTIMIZER = 'auto' # 优化器 (TODO 'SGD', 'Adam', 'AdamW', 'auto' TODO)
WORKERS = 4 # 数据加载的工作线程数
# --- 5. 预测设置 (Prediction Settings) ---
SHOW_LABELS = True # 是否在预测结果上显示类别标签
SHOW_CONF = True # 是否在预测结果上显示置信度和边界框
SAVE_PREDICTIONS = True # 是否保存预测结果图像
# ==============================================================================
# --- 主执行块:只在直接运行时才打印配置信息 ---
# ==============================================================================
def show_config_summary():
"""打印所有配置摘要信息"""
print(f"设定输出路径为: '{str(OUTPUTS_DIR)}'")
print(f"最佳模型路径为: '{str(PREDICT_BEST_MODEL_DIR)}'")
# 再次调用 get_auto_device 并设置 verbose=True 来打印设备信息
get_auto_device(verbose=True)
# 当这个脚本被直接执行时python yolo_config.py__name__ 的值是 '__main__'
# 当它被其他脚本import时__name__ 的值是 'yolo_config'
if __name__ == '__main__':
show_config_summary()

View File

@@ -0,0 +1,167 @@
#!/bin/bash
# =================================================================
# YOLO 模型批量并行预测脚本
# =================================================================
# - 此脚本会自动为每个模型架构查找其训练好的 'best.pt'
# - 使用 'echo "1" |' 来自动选择找到的第一个训练版本
# - 在不同的指定GPU上并行执行预测任务
# =================================================================
# --- 1. Conda 环境设置 ---
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- 在这里修改为您自己的 Conda 路径
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # 可用 SEG_CONDA_ENV=SMP bash yolo_predict.sh 临时覆盖
pt_name="best.pt" # <--- 在这里修改为您想使用的权重文件名,例如 "best.pt" 或 "epoch100.pt"
conf_threshold=0.2 # <--- [新增] 默认的置信度阈值
heatmap_method="None" # <--- [!! 新增 !!] 默认不运行热度图
# 循环解析参数
while [[ $# -gt 0 ]]; do
key="$1"
case $key in
--pt_name)
if [ -n "$2" ] && [[ "$2" != -* ]]; then
pt_name="$2"
shift # 移过 --pt_name
shift # 移过它的值
else
echo "错误: --pt_name 参数需要一个值。" >&2
exit 1
fi
;;
--conf)
if [ -n "$2" ] && [[ "$2" != -* ]]; then
conf_threshold="$2"
shift # 移过 --conf
shift # 移过它的值
else
echo "错误: --conf 参数需要一个值。" >&2
exit 1
fi
;;
--heatmap_method)
if [ -n "$2" ] && [[ "$2" != -* ]]; then
heatmap_method="$2"
shift # 移过 --heatmap_method
shift # 移过它的值
else
echo "错误: --heatmap_method 参数需要一个值 (例如 'GradCAM' 或 'All')。" >&2
exit 1
fi
;;
*)
# 移过未知参数,不报错
shift
;;
esac
done
# 初始化并激活 Conda 环境
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
conda activate "${CONDA_ENV_NAME}"
if [ $? -ne 0 ]; then
echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
exit 1
fi
echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
else
echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
exit 1
fi
# --- 2. 模型与 GPU 配置 ---
# 此处的分组应与 train.sh 保持一致,以确保能正确找到模型并分配资源
GPUS_GROUP_0="0"
GPUS_GROUP_1="1"
GPUS_GROUP_2="2"
GPUS_GROUP_3="3"
# TODO #
GPUS_GROUP_4="0"
GPUS_GROUP_5="1"
GPUS_GROUP_6="2"
GPUS_GROUP_7="3"
# 从 yolo_config.py/train.sh 中选择的模型列表
GROUP_0_MODELS=("YOLO11l-seg")
GROUP_1_MODELS=("YOLOv8n-seg" "YOLOv8m-seg")
GROUP_2_MODELS=("YOLO11n-seg" "YOLO11s-seg" "YOLO11m-seg")
GROUP_3_MODELS=("YOLOv9e-seg")
GROUP_4_MODELS=("YOLO11x-seg")
GROUP_5_MODELS=("YOLOv9c-seg" "YOLOv8s-seg")
GROUP_6_MODELS=("YOLOv8l-seg" "YOLO12-seg")
GROUP_7_MODELS=("YOLOv8x-seg")
# 1. 从 config.py 中读取 PREDICT_BEST_MODEL_DIR 的值
PREDICT_BEST_MODEL_DIR=$(python -c "from yolo_config import PREDICT_BEST_MODEL_DIR; print(PREDICT_BEST_MODEL_DIR)")
# 检查是否成功获取了 PREDICT_BEST_MODEL_DIR
if [ -z "$PREDICT_BEST_MODEL_DIR" ] || [ ! -e "$PREDICT_BEST_MODEL_DIR" ]; then
echo "PREDICT_BEST_MODEL_DIR: $PREDICT_BEST_MODEL_DIR"
echo "Error: Could not read PREDICT_BEST_MODEL_DIR from yolo_config.py. Exiting."
echo "Error 2: Or the directory specified by PREDICT_BEST_MODEL_DIR does not exist. Please create it first."
exit 1
fi
# 2. 定义带有时间戳的日志目录名
LOG_DIR_NAME="yolo_predict_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
# 3. 拼接成最终的完整路径
LOG_DIR="$PREDICT_BEST_MODEL_DIR/$LOG_DIR_NAME"
mkdir -p "${LOG_DIR}"
echo "所有模型的预测日志将保存在 ./${LOG_DIR}/ 目录中。"
echo "----------------------------------------------------"
# --- 3. 预测执行函数 ---
# 定义一个函数来启动一组预测,以避免代码重复
start_prediction_group() {
# 使用 nameref (引用) 来传递数组
local -n models=$1
local gpus=$2
local group_name=$3
echo ">>> 准备启动 ${group_name} 的预测任务 (后台运行)..."
# 遍历指定组中的所有模型
for model_key in "${models[@]}"; do
if [ "${heatmap_method}" == "None" ]; then
# --- 模式 1: 运行标准预测 (原有逻辑) ---
echo " -> 正在后台启动 [标准预测]: ${model_key} on GPUs: ${gpus}"
# [注意] 我为您添加了 --conf 参数,您原有的脚本 没有传递它
echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_V2.py --model "${model_key}" --pt_name "${pt_name}" --conf "${conf_threshold}" > "${LOG_DIR}/${model_key}_predict.log" 2>&1 &
echo " - 模型 ${model_key} 的预测已在后台启动。日志文件: ${LOG_DIR}/${model_key}_predict.log"
else
# --- 模式 2: 运行热度图可视化 ---
echo " -> 正在后台启动 [热度图可视化]: ${model_key} on GPUs: ${gpus} (Method: ${heatmap_method})"
# [注意] 我们使用 yolo_predict_visualize_nn.py 并传递新参数
echo "1" | CUDA_VISIBLE_DEVICES=${gpus} python yolo_predict_visualize_nn.py --model "${model_key}" --target_layers "default" --cam_method "${heatmap_method}" --pt_name "${pt_name}" > "${LOG_DIR}/${model_key}_heatmap.log" 2>&1 &
echo " - 模型 ${model_key} 的热度图已在后台启动。日志文件: ${LOG_DIR}/${model_key}_heatmap.log"
fi
echo " - 等待 5 秒,确保 GPU 资源稳定分配..."
sleep 5
done
echo ">>> ${group_name} 的所有模型均已启动。"
echo "----------------------------------------------------"
}
# --- 4. 依次启动所有预测任务 ---
# 脚本将快速地按顺序启动每一组任务到后台
start_prediction_group GROUP_0_MODELS "${GPUS_GROUP_0}" "第零组"
start_prediction_group GROUP_1_MODELS "${GPUS_GROUP_1}" "第一组"
start_prediction_group GROUP_2_MODELS "${GPUS_GROUP_2}" "第二组"
start_prediction_group GROUP_3_MODELS "${GPUS_GROUP_3}" "第三组"
start_prediction_group GROUP_4_MODELS "${GPUS_GROUP_4}" "第四组"
start_prediction_group GROUP_5_MODELS "${GPUS_GROUP_5}" "第五组"
start_prediction_group GROUP_6_MODELS "${GPUS_GROUP_6}" "第六组"
start_prediction_group GROUP_7_MODELS "${GPUS_GROUP_7}" "第七组"
# --- 5. 等待所有后台任务完成 ---
echo ""
echo "--- 所有模型均已在后台启动。现在等待所有预测任务完成... ---"
# 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕
wait
echo "--- 所有后台预测任务已全部完成! ---"
# --- 6. 退出脚本 ---
echo "预测流程结束。"
conda deactivate
echo "已取消激活 Conda 环境。"

View File

@@ -0,0 +1,238 @@
import logging, sys, argparse, yaml
from pathlib import Path
from ultralytics import YOLO
import yolo_config as config
config.show_config_summary() # 显示配置信息
from typing import List
import cv2
import numpy as np
# +----------------------------------------------------------------------------+
# ../BestMode_Predict_Results_DataSet_Public/
# └── YOLOv8n-seg_2025-09-20_10-00-00/ # The folder for the trained model you selected
# ├── weights/
# │ └── best.pt
# │
# ├── prediction/ # Visualized results (image with colored mask overlay)
# │ ├── image1.jpg
# │ └── ...
# │
# └── prediction_masks_combined/ # <-- Your new unique folder for combined masks
# ├── image1.png # Grayscale mask, original size, pixel values are class IDs
# ├── image2.png # Grayscale mask
# └── ...
# +----------------------------------------------------------------------------+
# --- 日志设置 ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()]
)
# --- 使用带 model_key 过滤的版本 ---
def find_trained_models(outputs_dir: Path, model_key: str) -> List[str]:
"""
扫描输出目录,查找特定基础模型的所有有效的、已完成的训练项目。
"""
trained_models = []
if not outputs_dir.is_dir():
logging.warning(f"输出目录不存在: {outputs_dir}")
return []
for project_folder in outputs_dir.iterdir():
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
if (project_folder / 'weights' / 'best.pt').exists():
trained_models.append(project_folder.name)
trained_models.sort()
return trained_models
# 老版predict函数只有默认输出
# def predict_old(model_path: str, source: str, project_name: str):
# """
# 使用指定的模型对图像源进行预测。
# """
# if not Path(source).exists():
# logging.error(f"错误:预测源路径不存在: {source}")
# return
# try:
# logging.info(f"正在加载模型: {model_path}")
# model = YOLO(model_path)
# logging.info("模型加载成功。")
# logging.info(f"正在对源进行预测: {source}")
# results = model.predict(
# source=source,
# save=config.SAVE_PREDICTIONS,
# show_labels=config.SHOW_LABELS,
# show_conf=config.SHOW_CONF,
# project=project_name,
# name="prediction",
# exist_ok=True,
# retina_masks=True, # <<< ADDED: 保存与原图一样大的分割图
# )
# final_save_dir = Path(project_name) / "prediction"
# logging.info(f"预测完成。结果保存在: {final_save_dir}")
# except Exception as e:
# logging.error(f"预测过程中发生错误: {e}")
def predict(model_path: str, source: str, project_name: str):
"""
使用指定模型进行预测,并将所有类别的掩码合并到一个单通道灰度图中。
在该图中每个像素的值对应其类别ID (0=背景, 1=类别1, 2=类别2, ...)。
"""
# --- 1. 加载类别名称 (用于日志记录) ---
try:
with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
class_names = yaml.safe_load(f)['names']
logging.info(f"成功从 {config.DATASET_YAML_PATH} 加载类别名称: {class_names}")
except Exception as e:
logging.error(f"加载或解析 dataset.yaml 失败: {e}")
return
if not Path(source).exists():
logging.error(f"错误:预测源路径不存在: {source}")
return
try:
# --- 2. 加载模型并执行预测 ---
logging.info(f"正在加载模型: {model_path}")
model = YOLO(model_path)
logging.info("模型加载成功。")
logging.info(f"正在对源进行预测: {source}")
results = model.predict(
source=source,
stream=True, # 使用流式处理以节省内存,但是会减速
save=config.SAVE_PREDICTIONS,
show_labels=config.SHOW_LABELS,
show_conf=config.SHOW_CONF,
project=project_name,
name="prediction",
exist_ok=True,
retina_masks=True,
conf=0.1
)
# --- 3. 创建新的输出目录并处理结果 ---
mask_save_dir = Path(project_name) / "predicted_raw_masks"
mask_save_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"预测完成。正在处理结果并保存组合掩码图于: {mask_save_dir}")
for result in results:
original_image_path = Path(result.path)
logging.info(f"正在处理图片: {original_image_path.name}")
h, w = result.orig_shape
# 创建一个空白画布用于存储所有类别的组合掩码。0 代表背景。
combined_mask = np.zeros((h, w), dtype=np.uint8)
if result.masks is None:
logging.warning(f" -> 在图片 {original_image_path.name} 中未检测到任何物体,将创建一张全黑的掩码图。")
# 如果没有检测到物体,直接保存空白掩码图并继续处理下一张图片
mask_filename = mask_save_dir / f"{original_image_path.stem}.png"
cv2.imwrite(str(mask_filename), combined_mask)
continue
# --- 核心:创建组合掩码 ---
masks_data = result.masks.data
class_ids = result.boxes.cls.int().cpu().numpy()
if len(masks_data) != len(class_ids):
logging.error(f" -> 掩码和类别ID数量不匹配跳过图片 {original_image_path.name}")
continue
# YOLO结果通常按置信度从高到低排序。
# 我们反向迭代(从低置信度到高置信度),以确保在掩码重叠区域,
# 置信度更高的物体类别能够覆盖置信度较低的。
for i in reversed(range(len(masks_data))):
instance_mask = masks_data[i].cpu().numpy().astype(bool)
class_id = class_ids[i]
# 将掩码区域的像素值设置为其对应的类别ID。
# 我们假设类别ID 0 是背景所以即使模型预测了ID为0的物体它也会被视为背景。
if class_id != 0:
combined_mask[instance_mask] = class_id
# --- 保存最终的组合掩码图 ---
mask_filename = mask_save_dir / f"{original_image_path.stem}.png"
cv2.imwrite(str(mask_filename), combined_mask)
logging.info(f"所有组合掩码图已成功保存。")
except Exception as e:
logging.error(f"预测或手动保存过程中发生错误: {e}", exc_info=True)
if __name__ == "__main__":
# 1. 创建解析器,现在只需要 --model 和 --source
parser = argparse.ArgumentParser(description="使用已训练的YOLO模型进行预测。")
parser.add_argument(
"--model",
type=str,
required=True,
choices=list(config.MODEL_CONFIGS.keys()),
help="选择一个基础模型类型,以筛选其训练历史。"
)
parser.add_argument(
"--source",
type=str,
default=str(config.TEST_IMAGE_DIR),
help="图片或图片文件夹的路径。"
)
args = parser.parse_args()
# 2. 根据 --model 参数查找对应的训练历史
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model)
run_to_use = None
# 3. 根据找到的结果数量,决定下一步操作
if not available_runs:
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中没有找到任何关于模型 '{args.model}' 的有效训练记录。")
sys.exit(1)
elif len(available_runs) == 1:
# 如果只有一个结果,自动选择
run_to_use = available_runs[0]
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
else:
# 如果有多个结果,进入交互式选择模式
print(f"\n为模型 '{args.model}' 找到多个训练版本:")
for i, run_name in enumerate(available_runs, 1):
print(f" [{i}] {run_name}")
while True:
try:
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
choice_index = int(choice)
if 1 <= choice_index <= len(available_runs):
run_to_use = available_runs[choice_index - 1]
break
else:
print("错误:输入无效,请输入列表中的序号。")
except ValueError:
print("错误:请输入一个数字。")
except (KeyboardInterrupt, EOFError):
print("\n操作已取消。")
sys.exit(0)
# 4. 使用最终确定的 run_to_use 来执行预测
if run_to_use:
model_to_use = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / 'best.pt'
project_to_save_in = str(config.PREDICT_BEST_MODEL_DIR / run_to_use)
if not model_to_use.exists():
logging.error(f"严重错误:找不到权重文件 {model_to_use}")
else:
predict(
model_path=str(model_to_use),
source=args.source,
project_name=project_to_save_in
)

View File

@@ -0,0 +1,538 @@
import logging, sys, argparse, yaml
from pathlib import Path
from ultralytics import YOLO
import yolo_config as config
config.show_config_summary() # 显示配置信息
from typing import List
import cv2
import numpy as np
# +----------------------------------------------------------------------------+
# ../BestMode_Predict_Results_DataSet_Public/
# └── YOLOv8n-seg_2025-09-20_10-00-00/ # The folder for the trained model you selected
# ├── weights/
# │ └── best.pt
# │
# ├── prediction/ # Visualized results (image with colored mask overlay)
# │ ├── image1.jpg
# │ └── ...
# │
# └── prediction_masks_combined/ # <-- Your new unique folder for combined masks
# ├── image1.png # Grayscale mask, original size, pixel values are class IDs
# ├── image2.png # Grayscale mask
# └── ...
# +----------------------------------------------------------------------------+
# --- 日志设置 ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()]
)
# --- 使用带 model_key 过滤的版本 ---
def find_trained_models(outputs_dir: Path, model_key: str, pt_name: str) -> List[str]:
"""
扫描输出目录,查找特定基础模型的所有有效的、已完成的训练项目。
"""
trained_models = []
if not outputs_dir.is_dir():
logging.warning(f"输出目录不存在: {outputs_dir}")
return []
for project_folder in outputs_dir.iterdir():
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
if (project_folder / 'weights' / pt_name).exists():
trained_models.append(project_folder.name)
trained_models.sort()
return trained_models
# 老版predict函数只有默认输出
# def predict_old(model_path: str, source: str, project_name: str):
# """
# 使用指定模型进行预测。
# 1. 保存原始灰度掩码 (predicted_raw_masks)像素值为类别ID。
# 2. 保存彩色可视化掩码 (predicted_color_masks),根据 color_map.yaml 上色。
# """
# # --- 1. 加载类别名称 (用于日志记录和确定类别数量) ---
# try:
# with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
# class_names = yaml.safe_load(f)['names']
# logging.info(f"成功从 {config.DATASET_YAML_PATH} 加载类别名称: {class_names}")
# num_classes = len(class_names)
# logging.info(f"检测到 {num_classes} 个类别 (ID 0 到 {num_classes-1})。")
# except Exception as e:
# logging.error(f"加载或解析 dataset.yaml 失败: {e}")
# return
# # --- [新增] 2. 加载颜色查找表 (LUT) ---
# color_lut = np.zeros((num_classes, 3), dtype=np.uint8) # 默认全黑
# color_map_path = 'color_map.yaml'
# try:
# with open(color_map_path, 'r', encoding='utf-8') as f:
# color_data = yaml.safe_load(f)['colors']
# logging.info(f"正在从 {color_map_path} 加载颜色...")
# for class_id, bgr_value in color_data.items():
# if 0 <= class_id < num_classes:
# color_lut[class_id] = bgr_value
# else:
# logging.warning(f"color_map.yaml 中的类别ID {class_id} 超出范围 (0-{num_classes-1}),已忽略。")
# logging.info("颜色查找表 (LUT) 创建成功。")
# except FileNotFoundError:
# logging.error(f"错误:未找到 {color_map_path}。")
# logging.warning("将使用随机颜色作为备用方案(背景除外)。")
# # 创建随机颜色作为备用
# for i in range(1, num_classes): # 保持 0 (背景) 为黑色 [0,0,0]
# color_lut[i] = np.random.randint(0, 255, 3, dtype=np.uint8)
# except Exception as e:
# logging.error(f"加载或解析 {color_map_path} 失败: {e}", exc_info=True)
# return
# if not Path(source).exists():
# logging.error(f"错误:预测源路径不存在: {source}")
# return
# try:
# # --- 3. 加载模型并执行预测 ---
# logging.info(f"正在加载模型: {model_path}")
# model = YOLO(model_path)
# logging.info("模型加载成功。")
# logging.info(f"正在对源进行预测: {source}")
# results = model.predict(
# source=source,
# stream=True,
# save=config.SAVE_PREDICTIONS,
# show_labels=config.SHOW_LABELS,
# show_conf=config.SHOW_CONF,
# project=project_name,
# name="prediction",
# exist_ok=True,
# retina_masks=True,
# conf=0.1
# )
# # --- 4. 创建新的输出目录并处理结果 ---
# # [修改] 定义两个保存目录
# raw_mask_save_dir = Path(project_name) / "predicted_raw_masks"
# color_mask_save_dir = Path(project_name) / "predicted_color_masks" # [新增] 彩色图目录
# raw_mask_save_dir.mkdir(parents=True, exist_ok=True)
# color_mask_save_dir.mkdir(parents=True, exist_ok=True) # [新增]
# # [修改] 更新日志信息
# logging.info(f"预测完成。正在处理结果...")
# logging.info(f" -> 原始灰度掩码将保存于: {raw_mask_save_dir}")
# logging.info(f" -> 彩色可视化掩码将保存于: {color_mask_save_dir}")
# for result in results:
# original_image_path = Path(result.path)
# logging.info(f"正在处理图片: {original_image_path.name}")
# h, w = result.orig_shape
# # 创建一个空白画布用于存储所有类别的组合掩码。0 代表背景。
# combined_mask = np.zeros((h, w), dtype=np.uint8)
# if result.masks is None:
# logging.warning(f" -> 在图片 {original_image_path.name} 中未检测到任何物体,将创建一张全黑的掩码图。")
# # --- 保存原始灰度掩码 (全黑) ---
# raw_mask_filename = raw_mask_save_dir / f"{original_image_path.stem}.png"
# cv2.imwrite(str(raw_mask_filename), combined_mask)
# # --- [新增] 保存彩色掩码 (全黑) ---
# # 即使是全黑的也应用LUT结果还是全黑
# colorized_mask = color_lut[combined_mask] # (H, W, 3)
# color_mask_filename = color_mask_save_dir / f"{original_image_path.stem}.png"
# cv2.imwrite(str(color_mask_filename), colorized_mask)
# continue # 继续处理下一张图片
# # --- 核心:创建组合掩码 ---
# masks_data = result.masks.data
# class_ids = result.boxes.cls.int().cpu().numpy()
# if len(masks_data) != len(class_ids):
# logging.error(f" -> 掩码和类别ID数量不匹配跳过图片 {original_image_path.name}")
# continue
# # (核心逻辑保持不变)
# # 反向迭代,确保高置信度覆盖低置信度
# for i in reversed(range(len(masks_data))):
# instance_mask = masks_data[i].cpu().numpy().astype(bool)
# class_id = class_ids[i]
# if class_id != 0: # 假设 0 是背景
# combined_mask[instance_mask] = class_id
# # --- 保存 1: 最终的组合掩码图 (灰度) ---
# raw_mask_filename = raw_mask_save_dir / f"{original_image_path.stem}.png"
# cv2.imwrite(str(raw_mask_filename), combined_mask)
# # --- [新增] 保存 2: 最终的彩色掩码图 (RGB) ---
# # 这是最高效的上色方法:
# # color_lut (14, 3)
# # combined_mask (H, W),值 0-13
# # NumPy 会自动将 (H, W) 中的每个值作为索引去 color_lut 中取 (B,G,R) 元组
# colorized_mask = color_lut[combined_mask] # 结果维度 (H, W, 3)
# color_mask_filename = color_mask_save_dir / f"{original_image_path.stem}.png"
# cv2.imwrite(str(color_mask_filename), colorized_mask)
# logging.info(f"所有掩码图已成功保存。")
# except Exception as e:
# logging.error(f"预测或手动保存过程中发生错误: {e}", exc_info=True)
def predict(model_path: str, source: str, project_name: str, pt_name: str, conf_threshold: float):
"""
使用指定模型进行预测。
1. 保存原始灰度掩码 (predicted_raw_masks)像素值为类别ID。
2. 保存彩色可视化掩码 (predicted_color_masks)。
3. 保存三合一对比图 (predicted_comparison) [Original | Ground Truth | Prediction]。
"""
#
pt_name_raw = pt_name.replace('.pt','') # pt_name去掉.pt后缀方便文件命名使用
# [新增] 标志,用于决定是否将(类别0..N-1) 偏移到 (1..N),为 "真背景" 腾出 0
perform_id_shift = False
# --- 1. 加载类别名称和颜色 ---
try:
with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
class_names = yaml_data.get('names')
color_data_from_dataset = yaml_data.get('colors')
if not class_names:
logging.error(f"{config.DATASET_YAML_PATH} 中未找到 'names' 键。")
return
logging.info(f"成功从 {config.DATASET_YAML_PATH} 加载类别名称: {class_names}")
num_classes = len(class_names)
logging.info(f"检测到 {num_classes} 个类别 (ID 0 到 {num_classes-1})。")
except Exception as e:
logging.error(f"加载或解析 dataset.yaml 失败: {e}")
return
# --- 2. 创建颜色查找表 (LUT) ---
# [调整] 先创建一个临时的 (N) 维 LUT
temp_lut = np.zeros((num_classes, 3), dtype=np.uint8)
if color_data_from_dataset:
logging.info(f"正在从 {config.DATASET_YAML_PATH} 加载颜色...")
try:
for class_id, rgb_value in color_data_from_dataset.items():
if 0 <= class_id < num_classes:
temp_lut[class_id] = rgb_value[::-1] # BGR
logging.info("临时颜色查找表 (LUT) 已成功从 dataset.yaml 创建。")
except Exception as e:
logging.error(f"解析 'colors' 键失败: {e}。将尝试备用方案。")
color_data_from_dataset = None # 强制使用备用方案
if not color_data_from_dataset:
# 备用方案 1: 尝试 color_map.yaml
color_map_path = 'color_map.yaml'
logging.warning(f"未从 dataset.yaml 加载颜色,正在尝试备用方案: {color_map_path}")
try:
with open(color_map_path, 'r', encoding='utf-8') as f:
color_data = yaml.safe_load(f)['colors']
for class_id, bgr_value in color_data.items():
if 0 <= class_id < num_classes:
color_lut[class_id] = bgr_value
logging.info(f"颜色查找表 (LUT) 已成功从 {color_map_path} 创建。")
except Exception:
# 备用方案 2: 随机颜色
logging.warning(f"{color_map_path} 未找到或解析失败。将使用随机颜色。")
for i in range(1, num_classes): # 保持 0 (背景) 为黑色 [0,0,0]
color_lut[i] = np.random.randint(0, 255, 3, dtype=np.uint8)
# --- [新增] 核心条件判断 ---
# 检查 类别0 (索引0) 的颜色是否为 BGR(0,0,0)
if np.any(temp_lut[0] != [0, 0, 0]):
# 颜色不是黑色,执行 "ID+1" 偏移
logging.info("检测到 类别0 的颜色不是黑色。将执行 ID+1 偏移,使 0 成为专用背景。")
perform_id_shift = True
# [调整] 创建一个 (N+1) 维的最终 LUT
color_lut = np.zeros((num_classes + 1, 3), dtype=np.uint8)
# 将原 0..N-1 的颜色 复制到 1..N
color_lut[1:] = temp_lut
# 索引 0 保持为 [0, 0, 0] (真背景)
else:
# 颜色是黑色,不执行偏移
logging.info("检测到 类别0 的颜色是黑色。将 类别0 视为背景。")
perform_id_shift = False
color_lut = temp_lut # [调整] 使用 (N) 维的原始 LUT
if not Path(source).exists():
logging.error(f"错误:预测源路径不存在: {source}")
return
try:
# --- 3. 加载模型并执行预测 ---
logging.info(f"正在加载模型: {model_path}")
model = YOLO(model_path)
logging.info("模型加载成功。")
logging.info(f"正在对源进行预测: {source}")
results = model.predict(
source=source,
stream=True,
save=False, # <-- 关闭保存,之后手动处理
show_labels=config.SHOW_LABELS,
show_conf=config.SHOW_CONF,
# project=project_name, # <-- 不再需要,我们手动处理
# name="prediction", # <-- 不再需要,我们手动处理
exist_ok=True,
retina_masks=True,
conf=conf_threshold # TODO
)
# --- 4. [扩展] 创建所有输出目录 ---
yolo_prediction_dir = Path(project_name) / "prediction"
raw_mask_save_dir = Path(project_name) / "predicted_raw_masks"
color_mask_save_dir = Path(project_name) / "predicted_color_masks"
comparison_save_dir = Path(project_name) / "predicted_comparison" # [新增]
logging.info(f"预测完成。正在处理结果...")
logging.info(f" -> Yolo预测结果将保存于: {yolo_prediction_dir}")
logging.info(f" -> 原始灰度掩码将保存于: {raw_mask_save_dir}")
logging.info(f" -> 彩色预测掩码将保存于: {color_mask_save_dir}")
logging.info(f" -> 三合一对比图将保存于: {comparison_save_dir}") # [新增]
yolo_prediction_dir.mkdir(parents=True, exist_ok=True)
raw_mask_save_dir.mkdir(parents=True, exist_ok=True)
color_mask_save_dir.mkdir(parents=True, exist_ok=True)
comparison_save_dir.mkdir(parents=True, exist_ok=True) # [新增]
# --- 辅助函数:在图像上添加标签 ---
def add_label_to_image(img, label):
"""在图像左上角添加一个带背景的标签"""
img_with_label = img.copy()
h, w = img_with_label.shape[:2]
# 动态调整字体大小和粗细
font_scale = max(0.8, w // 1000)
thickness = max(1, w // 500)
(text_w, text_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
# 绘制黑色背景矩形
cv2.rectangle(img_with_label, (0, 0), (text_w + 20, text_h + 30), (0, 0, 0), -1)
# 绘制白色文字
cv2.putText(img_with_label, label, (10, text_h + 15), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
return img_with_label
# --- 5. [扩展] 处理结果循环 ---
for result in results:
original_image_path = Path(result.path)
logging.info(f"正在处理图片: {original_image_path.name}")
h, w = result.orig_shape
# --- 5a. 保存 Yolo-predict 本身结果 ---
if config.SAVE_PREDICTIONS == True:
# a. 获取原始图片的文件名和后缀
# result.path 是原始图片的完整路径
original_path = Path(result.path)
image_stem = original_path.stem # 这是 "图片名"
image_suffix = original_path.suffix # 这是 ".后缀" (例如 .jpg)
# b. 构建你的新文件名
# 格式: 图片名_{pt_name_raw}.后缀
new_filename = f"{image_stem}_{pt_name_raw}{image_suffix}"
# c. 构建完整的保存路径
save_path = yolo_prediction_dir / new_filename
# d. 获取绘制了检测框的图像 (NumPy 数组)
# result.plot() 会自动使用你在 predict 中设置的 show_labels, show_conf 等参数
annotated_image = result.plot()
# e. 使用 OpenCV (cv2) 保存图像
cv2.imwrite(str(save_path), annotated_image)
print(f"已保存: {save_path}")
# --- 5b. [新增] 查找并加载 Ground Truth 掩码 ---
gt_mask = None
try:
# 假设的路径结构: .../DataSet_Public/5_.../images/val/image.png
# 我们要找: .../DataSet_Public/5_.../labels_GT/val/image.png
image_filename = original_image_path.name
image_parent_dir_name = original_image_path.parent.name # 'val' or 'train'
images_dir = original_image_path.parent.parent # '.../images'
dataset_root_dir = images_dir.parent # '.../5_My_Gastric_2025_10_29'
gt_path = dataset_root_dir / "labels_GT" / image_parent_dir_name / image_filename
if not gt_path.exists():
logging.warning(f" -> 未找到 Ground Truth 文件: {gt_path}")
else:
gt_mask = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
if gt_mask is None:
logging.warning(f" -> 读取 Ground Truth 文件失败: {gt_path}")
elif gt_mask.shape != (h, w):
logging.warning(f" -> GT 掩码形状 {gt_mask.shape} 与图像形状 {(h, w)} 不匹配。正在调整GT大小...")
gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
except Exception as e:
logging.error(f" -> 查找或加载GT文件时出错: {e}", exc_info=True)
gt_mask = None
# --- 5c. 创建预测掩码 (与之前相同) ---
combined_mask = np.zeros((h, w), dtype=np.uint8)
if result.masks is None:
logging.warning(f" -> 在图片 {original_image_path.name} 中未检测到任何物体。")
else:
masks_data = result.masks.data
class_ids = result.boxes.cls.int().cpu().numpy()
if len(masks_data) != len(class_ids):
logging.error(f" -> 掩码和类别ID数量不匹配跳过图片 {original_image_path.name}")
continue
for i in reversed(range(len(masks_data))):
instance_mask = masks_data[i].cpu().numpy().astype(bool)
class_id = class_ids[i]
# [调整] 根据标志执行不同逻辑
if perform_id_shift:
# 偏移模式:将所有 ID (包括0) 都 +1
combined_mask[instance_mask] = class_id + 1
else:
# 原始模式:忽略 ID 0
if class_id != 0:
combined_mask[instance_mask] = class_id
# --- 5d. [扩展] 保存所有三种输出 ---
# --- 保存 1: 原始灰度掩码 (预测) ---
raw_mask_filename = raw_mask_save_dir / f"{original_image_path.stem}_{pt_name_raw}.png"
cv2.imwrite(str(raw_mask_filename), combined_mask)
# --- 保存 2: 彩色预测掩码 ---
colorized_pred_mask = color_lut[combined_mask] # (H, W, 3)
color_mask_filename = color_mask_save_dir / f"{original_image_path.stem}_{pt_name_raw}.png"
cv2.imwrite(str(color_mask_filename), colorized_pred_mask)
# --- [新增] 保存 3: 三合一对比图 ---
try:
# 1. 获取原始图
original_image_bgr = result.orig_img # YOLO 结果中自带 BGR 格式原图
# 2. 获取彩色 GT 图
if gt_mask is not None:
colorized_gt_mask = color_lut[gt_mask]
else:
# 如果GT不存在创建一个黑色占位图
logging.warning(f" -> 在对比图中将使用黑色图像作为 Ground Truth 占位符。")
colorized_gt_mask = np.zeros_like(original_image_bgr) # (H, W, 3)
# 3. 获取彩色预测图 (已在上面生成: colorized_pred_mask)
# 确保所有图像形状一致 (理论上应该一致)
if not (original_image_bgr.shape == colorized_gt_mask.shape == colorized_pred_mask.shape):
logging.error(f" -> 形状不匹配! 原图: {original_image_bgr.shape}, GT: {colorized_gt_mask.shape}, 预测: {colorized_pred_mask.shape}。跳过对比图。")
continue
# 为每张图添加标签
img_labeled = add_label_to_image(original_image_bgr, "Original")
gt_labeled = add_label_to_image(colorized_gt_mask, "Ground Truth")
pred_labeled = add_label_to_image(colorized_pred_mask, "Prediction")
# 水平拼接
comparison_image = cv2.hconcat([img_labeled, gt_labeled, pred_labeled])
# 保存 (使用 .jpg 格式以节省空间)
comparison_filename = comparison_save_dir / f"{original_image_path.stem}_{pt_name_raw}.jpg"
cv2.imwrite(str(comparison_filename), comparison_image)
except Exception as e:
logging.error(f" -> 创建或保存对比图失败: {e}", exc_info=True)
logging.info(f"所有掩码图和对比图已成功保存。")
except Exception as e:
logging.error(f"预测或手动保存过程中发生错误: {e}", exc_info=True)
if __name__ == "__main__":
# 1. 创建解析器,现在只需要 --model 和 --source
parser = argparse.ArgumentParser(description="使用已训练的YOLO模型进行预测。")
parser.add_argument(
"--model",
type=str,
required=True,
choices=list(config.MODEL_CONFIGS.keys()),
help="选择一个基础模型类型,以筛选其训练历史。"
)
parser.add_argument(
"--source",
type=str,
default=str(config.TEST_IMAGE_DIR),
help="图片或图片文件夹的路径。"
)
parser.add_argument(
"--pt_name",
type=str,
default=str("best.pt"),
help="图片或图片文件夹的路径。"
)
parser.add_argument(
"--conf",
type=float, # 1. 类型应为浮点数
default=0.2,
help="设置预测的置信度阈值 (例如: 0.25)" # 2. 补充 help 信息
)
args = parser.parse_args()
# 2. 根据 --model 参数查找对应的训练历史
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model, args.pt_name)
run_to_use = None
# 3. 根据找到的结果数量,决定下一步操作
if not available_runs:
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中没有找到任何关于模型 '{args.model}' 的有效训练记录。")
sys.exit(1)
elif len(available_runs) == 1:
# 如果只有一个结果,自动选择
run_to_use = available_runs[0]
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
else:
# 如果有多个结果,进入交互式选择模式
print(f"\n为模型 '{args.model}' 找到多个训练版本:")
for i, run_name in enumerate(available_runs, 1):
print(f" [{i}] {run_name}")
while True:
try:
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
choice_index = int(choice)
if 1 <= choice_index <= len(available_runs):
run_to_use = available_runs[choice_index - 1]
break
else:
print("错误:输入无效,请输入列表中的序号。")
except ValueError:
print("错误:请输入一个数字。")
except (KeyboardInterrupt, EOFError):
print("\n操作已取消。")
sys.exit(0)
# 4. 使用最终确定的 run_to_use 来执行预测
if run_to_use:
model_to_use = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / args.pt_name
project_to_save_in = str(config.PREDICT_BEST_MODEL_DIR / run_to_use)
if not model_to_use.exists():
logging.error(f"严重错误:找不到权重文件 {model_to_use}")
else:
predict(
model_path=str(model_to_use),
source=args.source,
project_name=project_to_save_in,
pt_name = args.pt_name,
conf_threshold=args.conf
)

View File

@@ -0,0 +1,383 @@
import cv2
import numpy as np
import yaml
import logging
import argparse
import sys
from pathlib import Path
from tqdm import tqdm
# --- 日志设置 ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()]
)
try:
import yolo_config as config
logging.info(f"成功加载配置 yolo_config.py")
logging.info(f" -> 最佳模型目录: {config.PREDICT_BEST_MODEL_DIR}")
logging.info(f" -> 数据集 YAML: {config.DATASET_YAML_PATH}")
logging.info(f" -> 测试图片目录: {config.TEST_IMAGE_DIR}")
except ImportError:
logging.error("错误: yolo_config.py 未找到。请确保它在同一目录下。")
sys.exit(1)
except AttributeError as e:
logging.error(f"错误: yolo_config.py 中缺少必要的配置。 {e}")
sys.exit(1)
# --- 辅助函数: 1. 添加标签 ---
def add_label_to_image(img, label):
"""在图像左上角添加一个带背景的标签"""
img_with_label = img.copy()
h, w = img_with_label.shape[:2]
# 动态调整字体大小和粗细
font_scale = max(0.8, w // 1000)
thickness = max(1, w // 500)
(text_w, text_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
# 绘制黑色背景矩形
cv2.rectangle(img_with_label, (0, 0), (text_w + 20, text_h + 30), (0, 0, 0), -1)
# 绘制白色文字
cv2.putText(img_with_label, label, (10, text_h + 15), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
return img_with_label
# --- 辅助函数: 2. 加载颜色LUT ---
def load_color_lut(dataset_yaml_path):
"""
加载颜色查找表 (LUT),精确复制 yolo_predict_V2.py 的逻辑
"""
try:
with open(dataset_yaml_path, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
class_names = yaml_data.get('names')
color_data_from_dataset = yaml_data.get('colors')
if not class_names:
logging.error(f"{dataset_yaml_path} 中未找到 'names' 键。")
return None
num_classes = len(class_names)
logging.info(f"检测到 {num_classes} 个类别 (ID 0 到 {num_classes-1})。")
except Exception as e:
logging.error(f"加载或解析 dataset.yaml 失败: {e}")
return None
# 创建 (N) 维临时 LUT
temp_lut = np.zeros((num_classes, 3), dtype=np.uint8)
if color_data_from_dataset:
logging.info(f"正在从 {dataset_yaml_path} 加载颜色...")
try:
for class_id, rgb_value in color_data_from_dataset.items():
if 0 <= class_id < num_classes:
temp_lut[class_id] = rgb_value[::-1] # BGR
logging.info("临时颜色查找表 (LUT) 已成功从 dataset.yaml 创建。")
except Exception as e:
logging.error(f"解析 'colors' 键失败: {e}。将尝试备用方案。")
color_data_from_dataset = None
if not color_data_from_dataset:
color_map_path = 'color_map.yaml'
logging.warning(f"未从 dataset.yaml 加载颜色,正在尝试备用方案: {color_map_path}")
try:
with open(color_map_path, 'r', encoding='utf-8') as f:
color_data = yaml.safe_load(f)['colors']
for class_id, bgr_value in color_data.items():
if 0 <= class_id < num_classes:
temp_lut[class_id] = bgr_value # 假设 color_map.yaml 存的是 BGR
logging.info(f"颜色查找表 (LUT) 已成功从 {color_map_path} 创建。")
except Exception:
logging.warning(f"{color_map_path} 未找到或解析失败。将使用随机颜色。")
for i in range(1, num_classes):
temp_lut[i] = np.random.randint(0, 255, 3, dtype=np.uint8)
# --- 核心条件判断 (与 yolo_predict_V2.py 完全一致) ---
if np.any(temp_lut[0] != [0, 0, 0]):
logging.info("检测到 类别0 的颜色不是黑色。将创建 (N+1) LUT使 0 成为专用背景。")
color_lut = np.zeros((num_classes + 1, 3), dtype=np.uint8)
color_lut[1:] = temp_lut
else:
logging.info("检测到 类别0 的颜色是黑色。将使用 (N) LUT。")
color_lut = temp_lut
logging.info(f"最终颜色LUT创建成功维度: {color_lut.shape}")
return color_lut
# --- 辅助函数: 3. 查找模型目录 ---
def find_model_dirs(base_dir, required_subdir):
"""查找所有包含 'required_subdir' 子目录的有效模型文件夹"""
model_dirs = []
if not base_dir.is_dir():
logging.error(f"基础目录不存在: {base_dir}")
return []
for d in base_dir.iterdir():
# 必须是一个目录,并且包含所需的预测结果子目录
if d.is_dir() and (d / required_subdir).exists():
model_dirs.append(d)
model_dirs.sort()
return model_dirs
# --- 辅助函数: 4. 查找图像名称 ---
def find_image_names(model_dirs, sub_dir_name, file_glob_patterns):
"""
从第一个模型目录的指定子目录(sub_dir_name)中获取所有待处理的图像文件名
使用 file_glob_patterns 列表 (例如 ['*.jpg', '*.png'])
"""
image_names = set() # 使用集合避免重复
if not model_dirs:
return []
first_model_target_dir = model_dirs[0] / sub_dir_name
for pattern in file_glob_patterns:
for f in first_model_target_dir.glob(pattern):
if f.is_file():
image_names.add(f.name)
sorted_names = sorted(list(image_names))
return sorted_names
# --- 辅助函数: 5. 获取原图和GT路径 ---
def get_gt_and_original_paths(image_name):
"""
根据输出的掩码文件名反向推导原始图像和GT图像的路径。
假设文件名格式为: {original_stem}_{pt_name_raw}.png
"""
try:
# 1. 从 config 中获取关键路径
test_image_dir = config.TEST_IMAGE_DIR
images_dir = test_image_dir.parent
dataset_root_dir = images_dir.parent
image_parent_dir_name = test_image_dir.name # e.g., 'val'
# 2. 推导原始图像的 stem
if '_' not in image_name:
logging.warning(f" -> 图像 {image_name} 似乎不含 '_{{pt_name}}' 后缀,将尝试使用完整文件名作为 stem。")
original_stem = Path(image_name).stem
else:
# 从最后一个 '_' 拆分
original_stem = image_name[:image_name.rfind('_')]
# 3. 在 test_image_dir 中查找原始图像
original_image_path = None
# glob 查找, 匹配 .png, .jpg, .bmp 等
for f in test_image_dir.glob(f"{original_stem}.*"):
if f.is_file() and f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
original_image_path = f
break # 找到第一个匹配项
if not original_image_path:
# 尝试在上一级目录查找 (兼容某些奇怪的 val/test 结构)
for f in test_image_dir.parent.glob(f"{original_stem}.*"):
if f.is_file() and f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
original_image_path = f
break
if not original_image_path:
logging.warning(f" -> 在 {test_image_dir} (及其父目录) 中未找到原始图像 (如 {original_stem}.*)")
return None, None
# 4. 推导GT路径 (逻辑复制自 yolo_predict_V2.py)
image_filename = original_image_path.name
gt_path = dataset_root_dir / "labels_GT" / image_parent_dir_name / image_filename
gt_path_png = gt_path.with_suffix('.png') # 备用
if gt_path.exists():
return original_image_path, gt_path
elif gt_path_png.exists():
return original_image_path, gt_path_png
else:
# 返回预期的路径,主循环将处理 "文件未找到"
return original_image_path, gt_path
except Exception as e:
logging.error(f" -> 推导路径时出错: {e}", exc_info=True)
return None, None
# --- [新增] 核心处理函数 ---
def generate_comparison_images(base_dir, output_dir_name, source_subdir_name,
gt_panel_label, use_gt_overlay,
file_glob_patterns, color_lut):
"""
核心处理函数,用于生成一种类型的对比图。
"""
logging.info(f"--- 🚀 开始生成 '{output_dir_name}' ---")
output_dir = base_dir / output_dir_name
output_dir.mkdir(exist_ok=True)
# 3. 查找模型目录
model_dirs = find_model_dirs(base_dir, source_subdir_name)
if not model_dirs:
logging.error(f"{base_dir} 中未找到任何包含 '{source_subdir_name}' 的有效模型预测目录。")
logging.warning(f"--- ⚠️ 跳过 '{output_dir_name}' 的生成 ---")
return
logging.info(f"找到 {len(model_dirs)} 个模型进行对比:")
for d in model_dirs:
logging.info(f" -> {d.name}")
# 4. 查找图像列表
image_names = find_image_names(model_dirs, source_subdir_name, file_glob_patterns)
if not image_names:
logging.error(f"未找到任何匹配 {file_glob_patterns} 的图像文件。")
logging.error(f"请检查 {model_dirs[0] / source_subdir_name} 目录。")
logging.warning(f"--- ⚠️ 跳过 '{output_dir_name}' 的生成 ---")
return
logging.info(f"找到 {len(image_names)} 张图像进行处理。")
# 5. --- 遍历每张图像 ---
for image_name in tqdm(image_names, desc=f"生成 {output_dir_name}"):
comparison_panels = []
# 5a. 获取路径
orig_path, gt_path = get_gt_and_original_paths(image_name)
if not orig_path:
logging.warning(f"\n跳过: 未能找到 {image_name} 的原始图像。")
continue
# 5b. 加载原图
orig_img = cv2.imread(str(orig_path))
if orig_img is None:
logging.warning(f"\n跳过: 无法读取原始图像 {orig_path}")
continue
h, w = orig_img.shape[:2]
# 1. ORI_PIC
comparison_panels.append(add_label_to_image(orig_img, "ORI_PIC"))
# 5c. 加载 Ground Truth
gt_color = np.zeros_like(orig_img) # 默认黑色
if gt_path and gt_path.exists():
gt_gray = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
if gt_gray is not None:
if gt_gray.shape != (h, w):
gt_gray = cv2.resize(gt_gray, (w, h), interpolation=cv2.INTER_NEAREST)
gt_color = color_lut[gt_gray]
else:
logging.warning(f"\n无法读取GT图像 {gt_path},使用黑色占位符。")
else:
logging.warning(f"\nGT图像未找到 (预期路径: {gt_path}),使用黑色占位符。")
# 2. GT Panel (Mask or Overlay)
if use_gt_overlay:
gt_panel = cv2.addWeighted(orig_img, 0.6, gt_color, 0.4, 0)
else:
gt_panel = gt_color
comparison_panels.append(add_label_to_image(gt_panel, gt_panel_label))
# 5d. 加载每个模型的预测结果
for model_dir in model_dirs:
pred_path = model_dir / source_subdir_name / image_name
pred_img = np.zeros_like(orig_img) # 默认黑色
if pred_path.exists():
img = cv2.imread(str(pred_path))
if img is not None:
if img.shape != (h, w, 3):
logging.warning(f"\n预测图 {pred_path.name} 尺寸 {img.shape} 与原图 {(h,w,3)} 不匹配。正在调整大小...")
img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA)
pred_img = img
else:
logging.warning(f"\n无法读取预测图 {pred_path},使用黑色占位符。")
else:
logging.warning(f"\n预测图未找到 {pred_path},使用黑色占位符。")
short_label = model_dir.name.split('_')[0]
comparison_panels.append(add_label_to_image(pred_img, short_label))
# 5e. 拼接并保存
try:
final_image = cv2.hconcat(comparison_panels)
save_stem = Path(image_name).stem
save_path = output_dir / (save_stem + ".png")
cv2.imwrite(str(save_path), final_image)
except Exception as e:
logging.error(f"\n--- 拼接或保存 {image_name} 失败: {e} ---")
logging.error(f"面板尺寸 (H, W, C): {[p.shape for p in comparison_panels]}")
logging.warning("请检查所有图像是否具有相同的高度。")
logging.info(f"--- ✅ '{output_dir_name}' 生成完毕 ---")
# --- [修改] 主函数 (重构) ---
def main():
parser = argparse.ArgumentParser(
description="将所有模型的预测结果 (Mask 和 Yolo) 与原图和GT图拼接成对比图。"
)
# 模式参数已移除
parser.add_argument(
"--pt_name",
type=str,
default="all",
help="您想比较的权重文件名称 (例如 'best.pt''epoch100.pt''all'表示对全部权重进行处理)"
)
args = parser.parse_args()
base_dir = config.PREDICT_BEST_MODEL_DIR
# 1. 加载颜色LUT (只需一次)
color_lut = load_color_lut(config.DATASET_YAML_PATH)
if color_lut is None:
logging.error("无法加载颜色LUT脚本终止。")
return
# 2. 定义通用文件模式
pt_name_raw = args.pt_name.replace('.pt', '')
all_image_types = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff"]
# 3. --- 执行 "Mask" 模式 ---
mask_patterns = []
if args.pt_name.lower() == "all":
mask_patterns = ["*.png"]
else:
mask_patterns = [f"*_{pt_name_raw}.png"]
generate_comparison_images(
base_dir=base_dir,
output_dir_name="compare_all_masks",
source_subdir_name="predicted_color_masks",
gt_panel_label="Ground_Truth",
use_gt_overlay=False,
file_glob_patterns=mask_patterns,
color_lut=color_lut
)
# 4. --- 执行 "Yolo" 模式 ---
yolo_patterns = []
if args.pt_name.lower() == "all":
yolo_patterns = all_image_types
else:
yolo_patterns = [
f"*_{pt_name_raw}.jpg", f"*_{pt_name_raw}.jpeg",
f"*_{pt_name_raw}.png", f"*_{pt_name_raw}.bmp",
f"*_{pt_name_raw}.tif", f"*_{pt_name_raw}.tiff"
]
generate_comparison_images(
base_dir=base_dir,
output_dir_name="compare_all_masks_Yolo",
source_subdir_name="prediction",
gt_panel_label="GT_Overlay",
use_gt_overlay=True,
file_glob_patterns=yolo_patterns,
color_lut=color_lut
)
logging.info("--- 🏁 所有对比图任务执行完毕 ---")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,233 @@
import sys
import yaml
from pathlib import Path
from typing import Set, Tuple
import argparse # <-- [新增]
import logging # <-- [新增]
try:
# Import config to get base directories and dataset configuration path
import yolo_config as config
except ImportError:
print("错误:无法导入 'yolo_config.py'。请确保此脚本与 config.py 在同一目录下。")
sys.exit(1)
def get_file_info(directory: Path) -> Tuple[Set[str], Set[str]]:
"""
安全地获取目录中所有文件的基本名stem和后缀suffix
返回:
一个包含 stems 集合和 suffixes 集合的元组。
"""
if not directory.is_dir():
return set(), set()
stems = set()
suffixes = set()
for item in directory.iterdir():
if item.is_file():
stems.add(item.stem)
suffixes.add(item.suffix)
return stems, suffixes
def get_source_labels_dir() -> Path:
"""
从 dataset.yaml 解析并构建源标签目录的绝对路径。
路径的最后一部分 (例如 'train', 'val''test') 将根据 dataset.yaml 中的 'test' 键动态确定。
"""
try:
with open(config.DATASET_YAML_PATH, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
# 1. 获取数据集的根目录相对路径
relative_path_from_yaml = yaml_data.get('path')
if not relative_path_from_yaml:
print(f"错误: 在 '{config.DATASET_YAML_PATH}' 文件中没有找到 'path' 键。")
sys.exit(1)
# 2. 【新增】获取 'test' 键的值 (例如 'images/val')
test_path_from_yaml = yaml_data.get('test')
if not test_path_from_yaml:
print(f"错误: 在 '{config.DATASET_YAML_PATH}' 文件中没有找到 'test' 键。")
sys.exit(1)
# 3. 【新增】从 'images/val' 中提取最后一部分 'val'
# Path(...).name 可以轻松实现这个功能
source_split = Path(test_path_from_yaml).name
# 4. 构建数据集的绝对路径
dataset_dir = (config.DATASET_YAML_PATH.parent / relative_path_from_yaml).resolve()
# 5. 【已修改】使用动态提取的 source_split 构建并返回标签目录路径
labels_dir = dataset_dir / "labels" / source_split
return labels_dir
except FileNotFoundError:
print(f"错误: YAML 配置文件未找到: '{config.DATASET_YAML_PATH}'")
sys.exit(1)
except Exception as e:
print(f"读取或解析 YAML 文件时发生错误: {e}")
sys.exit(1)
def main():
"""
主函数,用于详细检查源训练标签集和所有预测输出的原始掩码
目录之间的文件数量和文件名一致性。
"""
# --- [新增] ---
parser = argparse.ArgumentParser(description="检查预测掩码与源标签的一致性。")
parser.add_argument(
"--pt_name",
type=str,
default="best.pt",
help="要检查的权重文件名称 (例如 'best.pt''epoch100.pt')。脚本将查找以此为后缀的掩码文件。"
)
args = parser.parse_args()
# 从 'best.pt' 提取 'best'
pt_name_raw = Path(args.pt_name).stem
# 构建预期的后缀, e.g., '_best'
expected_suffix = f"_{pt_name_raw}"
# [核心修改] 只有 'best.pt' 会兼容检查无后缀的传统文件
check_traditional = (args.pt_name == "best.pt")
print(f"--- 开始详细检查预测原始掩码 (predicted_raw_masks) ---")
if check_traditional:
print(f"[i] 本次检查目标: --pt_name='{args.pt_name}' (检查后缀 '{expected_suffix}' 或 无后缀的传统文件)")
else:
print(f"[i] 本次检查目标: --pt_name='{args.pt_name}' (仅检查后缀 '{expected_suffix}' 的文件)")
# 1. 获取源标签目录并解析文件信息
source_dir = get_source_labels_dir()
if not source_dir.exists() or not source_dir.is_dir():
print(f"错误:在 '{source_dir}' 找不到源训练标签目录。请检查 dataset.yaml 中的 'path' 设置。")
sys.exit(1)
source_stems, _ = get_file_info(source_dir)
source_file_count = len(source_stems)
if source_file_count == 0:
print(f"警告:源标签目录 '{source_dir}' 为空或不包含任何文件。检查中止。")
sys.exit(0)
# 在你的日志中,路径是 '1_CholecSeg8k-13Type-1920x1080',但配置文件是 '4_Dresden-11Type-512x512'
# 这里我们以代码逻辑为准,它会动态读取配置文件
print(f"源标签集: 在 '{source_dir}' 中找到 {source_file_count} 个标签文件 (.txt)。")
print("-" * 60)
# 2. 获取预测结果的基础目录
predictions_base_dir = config.PREDICT_BEST_MODEL_DIR
if not predictions_base_dir.exists() or not predictions_base_dir.is_dir():
print(f"错误:在 '{predictions_base_dir}' 找不到预测结果的基础目录。")
sys.exit(1)
# 3. 遍历每个模型的运行目录并进行详细检查
found_predictions = False
global_mismatch_found = False
# 【已修正】直接查找所有模型的运行目录,例如 'YOLOv8n-seg_2025-09-20_10-00-00'
run_dirs = sorted([d for d in predictions_base_dir.iterdir() if d.is_dir()])
if not run_dirs:
print(f"信息:在基础目录 '{predictions_base_dir}' 中没有找到任何模型的预测结果文件夹。")
sys.exit(0)
# 【已修正】移除外层循环,直接遍历 run_dirs
for run_dir in run_dirs:
target_dir = run_dir / "predicted_raw_masks"
# 如果目标目录不存在,则跳过
if not (target_dir.exists() and target_dir.is_dir()):
continue
print(f"正在检查: '{target_dir}'...")
found_predictions = True
is_current_dir_ok = True
# 获取目录中所有的stems
all_predicted_stems, _ = get_file_info(target_dir)
# --- [核心修改] ---
# 1. 提取我们关心的stems
processed_stems = set() # 存储处理后 (去掉后缀) 的stems, e.g., {'fileA', 'fileB'}
stems_we_checked = set() # 存储我们检查过的原始stems, e.g., {'fileA_best', 'fileA'}
for stem in all_predicted_stems:
# 检查 1: 是否为带后缀的新模式 (e.g., "fileA_best" 或 "fileA_epoch100")
if stem.endswith(expected_suffix):
# 去掉后缀, e.g., 'fileA_best' -> 'fileA'
unstripped_stem = stem[:-len(expected_suffix)]
# 仅当去掉后缀后的部分在源文件中时才添加
if unstripped_stem in source_stems:
processed_stems.add(unstripped_stem)
stems_we_checked.add(stem)
# 检查 2: [条件] 是否允许检查传统文件 (check_traditional == True)
# [条件] 且该文件是否为无后缀的传统文件 (e.g., "fileA")
elif check_traditional and (stem in source_stems):
# 兼容 'best.pt' 检查传统文件
processed_stems.add(stem)
stems_we_checked.add(stem)
# 2. 现在,'processed_stems' 只包含与 'source_stems' 对应的基础文件名
predicted_file_count = len(processed_stems)
# --- 检查 1: 文件数量 ---
if source_file_count != predicted_file_count:
print(f" [✗ 数量不匹配] 源目录有 {source_file_count} 个文件,但此目录中(匹配检查规则的)有 {predicted_file_count} 个。")
is_current_dir_ok = False
# --- 检查 2: 文件名 (基于处理后的stems) ---
missing_files = source_stems - processed_stems
if missing_files:
examples = list(missing_files)[:3]
print(f" [✗ 文件名缺失] 预测结果中缺少 {len(missing_files)} 个文件 (匹配检查规则)。例如: {examples}...")
is_current_dir_ok = False
extra_files = processed_stems - source_stems
if extra_files:
examples = list(extra_files)[:3]
print(f" [✗ 文件名多余] (逻辑异常) 预测结果中多出 {len(extra_files)} 个文件。例如: {examples}...")
is_current_dir_ok = False
# --- [新增] 检查 3: 提示其他被忽略的文件 ---
other_files = all_predicted_stems - stems_we_checked
if other_files:
examples = list(other_files)[:3]
print(f" [i 信息] 目录中还包含 {len(other_files)} 个其他文件 (例如: {examples}...)。")
if check_traditional:
print(f" [i 信息] (这些文件不匹配后缀 '{expected_suffix}' 且不属于无后缀传统文件,已忽略)")
else:
print(f" [i 信息] (这些文件不匹配后缀 '{expected_suffix}',已忽略)")
# --- 当前目录的检查总结 ---
if is_current_dir_ok:
# 再次检查目录为空的特殊情况
if predicted_file_count == 0 and source_file_count > 0:
print(f" [✗ 目录为空] 预测目录中没有找到任何匹配本次检查规则的文件。")
global_mismatch_found = True
else:
print(f" [✓ OK] 所有检查通过 (文件数量: {predicted_file_count},匹配检查规则)")
else:
global_mismatch_found = True
print("-" * 30)
# 4. 输出最终的全局检查摘要
print("\n--- 检查摘要 ---")
if not found_predictions:
print("结果: 没有找到任何 'predicted_raw_masks' 目录进行检查。")
elif global_mismatch_found:
print("结论: 检查不通过。发现至少一个预测目录未能与源标签集完全匹配,请查看上方日志。")
sys.exit(1)
else:
print("结论: 检查通过!所有已找到的 'predicted_raw_masks' 目录均通过了与源标签集的文件名和数量一致性检查。")
print("--- 检查完成 ---")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,538 @@
import logging
import sys
import argparse
import shutil
import os
from pathlib import Path
from typing import List, Dict, Any
from ultralytics import YOLO
# --- [!! 关键 !!] 动态添加项目根目录 (来自之前的讨论) ---
script_path = Path(__file__).resolve()
project_root = script_path.parent.parent
sys.path.insert(0, str(project_root))
# ----------------------------------------------------
import yolo_config as config
# --- [新增] 导入 V2 所需的库 ---
import cv2
import numpy as np
import torch
from pytorch_grad_cam import (
GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM,
HiResCAM, LayerCAM, RandomCAM, EigenGradCAM
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.base_cam import BaseCAM
# --- 日志设置 ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()]
)
# --- [保留] V2 中的 CAM 方法字典 ---
CAM_METHODS = {
"GradCAM": GradCAM, # ok # 速度快
"GradCAMPlusPlus": GradCAMPlusPlus, # ok # 速度快
# "XGradCAM": XGradCAM, # ok # 速度快
"EigenCAM": EigenCAM,
"HiResCAM": HiResCAM, # ok # 速度快
# "LayerCAM": LayerCAM, # ok # 速度快
"RandomCAM": RandomCAM, # ok # 速度快
# "EigenGradCAM": EigenGradCAM, # 这个耗时长
}
# --- [保留] V2 中的 ActivationMaximizationTarget ---
class ActivationMaximizationTarget:
def __init__(self, channel=0):
self.channel = channel
def __call__(self, model_output):
if model_output.ndim == 4:
return model_output[:, self.channel, :, :].mean()
elif model_output.ndim == 3:
return model_output[self.channel, :, :].mean()
else:
raise ValueError(f"Unsupported model_output shape: {model_output.shape}")
# --- [保留] 图像预处理函数 ---
def preprocess_image(bgr_img: np.ndarray, imgsz: int, device: str) -> (torch.Tensor, Dict[str, Any]):
# (函数内容与原版相同,为节省篇幅已折叠)
orig_h, orig_w = bgr_img.shape[:2]
rgb_img = bgr_img[:, :, ::-1]
scale = min(imgsz / orig_w, imgsz / orig_h)
new_w, new_h = int(orig_w * scale), int(orig_h * scale)
resized_image = cv2.resize(rgb_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
pad_w, pad_h = imgsz - new_w, imgsz - new_h
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
padded_image = cv2.copyMakeBorder(resized_image, pad_top, pad_bottom, pad_left, pad_right,
cv2.BORDER_CONSTANT, value=(114, 114, 114))
img_float = np.float32(padded_image) / 255.0
input_tensor = torch.from_numpy(img_float).permute(2, 0, 1).unsqueeze(0).to(device)
padding_info = {
"pad_top": pad_top, "new_h": new_h, "pad_left": pad_left,
"new_w": new_w, "orig_h": orig_h, "orig_w": orig_w
}
return input_tensor, padding_info
# --- [保留] V1 中的关键修复函数 ---
def reshape_transform(outputs):
if isinstance(outputs, tuple):
return outputs[0]
return outputs
# --- [保留] V1 中的模型查找函数 ---
def find_trained_models(outputs_dir: Path, model_key: str, pt_name: str) -> List[str]:
# (函数内容与原版相同,为节省篇幅已折叠)
trained_models = []
if not outputs_dir.is_dir():
logging.warning(f"输出目录不存在: {outputs_dir}")
return []
for project_folder in outputs_dir.iterdir():
if project_folder.is_dir() and project_folder.name.startswith(model_key + '_'):
if (project_folder / 'weights' / pt_name).exists():
trained_models.append(project_folder.name)
trained_models.sort()
return trained_models
# --- [!! 核心重构 !!] 融合 V1 和 V2 的可视化函数 ---
def visualize_nn_comprehensive(
model_path: str,
source_dir: str,
base_save_dir: Path,
pt_name: str,
cam_method_name: str,
target_layer_str: str,
model_key: str # [!! 新增 !!] 接收模型 Key (e.g., "YOLO11m-seg")
):
"""
使用 V2 的多方法和 V1 的框架进行 CAM 可视化。
"""
# --- 1. 加载模型 (V1 逻辑) ---
try:
logging.info(f"正在加载模型: {model_path}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO(model_path).to(device)
model.model.eval()
imgsz = model.model.args['imgsz']
logging.info(f"模型加载成功。设备: {device}, 图像尺寸: {imgsz}")
except Exception as e:
logging.error(f"无法加载模型 {model_path}: {e}", exc_info=True)
return
# --- 2. [!! 关键修改 !!] 解析目标层 ---
target_layers_to_process = [] # 存储 (name, layer_module) 元组
if target_layer_str.strip().lower() == "default":
logging.info("检测到 'default' 目标层,正在解析...")
# 对 YOLOv8 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
if model_key.lower().startswith("yolov8"):
logging.info(f"检测到 YOLOv8 系列模型 ({model_key})。将使用精选的8个默认层。")
default_yolo8_layers = [
(1, "model.model.model[1]"), # Layer_01_Conv
(4, "model.model.model[4]"), # Layer_04_C2f
(9, "model.model.model[9]"), # Layer_09_SPPF
(11, "model.model.model[11]"), # Layer_11_Concat
(14, "model.model.model[14]"), # Layer_14_Concat
(19, "model.model.model[19]"), # Layer_19_Conv
(20, "model.model.model[20]"), # Layer_20_Concat
(21, "model.model.model[21]") # Layer_21_C2f
]
# 如果使用 EigenCAM 或 EigenGradCAM则只分析最后两个层
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
default_yolo8_layers = default_yolo8_layers[-3:] # (只保留最后两个)
for idx, path_str in default_yolo8_layers:
try:
layer = model.model.model[idx]
target_layers_to_process.append((path_str, layer))
logging.info(f" -> 已添加 YOLOv8 默认层: {path_str} (索引 {idx})")
except Exception as e:
logging.warning(f" -> [!] 无法添加 YOLOv8 默认层 {path_str}: {e}")
# 对 YOLOv9 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
elif model_key.lower().startswith("yolov9"):
logging.info(f"检测到 YOLOv9 系列模型 ({model_key})。将使用精选的8个默认层。")
# (这是您新指定的列表)
default_yolo9_layers = [
(2, "model.model.model[2]"), # Layer_02_Conv
(7, "model.model.model[7]"), # Layer_07_RepNCSPELAN4
(29, "model.model.model[29]"), # Layer_29_SPPELAN
(31, "model.model.model[31]"), # Layer_31_Concat
(34, "model.model.model[34]"), # Layer_34_Concat
(37, "model.model.model[37]"), # Layer_37_Concat
(40, "model.model.model[40]"), # Layer_40_Concat
(41, "model.model.model[41]") # Layer_41_RepNCSPELAN4
]
# 如果使用 EigenCAM 或 EigenGradCAM则只分析最后两个层
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
default_yolo9_layers = default_yolo9_layers[-3:] # (只保留最后两个)
for idx, path_str in default_yolo9_layers:
try:
layer = model.model.model[idx]
target_layers_to_process.append((path_str, layer))
logging.info(f" -> 已添加 YOLOv9 默认层: {path_str} (索引 {idx})")
except Exception as e:
logging.warning(f" -> [!] 无法添加 YOLOv9 默认层 {path_str}: {e}")
# 对 YOLO11 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
elif model_key.lower().startswith("yolo11"):
logging.info(f"检测到 YOLO11 系列模型 ({model_key})。将使用精选的8个默认层。")
# (索引, 路径字符串)
default_yolo11_layers = [
(0, "model.model.model[0]"), # Layer_00_Conv
(4, "model.model.model[4]"), # Layer_04_C3k2
(9, "model.model.model[9]"), # Layer_09_SPPF
(12, "model.model.model[12]"), # Layer_12_Concat
(15, "model.model.model[15]"), # Layer_15_Concat
(20, "model.model.model[20]"), # Layer_20_Conv
(21, "model.model.model[21]"), # Layer_21_Concat
(22, "model.model.model[22]") # Layer_22_C3k2
]
# 如果使用 EigenCAM 或 EigenGradCAM则只分析最后两个层
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
default_yolo11_layers = default_yolo11_layers[-3:] # (只保留最后两个)
for idx, path_str in default_yolo11_layers:
try:
# 假设 'model.model.model' 是一个列表
layer = model.model.model[idx]
target_layers_to_process.append((path_str, layer))
logging.info(f" -> 已添加 YOLO11 默认层: {path_str} (索引 {idx})")
except Exception as e:
logging.warning(f" -> [!] 无法添加 YOLO11 默认层 {path_str}: {e}")
# 对 YOLO12 系列模型的 'default' 逻辑 # TODO # 对模型进行修改
elif model_key.lower().startswith("yolo12"):
logging.info(f"检测到 YOLOv12 系列模型 ({model_key})。将使用您精选的8个默认层。")
# (这是您新指定的列表)
default_yolo12_layers = [
(2, "model.model.model[2]"), # Layer_02_C3k2
(4, "model.model.model[4]"), # Layer_04_C3k2
(8, "model.model.model[8]"), # Layer_08_A2C2f
(10, "model.model.model[10]"), # Layer_10_Concat
(13, "model.model.model[13]"), # Layer_13_Concat
(17, "model.model.model[17]"), # Layer_17_A2C2f
(19, "model.model.model[19]"), # Layer_19_Concat
(20, "model.model.model[20]") # Layer_20_C3k2
]
# 如果使用 EigenCAM 或 EigenGradCAM则只分析最后两个层
if cam_method_name in ["EigenCAM", "EigenGradCAM"]:
logging.info(f" -> 使用 {cam_method_name}。将自动裁剪为最后2个默认层。")
default_yolo12_layers = default_yolo12_layers[-3:] # (只保留最后两个)
for idx, path_str in default_yolo12_layers:
try:
layer = model.model.model[idx]
target_layers_to_process.append((path_str, layer))
logging.info(f" -> 已添加 YOLOv12 默认层: {path_str} (索引 {idx})")
except Exception as e:
logging.warning(f" -> [!] 无法添加 YOLOv12 默认层 {path_str}: {e}")
# [!! 保留 !!] 非 YOLOv8、YOLOv9、YOLO11、YOLO12 模型的 'default' 逻辑
else:
logging.info(f"非 YOLO11 模型 ({model_key})。将使用标准 'default' 逻辑。")
try:
layer = model.model.model[8].proto.cv3 # (来自您 V2 文件的 `[8]`)
name = "model.model.model[8].proto.cv3"
target_layers_to_process.append((name, layer))
logging.info(f" -> 已添加默认层: {name}")
except Exception as e:
logging.warning(f" -> 无法定位 'model[8].proto.cv3'。回退到 'model[15]'。错误: {e}")
try:
layer = model.model.model[15]
name = "model.model.model[15]"
target_layers_to_process.append((name, layer))
logging.info(f" -> 已添加回退层: {name}")
except Exception as e_fallback:
logging.error(f" -> [!] 无法定位 'model[15]'。跳过 'default'。错误: {e_fallback}")
# [!! 保留 !!] 处理非 'default' 的自定义层路径
else:
logging.info(f"正在解析自定义目标层: {target_layer_str}")
layer_paths = target_layer_str.split(',')
for path in layer_paths:
path_cleaned = path.strip()
if not path_cleaned:
continue
try:
layer = eval(path_cleaned, {"model": model})
if isinstance(layer, torch.nn.Module):
target_layers_to_process.append((path_cleaned, layer))
logging.info(f" -> 已添加指定层: {path_cleaned}")
else:
logging.warning(f" -> 路径 '{path_cleaned}' 不是一个 nn.Module已跳过。")
except Exception as e:
logging.error(f" -> [!] 无法解析目标层路径 '{path_cleaned}': {e}。已跳过。")
if not target_layers_to_process:
logging.error("未找到有效的目标层进行可视化。程序退出。")
return
# --- 3. 查找源图片 (V1 逻辑) ---
source_path = Path(source_dir)
image_paths = []
if source_path.is_file():
image_paths = [source_path]
elif source_path.is_dir():
image_paths = sorted(
list(source_path.glob("*.jpg")) +
list(source_path.glob("*.jpeg")) +
list(source_path.glob("*.png"))
)
if not image_paths:
logging.error(f"在源路径中未找到任何图片: {source_dir}")
return
# --- 4. 获取 CAM 方法 (V2 逻辑) ---
cam_class = CAM_METHODS.get(cam_method_name)
if not cam_class:
logging.error(f"无效的 CAM 方法: {cam_method_name}")
return
logging.info(f"--- 将使用 CAM 方法: {cam_method_name} ---")
# --- 5. [!! 循环顺序保留 !!] ---
targets = None
if cam_method_name != "EigenCAM":
targets = [ActivationMaximizationTarget(channel=0)]
logging.info("已为非 EigenCAM 方法设置 ActivationMaximizationTarget(channel=0)。")
pt_name_raw = pt_name.replace('.pt', '') # 移到循环外
# 外循环:遍历图片
for img_path in image_paths:
logging.info(f"\n--- 正在处理图片: {img_path.name} ---")
try:
# --- a. 加载和预处理 ---
orig_img_bgr = cv2.imread(str(img_path))
if orig_img_bgr is None:
logging.warning(f" -> 无法读取图片: {img_path.name},已跳过。")
continue
orig_img_rgb_float = np.float32(orig_img_bgr[:, :, ::-1]) / 255
input_tensor, pad_info = preprocess_image(orig_img_bgr, imgsz, device)
input_tensor.requires_grad_()
# --- b. [!! 关键修改 !!] 创建新的三层输出目录 ---
image_stem = img_path.stem
# 1. 创建算法根目录 (e.g., .../train6/HeartMap_Visual/GradCAM/)
cam_root = base_save_dir / "HeartMap_Visual" / f"{cam_method_name}_{pt_name_raw}"
# 2. 创建该图片专用的子目录 (e.g., .../HeartMap_Visual/GradCAM/image01_GradCAM/)
nn_visual_root = cam_root / f"{image_stem}_{cam_method_name}"
nn_visual_root.mkdir(parents=True, exist_ok=True)
# 3. 创建该图片专用的 _ORI 子目录 (e.g., .../HeartMap_Visual/GradCAM/image01_GradCAM_ORI/)
nn_visual_ori_root = cam_root / f"{image_stem}_{cam_method_name}_ORI"
nn_visual_ori_root.mkdir(parents=True, exist_ok=True)
# [!! 关键修改 !!] 更新日志打印路径,使其相对于 base_save_dir
logging.info(f" -> 覆盖图将保存到: {nn_visual_root.relative_to(base_save_dir)}")
logging.info(f" -> 纯热度图将保存到: {nn_visual_ori_root.relative_to(base_save_dir)}")
# --- c. 内循环:遍历层 ---
for layer_name, layer_module in target_layers_to_process:
safe_layer_name = layer_name.replace('.', '_').replace('[', '').replace(']', '')
logging.info(f" -> 正在处理层: {layer_name}")
try:
# --- d. 初始化 CAM ---
with cam_class(model=model.model,
target_layers=[layer_module],
reshape_transform=reshape_transform) as cam:
# --- e. 运行 CAM ---
cam_output = cam(input_tensor=input_tensor, targets=targets)
# --- f. CAM 输出后处理 ---
# (此部分逻辑与 V2 相同)
grayscale_cam = None
if isinstance(cam_output, (list, tuple)):
grayscale_cam = cam_output[0]
else:
grayscale_cam = cam_output
if isinstance(grayscale_cam, torch.Tensor):
grayscale_cam = grayscale_cam.detach().cpu().numpy()
if grayscale_cam.ndim == 4:
grayscale_cam = grayscale_cam[0].mean(axis=0)
elif grayscale_cam.ndim == 3:
grayscale_cam = grayscale_cam[0, :]
elif grayscale_cam.ndim == 2:
pass # OK
else:
logging.warning(f" -> [!] 层 {layer_name} 的 CAM 结果维度异常: {grayscale_cam.shape},已跳过。")
continue
# --- g. 裁剪 Padding 并 Resize ---
cam_cropped = grayscale_cam[pad_info['pad_top'] : pad_info['pad_top'] + pad_info['new_h'],
pad_info['pad_left'] : pad_info['pad_left'] + pad_info['new_w']]
cam_upscaled = cv2.resize(cam_cropped,
(pad_info['orig_w'], pad_info['orig_h']),
interpolation=cv2.INTER_LINEAR)
cam_upscaled = (cam_upscaled - cam_upscaled.min()) / (cam_upscaled.max() + 1e-8)
# --- h. 定义文件名和路径 ---
save_filename_overlay = f"Layer_{safe_layer_name}_overlay.jpg"
save_filename_ori = f"Layer_{safe_layer_name}_heatmap.jpg"
save_path_overlay = nn_visual_root / save_filename_overlay
save_path_ori = nn_visual_ori_root / save_filename_ori
# --- i. 保存覆盖图 ---
visualization = show_cam_on_image(orig_img_rgb_float,
cam_upscaled,
use_rgb=True,
image_weight=0.5)
viz_bgr = visualization[:, :, ::-1]
cv2.imwrite(str(save_path_overlay), viz_bgr)
# --- j. 保存纯热度图 ---
heatmap_uint8 = np.uint8(255 * cam_upscaled)
heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
cv2.imwrite(str(save_path_ori), heatmap_color)
logging.info(f" -> (√) {layer_name} 可视化结果已保存")
except Exception as e_layer:
logging.error(f" -> [!] 处理层 {layer_name} 时发生错误: {e_layer}", exc_info=False)
except Exception as e_img:
logging.error(f" -> [!] 处理图片 {img_path.name} 时发生严重错误: {e_img}", exc_info=True)
logging.info("\n--- 所有图片已处理完毕 ---")
if __name__ == "__main__":
config.show_config_summary()
parser = argparse.ArgumentParser(description="使用 Grad-CAM (EigenCAM) 生成 YOLO 神经网络可视化。")
parser.add_argument(
"--model",
type=str,
required=True,
choices=list(config.MODEL_CONFIGS.keys()),
help="选择一个基础模型类型来筛选其训练历史。"
)
parser.add_argument(
"--source",
type=str,
default=str(config.TEST_IMAGE_DIR),
help="图片或图片文件夹的路径。"
)
parser.add_argument(
"--pt_name",
type=str,
default="best.pt",
help="要使用的权重文件名 (例如 'best.pt''epoch100.pt')。"
)
all_cam_choices = list(CAM_METHODS.keys()) + ["All"]
parser.add_argument(
"--cam_method",
type=str,
default="All",
choices=all_cam_choices,
help=f"选择要使用的 CAM 可视化方法。默认为: All。"
)
parser.add_argument(
"--target_layers",
type=str,
default="default",
help="指定要可视化的目标层。 "
"'default': (YOLOv8/YOLO11系列使用精选8层其他使用标准默认层); "
"'model.model.model[15]': (指定单个层); "
"'model.model.model[10],model.model.model[15]': (指定多个层)"
)
args = parser.parse_args()
# --- [!! 关键修改 !!] 移除路径校准 ---
# 假设 yolo_config.py 已被修复config.PREDICT_BEST_MODEL_DIR 是绝对路径
logging.info(f"正在搜索模型目录: {config.PREDICT_BEST_MODEL_DIR}")
available_runs = find_trained_models(config.PREDICT_BEST_MODEL_DIR, args.model, args.pt_name)
run_to_use = None
if not available_runs:
logging.error(f"错误:在 {config.PREDICT_BEST_MODEL_DIR} 中未找到模型 '{args.model}' (使用 '{args.pt_name}') 的有效训练记录。")
sys.exit(1)
elif len(available_runs) == 1:
run_to_use = available_runs[0]
logging.info(f"只找到一个训练版本,已自动选择: {run_to_use}")
else:
# (V1 的模型选择菜单)
print(f"\n为模型 '{args.model}' (使用 '{args.pt_name}') 找到多个训练版本:")
for i, run_name in enumerate(available_runs, 1):
print(f" [{i}] {run_name}")
while True:
try:
choice = input(f"请输入您想使用的版本序号 (1-{len(available_runs)}): ")
choice_index = int(choice)
if 1 <= choice_index <= len(available_runs):
run_to_use = available_runs[choice_index - 1]
break
else:
print("错误:输入无效,请输入列表中的序号。")
except ValueError:
print("错误:请输入一个数字。")
except (KeyboardInterrupt, EOFError):
print("\n操作已取消。")
sys.exit(0)
# --- 4. 调用 [!! 新的 !!] 可视化函数 ---
if run_to_use:
model_to_use = config.PREDICT_BEST_MODEL_DIR / run_to_use / 'weights' / args.pt_name
project_to_save_in = config.PREDICT_BEST_MODEL_DIR / run_to_use
if not model_to_use.exists():
logging.error(f"严重错误:找不到权重文件 {model_to_use}")
else:
methods_to_run = []
if args.cam_method == "All":
methods_to_run = list(CAM_METHODS.keys()) #
logging.info(f"--- [!!] 检测到 'All' 选项,将运行全部 {len(methods_to_run)} 种 CAM 方法 ---")
else:
methods_to_run = [args.cam_method] # 列表中只有一项
for method_name in methods_to_run:
# (可选) 检查 Eigen* 方法是否与 'default' 一起使用
is_slow_method = method_name in ["EigenCAM", "EigenGradCAM"] #
is_default_layers = args.target_layers == "default" #
if is_slow_method and not is_default_layers:
logging.warning(f"--- [警告] 您正在对自定义层运行 {method_name}")
logging.warning(f"--- 如果目标层是浅层,可能会非常缓慢或卡住。---")
logging.info(f"\n--- 正在启动 CAM 方法: [ {method_name} ] ---")
# 将原始调用放入循环中
visualize_nn_comprehensive(
model_path=str(model_to_use),
source_dir=args.source,
base_save_dir=project_to_save_in,
pt_name=args.pt_name,
cam_method_name=method_name, # [!! 修改 !!] 使用循环中的变量
target_layer_str=args.target_layers,
model_key=args.model
)
logging.info("\n--- [!!] 所有指定的 CAM 方法均已执行完毕 ---")

View File

@@ -0,0 +1,158 @@
import argparse, shutil, glob
import logging
from pathlib import Path
from ultralytics import YOLO
from datetime import datetime
import yolo_config as config
import gc, torch
# --- 日志设置 ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[logging.StreamHandler()]
)
# 寻找最匹配的 last.pt文件 和 其所在文件夹
def find_latest_last_pt(model_key: str, outputs_dir: Path):
"""根据模型名查找最新的 last.pt 路径,并返回 (last_pt路径, 文件夹名)"""
pattern = str(outputs_dir / f"{model_key}_*" / "weights" / "last.pt")
candidates = glob.glob(pattern)
if not candidates:
return None, None
# 按修改时间排序,取最新的
latest_pt = max(candidates, key=lambda p: Path(p).stat().st_mtime)
# 提取文件夹名
save_folder_name = Path(latest_pt).parent.parent.name # weights 的上级就是训练目录
return latest_pt, save_folder_name
def train_model(model_key: str):
"""
根据给定的模型密钥训练YOLO分割模型。
Args:
model_key (str): 在 yolo_config.MODEL_CONFIGS 中定义的模型密钥。
"""
# 生成时间戳以区分不同训练运行
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
project_folder = f"{model_key}_{timestamp}"
if model_key not in config.MODEL_CONFIGS:
logging.error(f"错误:模型 '{model_key}' 不在 yolo_config.py 中定义。")
logging.info(f"可用模型: {list(config.MODEL_CONFIGS.keys())}")
return
model_config = config.MODEL_CONFIGS[model_key]
model_file = model_config['weights']
model_batch_size = model_config['batch_size']
model_image_size = model_config['image_size']
logging.info(f"开始训练模型: {model_key} ({model_file})")
try:
# # V1--- 1. 加载模型 ---
# model = YOLO(model_file) # 如果没有模型会自动下载
# logging.info("模型加载成功。")
# V2--- 1. 加载模型(支持断点续训) ---
last_pt, save_folder_name = find_latest_last_pt(model_key, config.OUTPUTS_DIR)
if last_pt and Path(last_pt).exists():
model = YOLO(last_pt)
logging.info(f"断点续训: {last_pt}")
logging.info(f"续训目录名为: {save_folder_name}")
model.train(
resume=True, # ✅ 关键参数
data=str(config.DATASET_YAML_PATH),
epochs=config.EPOCHS,
imgsz=model_image_size,
batch=model_batch_size,
optimizer=config.OPTIMIZER,
lr0=config.LEARNING_RATE,
device=config.DEVICE,
project=str(config.OUTPUTS_DIR), # 指定输出的根目录
name=save_folder_name, # 指定本次训练的项目名
workers=config.WORKERS,
exist_ok=True, # 如果项目已存在,则覆盖
patience=config.PATIENCE, # 提前停止训练的轮数
save_period=config.SAVE_PERIOD # 每隔多少轮保存一次模型
)
else:
model = YOLO(model_file)
logging.info(f"从头训练: {model_file}")
save_folder_name = project_folder
# --- 2. 模型训练 ---
logging.info(f"数据集配置文件: {config.DATASET_YAML_PATH}")
logging.info(f"训练参数: Epochs={config.EPOCHS}, Batch Size={model_batch_size}, Img Size={model_image_size}")
model.train(
data=str(config.DATASET_YAML_PATH),
epochs=config.EPOCHS,
imgsz=model_image_size,
batch=model_batch_size,
optimizer=config.OPTIMIZER,
lr0=config.LEARNING_RATE,
device=config.DEVICE,
project=str(config.OUTPUTS_DIR), # 指定输出的根目录
name=save_folder_name, # 指定本次训练的项目名
workers=config.WORKERS,
exist_ok=True, # 如果项目已存在,则覆盖
patience=config.PATIENCE, # 提前停止训练的轮数
save_period=config.SAVE_PERIOD # 每隔多少轮保存一次模型
)
logging.info("模型训练完成。")
logging.info(f"训练结果保存在: {config.OUTPUTS_DIR / save_folder_name}")
return save_folder_name
except Exception as e:
logging.error(f"训练过程中发生错误: {e}")
return None
# --- 3. 复制结果到硬盘并删除原文件夹 (新增逻辑) ---
def move_results_to_hardisk(project_folder: str):
source_dir = config.OUTPUTS_DIR / project_folder
# 确保目标硬盘目录存在
outputs_folder_name = config.OUTPUTS_DIR.name
destination_dir = config.HARDISK_DIR / outputs_folder_name / project_folder
destination_dir.parent.mkdir(parents=True, exist_ok=True)
logging.info(f"准备将结果从 {source_dir} 移动到 {destination_dir}...")
try:
# 步骤 1: 复制文件夹
shutil.copytree(source_dir, destination_dir)
logging.info(f"成功复制结果到: {destination_dir}")
# 步骤 2: 复制成功后,删除原文件夹
logging.info(f"正在删除原文件夹: {source_dir}")
shutil.rmtree(source_dir)
logging.info("成功删除原文件夹。")
logging.info(f"任务完成,最终结果已保存至: {destination_dir}")
except Exception as e:
logging.error(f"移动文件夹时发生错误: {e}")
logging.warning(f"原始训练结果仍保留在: {source_dir}")
if __name__ == "__main__":
# 创建命令行参数解析器
parser = argparse.ArgumentParser(description="训练 YOLO 分割模型。")
parser.add_argument(
"--model",
type=str,
required=True,
choices=list(config.MODEL_CONFIGS.keys()),
help=f"选择要训练的模型。可选: {list(config.MODEL_CONFIGS.keys())}"
)
args = parser.parse_args()
# 1. 开始训练
save_folder_name = train_model(args.model)
# 2. 训练完成后,移动结果到硬盘
if save_folder_name:
move_results_to_hardisk(save_folder_name)
else:
logging.error("由于训练失败,未执行结果移动操作。")
# 3. 释放资源
torch.cuda.empty_cache() # 释放 CUDA 显存
gc.collect() # 释放系统内存
logging.info("已清理 CUDA 显存与系统内存。")

View File

@@ -0,0 +1,137 @@
#!/bin/bash
# =================================================================
# YOLO 模型批量并行训练脚本
# =================================================================
# --- 1. Conda 环境设置 ---
CONDA_BASE_PATH="/home/wkmgc/miniconda3" # <--- 在这里修改为您自己的 Conda 路径
CONDA_ENV_NAME="${SEG_CONDA_ENV:-seg_smp}" # 可用 SEG_CONDA_ENV=SMP bash yolo_train.sh 临时覆盖
# 初始化并激活 Conda 环境
if [ -f "${CONDA_BASE_PATH}/etc/profile.d/conda.sh" ]; then
source "${CONDA_BASE_PATH}/etc/profile.d/conda.sh"
conda activate "${CONDA_ENV_NAME}"
if [ $? -ne 0 ]; then
echo "错误: 激活 Conda 环境 '${CONDA_ENV_NAME}' 失败!"
exit 1
fi
echo "Conda 环境 '${CONDA_ENV_NAME}' 已成功激活。"
else
echo "错误: 找不到 conda.sh 脚本。请检查您的 CONDA_BASE_PATH 设置是否正确。"
exit 1
fi
echo "正在为当前终端会话设置 HTTP/HTTPS 代理..."
export https_proxy=http://127.0.0.1:1080
export http_proxy=http://127.0.0.1:1080
echo "代理已设置为 http://127.0.0.1:1080"
echo "如果模型权重不存在,将通过此代理进行下载。"
# --- 2. 模型与 GPU 配置 ---
# 根据您的 GPU 硬件和训练计划进行分组
# yolo_train.py 会自动使用分配给它的所有可见 GPU
GPUS_GROUP_0="0"
GPUS_GROUP_1="1"
GPUS_GROUP_2="2"
GPUS_GROUP_3="3"
# TODO
GPUS_GROUP_4="0"
GPUS_GROUP_5="1"
GPUS_GROUP_6="2"
GPUS_GROUP_7="3"
# GPUS_GROUP_4="4"
# GPUS_GROUP_5="5"
# GPUS_GROUP_6="6"
# GPUS_GROUP_7="7"
# 从 yolo_config.py 中选择要训练的模型,并分配到不同的组
# 建议将计算量/模型大小相似的模型放在同一组
GROUP_0_MODELS=("YOLO11l-seg" "YOLOv8m-seg")
GROUP_1_MODELS=("YOLO11x-seg")
GROUP_2_MODELS=("YOLO11n-seg" "YOLO11s-seg" "YOLO11m-seg")
GROUP_3_MODELS=("YOLOv9e-seg")
GROUP_4_MODELS=("YOLOv9c-seg")
GROUP_5_MODELS=("YOLOv8x-seg")
GROUP_6_MODELS=("YOLOv8l-seg" "YOLO12-seg") # YOLOv9e-seg # 模型太大
GROUP_7_MODELS=("YOLOv8n-seg" "YOLOv8s-seg" ) # 训练完毕
# 1. 从 config.py 中读取 OUTPUTS_DIR 的值
OUTPUTS_DIR=$(python -c "from yolo_config import OUTPUTS_DIR; print(OUTPUTS_DIR)")
# 检查是否成功获取了 OUTPUTS_DIR
if [ -z "$OUTPUTS_DIR" ]; then
echo "OUTPUTS_DIR: $OUTPUTS_DIR"
echo "Error 1: Could not read OUTPUTS_DIR from yolo_config.py. Exiting."
echo "Error 2: Or the directory specified by OUTPUTS_DIR does not exist. Please create it first."
# exit 1
fi
# 2. 定义带有时间戳的日志目录名
LOG_DIR_NAME="yolo_train_logs_parallel_$(date +%Y-%m-%d_%H-%M-%S)"
# 3. 拼接成最终的完整路径
LOG_DIR="$OUTPUTS_DIR/$LOG_DIR_NAME"
mkdir -p "${LOG_DIR}"
echo "所有模型的日志将保存在 ./${LOG_DIR}/ 目录中。"
echo "----------------------------------------------------"
# --- 3. 训练执行函数 ---
# 定义一个函数来启动一组训练,以避免代码重复
start_training_group() {
# 使用 nameref (引用) 来传递数组
local -n models=$1
local gpus=$2
local group_name=$3
echo ">>> 准备启动 ${group_name} 的训练任务 (后台运行)..."
# 遍历指定组中的所有模型
for model_key in "${models[@]}"; do
echo " -> 正在后台启动模型: ${model_key} on GPUs: ${gpus}"
# 使用 '&' 将命令放入后台运行
# 通过 --model 参数将模型名称传递给 yolo_train.py
CUDA_VISIBLE_DEVICES=${gpus} python yolo_train.py --model "${model_key}" > "${LOG_DIR}/${model_key}.log" 2>&1 &
echo " - 模型 ${model_key} 已在后台启动。日志文件: ${LOG_DIR}/${model_key}.log"
echo " - 等待 30 秒,确保 GPU 资源稳定分配..."
sleep 30
done
echo ">>> ${group_name} 的所有模型均已启动。"
echo "----------------------------------------------------"
}
# --- 4. 依次启动所有训练任务 ---
# 脚本将快速地按顺序启动每一组任务到后台
start_training_group GROUP_0_MODELS "${GPUS_GROUP_0}" "第零组" # TODO
start_training_group GROUP_1_MODELS "${GPUS_GROUP_1}" "第一组"
start_training_group GROUP_2_MODELS "${GPUS_GROUP_2}" "第二组"
start_training_group GROUP_3_MODELS "${GPUS_GROUP_3}" "第三组"
# --- 5. 等待所有后台任务完成 ---
echo ""
echo "--- 所有模型均已在后台启动。现在等待所有训练任务完成... ---"
# 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕
wait
echo "--- 所有后台训练任务已全部完成! ---"
# 适用于4卡版本
start_training_group GROUP_4_MODELS "${GPUS_GROUP_4}" "第四组" # TODO
start_training_group GROUP_5_MODELS "${GPUS_GROUP_5}" "第五组"
start_training_group GROUP_6_MODELS "${GPUS_GROUP_6}" "第六组"
start_training_group GROUP_7_MODELS "${GPUS_GROUP_7}" "第七组"
# --- 5. 等待所有后台任务完成 ---
echo ""
echo "--- 所有模型均已在后台启动。现在等待所有训练任务完成... ---"
# 'wait' 命令会暂停脚本,直到所有由此脚本启动的后台子进程全部执行完毕
wait
echo "--- 所有后台训练任务已全部完成! ---"
# --- 6. 退出脚本 ---
# 训练完成,取消激活 Conda 环境
echo "训练流程结束。"
# --- 取消网络代理设置 ---
echo "正在取消当前终端会话的 HTTP/HTTPS 代理设置..."
unset https_proxy
unset http_proxy
echo "代理已取消。"
conda deactivate
echo "取消激活 Conda 环境。"

View File

@@ -0,0 +1,100 @@
############## A. 配置环境 ##############
pip install ultralytics grad-cam
修改 /home/wkmgc/miniconda3/envs/SMP/lib/python3.9/site-packages/pytorch_grad_cam/base_cam.py
if targets is None:
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
targets = [ClassifierOutputTarget(category) for category in target_categories]
变为
if targets is None:
# 循环解包,直到找到第一个非元组的元素,这可以处理嵌套元组,例如 ((Tensor, ...), ...)
processed_outputs = outputs
print(f"DEBUG: 原始 outputs 类型: {type(processed_outputs)}")
while isinstance(processed_outputs, tuple):
if len(processed_outputs) == 0:
# 如果遇到一个空元组,我们就无法继续,打印错误并退出循环
print(f"ERROR: 在解包时遇到空元组。原始 outputs: {outputs}")
break
print(f"DEBUG: 正在解包元组... 选择第 0 个元素。")
processed_outputs = processed_outputs[0] # 逐次选择第一个元素
print(f"DEBUG: 最终选定的 outputs 类型: {type(processed_outputs)}")
# 现在processed_outputs 应该是我们需要的张量 (Tensor)
try:
target_categories = np.argmax(processed_outputs.cpu().data.numpy(), axis=-1)
targets = [ClassifierOutputTarget(category) for category in target_categories]
except AttributeError as e:
print(f"ERROR: 最终选定的元素 (类型: {type(processed_outputs)}) 无法处理。")
print(f" 它没有 .cpu() 属性。原始错误: {e}")
print(f" 请检查您的模型输出结构。原始 outputs: {outputs}")
raise e # 重新抛出异常,以便程序停止
############## B. 数据集构建(可选) ##############
参考:./Yolo数据集构建
############## C. Train 训练程序 ##############
# A. 第一次训练开一个梯子【需要下载内容】
export https_proxy=http://127.0.0.1:1080 http_proxy=http://127.0.0.1:1080
CUDA_VISIBLE_DEVICES=0 python train.py --model {YOLOv8n-seg, YOLOv8s-seg, YOLOv8m-seg, YOLOv8l-seg, YOLOv8x-seg, YOLOv9e-seg, YOLOv9c-seg, YOLO11l-seg, YOLO11n-seg, YOLO11s-seg, YOLO11m-seg, YOLO11x-seg, YOLO12-seg, YOLO12-seg}
# B. 运行单个训练程序
1. 在 dataset.yaml 中修改 训练数据集CUDA_VISIBLE_DEVICES修改使用显卡--model 修改算法;
2. CUDA_VISIBLE_DEVICES=0 python yolo_train.py --model "YOLOv8n-seg"
# C. 批量运行训练程序
1. train.sh 修改想让它使用的显卡、算法;
2. bash yolo_train.sh
3. 会生成 ./yolo_logs_parallel_DATE 的终端记录文件;结果先存储在 ../DataSet_Public_outputs/DATASET-Yolo 后自动移动到 ../Hardisk/DATASET_outputs-SegModel 中
############## D. Predict 推理程序 ##############
# A. 预先步骤:将模型同步到 Nas_BackUp_Seg 或 ./Hardisk 文件夹中【cd .. && bash ./Back_Up.sh】
# B1. 将最优模型文件从 ./Nas_BackUp_Seg 移动到 ./BestMode_Predict_Results_DataSet_Public 指定文件夹中
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" # "epoch100.pt" # 修改里面的路径
# B2. 如果需要处理自定义数据集,请将模型文件夹手动复制为 ./BestMode_Predict_Results_DataSet_Public/DATASET-Yolo 中
可以先对原有 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET-Yolo 临时改名
运行 bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" # "epoch100.pt" # 修改里面的路径
在将生成的 ./BestMode_Predict_Results_DataSet_Public/ORI_DATASET-Yolo 改为 ./BestMode_Predict_Results_DataSet_Public/DATASET-Yolo
# C. 运行单个推理程序
1. 在 dataset.yaml 中修改 预测数据集 及 val/test yolo_config.py 中需修改模型保存路径 PREDICT_ALL_BEST_MODELS_DIR/PREDICT_BEST_MODEL_DIRCUDA_VISIBLE_DEVICES修改使用显卡--model 修改算法;
2. CUDA_VISIBLE_DEVICES=0 python yolo_predict_V1_NoColor.py --model "YOLOv8n-seg"
CUDA_VISIBLE_DEVICES=0 python yolo_predict_V2.py --model "YOLOv8n-seg" --conf 0.2 --pt_name "epoch100.pt" # "epoch100.pt" # "best.pt"
# D. 批量运行推理程序(分割图可视化 或 热图可视化)
1. yolo_predict.sh 修改想让它使用的显卡、算法;
2. bash yolo_predict.sh --conf 0.2 --pt_name "epoch100.pt" # "best.pt"(默认)# 分割图可视化
3. bash yolo_predict.sh --conf 0.2 --heatmap_method "All" --pt_name "epoch100.pt" # "best.pt"(默认)# 热图可视化
4. 会生成 ./yolo_predict_logs_parallel_DATE 的终端记录文件;结果存储在 ../BestMode_Predict_Results_DataSet_Public/DATASET-Yolo 中
# E. 横向对比训练结果
1. 将模型结果进行横向对比(生成color_mask、Yolo_result的横向对比)
python yolo_predict_V2_compare_all.py --pt_name "all" # "epoch100.pt" # "best.pt" # "all"(默认)
############## E. Predict_raw_img_Check 检测推理输出图片是否齐全 ##############
# 检测 dataset.yaml path/"labels"/test.name 和 yolo_config.PREDICT_BEST_MODEL_DIR/****/predicted_raw_masks 中图片是否匹配
1. 在 dataset.yaml 中修改 path、test在yolo_config.PREDICT_ALL_BEST_MODELS_DIRCUDA_VISIBLE_DEVICES修改使用显卡
2. python yolo_predict_raw_masks_check.py --pt_name "best.pt" # "epoch100.pt" # "best.pt"(默认)
############## F. 神经网络热图可视化目前只支持EigenCAM无指定类别的方法 ##############
# A. 预先步骤cd ./Yolo可视化测试
python yolo_layer_tester.py --model "YOLO12-seg" --cam_method "GradCAM" --pt_name "best.pt" # 测试各个层是否能成功生成热图,选择合适的层
vim ../yolo_predict_visualize_nn.py # 修改 visualize_nn_comprehensive 中 model_key.lower().startswith("yoloXXX") 的部分,添加对应模型的默认层
# B. 运行单个热图可视化程序
# YOLOv8n-seg,YOLOv8s-seg,YOLOv8m-seg,YOLOv8l-seg,YOLOv8x-seg,YOLOv9c-seg,YOLOv9e-seg,YOLO11n-seg,YOLO11s-seg,YOLO11m-seg,YOLO11l-seg,YOLO11x-seg,YOLO12-seg
python yolo_predict_visualize_nn.py --model "YOLO11m-seg" --target_layers "default" --cam_method "All" --pt_name "best.pt" # "epoch100.pt" # "best.pt"
############## G.※ 快速进行分割实验 ※ ##############
# 1. 修改 dataset.yaml 中 训练数据集、修改 yolo_config.py 中 EPOCHS、PATIENCE
cd ~/Desktop/Seg/Seg_All_In_One_YoloModel
# 2. 批量化训练
bash yolo_train.sh
# 3. 复制最优模型到预测文件夹
bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "best.pt" && bash ./Tool_Yolo_Copy_Best_Model.sh --pt_name "epoch100.pt"
# 4. 批量化预测+热度图可视化
bash yolo_predict.sh --conf 0.2 --pt_name "epoch100.pt" && bash yolo_predict.sh --conf 0.2
bash ./yolo_predict.sh --heatmap_method "All" && bash ./yolo_predict.sh --pt_name "epoch100.pt" --heatmap_method "All"
# 5. 横向对比结果
python yolo_predict_V2_compare_all.py
# 6. 打包预测结果(不包含*.pt模型文件)
cd /home/wkmgc/Desktop/Seg/BestMode_Predict_Results_DataSet_Public/
zip -r /home/wkmgc/Desktop/5_My_Gastric_2025_10_29_938-Yolo.zip 5_My_Gastric_2025_10_29_938*-Yolo -x "*.pt"
# 7. 打包训练结果(只有.png、.jpg、.csv文件)
cd /home/wkmgc/Desktop/Seg/Hardisk/
zip -r /home/wkmgc/Desktop/5_My_Gastric_2025_10_29_938-Yolo_train.zip 5_My_Gastric_2025_10_29_938-Yolo -i \*.png \*.jpg \*.csv