213 lines
7.3 KiB
Python
213 lines
7.3 KiB
Python
import os
|
|
import argparse
|
|
import subprocess
|
|
import cv2
|
|
import numpy as np
|
|
import glob
|
|
import sys
|
|
|
|
def read_image_safe(path):
    """Read an image from ``path``, tolerating non-ASCII (e.g. Chinese) paths.

    cv2.imread mishandles the encoding of such paths, so the file is loaded
    as a raw byte buffer with NumPy and decoded in memory instead.

    Args:
        path: Location of the image file on disk.

    Returns:
        Whatever ``cv2.imdecode`` produces for the file's bytes (decoded with
        ``cv2.IMREAD_COLOR``), or ``None`` if reading or decoding raised.
        A message is printed when an exception is swallowed.
    """
    try:
        # Pull the whole file into memory as bytes, then decode from memory.
        raw_bytes = np.fromfile(path, dtype=np.uint8)
        return cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
    except Exception as e:
        # Best-effort reader: report the problem and let the caller skip it.
        print(f"Error reading image {path}: {e}")
    return None
|
|
|
|
def write_image_safe(path, img):
    """Save ``img`` to ``path``, tolerating non-ASCII (e.g. Chinese) paths.

    cv2.imwrite mishandles the encoding of such paths, so the image is
    encoded in memory with ``cv2.imencode`` and the resulting buffer is
    written to disk through NumPy's ``tofile``.

    Args:
        path: Destination file path. Its extension selects the encoder;
            ".png" is used when the path has no extension.
        img: Image array to encode.

    Returns:
        ``True`` when the file was written, ``False`` when encoding reported
        failure or any exception was raised (a message is printed either way).
    """
    try:
        # The extension (e.g. ".png") tells imencode which format to produce.
        suffix = os.path.splitext(path)[1] or ".png"

        encoded_ok, buffer = cv2.imencode(suffix, img)
        if not encoded_ok:
            print(f"Error encoding image for {path}")
            return False

        # Write the encoded bytes directly; this bypasses cv2's path handling.
        buffer.tofile(path)
        return True
    except Exception as e:
        print(f"Error writing image {path}: {e}")
        return False
|
|
|
|
def run_encoder(encoder, img_path, out_dir, grayscale):
    """Invoke ``run.py`` in a subprocess to produce depth maps for one encoder.

    The command is ``<current python> run.py --encoder ... --img-path ...
    --outdir ... --pred-only`` with ``--grayscale`` appended on request.
    ``run.py`` is resolved relative to the current working directory.

    Args:
        encoder: Encoder name passed to ``--encoder``.
        img_path: Input path passed to ``--img-path``.
        out_dir: Output directory passed to ``--outdir``.
        grayscale: When truthy, ``--grayscale`` is added to the command.

    Returns:
        ``True`` if the subprocess exited with status 0, ``False`` if it
        exited non-zero or could not be launched at all.
    """
    cmd = [
        sys.executable, 'run.py',
        '--encoder', encoder,
        '--img-path', img_path,
        '--outdir', out_dir,
        '--pred-only',
    ]

    if grayscale:
        cmd.append('--grayscale')

    print(f"Executing: {' '.join(cmd)}")
    try:
        subprocess.check_call(cmd)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: run.py exited non-zero.
        # OSError: the process could not be started (e.g. interpreter not
        # found). Report both the same way so one failure does not abort the
        # caller's batch loop.
        print(f"Error occurred while running {encoder} (Gray={grayscale}): {e}")
        return False
    return True
|
|
|
|
def add_text(img, text):
    """Draw ``text`` near the top-left corner of ``img`` and return ``img``.

    The label is rendered twice at the same position: first as a thick black
    stroke, then as a thinner white stroke on top, so it stays legible on
    both light and dark backgrounds. ``img`` is drawn on in place.

    Args:
        img: Image array to annotate (modified in place).
        text: Label to render.

    Returns:
        The same ``img`` object that was passed in.
    """
    face = cv2.FONT_HERSHEY_SIMPLEX
    scale = 1.0
    stroke = 2
    origin = (30, 50)

    # (color, thickness) passes: black outline first, white fill second.
    passes = (
        ((0, 0, 0), stroke + 3),
        ((255, 255, 255), stroke),
    )
    for color, width in passes:
        cv2.putText(img, text, origin, face, scale, color, width, cv2.LINE_AA)
    return img
|
|
|
|
def create_comparison(image_files, folder_name, encoders, mode_suffix, out_folder_name):
    """Build one horizontal comparison strip per input image.

    Each strip is the labeled original followed by the labeled depth map of
    every encoder, all resized to the original's size. Depth maps are looked
    up at ``<folder_name>/V2_<encoder><mode_suffix>/<image stem>.png``. An
    image is skipped (with a message) if it cannot be read or if any
    encoder's depth map is missing or unreadable.

    Args:
        image_files: Paths of the original images.
        folder_name: Root results folder containing the per-encoder outputs.
        encoders: Encoder names, in the left-to-right order of the strip.
        mode_suffix: Suffix of the per-encoder result folder name
            (e.g. ``'_gray_result'``).
        out_folder_name: Sub-folder of ``folder_name`` that receives the
            ``<image stem>_compare.png`` files; created if absent.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    compare_dir = os.path.join(folder_name, out_folder_name)
    os.makedirs(compare_dir, exist_ok=True)

    print(f"\n--- Generating Comparisons ({out_folder_name}) -> {compare_dir} ---")

    for img_file in image_files:
        file_basename = os.path.basename(img_file)
        file_name_no_ext = os.path.splitext(file_basename)[0]

        # 1. Load the original image.
        raw_img = read_image_safe(img_file)
        if raw_img is None:
            # Previously a silent skip; log it like the depth-map branch does
            # so undecodable inputs are visible in the output.
            print(f"Failed to read {img_file}, skipping.")
            continue

        # Target size (height, width) that every depth map is resized to.
        h, w = raw_img.shape[:2]

        # Label a copy so the loaded original itself is left untouched.
        images_to_concat = [add_text(raw_img.copy(), "Original")]

        # 2. Load each encoder's result; any failure abandons this image.
        for encoder in encoders:
            result_dir = os.path.join(folder_name, f'V2_{encoder}{mode_suffix}')
            result_path = os.path.join(result_dir, file_name_no_ext + '.png')

            if not os.path.exists(result_path):
                print(f"Missing result for {file_basename} in {encoder}, skipping.")
                break

            depth_img = read_image_safe(result_path)
            if depth_img is None:
                print(f"Failed to read {result_path}, skipping.")
                break

            # Ensure 3 channels so the tile can be concatenated with the original.
            if len(depth_img.shape) == 2:
                depth_img = cv2.cvtColor(depth_img, cv2.COLOR_GRAY2BGR)

            # Force the tile to the original's size regardless of output size;
            # note cv2.resize takes (width, height).
            if depth_img.shape[:2] != (h, w):
                depth_img = cv2.resize(depth_img, (w, h), interpolation=cv2.INTER_LINEAR)

            label = f"Depth_Anything_V2_{encoder}"
            images_to_concat.append(add_text(depth_img, label))
        else:
            # Reached only when no encoder was missing: stitch and save.
            try:
                combined_result = cv2.hconcat(images_to_concat)
                save_path = os.path.join(compare_dir, f"{file_name_no_ext}_compare.png")
                if write_image_safe(save_path, combined_result):
                    print(f"Saved: {save_path}")
                else:
                    print(f"Failed to save: {save_path}")
            except Exception as e:
                # Best-effort batch: report and move on to the next image.
                print(f"Error combining images for {file_name_no_ext}: {e}")
|
|
|
|
def main():
    """Run all encoders over an image folder and build comparison strips.

    Steps:
      1. Derive the results root as ``<input folder name>_results``
         (a relative path, so it is created under the current directory).
      2. For each encoder, run grayscale and color inference via
         ``run_encoder`` unless that encoder's output folder already exists.
      3. Collect the input images (recursively) by file extension.
      4. Build the grayscale and the color comparison sets.

    Command line:
        --img-path  Path to the input images folder (required).
    """
    parser = argparse.ArgumentParser(description='Batch process depth estimation (Gray & Color) and compare.')
    parser.add_argument('--img-path', type=str, required=True, help='Path to input images folder')
    args = parser.parse_args()

    # 1. Root output folder name: input folder name + "_results".
    folder_name = os.path.basename(os.path.normpath(args.img_path)) + "_results"

    encoders = ['vits', 'vitb', 'vitl', 'vitg']

    # 2. Batch inference. An existing output folder is treated as "already
    #    done" and is not regenerated.
    for encoder in encoders:
        # A. Grayscale depth maps
        gray_out_dir = os.path.join(folder_name, f'V2_{encoder}_gray_result')
        if not os.path.exists(gray_out_dir):
            run_encoder(encoder, args.img_path, gray_out_dir, grayscale=True)

        # B. Color depth maps
        color_out_dir = os.path.join(folder_name, f'V2_{encoder}_colorful_result')
        if not os.path.exists(color_out_dir):
            run_encoder(encoder, args.img_path, color_out_dir, grayscale=False)

    # 3. Collect the original images.
    #    - glob.escape keeps folder names containing [, ], * or ? from being
    #      interpreted as glob patterns.
    #    - Extensions are compared case-insensitively so files such as
    #      "IMG.JPG" are found on case-sensitive filesystems too.
    #    - With recursive=True, '**' also matches zero directories, so files
    #      directly inside the folder are included.
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    pattern = os.path.join(glob.escape(args.img_path), '**', '*')
    image_files = sorted({
        candidate
        for candidate in glob.glob(pattern, recursive=True)
        if os.path.splitext(candidate)[1].lower() in image_exts
    })

    if not image_files:
        print("No images found in input path.")
        return

    # 4. Grayscale comparison strips
    create_comparison(
        image_files=image_files,
        folder_name=folder_name,
        encoders=encoders,
        mode_suffix='_gray_result',
        out_folder_name='compare_all_gray',
    )

    # 5. Color comparison strips
    create_comparison(
        image_files=image_files,
        folder_name=folder_name,
        encoders=encoders,
        mode_suffix='_colorful_result',
        out_folder_name='compare_all_colorful',
    )

    print(f"\nAll done! Results are in: {folder_name}")
|
|
|
|
# Script entry point: run the batch pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()