【Beginner Tutorial 1.11】A Hands-On Guide to Estimating Disparity and Computing Depth with PSMNet, and Mapping the Result to a 3D Point Cloud (Python Code Included)
In an earlier chapter we showed how to perform monocular depth estimation and 3D point-cloud generation with Depth Anything V2:
【Beginner Tutorial 1.8】A Hands-On Guide to Estimating Monocular Depth with Depth Anything V2 and Mapping It to a 3D Point Cloud
However, monocular depth estimation suffers from an inherent scale ambiguity:
So this time we use PSMNet disparity estimation to obtain accurate stereo (binocular) depth and convert the scene into a 3D point cloud. The full Python code is at the end.
The final result looks like this:
Compared with the traditional method:
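Unlike the monocular case, a calibrated and rectified stereo rig recovers metric depth directly from disparity via Z = f·B/d, where f is the focal length in pixels, B the baseline in meters, and d the disparity in pixels. A minimal sketch of this conversion (the focal length and baseline below are hypothetical, KITTI-like values):

```python
import numpy as np

def disparity_to_depth(disparity, focal_px, baseline_m, eps=1e-6):
    """Convert a disparity map (pixels) to metric depth: Z = f * B / d."""
    d = np.asarray(disparity, dtype=np.float64)
    depth = np.zeros_like(d)
    valid = d > eps                  # zero disparity means invalid / infinitely far
    depth[valid] = focal_px * baseline_m / d[valid]
    return depth

# Hypothetical KITTI-like calibration: f ~ 721 px, baseline ~ 0.54 m
depth = disparity_to_depth(np.array([[72.1, 36.05]]), focal_px=721.0, baseline_m=0.54)
print(depth)  # a pixel with disparity 72.1 px maps to ~5.4 m, 36.05 px to ~10.8 m
```

Halving the disparity doubles the depth, which is why small disparity errors on far objects translate into large depth errors.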
1. Introduction to PSMNet
- Proposes an end-to-end stereo matching network that requires no post-processing.
- Introduces a pyramid pooling module that aggregates global context information into the image features.
- Proposes a stacked-hourglass 3D CNN that extends the regional support of context information in the cost volume.
- Achieves state-of-the-art accuracy on the KITTI datasets.
A more detailed introduction is in the previous chapter:
【Beginner Tutorial 1.10】A Hands-On Guide to Disparity Estimation with Deep Learning (PSMNet) (with a Python Code Walkthrough)
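The cost-volume idea from the bullet points above can be made concrete: PSMNet concatenates left-image features with right-image features shifted by each candidate disparity, producing a 4D volume that the 3D CNN then regularizes. A minimal sketch of that construction (function name and toy shapes are mine; PSMNet does this at 1/4 resolution with `maxdisp // 4` levels):

```python
import torch

def build_concat_cost_volume(feat_l, feat_r, max_disp):
    """Concatenate left features with right features shifted by each candidate
    disparity. feat_*: (B, C, H, W); returns (B, 2C, max_disp, H, W)."""
    b, c, h, w = feat_l.shape
    cost = feat_l.new_zeros(b, 2 * c, max_disp, h, w)
    for d in range(max_disp):
        if d == 0:
            cost[:, :c, d] = feat_l
            cost[:, c:, d] = feat_r
        else:
            # columns < d have no matching right pixel and stay zero
            cost[:, :c, d, :, d:] = feat_l[:, :, :, d:]
            cost[:, c:, d, :, d:] = feat_r[:, :, :, :-d]
    return cost

cv = build_concat_cost_volume(torch.randn(1, 8, 16, 32),
                              torch.randn(1, 8, 16, 32), max_disp=4)
print(cv.shape)  # torch.Size([1, 16, 4, 16, 32])
```

Because features are concatenated rather than subtracted or correlated, the 3D CNN can learn its own similarity measure instead of being handed one.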
2. Environment Setup
Download the source code:
git clone https://github.com/JiaRenChang/PSMNet.git
cd PSMNet
Create a Conda environment:
conda create -n psmnet python=3.7
conda activate psmnet
conda install pytorch==1.6.0 torchvision==0.7.0 -c pytorch
pip install opencv-python
pip install matplotlib
pip install scikit-image
pip install open3d
3. Download the Pretrained Model
4. Modify the Inference Code
Since we run inference on the CPU, Test_img.py needs the following changes:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import time
import math
from models import *
import cv2
from PIL import Image

# 2012 data /media/jiaren/ImageNet/data_scene_flow_2012/testing/
parser = argparse.ArgumentParser(description='PSMNet')
parser.add_argument('--KITTI', default='2015',
                    help='KITTI version')
parser.add_argument('--datapath', default='/media/jiaren/ImageNet/data_scene_flow_2015/testing/',
                    help='path to the test data')
parser.add_argument('--loadmodel', default='./trained/pretrained_model_KITTI2015.tar',
                    help='path to the pretrained weights')
parser.add_argument('--leftimg', default='./VO04_L.png',
                    help='left input image')
parser.add_argument('--rightimg', default='./VO04_R.png',
                    help='right input image')
parser.add_argument('--model', default='stackhourglass',
                    help='select model')
parser.add_argument('--maxdisp', type=int, default=192,
                    help='maximum disparity')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()

torch.manual_seed(args.seed)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    print('no model')

if args.loadmodel is not None:
    print('load PSMNet')
    # load onto the CPU and strip the "module." prefix left by DataParallel
    state_dict = torch.load(args.loadmodel, map_location=torch.device('cpu'))
    state_dict2load = {}
    for k, v in state_dict['state_dict'].items():
        state_dict2load[k.replace("module.", "")] = v
    model.load_state_dict(state_dict2load)

print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

def test(imgL, imgR):
    model.eval()
    # no_grad avoids storing activations, which would exhaust CPU memory
    with torch.no_grad():
        disp = model(imgL, imgR)
    disp = torch.squeeze(disp)
    pred_disp = disp.data.cpu().numpy()
    return pred_disp

def main():
    normal_mean_var = {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}
    infer_transform = transforms.Compose([transforms.ToTensor(),
                                          transforms.Normalize(**normal_mean_var)])

    imgL_o = Image.open(args.leftimg).convert('RGB')
    imgR_o = Image.open(args.rightimg).convert('RGB')

    imgL = infer_transform(imgL_o)
    imgR = infer_transform(imgR_o)

    # pad height and width up to multiples of 16
    if imgL.shape[1] % 16 != 0:
        times = imgL.shape[1] // 16
        top_pad = (times + 1) * 16 - imgL.shape[1]
    else:
        top_pad = 0

    if imgL.shape[2] % 16 != 0:
        times = imgL.shape[2] // 16
        right_pad = (times + 1) * 16 - imgL.shape[2]
    else:
        right_pad = 0

    imgL = F.pad(imgL, (0, right_pad, top_pad, 0)).unsqueeze(0)
    imgR = F.pad(imgR, (0, right_pad, top_pad, 0)).unsqueeze(0)

    start_time = time.time()
    pred_disp = test(imgL, imgR)
    print('time = %.2f' % (time.time() - start_time))

    # crop the padding back off the predicted disparity
    if top_pad != 0 and right_pad != 0:
        img = pred_disp[top_pad:, :-right_pad]
    elif top_pad == 0 and right_pad != 0:
        img = pred_disp[:, :-right_pad]
    elif top_pad != 0 and right_pad == 0:
        img = pred_disp[top_pad:, :]
    else:
        img = pred_disp

    print(img.shape)
    np.save("disp.npy", img)
    img = (img * 256).astype('uint16')
    img = Image.fromarray(img)
    img.save('Test_disparity.png')

if __name__ == '__main__':
    main()
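The padding logic in Test_img.py rounds the height and width up to multiples of 16 (the feature extractor downsamples by powers of two), then crops the same amount off the predicted disparity. The rounding step boils down to the following (helper name is mine):

```python
def pad_amount(size, multiple=16):
    """Extra pixels needed to round `size` up to the next multiple."""
    return (multiple - size % multiple) % multiple

# KITTI 2015 images are 375 x 1242: pad 9 rows on top and 6 columns on the right
print(pad_amount(375), pad_amount(1242))  # 9 6
```

Padding is applied to the top and right edges only, which is why the crop afterwards is `pred_disp[top_pad:, :-right_pad]`.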
5. Estimating Disparity with PSMNet
python Test_img.py --loadmodel pretrained_model_KITTI2015.tar --leftimg ../Code/left_img.png --rightimg ../Code/right_img.png
6. Troubleshooting
RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:72] data. DefaultCPUAllocator: not
enough memory: you tried to allocate 190464000 bytes. Buy new RAM!
Downsample the images first to avoid running out of memory.
7. Mapping to a 3D Point Cloud
Note: the mapping helper code comes from the earlier posts in this series.
import cv2
import numpy as np
import time
import matplotlib.pyplot as plt

# these helpers come from the earlier posts in this series
from depth import depth_map
from configs import img_path1, img_path2
from disparity import disparitymap
from image import Image_processing, downsample_image, create_output

def main():
    img = cv2.imread(img_path1, 1)
    img = downsample_image(img, 1)

    imgL = Image_processing(img_path1)
    imgR = Image_processing(img_path2)

    # disparity predicted by PSMNet in the previous step
    disp = np.load("disp.npy")
    print(imgL.shape, disp.shape, disp.max())  # (1000, 1482) (1000, 1482) 120.57847

    coordinates = depth_map(disp, img)

    print('\n Creating the output file... \n')
    create_output(coordinates, 'praxis_psmnet.ply')

main()
This produces the corresponding disparity map:
8. Comparison with the Traditional Method
9. Point Cloud Visualization
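The `depth_map` and `create_output` helpers above live in the earlier posts; as a self-contained illustration of what the back-projection step does, here is a pinhole-model sketch (the focal length, baseline, and principal point below are hypothetical placeholders; substitute your own calibration):

```python
import numpy as np

def disparity_to_points(disp, focal_px, baseline_m, cx, cy):
    """Back-project a disparity map into an (N, 3) array of XYZ points:
    Z = f*B/d, X = (u - cx) * Z / f, Y = (v - cy) * Z / f."""
    h, w = disp.shape
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    valid = disp > 0                    # skip invalid (zero) disparities
    z = focal_px * baseline_m / disp[valid]
    x = (u[valid] - cx) * z / focal_px
    y = (v[valid] - cy) * z / focal_px
    return np.stack([x, y, z], axis=1)

disp = np.full((4, 4), 36.05, dtype=np.float32)
pts = disparity_to_points(disp, focal_px=721.0, baseline_m=0.54, cx=2.0, cy=2.0)
print(pts.shape)  # (16, 3) -- one XYZ point per valid pixel
```

Attaching each pixel's RGB value to its XYZ row is all that remains to write a colored `.ply` file.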
import open3d as o3d

# read the .ply file
ply_file_path = "praxis_psmnet.ply"  # replace with the path to your .ply file
point_cloud = o3d.io.read_point_cloud(ply_file_path)

# create a visualization window
vis = o3d.visualization.Visualizer()
vis.create_window()

# add the point cloud to the window
vis.add_geometry(point_cloud)

# fetch the render options and shrink the point size
render_option = vis.get_render_option()
render_option.point_size = 2.5  # default is usually 5.0; smaller values draw smaller points

# run the visualizer
vis.run()

# destroy the window
vis.destroy_window()