【小白深度教程 1.4】手把手教你复现 CompletionFormer 深度补全网络(含代码解读)

在上一节中,我们展示了如何利用训练好的 BP-Net 进行深度补全:

【小白深度教程 1.3】使用 BP-Net 深度补全网络,进行 KITTI 稠密点云和图像融合(含 Python 代码)

这节我们将手把手教你,如何训练和使用 CompletionFormer。

在这里插入图片描述

1. 介绍

CompletionFormer 结合 卷积神经网络 (CNN)和 Vision Transformer,提出了一种联合卷积注意力和 Transformer 块(JCAT),用于深度补全任务。该方法将卷积的局部连接性和 Transformer 的全局上下文结合到一个单一模型中,从而在户外 KITTI 和室内 NYUv2 数据集上超越了现有的基于 CNN 的方法,并在效率上显著优于纯 Transformer 方法。

具体解析可以查看:

CompletionFormer:用于深度补全的 Transformer 网络!CVPR 2023

2. 配置环境

下载代码:

git clone https://github.com/youmi-zym/CompletionFormer.git
cd CompletionFormer
  • 1
  • 2

假设我们已经安装了 Anaconda ,我们可以运行如下命令创建和安装环境:

conda create -n completionformer python=3.8
conda activate completionformer
# For CUDA Version == 11.3
pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113
pip install mmcv-full==1.4.4 mmsegmentation==0.22.1  
pip install timm tqdm thop tensorboardX opencv-python ipdb h5py ipython Pillow==9.5.0 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

注意:这里假设我们的电脑有独立显卡。

源代码使用的环境是:PyTorch 1.10.1, CUDA 11.3, Python 3.8 以及 Ubuntu 20.04。

然后我们安装 apex 进行多卡训练:

git clone https://github.com/NVIDIA/apex
cd apex
git reset --hard 4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 
  • 1
  • 2
  • 3
  • 4

然后编译可变形卷积:

cd THIS_PROJECT_ROOT/src/model/deformconv
sh make.sh
  • 1
  • 2

3. 准备数据集

这里我们使用 KITTI 数据集 进行训练,可以在 这里 下载:

在这里插入图片描述

这三个文件均要下载,并且按照如下文件结构进行排布:

├── kitti_depth
|   ├──data_depth_annotated
|   |  ├── train
|   |  ├── val
|   ├── data_depth_velodyne
|   |  ├── train
|   |  ├── val
|   ├── data_depth_selection
|   |  ├── test_depth_completion_anonymous
|   |  |── test_depth_prediction_anonymous
|   |  ├── val_selection_cropped
|   ├── kitti_raw
|   |   ├── 2011_09_26
|   |   ├── 2011_09_28
|   |   ├── 2011_09_29
|   |   ├── 2011_09_30
|   |   ├── 2011_10_03
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

注意我们还需要下载 kitti_raw,因为刚刚下载的数据中不包含图像。

4. 下载 KITTI RAW 数据

可以使用 KITTI 的官方工具:

https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data_downloader.zip

直接下载后运行:

bash raw_data_downloader.sh
  • 1

5. 开始训练

我们使用 L1 和 L2 损失进行训练:

cd /src

python main.py --dir_data PATH_TO_KITTI_DC --data_name KITTIDC --split_json ../data_json/kitti_dc.json \
    --patch_height 240 --patch_width 1216 --gpus 0,1,2,3 --loss 1.0*L1+1.0*L2 --lidar_lines 64 \
    --batch_size 3 --max_depth 90.0 --lr 0.001 --epochs 250 --milestones 150 180 210 240 \
    --top_crop 100 --test_crop --log_dir ../experiments/ --save NAME_TO_SAVE \
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

其中 PATH_TO_KITTI_DC 要改成 KITTI 数据集的路径。

6. 测试

python main.py --dir_data PATH_TO_KITTI_DC --data_name KITTIDC --split_json ../data_json/kitti_dc.json \
    --patch_height 240 --patch_width 1216 --gpus 0 --max_depth 90.0 --top_crop 100 --test_crop --save-image \
    --test_only --pretrain PATH_TO_WEIGHTS --save NAME_TO_SAVE
  • 1
  • 2
  • 3

7. 生成提交结果

python main.py --dir_data PATH_TO_KITTI_DC --data_name KITTIDC --split_json ../data_json/kitti_dc_test.json \
    --patch_height 240 --patch_width 1216 --gpus 0 --max_depth 90.0 \
    --test_only --pretrain PATH_TO_WEIGHTS --save_image --save_result_only --save NAME_TO_SAVE
  • 1
  • 2
  • 3

这是我们可以通过 KITTI 官方提交界面 提交结果:

在这里插入图片描述

8. 代码解读

Backbone 模块

Backbone 是一个包含编码器和解码器的模块,用于处理 RGB、深度(Depth),或者 RGB-D(结合 RGB 和深度)输入。该网络通过一系列卷积、转置卷积和基本块来提取和融合特征,最终生成深度图、引导特征图和置信度图。

代码解析

  1. 初始化函数 ( __init__ 方法) :

    • 初始化函数接收参数 args (网络配置参数)和 mode (输入模式: rgbd rgb d )。
    • 根据输入模式初始化不同的卷积层:
      • 如果是 rgbd 模式,分别初始化 RGB 和深度图的卷积层,并在后续将它们融合。
      • 如果是 rgb 模式,仅初始化 RGB 输入的卷积层。
      • 如果是 d 模式,仅初始化深度输入的卷积层。
    • 使用 PVT(Pyramid Vision Transformer)作为编码器,提取特征。
    • 定义通道数列表 channels ,用于后续的解码操作。
    • 共享解码器部分
      • 解码器从 1/16 尺度逐步恢复到 1/1 尺度,分别通过转置卷积和 BasicBlock 实现。
    • 深度分支
      • 用于生成初始深度图的解码器部分。
    • 引导分支
      • 用于生成引导特征图的解码器部分。
    • 置信度分支 (可选):
      • 用于生成置信度图,仅在 conf_prop (置信度传播)开启时使用。
  2. 特征融合函数 ( _concat 方法) :

    • 该方法用于将不同尺度的特征图通过双线性插值调整到相同大小,然后在指定维度上进行拼接。
  3. 前向传播函数 ( forward 方法) :

    • 根据输入模式对 RGB 和深度数据进行特征提取。
      • rgbd 模式下,将 RGB 和深度特征进行融合。
      • rgb d 模式下,分别处理 RGB 或深度输入。
    • 将提取到的特征输入到 PVT 编码器中,获得多层特征图。
    • 共享解码过程
      • 从最深层特征图开始,通过解码器逐步恢复特征图的空间尺寸,并融合相应的编码器特征。
    • 深度解码
      • 利用共享解码后的特征图,进一步处理生成初始深度图。
    • 引导特征图解码
      • 生成用于空间传播的引导特征图。
    • 置信度图解码(可选)
      • 如果开启置信度传播,则生成置信度图,否则返回 None

关键概念

  • PVT 编码器 :用 Pyramid Vision Transformer 提取多尺度特征,增强了网络的上下文感知能力。
  • 解码器模块 :利用转置卷积和 BasicBlock 将特征图逐步还原到输入图像的原始分辨率。
  • 深度、引导和置信度解码 :网络输出三个关键图——深度图、引导图和置信度图,分别用于初始深度估计、特征传播和自信度评估。

NLSPN 模块

NLSPN 模块是非局部空间传播网络(Non-Local Spatial Propagation Network)的实现,主要用于需要进行空间传播和亲和力建模的任务。它广泛应用于深度估计、图像修复等需要对像素级细节进行调整的任务中。以下是对该代码的详细解读。

模块概述

NLSPN 通过引入可调制的可变形卷积(Modulated Deformable Convolution)和亲和力计算来增强特征传播能力。该模块的核心在于通过可变形卷积灵活地调整特征图的传播方式,从而实现对特征的自适应调整。

代码的关键部分

  1. . 偏移和亲和力计算 ( _get_offset_affinity 方法) :

    • 该方法根据引导图计算偏移和亲和力。
    • 使用卷积操作生成偏移量和亲和力,然后根据特定的亲和力模式(如 AS ASS TC TGASS )进行处理。
    • 如果启用了置信度传播( conf_prop ),则会进一步处理偏移量,利用可变形卷积函数对偏移量进行调整。
    • 最后对亲和力进行归一化处理,以确保亲和力的值在传播过程中保持合适的比例。
  2. 单次传播 ( _propagate_once 方法) :

    • 该方法使用可变形卷积函数,根据之前计算的偏移和亲和力,对特征图进行传播。
    • 可变形卷积可以根据偏移和亲和力动态调整卷积核的位置和形状,从而实现特征图的灵活传播。
  3. 前向传播 ( forward 方法) :

    • 前向传播方法处理初始特征图( feat_init ),引导图( guidance )以及可选的置信度输入。
    • 方法中对输入形状进行检查,以确保输入符合预期的尺寸要求,并在启用置信度传播时保证置信度输入存在。
    • 如果启用了输入保护( preserve_input ),则在每次传播中保留固定的特征图。
    • 通过指定的传播次数( prop_time )迭代更新特征图,每次迭代都基于当前的偏移和亲和力进行传播。
    • 最终返回结果特征图、传播过程中的所有特征图、偏移、亲和力以及亲和力的缩放常数。

应用场景

NLSPN 模块通常用于需要细粒度特征调整的场景,如深度估计、图像补全等。通过灵活的卷积操作和亲和力建模, NLSPN 能够更有效地处理空间上的不规则特征传播问题。

如果你需要进一步深入了解某个部分或有其他问题,欢迎告诉我!

文章知识点与官方知识档案匹配,可进一步学习相关知识

举报

选择你想要举报的内容(必选)
  • 内容涉黄
  • 政治相关
  • 内容抄袭
  • 涉嫌广告
  • 内容侵权
  • 侮辱谩骂
  • 样式问题
  • 其他