
深層学習のフレームワークとして広く使われているPyTorchでは、モデルの保存と読み込みが非常に重要なステップです。学習済みモデルを再利用したり、他の環境で推論を行ったりする際には、適切な保存方法を理解しておく必要があります。本記事では、PyTorchにおけるモデル保存の基本的な方法と、保存形式の種類(state_dictと全体保存)について、初心者にもわかりやすく解説します。
Pytorchモデルの形式について
Pytorchでよく見る保存形式として、.ptや.pthがあります。どちらもPytorchのモデルや重みを保存するための拡張子なので、どちらも共通の関数で保存、読み込み可能です。しかし、慣習的な使い分けとして以下のような区別があります。
| 拡張子 | 意味・用途 | 備考 |
.pt | モデル全体(モデルの構造+重み)を保存することが多い | 汎用性が.pthより高い |
.pth | 学習済みの重み(state_dict():後述)のみを保存するのに多用される | 学習済みの重みの保存に使用されることが多い |
どちらもPytorchモデルをシリアライズして保存している形式なので、拡張子によってテクニカルな差分があるわけではありません。
Pytorchモデルの保存方法
Pytorchモデルの保存については、以下のように実施します。ここでは.pth形式で保存をしています。また説明を簡略化するために、学習する部分のコードは省きます。
import torch
# Load Model arch.
model = MyModel()
# --- training code ---
# Save model weight.
torch.save(model.state_dict(), "model.pth")
まず学習するモデルの定義を読み込みます。その後Pytorchモデルを学習し、torch.saveを使って読み込んだモデルのstate_dict()を保存します。こうすることで、Pytorchモデルの重みとバイアスのみを.pthで保存することができます。
この.pthには、モデルの構造は保存されておらず、重みとバイアスのみが保存されています。そのため、読み込みの際にはモデル定義を別で読み込む必要がある点に注意してください。
Pytorchモデルの読み込み方法
先ほど保存したPytorchモデルを読み込んで利用するには、以下のようにコードを書きます。
# Load Model arch.
model = MyModel()
# Load weight we saved.
model.load_state_dict(torch.load("model.pth"))
model.state_dictを用いた重みの保存では、モデル構造自体は保存されていませんので、先に保存したときに呼び出したモデルと同じクラスを読み込みます。そのあとに先ほど保存したPytorchモデルの重みをload_state_dict()を用いることで重みを読み込むことができます。
Pytorchモデルの全体を保存・読み込む方法
モデル全体(構造+重み)を保存したい場合は、torch.save(model, "model.pt") のようにモデルクラスのオブジェクトをそのまま保存します。読み込むには、torch.loadを使えばモデル構造ごと読み込むことが可能です。
# save
torch.save(model, "model.pt")
# load
model = torch.load("model.pt")
ただし、公式ドキュメントでは、あまりこのやり方は推奨されていません。
なぜstate_dict()を使った方法が推奨されるのか?
モデルの構造も保持できる保存方法のほうが、一見楽なように見えますが、なぜstate_dictでモデルの重みのみを保存することが推奨されているのでしょうか?公式ドキュメントなどを踏まえると、以下の理由が挙げられます。
柔軟性が高い
state_dictはモデルのパラメータ(重みやバイアス)だけを保存します。- モデルの構造(クラス定義)は別途保持するため、構造を変更したり、部分的に再利用したりすることが可能です。
- 例えば、転移学習で一部の層だけを読み込むといった使い方ができます。
安全性と互換性
- モデル全体(
torch.save(model, ...))を保存すると、Pythonのpickle形式で保存されます。 - pickleはPythonのバージョンや環境に依存しやすく、他の環境で読み込めない可能性があります。
- 一方、
state_dictは純粋なテンソルの辞書なので、環境依存性が低く、安全に読み込めます。
再現性と管理のしやすさ
state_dictを使えば、モデルの構造と重みを明確に分離できます。- これにより、コード管理(Gitなど)やモデルのバージョン管理がしやすくなります。
- モデル構造をコードで明示的に定義することで、再現性の高い実験が可能になります。
別環境などでの再利用可能性や、柔軟性等を考慮するとstate_dictを素直に使った方がよさそうです。
まとめ
Pytorchでのモデルの保存・読み出し方法について解説しました。今までなんとなく.ptと.pthを使用していたので、この機会に整理できてよかったなと思いました。
参考
https://docs.pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html