# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
Implicitron のコンポーネントはすべて、統一された階層構成システムに基づいています。これにより、構成可能な変数とすべてのデフォルト値を新しいコンポーネントごとに別々に定義できます。次に、実験に関連するすべての構成は、実験を完全に指定する単一の構成ファイルに自動的に構成されます。特に重要な機能は、ユーザーが Implicitron ベースコンポーネントのサブクラスを挿入できる拡張ポイントです。
このシステムを定義するファイルは、PyTorch3D リポジトリの こちら にあります。Implicitron のボリュームチュートリアルには、構成システムを使用する簡単な例が含まれています。このチュートリアルでは、Implicitron の構成可能なコンポーネントを使用して、詳細な実践的な経験を提供します。
torch
と torchvision
がインストールされていることを確認します。pytorch3d
がインストールされていない場合は、次のセルを使用してインストールします
import os
import sys
import torch
need_pytorch3d=False
try:
import pytorch3d
except ModuleNotFoundError:
need_pytorch3d=True
if need_pytorch3d:
if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
# We try to install PyTorch3D via a released wheel.
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
f"py3{sys.version_info.minor}_cu",
torch.version.cuda.replace(".",""),
f"_pyt{pyt_version_str}"
])
!pip install fvcore iopath
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
else:
# We try to install PyTorch3D from source.
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
omegaconf がインストールされていることを確認します。インストールされていない場合は、このセルを実行します。(ランタイムを再起動する必要はありません。)
!pip install omegaconf
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
expand_args_fields,
get_default_args,
registry,
run_auto_creation,
)
@dataclass
class MyDataclass:
a: int
b: int = 8
c: Optional[Tuple[int, ...]] = None
def __post_init__(self):
print(f"created with a = {self.a}")
self.d = 2 * self.b
my_dataclass_instance = MyDataclass(a=18)
assert my_dataclass_instance.d == 16
👷 ここで dataclass
デコレーターは、クラス自体の定義を変更する関数であることに注意してください。定義の直後に実行されます。当社の構成システムでは、implicitron ライブラリコードには、ユーザー定義の実装を認識する必要がある変更されたバージョンのクラスが含まれる必要があります。したがって、クラスの変更を遅らせる必要があります。デコレーターは使用しません。
dc = DictConfig({"a": 2, "b": True, "c": None, "d": "hello"})
assert dc.a == dc["a"] == 2
OmegaConf には、yaml とのシリアル化があります。 Hydra ライブラリはこれを構成ファイルに依存しています。
print(OmegaConf.to_yaml(dc))
assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc
OmegaConf.structured は、dataclass またはdataclass のインスタンスから DictConfig を提供します。通常の DictConfig とは異なり、型チェックされており、既知のキーのみを追加できます。
structured = OmegaConf.structured(MyDataclass)
assert isinstance(structured, DictConfig)
print(structured)
print()
print(OmegaConf.to_yaml(structured))
structured
は a
の値がないことを認識しています。
そのようなオブジェクトはdataclass と互換性のあるメンバーを持っているため、次のように初期化を実行できます。
structured.a = 21
my_dataclass_instance2 = MyDataclass(**structured)
print(my_dataclass_instance2)
OmegaConf.structured をインスタンスに呼び出すこともできます。
structured_from_instance = OmegaConf.structured(my_dataclass_instance)
my_dataclass_instance3 = MyDataclass(**structured_from_instance)
print(my_dataclass_instance3)
OmegaConf.structured
と同等の機能を持ちながら、より多くの機能をサポートする関数を提供します。次の操作を実行することで、関数の使用を向上させます。構成可能なクラスは、デコレータではなく、特殊な基本クラス Configurable
を使用して示すことに注意してください。
class MyConfigurable(Configurable):
a: int
b: int = 8
c: Optional[Tuple[int, ...]] = None
def __post_init__(self):
print(f"created with a = {self.a}")
self.d = 2 * self.b
# The expand_args_fields function modifies the class like @dataclasses.dataclass.
# If it has not been called on a Configurable object before it has been instantiated, it will
# be called automatically.
expand_args_fields(MyConfigurable)
my_configurable_instance = MyConfigurable(a=18)
assert my_configurable_instance.d == 16
# get_default_args also calls expand_args_fields automatically
our_structured = get_default_args(MyConfigurable)
assert isinstance(our_structured, DictConfig)
print(OmegaConf.to_yaml(our_structured))
our_structured.a = 21
print(MyConfigurable(**our_structured))
当社のシステムでは、Configurable クラスが相互に含まれることができます。1 つ注意すべき点は、__post_init__
で run_auto_creation
を呼び出すことです。
class Inner(Configurable):
a: int = 8
b: bool = True
c: Tuple[int, ...] = (2, 3, 4, 6)
class Outer(Configurable):
inner: Inner
x: str = "hello"
xx: bool = False
def __post_init__(self):
run_auto_creation(self)
outer_dc = get_default_args(Outer)
print(OmegaConf.to_yaml(outer_dc))
outer = Outer(**outer_dc)
assert isinstance(outer, Outer)
assert isinstance(outer.inner, Inner)
print(vars(outer))
print(outer.inner)
inner_args が outer の追加メンバーである方法に注意してください。run_auto_creation(self)
は以下と同等です。
self.inner = Inner(**self.inner_args)
クラスが Configurable
ではなく ReplaceableBase
を基本クラスとして使用する場合、それを置換可能と呼びます。それは、代わりに子クラスによって使用されるように設計されていることを示します。サブクラスが実装することが想定される機能を示すために NotImplementedError
を使用する場合があります。システムは、各 ReplaceableBase のサブクラスを含むグローバル registry
を管理します。サブクラスはデコレータを使用して、それに自身を登録します。
ReplaceableBase を含む構成可能なクラス(つまり、当社のシステムを使用するクラス、つまり Configurable
または ReplaceableBase
の子)には、使用する具体的な子クラスを示す str
型の対応する class_type フィールドも含まれている必要があります。
class InnerBase(ReplaceableBase):
def say_something(self):
raise NotImplementedError
@registry.register
class Inner1(InnerBase):
a: int = 1
b: str = "h"
def say_something(self):
print("hello from an Inner1")
@registry.register
class Inner2(InnerBase):
a: int = 2
def say_something(self):
print("hello from an Inner2")
class Out(Configurable):
inner: InnerBase
inner_class_type: str = "Inner1"
x: int = 19
def __post_init__(self):
run_auto_creation(self)
def talk(self):
self.inner.say_something()
Out_dc = get_default_args(Out)
print(OmegaConf.to_yaml(Out_dc))
Out_dc.inner_class_type = "Inner2"
out = Out(**Out_dc)
print(out.inner)
out.talk()
この場合、多くの args
メンバーがあることに注意してください。通常はコードでそれらを無視しても問題ありません。それらは構成に必要です。
print(vars(out))
class MyLinear(torch.nn.Module, Configurable):
d_in: int = 2
d_out: int = 200
def __post_init__(self):
super().__init__()
self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)
def forward(self, x):
return self.linear.forward(x)
my_linear = MyLinear()
input = torch.zeros(2)
output = my_linear(input)
print("output shape:", output.shape)
my_linear
には Module の通常の機能がすべて備わっています。たとえば、torch.save
と torch.load
で保存およびロードできます。パラメータがあります。
for name, value in my_linear.named_parameters():
print(name, value.shape)
セクション 5 のように Out
が含まれるライブラリを使用していますが、InnerBase の独自の子を実装したいとします。必要な作業は定義の登録のみですが、expand_args_fields が Out で明示的または暗黙的に呼び出される前に、これを実行する必要があります。
@registry.register
class UserImplementedInner(InnerBase):
a: int = 200
def say_something(self):
print("hello from the user")
この時点で、Out クラスを再定義する必要があります。それ以外の場合は、UserImplementedInner なしで既に拡張されていると、クラスが拡張された時点で既知の実装が固定されるため、以下は機能しません。
スクリプトから実験を実行している場合、ここで覚えておくべきことは、ライブラリクラスを使用する前に、独自の実装を登録する独自モジュールをインポートする必要があるということです。
class Out(Configurable):
inner: InnerBase
inner_class_type: str = "Inner1"
x: int = 19
def __post_init__(self):
run_auto_creation(self)
def talk(self):
self.inner.say_something()
out2 = Out(inner_class_type="UserImplementedInner")
print(out2.inner)
ユーザーが独自に供給できるように、サブコンポーネントをプラグイン可能にする必要がある場合に発生する必要があることを確認しましょう。
class SubComponent(Configurable):
x: float = 0.25
def apply(self, a: float) -> float:
return a + self.x
class LargeComponent(Configurable):
repeats: int = 4
subcomponent: SubComponent
def __post_init__(self):
run_auto_creation(self)
def apply(self, a: float) -> float:
for _ in range(self.repeats):
a = self.subcomponent.apply(a)
return a
large_component = LargeComponent()
assert large_component.apply(3) == 4
print(OmegaConf.to_yaml(LargeComponent))
ジェネリックにする
class SubComponentBase(ReplaceableBase):
def apply(self, a: float) -> float:
raise NotImplementedError
@registry.register
class SubComponent(SubComponentBase):
x: float = 0.25
def apply(self, a: float) -> float:
return a + self.x
class LargeComponent(Configurable):
repeats: int = 4
subcomponent: SubComponentBase
subcomponent_class_type: str = "SubComponent"
def __post_init__(self):
run_auto_creation(self)
def apply(self, a: float) -> float:
for _ in range(self.repeats):
a = self.subcomponent.apply(a)
return a
large_component = LargeComponent()
assert large_component.apply(3) == 4
print(OmegaConf.to_yaml(LargeComponent))
次の内容を変更する必要がありました。
@registry.register
デコレーションを取得し、その基本クラスが新しいものに変わりました。subcomponent_class_type
が追加されました。subcomponent_args
を subcomponent_SubComponent_args
に変更する必要がありました。__post_init__
の定義を省略するか、その中で run_auto_creation
を呼び出さない。 subcomponent_class_type = "SubComponent"
の代わりに subcomponent_class_type: str = "SubComponent"