【小白深度教程 1.16】手把手教你使用 Pytorch3D(1)使用 3D 损失函数来拟合 Mesh
在这篇文章中,我们将学习如何使用3D损失函数变形源网格(Source Mesh)以形成目标网格(Target Mesh)
在本教程中,我们学习如何将一个初始的通用形状(例如球体)变形为目标形状。
使用 3D 损失函数来拟合 Mesh_files/29712dd013c14a95b4adcf3cdd0b0ed9.png)
我们将涵盖:
- 如何从
.obj文件中加载网格 - 如何使用 PyTorch3D 的 Meshes 数据结构
- 如何使用 PyTorch3D 的 4 种不同的网格损失函数
- 如何设置一个优化循环
从一个球体网格开始,我们学习网格中每个顶点的偏移量,使得在每次优化步骤中预测的网格更接近目标网格。为此,我们需要最小化:
chamfer_distance,即预测(变形后)网格与目标网格之间的距离,定义为从其表面可微分采样点云集合之间的切面距离。
然而,仅仅最小化预测网格和目标网格之间的切面距离会导致不光滑的形状(可以通过将 w_chamfer=1.0 而所有其他权重设为 0.0 来验证这一点)。
我们通过在目标函数中添加形状正则化器来强制实现平滑性。具体来说,我们添加了:
mesh_edge_length,用于最小化预测网格中边的长度。mesh_normal_consistency,用于强制相邻面的法线一致性。mesh_laplacian_smoothing,即拉普拉斯正则化器。
1. 安装和导入模块
确保安装了“torch”和“torchvision”。如果’ pytorch3d '未安装,请使用以下命令安装:
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
或者从 github 源码安装:
pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
然后导入模块:
import os
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
chamfer_distance,
mesh_edge_loss,
mesh_laplacian_smoothing,
mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
# Set the device
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
print("WARNING: CPU only, this will be slow!")
2. 加载 .obj 文件并创建 Mesh 对象
下载海豚的目标3D模型。它将在本地保存为一个名为 dolphin.obj 的文件。
wget https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj
# Load the dolphin mesh.
trg_obj = 'dolphin.obj'
# We read the target 3D model using load_obj
verts, faces, aux = load_obj(trg_obj)
# verts is a FloatTensor of shape (V, 3) where V is the number of vertices in the mesh
# faces is an object which contains the following LongTensors: verts_idx, normals_idx and textures_idx
# For this tutorial, normals and textures are ignored.
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)
# We scale normalize and center the target mesh to fit in a sphere of radius 1 centered at (0,0,0).
# (scale, center) will be used to bring the predicted mesh to its original center and scale
# Note that normalizing the target mesh, speeds up the optimization but is not necessary!
center = verts.mean(0)
verts = verts - center
scale = max(verts.abs().max(0)[0])
verts = verts / scale
# We construct a Meshes structure for the target mesh
trg_mesh = Meshes(verts=[verts], faces=[faces_idx])
# We initialize the source shape to be a sphere of radius 1
src_mesh = ico_sphere(4, device)
3. 可视化源 Mesh 和目标 Mesh
def plot_pointcloud(mesh, title=""):
# Sample points uniformly from the surface of the mesh.
points = sample_points_from_meshes(mesh, 5000)
x, y, z = points.clone().detach().cpu().squeeze().unbind(1)
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(x, z, -y)
ax.set_xlabel('x')
ax.set_ylabel('z')
ax.set_zlabel('y')
ax.set_title(title)
ax.view_init(190, 30)
plt.show()
# %matplotlib notebook
plot_pointcloud(trg_mesh, "Target mesh")
plot_pointcloud(src_mesh, "Source mesh")
4. 迭代优化进行拟合
# We will learn to deform the source mesh by offsetting its vertices
# The shape of the deform parameters is equal to the total number of vertices in src_mesh
deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)
# The optimizer
optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)
# Number of optimization steps
Niter = 2000
# Weight for the chamfer loss
w_chamfer = 1.0
# Weight for mesh edge loss
w_edge = 1.0
# Weight for mesh normal consistency
w_normal = 0.01
# Weight for mesh laplacian smoothing
w_laplacian = 0.1
# Plot period for the losses
plot_period = 250
loop = tqdm(range(Niter))
chamfer_losses = []
laplacian_losses = []
edge_losses = []
normal_losses = []
%matplotlib inline
for i in loop:
# Initialize optimizer
optimizer.zero_grad()
# Deform the mesh
new_src_mesh = src_mesh.offset_verts(deform_verts)
# We sample 5k points from the surface of each mesh
sample_trg = sample_points_from_meshes(trg_mesh, 5000)
sample_src = sample_points_from_meshes(new_src_mesh, 5000)
# We compare the two sets of pointclouds by computing (a) the chamfer loss
loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
# and (b) the edge length of the predicted mesh
loss_edge = mesh_edge_loss(new_src_mesh)
# mesh normal consistency
loss_normal = mesh_normal_consistency(new_src_mesh)
# mesh laplacian smoothing
loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")
# Weighted sum of the losses
loss = loss_chamfer * w_chamfer + loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian
# Print the losses
loop.set_description('total_loss = %.6f' % loss)
# Save the losses for plotting
chamfer_losses.append(float(loss_chamfer.detach().cpu()))
edge_losses.append(float(loss_edge.detach().cpu()))
normal_losses.append(float(loss_normal.detach().cpu()))
laplacian_losses.append(float(loss_laplacian.detach().cpu()))
# Plot mesh
if i % plot_period == 0:
plot_pointcloud(new_src_mesh, title="iter: %d" % i)
# Optimization step
loss.backward()
optimizer.step()
5. 可视化损失
fig = plt.figure(figsize=(13, 5))
ax = fig.gca()
ax.plot(chamfer_losses, label="chamfer loss")
ax.plot(edge_losses, label="edge loss")
ax.plot(normal_losses, label="normal loss")
ax.plot(laplacian_losses, label="laplacian loss")
ax.legend(fontsize="16")
ax.set_xlabel("Iteration", fontsize="16")
ax.set_ylabel("Loss", fontsize="16")
ax.set_title("Loss vs iterations", fontsize="16");
6. 保存结果
# Fetch the verts and faces of the final predicted mesh
final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)
# Scale normalize back to the original target size
final_verts = final_verts * scale + center
# Store the predicted mesh using save_obj
final_obj = 'final_model.obj'
save_obj(final_obj, final_verts, final_faces)