Engineering python pytorch 機械学習

Pytorch入門:nn.ModuleDictを解説 - nn.Moduleを辞書型として保持する

pytorchには、pythonの辞書型のようにModuleを保持できるnn.ModuleDictがあります。辞書型ライクでModuleを持つことで名前付きでレイヤを管理でき、動的にレイヤを選択するネットワークや、複数の分岐を持つモデルなど、より複雑なモデルを定義できるようになります。この記事では、nn.ModuleDictの基本的な使い方から、実践的な活用例までを初心者向けに解説できればと思います。

nn.ModuleDictとは?

公式ドキュメント(https://docs.pytorch.org/docs/stable/generated/torch.nn.ModuleDict.html#torch.nn.ModuleDict)によると以下のように説明されています。

ModuleDict can be indexed like a regular Python dictionary, but modules it contains are properly registered, and will be visible by all Module methods.

(ModuleDictは通常のPython辞書のようにインデックス付けできます。含まれるモジュールは適切に登録され、すべてのModuleメソッドから参照可能になります。)

つまりnn.ModuleDictとは、torchのnn.Moduleを辞書型で持つことができるクラスになります。
例えば以下のように通常のPythonの辞書型で定義をしてforward()で使用すると、Pytorchはそのレイヤを学習対象として認識することができず、エラーが発生することがあります(勾配計算ができない)。

import torch
import torch.nn as nn

class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Dict as normal python.
        self.dict_layers = {
            "conv": nn.Conv2d(3, 16, 3),
            "relu": nn.ReLU()
        }

    def forward(self, x):
        x1 = self.dict_layers["conv"](x) # not trainable
        x1 = self.dict_layers["relu"](x1) # not trainable

        return x1

このような際に使うのがnn.ModuleDictになります。

nn.ModuleDictの使い方

nn.ModuleDictの使い方はシンプルで、dictにレイヤを格納したあとに、ModuleDictクラスに投げればOKです。
こうすることで、pytorchが学習可能なレイヤとして認識をし、レイヤを辞書型で保持して使うことが可能になります。

import torch
import torch.nn as nn

class SampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dict_layers = nn.ModuleDict({
            "conv": nn.Conv2d(3, 16, 3),
            "relu": nn.ReLU()
        })

    def forward(self, x):
        x1 = self.dict_layers["conv"](x)
        x1 = self.dict_layers["relu"](x1)

        return x1

nn.ModuleDictの実際の活用例

このようにnn.ModuleDictを使って名前付きでレイヤを持つことで、柔軟にネットワークを設計できるようになります。

動的にレイヤを選択

nn.moduleDictを使って、名前付きのレイヤとして管理することで、実行時に動的に使用するレイヤを切り替えることが可能になります。
以下の例では、活性化関数の名前を実行時に与えることで、レイヤの切り替えを行っています。

import torch
import torch.nn as nn

class DynamicActivationModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.activations = nn.ModuleDict({
            "relu": nn.ReLU(),
            "leaky": nn.LeakyReLU(),
            "sigmoid": nn.Sigmoid()
        })
        self.linear = nn.Linear(10, 10)

    def forward(self, x, act_type="relu"):
        x = self.linear(x)
        x = self.activations[act_type](x)  # Switch by 'act_type'
        return x

マルチタスク学習

マルチタスク学習においても、nn.ModuleDictが活用できます。以下のようにClassificationとRegressionのHeadがある場合、例えばLoss計算などの際に別々のKeyを与えることで、同じコード上で別々のタスクの推論を行うことができ、コードをスマートにすることが可能にあります。

class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(100, 50)
        self.heads = nn.ModuleDict({
            "classification": nn.Linear(50, 10),
            "regression": nn.Linear(50, 1)
        })

    def forward(self, x, task="classification"):
        x = self.shared(x)
        x = self.heads[task](x)
        return x

# Forward
model = MultiTaskModel()
class_logits = model(x, task="classification")  # get classification result
reg_score = model(x, task="regression")      # get regression result

$ loss
loss_cls = nn.CrossEntropyLoss()(class_logits, class_labels)
loss_reg = nn.MSELoss()(reg_score, reg_labels)

まとめ

nn.ModuleDictに関して、実例も踏まえて解説しました。名前付きで各レイヤを管理できるので、今回出した例のような柔軟なモデルの設計を行うことが可能になります。ぜひトライしてみてください。

-Engineering, python, pytorch, 機械学習
-, , ,