【小白深度教程 1.11】手把手教你使用 PSMNet 估计视差和计算深度,并映射到 3D 点云(含 Python 代码)

在之前的章节中,我们展示了如何用 Depth Anything V2 进行 单目深度估计 ,以及 3D 点云生成:

【小白深度教程 1.8】手把手教你使用 Depth Anything V2 估计单目深度并映射到 3D 点云

但是,单目深度估计存在固有的尺度模糊问题:

在这里插入图片描述

因此这次我们尝试使用 PSMNet 视察估计技术,来进行准确的双目深度估计,并将场景转换成 3D 点云,Python 代码在最后。

最终效果如图:
在这里插入图片描述
对比传统方法:
在这里插入图片描述

1. PSMNet 简介

  • 提出了一种不要后处理的端到端的立体匹配网络。
  • 引入了一个金字塔池化模块,用于将全局上下文信息整合到图像特征中。
  • 提出了一个堆叠的沙漏 3D CNN 来扩展成本量中上下文信息的区域支持。
  • 在 KITTI 数据集上实现了最先进的精度。

在这里插入图片描述
更详细的介绍可以查看上一章节:

【小白深度教程 1.10】手把手教你使用深度学习方法(PSMNet)进行视差估计(含 Python 代码解析)

2. 环境配置

下载源码:

git clone https://github.com/JiaRenChang/PSMNet.git
cd PSMNet
  • 1
  • 2

创建 Conda 环境:

conda create -n psmnet python=3.7
conda activate psmnet
conda install pytorch==1.6.0 torchvision==0.7.0 -c pytorch
pip install opencv-python
pip install matplotlib
pip install scikit-image
pip install open3d
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

3. 下载预训练模型

在这里插入图片描述

4. 修改推理代码

由于我们使用 CPU 推理,所以要修改代码 Test_img.py:

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import time
import math
from models import *
import cv2
from PIL import Image

# 2012 data /media/jiaren/ImageNet/data_scene_flow_2012/testing/

parser = argparse.ArgumentParser(description='PSMNet')
parser.add_argument('--KITTI', default='2015',
                    help='KITTI version')
parser.add_argument('--datapath', default='/media/jiaren/ImageNet/data_scene_flow_2015/testing/',
                    help='select model')
parser.add_argument('--loadmodel', default='./trained/pretrained_model_KITTI2015.tar',
                    help='loading model')
parser.add_argument('--leftimg', default= './VO04_L.png',
                    help='load model')
parser.add_argument('--rightimg', default= './VO04_R.png',
                    help='load model')                                      
parser.add_argument('--model', default='stackhourglass',
                    help='select model')
parser.add_argument('--maxdisp', type=int, default=192,
                    help='maxium disparity')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()

torch.manual_seed(args.seed)

if args.model == 'stackhourglass':
    model = stackhourglass(args.maxdisp)
elif args.model == 'basic':
    model = basic(args.maxdisp)
else:
    print('no model')


if args.loadmodel is not None:
    print('load PSMNet')
    state_dict = torch.load(args.loadmodel, map_location=torch.device('cpu'))
    
    state_dict2load = {}
    for k, v in state_dict['state_dict'].items():
        state_dict2load[k.replace("module.", "")] = v
    
    model.load_state_dict(state_dict2load)

print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

def test(imgL,imgR):
    model.eval() 

    disp = model(imgL,imgR)

    disp = torch.squeeze(disp)
    pred_disp = disp.data.cpu().numpy()

    return pred_disp


def main():

        normal_mean_var = {'mean': [0.485, 0.456, 0.406],
                            'std': [0.229, 0.224, 0.225]}
        infer_transform = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize(**normal_mean_var)])    

        imgL_o = Image.open(args.leftimg).convert('RGB')
        imgR_o = Image.open(args.rightimg).convert('RGB')

        imgL = infer_transform(imgL_o)
        imgR = infer_transform(imgR_o) 
       

        # pad to width and hight to 16 times
        if imgL.shape[1] % 16 != 0:
            times = imgL.shape[1]//16       
            top_pad = (times+1)*16 -imgL.shape[1]
        else:
            top_pad = 0

        if imgL.shape[2] % 16 != 0:
            times = imgL.shape[2]//16                       
            right_pad = (times+1)*16-imgL.shape[2]
        else:
            right_pad = 0    

        imgL = F.pad(imgL,(0,right_pad, top_pad,0)).unsqueeze(0)
        imgR = F.pad(imgR,(0,right_pad, top_pad,0)).unsqueeze(0)

        start_time = time.time()
        pred_disp = test(imgL,imgR)
        print('time = %.2f' %(time.time() - start_time))

        
        if top_pad !=0 and right_pad != 0:
            img = pred_disp[top_pad:,:-right_pad]
        elif top_pad ==0 and right_pad != 0:
            img = pred_disp[:,:-right_pad]
        elif top_pad !=0 and right_pad == 0:
            img = pred_disp[top_pad:,:]
        else:
            img = pred_disp
        
        print(img.shape)
        np.save("disp.npy", img)
        
        img = (img*256).astype('uint16')
        img = Image.fromarray(img)
        img.save('Test_disparity.png')

if __name__ == '__main__':
   main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122

5. 用 PSMNet 估计视差

python Test_img.py --loadmodel pretrained_model_KITTI2015.tar --leftimg ../Code/left_img.png --rightimg ../Code/right_img.png
  • 1

6. 报错解决

RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:72] data. DefaultCPUAllocator: not 
enough memory: you tried to allocate 190464000 bytes. Buy new RAM!
  • 1
  • 2

需要先把图像下采样,防止内存不足

在这里插入图片描述

7. 映射到 3D 点云

注意:映射代码需要从之前的博客中查看

import cv2
import numpy as np
import time 
import matplotlib.pyplot as plt
from depth import depth_map
from configs import img_path1,img_path2
from disparity import disparitymap
from image import Image_processing,downsample_image,create_output


def main():
    img = cv2.imread(img_path1, 1)
    img = downsample_image(img, 1)
    
    imgL = Image_processing(img_path1)
    imgR = Image_processing(img_path2)
    
    disp = np.load("disp.npy")
    print(imgL.shape, disp.shape, disp.max()) # (1000, 1482) (1000, 1482) 120.57847
    
    coordinates = depth_map(disp, img)
    print('\n Creating the output file... \n')
    create_output(coordinates, 'praxis_psmnet.ply')
    
main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

可以得到对应的视差图:
在这里插入图片描述

8. 对比传统方法

在这里插入图片描述

9. 点云可视化

import open3d as o3d

# 读取 .ply 文件
ply_file_path = "praxis_psmnet.ply"  # 替换为你的 .ply 文件路径
point_cloud = o3d.io.read_point_cloud(ply_file_path)

# 创建可视化窗口
vis = o3d.visualization.Visualizer()
vis.create_window()

# 将点云添加到可视化窗口
vis.add_geometry(point_cloud)

# 获取渲染选项并调整点的大小
render_option = vis.get_render_option()
render_option.point_size = 2.5  # 调整点大小,默认值通常为 5.0,值越小点越小

# 启动可视化
vis.run()

# 销毁窗口
vis.destroy_window()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

在这里插入图片描述

文章知识点与官方知识档案匹配,可进一步学习相关知识
Python入门技能树 首页 概览 462192 人正在系统学习中

举报

选择你想要举报的内容(必选)
  • 内容涉黄
  • 政治相关
  • 内容抄袭
  • 涉嫌广告
  • 内容侵权
  • 侮辱谩骂
  • 样式问题
  • 其他