Engineering python pytorch 機械学習

Pytorch入門:nn.ModuleListを使ってnn.Moduleをリストで保持する

nn.ModuleListはpytorchのネットワーク定義の際に使用されるクラスの一つになります。pytorchでネットワークを定義していると、同じような記載を繰り返すことがあります。そのようなときはfor文などで、コードをシンプルにしたいのがエンジニアの性なのですが、nn.Moduleは普通のpythonのlistでそのまま持つことはできません。そのようなときに使うのがnn.ModuleListになります。今日は公式ドキュメントを元にnn.ModuleListの解説をしていきます。

nn.ModuleListとは?

nn.ModuleListは、公式ドキュメント(https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html#torch.nn.ModuleList)によると以下のように説明されています。

ModuleList can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all Module methods.

(ModuleListは通常のPythonのリストのようにインデックスを付けることができますが、含まれるモジュールは適切に登録され、すべてのModuleメソッドから見えるようになります)

つまり、pytorchのモデル定義の基本であるnn.Moduleをpythonのリスト構造で持てるようになります。そのような持ち方をすることで、繰り返し出てくるモジュールなどを

pythonのlistでModuleを保持できないのか?

pytorchでは、各nn.Moduleで使用されているlayerやほかのModule内のメンバ変数に格納されているオブジェクトの型を見て、ネットワークの構成がどのようなものであるかを認識しています。もし単純にpythonのlistlayerModuleが格納されただけの場合、それらはpytorchから学習可能なパラメータを持ったレイヤとしては認識されなくなってしまいます。そうすると学習時にうまく学習ができないエラーなどが発生するため、nn.ModuleListなどで学習可能なパラメータがあることをpytorchに認識させる必要があるのです。

詳しくは以下を参照ください

nn.Moduleの使い方

それでは実際の使用例を公式ドキュメントから引用して確認してみます。

# Official docs impl.
class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

上記ではnn.Linear (全結合層)をリストとして保持した後に、nn.ModuleListに渡してレイヤをリストとして定義しています。その後forward関数で、for文を用いて、1つずつ取り出して順伝搬を定義しています。こうすることで、大幅にコード量を減らすことができ、シンプルに繰り返しのレイヤーを定義できるようになります。

前述したように以下のようにpythonのlistとして持つと、学習可能なパラメータを持つレイヤであっても、そのようにpytorchから認識されなくなるため、注意が必要です。例えば、self.linears = [nn.Linear(10, 10) for i in range(10)]と単純に書いただけでは、パラメータを持たないレイヤとして認識されてしまうため、学習時にエラーが出たりするので注意が必要です。

まとめ

nn.ModuleListの使い方について解説をしました。pytorch初学者がはまりやすいポイントだと思いますので、この点に注意してうまくモデル定義ができるとよいかなと思います。

-Engineering, python, pytorch, 機械学習
-, ,