Files
Media_Depth/Tool_Gen_3d_points_Cloud.py
2026-05-20 12:25:12 +08:00

177 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import os
import numpy as np
import torch
from PIL import Image, ImageEnhance
import cv2
from argparse import ArgumentParser
from tqdm import tqdm
def read_image_safe(path):
    """Decode the image stored at ``path`` into an OpenCV colour array.

    The file is loaded as raw bytes via ``np.fromfile`` and decoded with
    ``cv2.imdecode(..., cv2.IMREAD_COLOR)`` instead of ``cv2.imread``
    (presumably so that paths with non-ASCII characters work on Windows).

    Returns:
        The decoded array, or ``None`` when reading/decoding raised an
        exception (the error is printed). Note that ``cv2.imdecode`` may
        also return ``None`` by itself for data it cannot decode.
    """
    try:
        encoded = np.fromfile(path, dtype=np.uint8)
        decoded = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    except Exception as exc:
        print(f"Error reading image {path}: {exc}")
        decoded = None
    return decoded
# Command-line arguments (the help texts are user-facing and kept as-is).
parser = ArgumentParser()
parser.add_argument("--img_path_ori", type=str, required=True, help="输入原始图像的目录路径")
parser.add_argument("--img_path_depth", type=str, required=True, help="输入3d灰度图像的目录路径")
parser.add_argument("--outdir", type=str, required=True, help="输出处理后的 PLY 文件的目录")
# NOTE: despite the option names, the depth-map file name is built as
# <--suffix> + <image stem> + <--appendix> + <ext>, i.e. --suffix is a prefix.
parser.add_argument("--appendix", type=str, default='_depth', help="深度图后缀")
parser.add_argument("--suffix", type=str, default='', help="深度图前缀")
parser.add_argument("--label-type", type=str, default='Disparity', help="标注图格式Depth 或 Disparity")
parser.add_argument("--focal-length", type=float, default=-1, help="相机焦距")
parser.add_argument("--FoV", type=float, default=60, help="相机视场角")
# === Point-cloud colour / geometry adjustment options ===
parser.add_argument("--z-scale", type=float, default=1.0, help="[新增] Z轴拉伸倍数数值越大模型越立体/修长")
# argparse applies %-formatting to help strings, so a literal percent sign
# must be written '%%'; a bare '%)' makes `--help` fail with ValueError.
parser.add_argument("--brightness", type=float, default=1.0, help="[新增] 亮度调整倍数 (例如 1.2 为变亮20%%)")
parser.add_argument("--saturation", type=float, default=1.0, help="[新增] 饱和度调整倍数 (例如 1.2 为颜色更鲜艳)")
parser.add_argument("--gamma", type=float, default=1.0, help="[新增] Gamma校正 (数值<1.0 提亮暗部细节)")
args = parser.parse_args()

# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(args.outdir, exist_ok=True)

# Extensions accepted for the original images, matched case-insensitively
# (so e.g. ".JPG" from cameras is picked up); this mirrors the extensions
# tried when looking up depth maps.
_IMAGE_EXTS = ('.png', '.jpg', '.jpeg')
if os.path.isdir(args.img_path_ori):
    file_paths_ori = sorted(
        os.path.join(args.img_path_ori, name)
        for name in os.listdir(args.img_path_ori)
        if name.lower().endswith(_IMAGE_EXTS)
    )
else:
    file_paths_ori = [args.img_path_ori]

# Depth maps are looked up by name inside this directory; when a single file
# is given, only its parent directory is used.
if os.path.isdir(args.img_path_depth):
    depth_file_path = args.img_path_depth
else:
    depth_file_path = os.path.dirname(args.img_path_depth)
def init_image_coor(height, width):
    """Build pixel-coordinate grids measured from the image centre.

    Args:
        height: Number of rows.
        width: Number of columns.

    Returns:
        ``(x, y)`` float32 tensors of shape ``(1, height, width)`` where
        ``x[0, r, c] = c - width / 2`` and ``y[0, r, c] = r - height / 2``.
    """
    # meshgrid with the default 'xy' indexing yields (height, width) grids:
    # cols varies along axis 1, rows along axis 0.
    cols, rows = np.meshgrid(
        np.arange(0, width).astype(np.float32),
        np.arange(0, height).astype(np.float32),
    )
    u_offset = torch.from_numpy(cols).unsqueeze(0) - width / 2.0
    v_offset = torch.from_numpy(rows).unsqueeze(0) - height / 2.0
    return u_offset, v_offset
def adjust_image_color(pil_image, brightness=1.0, saturation=1.0, gamma=1.0):
    """Adjust image colours in the order brightness -> saturation -> gamma.

    Args:
        pil_image: Source image. Brightness and saturation go through
            ``PIL.ImageEnhance`` and therefore need a ``PIL.Image``; when both
            factors are 1.0 anything ``np.array`` accepts (e.g. an
            ``(H, W, 3)`` uint8 array) works as well.
        brightness: Factor for ``ImageEnhance.Brightness``; 1.0 = unchanged.
        saturation: Factor for ``ImageEnhance.Color``; 1.0 = unchanged.
        gamma: Exponent applied to intensities normalised to [0, 1]
            (``I_new = I_old ** gamma``); values below 1.0 brighten dark
            regions, 1.0 = unchanged. Presumably expected to be > 0.

    Returns:
        ``np.uint8`` array of the adjusted pixels, rounded to the nearest
        integer and clipped to [0, 255].
    """
    if brightness != 1.0:
        pil_image = ImageEnhance.Brightness(pil_image).enhance(brightness)
    if saturation != 1.0:
        pil_image = ImageEnhance.Color(pil_image).enhance(saturation)

    img_array = np.array(pil_image).astype(np.float32)

    if gamma != 1.0:
        img_array = np.power(img_array / 255.0, gamma) * 255.0

    # Round before the uint8 cast: astype() truncates toward zero, which would
    # bias every gamma-adjusted pixel downward by up to one intensity level.
    return np.clip(np.rint(img_array), 0, 255).astype(np.uint8)
def _find_depth_file(depth_dir, base_name, prefix, appendix,
                     exts=('.png', '.jpg', '.jpeg')):
    """Return the first existing ``<prefix><base_name><appendix><ext>`` in ``depth_dir``, else None."""
    for ext in exts:
        candidate = os.path.join(depth_dir, prefix + base_name + appendix + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def _black_edge_mask(bgr_image):
    """Return a uint8 mask that is 1 on the near-black background and 0 elsewhere.

    Pixels with grey level > 5 form the foreground; the outer contours of the
    foreground regions are filled with 0, so 1 remains only on pixels that are
    not enclosed by any such contour.
    """
    gray = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 5, 255, cv2.THRESH_BINARY)
    # findContours returns (contours, hierarchy) in OpenCV 4 and
    # (image, contours, hierarchy) in OpenCV 3; [-2] is the contours in both.
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    mask = np.ones_like(gray, dtype=np.uint8)
    cv2.drawContours(mask, contours, -1, color=0, thickness=cv2.FILLED)
    return mask


def _write_ascii_ply(ply_path, vertices):
    """Write ``vertices`` (an (N, 6) array of x, y, z, r, g, b rows) as ASCII PLY.

    Coordinates are written with 4 decimals, colours as integers, and every
    vertex gets a constant alpha value of 0.
    """
    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {len(vertices)}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "property uchar alpha\n"
        "end_header\n"
    )
    with open(ply_path, "w") as file:
        file.write(header)
        # tolist() yields plain Python floats (typically much cheaper to format
        # than NumPy scalars); writelines streams the rows instead of building
        # one giant string for the whole cloud in memory.
        # NOTE(review): alpha is always written as 0, which viewers that honour
        # the alpha channel treat as fully transparent; 255 (opaque) is probably
        # intended -- confirm with downstream consumers before changing.
        file.writelines(
            f"{px:.4f} {py:.4f} {pz:.4f} {int(r)} {int(g)} {int(b)} 0\n"
            for px, py, pz, r, g, b in vertices.tolist()
        )


for file_path_ori in tqdm(file_paths_ori):
    base_name = os.path.splitext(os.path.basename(file_path_ori))[0]

    # Decode with OpenCV first so that files it cannot read are skipped before
    # any PIL decoding or colour work is done on them.
    bgr_image = read_image_safe(file_path_ori)
    if bgr_image is None:
        continue

    depth_file = _find_depth_file(depth_file_path, base_name, args.suffix, args.appendix)
    if depth_file is None:
        print(f'未找到对应的深度图: {base_name}')
        continue

    # 1 marks the near-black border/background of the original image.
    black_edge = _black_edge_mask(bgr_image)

    # Point colours: brightness -> saturation -> gamma adjustments.
    with Image.open(file_path_ori) as rgb_pil:
        rgb_array = adjust_image_color(rgb_pil.convert('RGB'), args.brightness, args.saturation, args.gamma)

    with Image.open(depth_file) as depth_pil:
        depth_array = np.array(depth_pil.convert('L'))
    h, w = depth_array.shape

    # The mask and colours index the depth map pixel-for-pixel; a size mismatch
    # would otherwise abort the whole batch with an IndexError/ValueError.
    if black_edge.shape != (h, w) or rgb_array.shape[:2] != (h, w):
        print(f'深度图尺寸与原图不一致, 跳过: {base_name}')
        continue

    is_disparity = args.label_type == "Disparity"
    # Push the black border to normalised value 1.0 (the largest z) for either
    # label type: 255 -> 1.0 for Depth, 0 -> 1 - 0.0 for Disparity.
    depth_array[black_edge == 1] = 0 if is_disparity else 255

    depth_tensor = torch.tensor(depth_array / 255.0).unsqueeze(0).unsqueeze(0)
    if is_disparity:
        # Invert so that larger disparity labels map to smaller z.
        depth_tensor = 1 - depth_tensor

    if args.focal_length != -1:
        focal_length = args.focal_length
    else:
        # Pinhole model from the field of view using the image height:
        # f = (H / 2) / tan(FoV / 2). True division keeps odd heights exact and
        # matches the H / 2.0 principal point used by init_image_coor.
        focal_length = (h / 2.0) / np.tan((args.FoV / 2.0) * np.pi / 180)

    u_u0, v_v0 = init_image_coor(h, w)
    # Back-project centred pixel offsets: x = (u - u0) * d / f, y = (v - v0) * d / f.
    x = u_u0 * depth_tensor / focal_length
    y = v_v0 * depth_tensor / focal_length
    # Only z is stretched by --z-scale; x and y use the unscaled depth.
    z = depth_tensor * args.z_scale

    vertices = np.stack([
        x.reshape(-1).numpy(),
        y.reshape(-1).numpy(),
        z.reshape(-1).numpy(),
        rgb_array[:, :, 0].reshape(-1),
        rgb_array[:, :, 1].reshape(-1),
        rgb_array[:, :, 2].reshape(-1),
    ], axis=1)

    _write_ascii_ply(os.path.join(args.outdir, base_name + ".ply"), vertices)