Engineering python pytorch 機械学習

Pytorch入門:nn.Module.trainとevalの使い方を解説

pytorchのnn.Moduleは、train()eval()という関数を持っており、pytorchのコードを見ているとmodel.train()model.eval()というような記述をよく見ると思います。俗にいう訓練モードと推論モードの切り替えを行うときに使うのですが、感覚的に使っている方も多いのではと思います。今日はtrain()eval()の使い方や実際にどのように動きが変わるのかを解説していきます。

nn.Moduleのモードとは?

nn.Moduleで定義したネットワークにはモードが存在しています。1つは訓練モード、もう1つは評価モードや検証モードと呼ばれるものです。それぞれにモードへの移行には、trainやeval関数を使用して切り替えをします。

  • nn.Module.train() : nn.Moduleで定義したモデルを訓練モードへ移行
  • nn.Module.eval() : nn.Moduleで定義したモデルを評価/検証モードへ移行

実際にはnn.Module.train()には引数があり、正確には以下のような引数になります。デフォルトはmode=Trueのため、train()とすることで訓練モードへ移行することになります。

  • train(mode=True) : 訓練モードへ移行
  • train(mode=False) :評価モードへ移行

train(mode=False)とすることは、eval()を使うことと等価になります。どちらも評価モードへ移行する挙動になります。

train()とeval()における違い

ではnn.Module.train()nn.Module.eval()におけるネットワークの違いは何なのでしょうか?

公式docs(https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval)によると、eval()の中で以下のように記述があります。

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e. whether they are affected, e.g. Dropout, BatchNorm, etc.

(訳:これは特定のモジュールにのみ効果があります。例えば、Dropout、BatchNormなどです。トレーニング/評価モードでの動作の詳細、つまり影響を受けるかどうかについては、特定のモジュールのドキュメントを参照してください。)

これだけだと少しわかりませんので、DropoutBatchNormのドキュメントも見てみます。

Dropoutにおける違い

Dropoutは過学習を抑える手法の1つになります。入力の一部を不活性化(入力値を0にして次ノードに順伝搬させる)させることで、過学習を抑制する効果を発生させます。

pytorchの公式docs(https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html#torch.nn.Dropout)によると訓練モードと評価モードで以下のように振る舞いが変わると記載があります。

During training, randomly zeroes some of the elements of the input tensor with probability p.
(学習中、入力テンソルのいくつかの要素を確率pでランダムにゼロにする)

(中略)

during training. This means that during evaluation the module simply computes an identity function.
(これは、評価時にモジュールが単に恒等関数を計算することを意味する)

つまり、訓練モードの時には実際にDropout層に入力されたテンソルを一定の確率で不活性化させるような振る舞いをし、評価モードの時には受け取った入力をそのまま後ろのノードに伝搬させる振る舞いに変化します。これはDropoutのアイデア自体が学習時にのみ作用する(過学習を抑制する)ものであって、推論などをする際には不要な挙動であることから、このような振る舞いに変化をします。

BatchNormにおける違い

Batch Normalizationとは、深層学習における学習を安定させるテクニックの1つになります。中間層における特徴マップにおける値を各チャンネルごとに正規化することで値を安定させ、勾配消失や勾配爆発のような現象を抑制し、学習を安定させることができると言われています。

pytorchの公式docs(https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)によると訓練モードと評価モードで以下のように振る舞いが変わると記載があります。

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1.
(またデフォルトでは、学習中、このレイヤーは計算された平均と分散の実行推定値を保持し、評価中の正規化に使用されます。実行中の推定値はデフォルトのモメンタム0.1で保持されます)

つまり、訓練モードでは入力された特徴から平均と分散の計算を行い保持をし、評価モードではその値を使って正規化が行われます。

このように学習時と推論時で異なる挙動をする必要があるレイヤのために、train()eval()が用意されています。

使い方

ではnn.Module.train()nn.Module.eval()はどのようなタイミングで使用すればよいのでしょうか?基本的にはシンプルで学習のステップではtrain()を呼び出し訓練モードに、推論や評価、検証のステップでは必ずeval()を呼び出して評価モードにすればOKです。pytorchのドキュメントにおいても、これらを可能な限り事前にコールすることが推奨されており、余計なバグを生み出さないためにも各ステップの最初で呼ぶようにしましょう。

まとめ

nn.Module.train()nn.Module.eval()について解説をしました。自分もなんとなく使っていた部分がありましたが、公式Docsを読んで理解が深まりましたので、これからは意識して使っていければと思います。

-Engineering, python, pytorch, 機械学習