Engineering python pytorch 機械学習

Pytorch入門:nn.Linearとは?引数や使い方を徹底解説

nn.Linearとは、Pytorchにおいて全結合層を定義するためのモジュールになります。PyTorchを使った機械学習や深層学習の実装において、最も基本的かつ頻繁に登場するレイヤのひとつであり、近年のトレンドであるTransformerにおいてもnn.Linearは頻繁に登場するレイヤになります。この記事では、nn.Linearの引数や使い方について解説をしていきます。

nn.Linearとは?

nn.Linearとは、線形層や全結合層などをネットワーク内に定義する際に使用するPytorchのモジュールになります。以下の図のように入力されたノードに対して線形変換を行い、ノードを出力するレイヤになります。

全結合層

主に以下のような使われ方がされるレイヤになります。これらの目的のためにPytorchにおいてはnn.Linearを使用します。

  • ネットワークの最後の分類層(画像認識の識別のクラス数などに合わせこむ)
  • 中間層での特徴変換
  • TransformerにおけるQKV生成時
  • 入力データの次元圧縮など

引数

nn.Linearの引数は以下のようになっています。

引数名デフォルト説明
in_featuresint-レイヤーに入力される特徴マップのChannel数
out_featuresint-レイヤーから出力される特徴マップのChannel数
biasbool (optional)Trueバイアス項を追加するか否か
nn.Linearの引数

実際の使い方

それではnn.Linearの実際の使い方を見てみましょう。

import torch.nn as nn

# in_features = 10, out_features = 5, bias=True(default)
linear = nn.Linear(10, 5)

# dummy tensor(batch:3、dim: 10)
x = torch.randn(3, 10)

# Apply
output = linear(x) # output.shape -> torch.Size([3, 5])

上記のnn.Linear(10, 5)の段階で、入力の次元数と、出力の次元数を定義して引数に渡します。その入力サイズの定義に合うテンソルを適用することで、定義した出力次元のテンソルが返ってきます。実際にネットワークを定義するときには、ほかのレイヤのモジュールと組み合わせて、以下のような書き方になります。

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Layer def
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

まとめ

nn.Linearの引数とその使い方について解説しました。そんなに難しいものでもないため、サクッと使いこなしましょう!

参考

-Engineering, python, pytorch, 機械学習
-, , ,