
nn.Linearとは、Pytorchにおいて全結合層を定義するためのモジュールになります。PyTorchを使った機械学習や深層学習の実装において、最も基本的かつ頻繁に登場するレイヤのひとつであり、近年のトレンドであるTransformerにおいてもnn.Linearは頻繁に登場するレイヤになります。この記事では、nn.Linearの引数や使い方について解説をしていきます。
nn.Linearとは?
nn.Linearとは、線形層や全結合層などをネットワーク内に定義する際に使用するPytorchのモジュールになります。以下の図のように入力されたノードに対して線形変換を行い、ノードを出力するレイヤになります。

主に以下のような使われ方がされるレイヤになります。これらの目的のためにPytorchにおいてはnn.Linearを使用します。
- ネットワークの最後の分類層(画像認識の識別のクラス数などに合わせこむ)
- 中間層での特徴変換
- TransformerにおけるQKV生成時
- 入力データの次元圧縮など
引数
nn.Linearの引数は以下のようになっています。
| 引数名 | 型 | デフォルト | 説明 |
| in_features | int | - | レイヤーに入力される特徴マップのChannel数 |
| out_features | int | - | レイヤーから出力される特徴マップのChannel数 |
| bias | bool (optional) | True | バイアス項を追加するか否か |
nn.Linearの引数実際の使い方
それではnn.Linearの実際の使い方を見てみましょう。
import torch.nn as nn
# in_features = 10, out_features = 5, bias=True(default)
linear = nn.Linear(10, 5)
# dummy tensor(batch:3、dim: 10)
x = torch.randn(3, 10)
# Apply
output = linear(x) # output.shape -> torch.Size([3, 5])
上記のnn.Linear(10, 5)の段階で、入力の次元数と、出力の次元数を定義して引数に渡します。その入力サイズの定義に合うテンソルを適用することで、定義した出力次元のテンソルが返ってきます。実際にネットワークを定義するときには、ほかのレイヤのモジュールと組み合わせて、以下のような書き方になります。
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
# Layer def
self.fc1 = nn.Linear(784, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
まとめ
nn.Linearの引数とその使い方について解説しました。そんなに難しいものでもないため、サクッと使いこなしましょう!
参考
公式doc: https://docs.pytorch.org/docs/stable/generated/torch.nn.Linear.html