【Pytorch】 EarlyStoppingを実装する
はじめに
本記事ではpytorchでEarlyStoppingを実装する方法を紹介します.EarlyStoppingはいくつか実装方法がありますので,そのうちの一つを扱います.
おさらい: EarlyStoppingとは
深層学習における教師あり学習では,訓練データを用いて学習を行いますが,やりすぎると過学習してしまいます.過学習を起こしているときは,訓練データのlossは減少する一方で,汎化性能が下がるため開発データのlossは増加することが予想されます.よって,開発データのlossを眺めつつ,増加傾向にあれば学習を打ち切るような工夫をします.このような工夫をEarlyStoppingと呼びます.
使用するリポジトリ
今回は以下のリポジトリの実装を使います.
上記リポジトリをcloneするか,cloneが面倒であれば pytorchtools.py
だけコピーし,手元に準備します.これ以降,from pytorchtools import EarlyStopping
のようにインポートすることになります.pytorchtools
はpipにも存在しますが中身が異なるため,必ず上記リポジトリのpytorchtools.py
を手元に置き,インポートする形にしてください.
この実装では,開発データの最小lossに注目し,最小lossが更新されているかということを打ち切る基準にします.最小lossが順調に更新されていれば学習を続けますし,一定のエポック数が経過しても更新できなければ打ち切ります.
実装
最小限構成
まずは,最小限の構成として以下のようなコードを書いてみます.
from torch import nn from pytorchtools import EarlyStopping # 適当なモデル class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.Linear = nn.Linear(1, 1) def forward(self, x): pass # インスタンス作成 model = Model() early_stopping = EarlyStopping(patience=3) # loss(だと思っている値)のループ for val_loss in [5, 4, 3, 2, 1, 2, 3, 4, 5, 6]: print('val_loss:', val_loss) early_stopping(val_loss, model) if early_stopping.early_stop: print("Early Stopping") break # 打ち切り
まずはモデルのインスタンスと,EarlyStopping
のインスタンスを作ります.今回は3回連続で最小lossを更新できなかった場合に打ち切ることにするので,patience=3
とします.モデルは何らかのパラメータを持つ必要があるので,適当にLinear()
を持たせています.
続くfor文では,開発データのlossのつもりで値をループしています.序盤は5,4,3,2,1
とlossが減少しますが,途中で過学習が起こって2,3,4,5
とlossが増加する状況を想定します.lossが得られるたびに,early_stopping(val_loss, model)
として,lossとモデルのインスタンスを渡します.
EarlyStopping
はメンバ変数に .early_stop
を持っており,打ち切るべきだと判断すれば True
になります.if文でこれを確認し,True
であればbreakしてエポックのループを打ち切ります.
実行すると以下のようになります.
val_loss: 5 val_loss: 4 val_loss: 3 val_loss: 2 val_loss: 1 val_loss: 2 EarlyStopping counter: 1 out of 3 val_loss: 3 EarlyStopping counter: 2 out of 3 val_loss: 4 EarlyStopping counter: 3 out of 3 Early Stopping
3回連続で最小lossを更新できなかったので,打ち切られたことが確認できます.
これと同時に,最小lossを達成するようなモデルがcheckpoint.pt
というファイル名で保存されています.ファイル名は後述するオプションで変更可能です.
実践的な例
実践的な例に関しては,同リポジトリのMNIST_Early_Stopping_example.ipynb
が非常に参考になります.このノートブックでは,ざっくり以下のような形でEarlyStoppingが使われています.
early_stopping = EarlyStopping(patience=20) for エポック数: for 訓練データのミニバッチ: モデルの訓練 for 開発データのミニバッチ: 開発データのミニバッチlossを計算,リストに保存 early_stopping(開発データのミニバッチlossの平均値, モデル) if early_stopping.early.stop: break # 打ち切り
訓練データのミニバッチを一周するたびに開発データのlossを計算します.開発データのlossはミニバッチ単位で複数得られるため,それらの平均値で打ち切るかどうかを判断します.
オプション
early_stopping = EarlyStopping()
でインスタンスを作る際には,オプションで以下のような設定ができます.
patience= (デフォルト: 7)
開発データのlossが何回連続で最小lossを更新できなければ打ち切るか指定します.例えば,patience=3
であれば,最小lossを3回連続更新できなければ打ち切ります.
verbose= (デフォルト: False)
普通,特別な出力があるのは開発データの最小lossを更新できなかったとき(悪くなったとき)の
EarlyStopping counter: 1 out of 3
みたいな表示だけですが,verbose=True
のもとでは,最小lossを更新したとき(良くなったとき)にも以下のようなメッセージを出力します.
Validation loss decreased (5.000000 --> 4.000000). Saving model ...
delta= (デフォルト: 0)
lossが前回からいくつ減少すれば良くなったと判断するか設定します.デフォルトは 0
なので,最小lossから少しでも値が減少していれば良くなったと判断します.一方で,仮に delta=0.1
に設定した場合,最小lossから 0.1
以上減らないと,最小lossを更新したとみなされません. delta
を大きくすればするほど,打ち切られやすいと解釈できます.
path= (デフォルト: 'checkpoint.pt')
最小lossが更新された際にモデルが保存されるパスを指定します.
おわりに
今回はPytorchでEarly Stoppingを実装する方法を紹介しました.今回紹介したライブラリは一例に過ぎません.他にもpytorchのigniteというライブラリ群にEarly Stoppingの実装があったりします.お好みのものを探して使ってみてください.