# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
このデモでは、PyTorch3DのVolumeRendererをImplicitronのカスタム陰関数として使用します。Implicitronの主要なオブジェクトの一部と
torch
とtorchvision
がインストールされていることを確認してください。pytorch3d
がインストールされていない場合は、次のセルを使用してインストールしてください。
import os
import sys
import torch
need_pytorch3d=False
try:
import pytorch3d
except ModuleNotFoundError:
need_pytorch3d=True
if need_pytorch3d:
if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
# We try to install PyTorch3D via a released wheel.
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
f"py3{sys.version_info.minor}_cu",
torch.version.cuda.replace(".",""),
f"_pyt{pyt_version_str}"
])
!pip install fvcore iopath
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
else:
# We try to install PyTorch3D from source.
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
omegaconfとvisdomがインストールされていることを確認してください。インストールされていない場合は、このセルを実行してください。(ランタイムを再起動する必要はありません。)
!pip install omegaconf visdom
import logging
from typing import Tuple
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from IPython.display import HTML
from omegaconf import OmegaConf
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components
from pytorch3d.renderer.implicit.renderer import VolumeSampler
from pytorch3d.structures import Volumes
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene
output_resolution = 80
torch.set_printoptions(sci_mode=False)
Implicitronにおけるデータセットのtrain、val、test部分はdataset_map
として表され、DatasetMapProvider
の実装によって提供されます。RenderedMeshDatasetMapProvider
は、メッシュを取得してレンダリングすることで、trainコンポーネントのみを持つシングルシーンデータセットを生成するものです。これを牛のメッシュで使用します。
Google Colabを使用してこのノートブックを実行する場合は、次のセルを実行してメッシュのobjファイルとテクスチャファイルをフェッチし、data/cow_meshパスに保存してください。ローカルで実行する場合は、データは既に正しいパスにあります。
!mkdir -p data/cow_mesh
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png
cow_provider = RenderedMeshDatasetMapProvider(
data_file="data/cow_mesh/cow.obj",
use_point_light=False,
resolution=output_resolution,
)
dataset_map = cow_provider.get_dataset_map()
tr_cameras = [training_frame.camera for training_frame in dataset_map.train]
# The cameras are all in the XZ plane, in a circle about 2.7 from the origin
centers = torch.cat([i.get_camera_center() for i in tr_cameras])
print(centers.min(0).values)
print(centers.max(0).values)
# visualization of the cameras
plot = plot_scene({"k": {i: camera for i, camera in enumerate(tr_cameras)}}, camera_scale=0.25)
plot.layout.scene.aspectmode = "data"
plot
ニューラルレンダリング手法の中核には、空間座標の関数である陰関数があり、これは何らかのレンダリングプロセスで使用されます。(多くの場合、これらの関数は、視点方向などの他のデータも追加で受け取ることができます。)一般的なレンダリングプロセスは、陰関数によって提供される密度と色に対するレイトレーシングです。ここでは、3Dボリュームグリッドからのサンプリングは、空間座標の非常に単純な関数です。
ここでは、PyTorch3Dの既存の機能を使用してボリュームグリッドからサンプリングする独自の陰関数を定義します。これは、ImplicitFunctionBase
をサブクラス化することで行います。特別なデコレータを使用して、サブクラスを登録する必要があります。モジュールの構成にはPythonのdataclassアノテーションを使用します。
@registry.register
class MyVolumes(ImplicitFunctionBase, torch.nn.Module):
grid_resolution: int = 50 # common HWD of volumes, the number of voxels in each direction
extent: float = 1.0 # In world coordinates, the volume occupies is [-extent, extent] along each axis
def __post_init__(self):
# We have to call this explicitly if there are other base classes like Module
super().__init__()
# We define parameters like other torch.nn.Module objects.
# In this case, both our parameter tensors are trainable; they govern the contents of the volume grid.
density = torch.full((self.grid_resolution, self.grid_resolution, self.grid_resolution), -2.0)
self.density = torch.nn.Parameter(density)
color = torch.full((3, self.grid_resolution, self.grid_resolution, self.grid_resolution), 0.0)
self.color = torch.nn.Parameter(color)
self.density_activation = torch.nn.Softplus()
def forward(
self,
ray_bundle: ImplicitronRayBundle,
fun_viewpool=None,
global_code=None,
):
densities = self.density_activation(self.density[None, None])
voxel_size = 2.0 * float(self.extent) / self.grid_resolution
features = self.color.sigmoid()[None]
# Like other PyTorch3D structures, the actual Volumes object should only exist as long
# as one iteration of training. It is local to this function.
volume = Volumes(densities=densities, features=features, voxel_size=voxel_size)
sampler = VolumeSampler(volumes=volume)
densities, features = sampler(ray_bundle)
# When an implicit function is used for raymarching, i.e. for MultiPassEmissionAbsorptionRenderer,
# it must return (densities, features, an auxiliary tuple)
return densities, features, {}
PyTorch3Dの主要なモデルオブジェクトはGenericModel
であり、レンダラーや陰関数など、主要なステップのプラグ可能なコンポーネントを備えています。ここでは同等の2つの構築方法があります。
CONSTRUCT_MODEL_FROM_CONFIG = True
if CONSTRUCT_MODEL_FROM_CONFIG:
# Via a DictConfig - this is how our training loop with hydra works
cfg = get_default_args(GenericModel)
cfg.implicit_function_class_type = "MyVolumes"
cfg.render_image_height=output_resolution
cfg.render_image_width=output_resolution
cfg.loss_weights={"loss_rgb_huber": 1.0}
cfg.tqdm_trigger_threshold=19000
cfg.raysampler_AdaptiveRaySampler_args.scene_extent= 4.0
gm = GenericModel(**cfg)
else:
# constructing GenericModel directly
gm = GenericModel(
implicit_function_class_type="MyVolumes",
render_image_height=output_resolution,
render_image_width=output_resolution,
loss_weights={"loss_rgb_huber": 1.0},
tqdm_trigger_threshold=19000,
raysampler_AdaptiveRaySampler_args = {"scene_extent": 4.0}
)
# In this case we can get the equivalent DictConfig cfg object to the way gm is configured as follows
cfg = OmegaConf.structured(gm)
デフォルトのレンダラーは、放射吸収レイトレーサーです。このデフォルトを維持します。
# We can display the configuration in use as follows.
remove_unused_components(cfg)
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
%page -r yaml
device = torch.device("cuda:0")
gm.to(device)
assert next(gm.parameters()).is_cuda
train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]
gm.train()
optimizer = torch.optim.Adam(gm.parameters(), lr=0.1)
iterator = tqdm.tqdm(range(2000))
for n_batch in iterator:
optimizer.zero_grad()
frame = train_data_collated[n_batch % len(dataset_map.train)]
out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)
out["objective"].backward()
if n_batch % 100 == 0:
iterator.set_postfix_str(f"loss: {float(out['objective']):.5f}")
optimizer.step()
すべての視点から完全な画像を生成して、外観を確認します。
def to_numpy_image(image):
# Takes an image of shape (C, H, W) in [0,1], where C=3 or 1
# to a numpy uint image of shape (H, W, 3)
return (image * 255).to(torch.uint8).permute(1, 2, 0).detach().cpu().expand(-1, -1, 3).numpy()
def resize_image(image):
# Takes images of shape (B, C, H, W) to (B, C, output_resolution, output_resolution)
return torch.nn.functional.interpolate(image, size=(output_resolution, output_resolution))
gm.eval()
images = []
expected = []
masks = []
masks_expected = []
for frame in tqdm.tqdm(train_data_collated):
with torch.no_grad():
out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)
image_rgb = to_numpy_image(out["images_render"][0])
mask = to_numpy_image(out["masks_render"][0])
expd = to_numpy_image(resize_image(frame.image_rgb)[0])
mask_expected = to_numpy_image(resize_image(frame.fg_probability)[0])
images.append(image_rgb)
masks.append(mask)
expected.append(expd)
masks_expected.append(mask_expected)
各視点から予測された画像と期待される画像、それに続く予測されたマスクと期待されるマスクを示すグリッドを描画します。これは、いくつかの大きな行にラップされた、4行の画像のグリッドです。
┌────────┬────────┐ ┌────────┐
│pred │pred │ │pred │
│image │image │ │image │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│image │image │ ... │image │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│pred │pred │ │pred │
│mask │mask │ │mask │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│mask │mask │ │mask │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│pred │pred │ │pred │
│image │image │ │image │
│n+1 │n+1 │ │2n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│image │image │ ... │image │
│n+1 │n+2 │ │2n │
├────────┼────────┤ ├────────┤
│pred │pred │ │pred │
│mask │mask │ │mask │
│n+1 │n+2 │ │2n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│mask │mask │ │mask │
│n+1 │n+2 │ │2n │
└────────┴────────┘ └────────┘
...
</center></small>
images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]
n_rows = 4
n_images = len(images)
blank_image = images[0] * 0
n_per_row = 1+(n_images-1)//n_rows
for _ in range(n_per_row*n_rows - n_images):
for group in images_to_display:
group.append(blank_image)
images_to_display_listed = [[[i] for i in j] for j in images_to_display]
split = []
for row in range(n_rows):
for group in images_to_display_listed:
split.append(group[row*n_per_row:(row+1)*n_per_row])
Image.fromarray(np.block(split))
# Print the maximum channel intensity in the first image.
print(images[1].max()/255)
plt.ioff()
fig, ax = plt.subplots(figsize=(3,3))
ax.grid(None)
ims = [[ax.imshow(im, animated=True)] for im in images]
ani = animation.ArtistAnimation(fig, ims, interval=80, blit=True)
ani_html = ani.to_jshtml()
HTML(ani_html)
# If you want to see the output of the model with the volume forced to opaque white, run this and re-evaluate
# with torch.no_grad():
# gm._implicit_functions[0]._fn.density.fill_(9.0)
# gm._implicit_functions[0]._fn.color.fill_(9.0)