Engineering python pytorch 機械学習

Pytorch入門:nn.Sequentialの使い方と実際のユースケースを解説

nn.Sequentialは、pytorchでネットワークの定義を行う際に、一番最初に出てくるクラスの一つではないかと思います。よく一方通行のモデルを定義する際に使用するという説明があったりするので、ほかのネットワーク定義のモデルと比べるといささか汎用性に欠けるような印象を受けますが、nn.Sequentialを使うことでメリットが得られる場合もあります。その点について公式ドキュメントをベースに解説します。

nn.Sequential とは?

nn.Sequentialとは、pytorchでモデルのネットワークを定義する際に使用される関数の一つで、公式ドキュメント(https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)によると以下のような説明がされています。

The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then “chains” outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.

(Sequentialのforward()メソッドは、任意の入力を受け取り、それを含む最初のモジュールに転送します。そして、それに続く各モジュールの入力に対する出力を順次「連鎖」させ、最終的に最後のモジュールの出力を返します。)

つまりは、nn.Sequentialに渡されたモジュールを順番に接続して、最後に渡されているモジュールの出力を返すネットワークを定義してくれるクラスということになります。公式ドキュメントでは以下のようなコードが例に上がっています。以下のように書くことで、シンプルなconv2dと、ReLUを2回繰り返すネットワークが定義されたモデルが返ってきます。

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

nn.Sequential の特徴と注意点

nn.Sequential は登録された順番のネットワークを作成する機能しかないため、複雑なマルチヘッドのモデルを構築したり、スキップコネクションを行ったりなど、ネットワークの分岐が必要になってくるような実際のビジネスの場で使われるモデルを定義するには、少し表現力が不足します。nn.Sequential で定義されたモデルの末尾などに、nn.Module.add_moduleを使用して、モジュールを追加することもできますが、あくまでも順番に処理するだけになるので、前述したようなモデルの定義を単独で行うのは、難しいクラスになります。

nn.Sequential を使うメリット

それではnn.Sequential を使うメリットは何なのでしょうか?それは、オリジナルのモデルの定義をする際に、forward()のコード量を大きく減らせるところにあります。以下にものすごく簡単に具体例を書いてみました。まずはnn.Sequentialを全く使用せずに、オリジナルのモデルクラスを定義した場合です。この場合、コンストラクタで定義されたモジュールをforward関数で順伝搬するために、4回繰り返しのようなコードを書くことになります。

import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # def each layer
        self.conv1 = nn.Conv2d(1, 3, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(3, 3, 3)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        # write forward function 4 times... 
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)

        return x

これがnn.Sequentialを使用すると以下のようになります。

import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # def each layer
        self.block1 = nn.Sequential(
          nn.Conv2d(1, 3, 3),
          nn.ReLU()
        )
        self.block2 = nn.Sequential(
          nn.Conv2d(3, 3, 3),
          nn.ReLU()
        )

    def forward(self, x):
        # write forward function only 2 lines.
        x = self.block1(x)
        x = self.block2(x)

        return x

forward関数の中身がすっきりしたかと思います。この程度のコード量だとメリットが感じにくいかもしれませんが、複雑なモデルになってくると、このようにモジュールをまとめることができるnn.Sequentialは重宝します。特にコード量を減らすだけでなく、各ネットワークのブロック単位などで括ってメンバ変数に格納することができるので、直感的に理解しやすいコードになり可読性を高く保てるところが大きなメリットだと考えています。また個人的には、ブロックごとに囲って実装できるので、実装ミスが起きにくくなるというようなメリットもあると考えています。

nn.Sequential と nn.ModuleListの違い

似たようなクラスにnn.ModuleListというものがあります。これとの違いについては公式ドキュメントでは以下のように説明されています。

What’s the difference between a Sequential and a torch.nn.ModuleList? A ModuleList is exactly what it sounds like–a list for storing Module s! On the other hand, the layers in a Sequential are connected in a cascading way.

Sequentialと torch.nn.ModuleList の違いは何ですか?ModuleListはその名の通り、モジュールを保存するためのリストです!一方、Sequentialのレイヤーはカスケード接続されています。

つまりは

  • nn.Sequentialは、渡されたモジュールたちを接続してモデルとして定義する。
  • nn.ModuleListは、あくまでもモジュールを保持するリストであり、その中身を接続する処理は自分でforward関数などで書く必要がある。

という違いがあります。

まとめ

nn.Sequentialの基本的な使い方と、実際にあるユースケースについて解説しました。なかなかチュートリアル的なところではありがたみは感じにくいクラスにはなりますが、複雑なオリジナルモデルを定義する際に使用すると、非常にありがたいクラスになりますので、是非とも使ってみてください。

-Engineering, python, pytorch, 機械学習
-, ,