[Beginner Deep Learning Tutorial 1.18] Hands-On PyTorch3D (3): Fitting a Volume with Differentiable Volume Rendering

This tutorial shows how to use differentiable volume rendering to fit a volume (Volume) given a set of views of a scene.

More specifically, this tutorial will explain how to:

  • Create a differentiable volume renderer.
  • Create a volumetric model (including how to use the Volumes class).
  • Fit the volume to the images using the differentiable volume renderer.
  • Visualize the predicted volume.

1. Installation and Imports

Make sure torch and torchvision are installed. If pytorch3d is not installed, install it with the following command:

pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
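
Here, {version_str} encodes your Python, CUDA, and PyTorch versions. In the official PyTorch3D tutorials it is constructed roughly as follows (a sketch that assumes a CUDA build of PyTorch; adjust accordingly for CPU-only installs):

import sys
import torch

# Build the wheel-index version string, e.g. "py39_cu117_pyt1131".
pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
version_str = "".join([
    f"py3{sys.version_info.minor}_cu",
    torch.version.cuda.replace(".", ""),
    f"_pyt{pyt_version_str}",
])
print(version_str)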

Alternatively, install from the GitHub source:

pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

Then import the required modules:

import os
import sys
import time
import json
import glob
import torch
import math
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from IPython import display

# Data structures and functions for rendering
from pytorch3d.structures import Volumes
from pytorch3d.renderer import (
    FoVPerspectiveCameras, 
    VolumeRenderer,
    NDCMultinomialRaysampler,
    EmissionAbsorptionRaymarcher
)
from pytorch3d.transforms import so3_exp_map

# Obtain the device to run on
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
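
The following sections also use two helper utilities from the PyTorch3D tutorial repository: generate_cow_renders (renders the cow mesh from multiple viewpoints) and image_grid (displays a grid of images). If you are running outside the official tutorial notebook, you can download and import them roughly as follows (this assumes the scripts are placed next to your script or notebook):

# Download the helper scripts from the PyTorch3D repository (only needed once):
# wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/generate_cow_renders.py
# wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/plot_image_grid.py

from generate_cow_renders import generate_cow_renders
from plot_image_grid import image_grid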

2. Generate Images and Masks of the Scene

The following cell generates our training data. It renders the cow mesh from the previous tutorial from multiple viewpoints and returns:

  • A batch of image and silhouette tensors produced by the cow mesh renderer.
  • A set of cameras corresponding to each render.

target_cameras, target_images, target_silhouettes = generate_cow_renders(num_views=40)
print(f'Generated {len(target_images)} images/silhouettes/cameras.')
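
As a quick sanity check (not part of the original tutorial), you can print the shapes of the generated tensors:

# Inspect the training data: a batch of rendered images and binary silhouettes.
print(target_images.shape)       # (num_views, H, W, channels)
print(target_silhouettes.shape)  # (num_views, H, W)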

3. Initialize the Volume Renderer

The following code initializes a volume renderer that emits a ray from each pixel of a target image and samples a set of uniformly spaced points along each ray. At every ray point, the corresponding density and color are obtained by querying the corresponding location in the volumetric model of the scene.

# render_size is the side length (in pixels) of the rendered images;
# we render at the same resolution as the target images.
render_size = target_images.shape[1]
# The rendered scene is centered around (0, 0, 0) and enclosed in a bounding
# box whose side is roughly 3.0 world units.
volume_extent_world = 3.0

# 1) Instantiate the raysampler: it emits one ray per image pixel (following
# the PyTorch3D NDC conventions) and samples n_pts_per_ray points per ray.
raysampler = NDCMultinomialRaysampler(
    image_width=render_size,
    image_height=render_size,
    n_pts_per_ray=150,
    min_depth=0.1,
    max_depth=volume_extent_world,
)

# 2) Instantiate the raymarcher: the emission-absorption raymarcher integrates
# the sampled densities and colors along each ray into a color and an opacity.
raymarcher = EmissionAbsorptionRaymarcher()

# Finally, assemble the volume renderer from the raysampler and raymarcher.
renderer = VolumeRenderer(
    raysampler=raysampler, raymarcher=raymarcher,
)
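
To see what the renderer produces, you can render a dummy (empty) volume from the first target camera. This is only an illustrative check, not part of the original tutorial: the emission-absorption raymarcher outputs, for each pixel, the rendered features concatenated with an opacity channel, i.e. an (N, H, W, 3 + 1) tensor here, which is why the model in the next section splits its output into a 3-channel image and a 1-channel silhouette.

# Illustrative check: render an all-zero (empty) 64^3 volume from one camera.
dummy_volumes = Volumes(
    densities=torch.zeros(1, 1, 64, 64, 64, device=device),
    features=torch.zeros(1, 3, 64, 64, 64, device=device),
    voxel_size=volume_extent_world / 64,
)
dummy_camera = FoVPerspectiveCameras(
    R=target_cameras.R[:1].to(device),
    T=target_cameras.T[:1].to(device),
    device=device,
)
dummy_render = renderer(cameras=dummy_camera, volumes=dummy_volumes)[0]
print(dummy_render.shape)  # (1, render_size, render_size, 4): RGB + opacity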

4. Initialize the Volume Model

Next we instantiate the volumetric model of the scene. It quantizes the 3D space into cubical voxels, where each voxel is described by an RGB color and a density scalar expressing the voxel's opacity.

class VolumeModel(torch.nn.Module):
    def __init__(self, renderer, volume_size=[64] * 3, voxel_size=0.1):
        super().__init__()
        # After the sigmoid in forward(), these initial values correspond to
        # near-zero densities and a neutral gray color everywhere.
        self.log_densities = torch.nn.Parameter(-4.0 * torch.ones(1, *volume_size))
        self.log_colors = torch.nn.Parameter(torch.zeros(3, *volume_size))
        self._voxel_size = voxel_size
        self._renderer = renderer
        
    def forward(self, cameras):
        batch_size = cameras.R.shape[0]
        # Convert the log-space parameters to densities and colors.
        densities = torch.sigmoid(self.log_densities)
        colors = torch.sigmoid(self.log_colors)
        
        # Instantiate the Volumes object, expanding the densities and colors
        # batch_size times so that one volume is paired with each camera.
        volumes = Volumes(
            densities=densities[None].expand(batch_size, *self.log_densities.shape),
            features=colors[None].expand(batch_size, *self.log_colors.shape),
            voxel_size=self._voxel_size,
        )
        # Run the renderer and return only the rendered images (the second
        # output describes the sampled rays and is not needed here).
        return self._renderer(cameras=cameras, volumes=volumes)[0]
    
# A helper for evaluating a smooth (huber-style) loss between the rendered
# silhouettes/colors and the targets.
def huber(x, y, scaling=0.1):
    diff_sq = (x - y) ** 2
    loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)
    return loss
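
The huber function above behaves approximately quadratically for differences much smaller than scaling and approximately linearly for large differences, which makes the fit robust to outlier pixels. A tiny illustration (not part of the original tutorial):

# Small difference -> roughly quadratic penalty; large difference -> roughly linear.
a = torch.tensor([0.0, 0.0])
b = torch.tensor([0.01, 1.0])
print(huber(a, b))  # approx. tensor([0.0005, 0.9050])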

5. Fit the Volume

Here we fit the volume with differentiable rendering: in each iteration we render the volume from a random batch of target viewpoints and minimize the color and silhouette errors against the ground-truth renders.

target_cameras = target_cameras.to(device)
target_images = target_images.to(device)
target_silhouettes = target_silhouettes.to(device)

# Instantiate the volumetric model with a 128^3 voxel grid whose voxel size is
# chosen so that the grid covers the volume_extent_world bounding box.
volume_size = 128
volume_model = VolumeModel(
    renderer,
    volume_size=[volume_size] * 3, 
    voxel_size=volume_extent_world / volume_size,
).to(device)

# Instantiate the Adam optimizer and set the number of iterations and the
# number of views rendered per iteration.
lr = 0.1
optimizer = torch.optim.Adam(volume_model.parameters(), lr=lr)
batch_size = 10
n_iter = 300

for iteration in range(n_iter):
    if iteration == round(n_iter * 0.75):
        print('Decreasing LR 10-fold ...')
        optimizer = torch.optim.Adam(
            volume_model.parameters(), lr=lr * 0.1
        )
    
    optimizer.zero_grad()
    
    # Sample a random subset of the target views for this iteration.
    batch_idx = torch.randperm(len(target_cameras))[:batch_size]
    
    # Assemble the minibatch of cameras corresponding to the sampled views.
    batch_cameras = FoVPerspectiveCameras(
        R=target_cameras.R[batch_idx], 
        T=target_cameras.T[batch_idx], 
        znear=target_cameras.znear[batch_idx],
        zfar=target_cameras.zfar[batch_idx],
        aspect_ratio=target_cameras.aspect_ratio[batch_idx],
        fov=target_cameras.fov[batch_idx],
        device=device,
    )
    
    # Render the current volume estimate and split the output into
    # a 3-channel image and a 1-channel silhouette.
    rendered_images, rendered_silhouettes = volume_model(
        batch_cameras
    ).split([3, 1], dim=-1)
    
    # Silhouette error: mean huber loss between rendered and target masks.
    sil_err = huber(
        rendered_silhouettes[..., 0], target_silhouettes[batch_idx],
    ).abs().mean()

    # Color error: mean huber loss between rendered and target images.
    color_err = huber(
        rendered_images, target_images[batch_idx],
    ).abs().mean()
    
    # The optimization loss is a simple sum of the color and silhouette errors.
    loss = color_err + sil_err
    
    if iteration % 10 == 0:
        print(
            f'Iteration {iteration:05d}:'
            + f' color_err = {float(color_err):1.2e}'
            + f' mask_err = {float(sil_err):1.2e}'
        )
    
    loss.backward()
    optimizer.step()
    
    # Visualize the intermediate renders every 40 iterations.
    if iteration % 40 == 0:
        im_show_idx = int(torch.randint(low=0, high=batch_size, size=(1,)))
        fig, ax = plt.subplots(2, 2, figsize=(10, 10))
        ax = ax.ravel()
        clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()
        ax[0].imshow(clamp_and_detach(rendered_images[im_show_idx]))
        ax[1].imshow(clamp_and_detach(target_images[batch_idx[im_show_idx], ..., :3]))
        ax[2].imshow(clamp_and_detach(rendered_silhouettes[im_show_idx, ..., 0]))
        ax[3].imshow(clamp_and_detach(target_silhouettes[batch_idx[im_show_idx]]))
        for ax_, title_ in zip(
            ax, 
            ("rendered image", "target image", "rendered silhouette", "target silhouette")
        ):
            ax_.grid("off")
            ax_.axis("off")
            ax_.set_title(title_)
        fig.canvas.draw(); fig.show()
        display.clear_output(wait=True)
        display.display(fig)

(Figure: intermediate renders of the fitted volume and its silhouette next to the target image and silhouette.)

6. Visualize the Optimized Volume

Finally, we visualize the optimized volume by rendering it from multiple viewpoints that rotate around the volume's y axis.
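
The rotations are built with so3_exp_map, which converts axis-angle "log rotation" vectors into 3x3 rotation matrices; the helper below fills the y component of those vectors with angles sweeping a full circle. A one-line illustration (hypothetical values, not from the original tutorial):

# A rotation of pi/2 radians around the y axis, as a (1, 3, 3) rotation matrix.
R_example = so3_exp_map(torch.tensor([[0.0, math.pi / 2, 0.0]]))
print(R_example.shape)  # torch.Size([1, 3, 3])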

def generate_rotating_volume(volume_model, n_frames=50):
    # Log-rotations around the y axis covering a full circle.
    logRs = torch.zeros(n_frames, 3, device=device)
    logRs[:, 1] = torch.linspace(0.0, 2.0 * 3.14, n_frames, device=device)
    Rs = so3_exp_map(logRs)
    # Place each camera 2.7 world units away from the volume center.
    Ts = torch.zeros(n_frames, 3, device=device)
    Ts[:, 2] = 2.7
    frames = []
    print('Generating rotating volume ...')
    for R, T in zip(tqdm(Rs), Ts):
        camera = FoVPerspectiveCameras(
            R=R[None], 
            T=T[None], 
            znear=target_cameras.znear[0],
            zfar=target_cameras.zfar[0],
            aspect_ratio=target_cameras.aspect_ratio[0],
            fov=target_cameras.fov[0],
            device=device,
        )
        frames.append(volume_model(camera)[..., :3].clamp(0.0, 1.0))
    return torch.cat(frames)
    
with torch.no_grad():
    rotating_volume_frames = generate_rotating_volume(volume_model, n_frames=7*4)

image_grid(rotating_volume_frames.clamp(0., 1.).cpu().numpy(), rows=4, cols=7, rgb=True, fill=True)
plt.show()
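
If you want to reuse the fitted volume outside the renderer (for example to export or threshold the voxel grid), the sigmoid-activated parameters of VolumeModel can be read out directly; this is only a sketch of one way to access them:

# Read out the fitted per-voxel densities and colors
# (the same activations VolumeModel.forward applies before rendering).
with torch.no_grad():
    fitted_densities = torch.sigmoid(volume_model.log_densities)  # (1, D, H, W)
    fitted_colors = torch.sigmoid(volume_model.log_colors)        # (3, D, H, W)
print(fitted_densities.shape, fitted_colors.shape)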