【Pytorch】 EarlyStoppingを実装する

はじめに

本記事ではpytorchでEarlyStoppingを実装する方法を紹介します.EarlyStoppingはいくつか実装方法がありますので,そのうちの一つを扱います.

おさらい: EarlyStoppingとは

深層学習における教師あり学習では,訓練データを用いて学習を行いますが,やりすぎると過学習してしまいます.過学習を起こしているときは,訓練データのlossは減少する一方で,汎化性能が下がるため開発データのlossは増加することが予想されます.よって,開発データのlossを眺めつつ,増加傾向にあれば学習を打ち切るような工夫をします.このような工夫をEarlyStoppingと呼びます.

使用するリポジトリ

今回は以下のリポジトリの実装を使います.

github.com

上記リポジトリをcloneするか,cloneが面倒であれば pytorchtools.py だけコピーし,手元に準備します.これ以降,from pytorchtools import EarlyStoppingのようにインポートすることになります.pytorchtools はpipにも存在しますが中身が異なるため,必ず上記リポジトリpytorchtools.py を手元に置き,インポートする形にしてください.

この実装では,開発データの最小lossに注目し,最小lossが更新されているかということを打ち切る基準にします.最小lossが順調に更新されていれば学習を続けますし,一定のエポック数が経過しても更新できなければ打ち切ります.

実装

最小限構成

まずは,最小限の構成として以下のようなコードを書いてみます.

from torch import nn
from pytorchtools import EarlyStopping

# 適当なモデル
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Linear = nn.Linear(1, 1)
    
    def forward(self, x):
        pass

# インスタンス作成
model = Model()
early_stopping = EarlyStopping(patience=3)
# loss(だと思っている値)のループ
for val_loss in [5, 4, 3, 2, 1, 2, 3, 4, 5, 6]:
    print('val_loss:', val_loss)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early Stopping")
        break # 打ち切り

まずはモデルのインスタンスと,EarlyStoppingインスタンスを作ります.今回は3回連続で最小lossを更新できなかった場合に打ち切ることにするので,patience=3とします.モデルは何らかのパラメータを持つ必要があるので,適当にLinear()を持たせています.

続くfor文では,開発データのlossのつもりで値をループしています.序盤は5,4,3,2,1とlossが減少しますが,途中で過学習が起こって2,3,4,5とlossが増加する状況を想定します.lossが得られるたびに,early_stopping(val_loss, model)として,lossとモデルのインスタンスを渡します.

EarlyStopping はメンバ変数に .early_stop を持っており,打ち切るべきだと判断すれば True になります.if文でこれを確認し,True であればbreakしてエポックのループを打ち切ります.

実行すると以下のようになります.

val_loss: 5
val_loss: 4
val_loss: 3
val_loss: 2
val_loss: 1
val_loss: 2
EarlyStopping counter: 1 out of 3
val_loss: 3
EarlyStopping counter: 2 out of 3
val_loss: 4
EarlyStopping counter: 3 out of 3
Early Stopping

3回連続で最小lossを更新できなかったので,打ち切られたことが確認できます.

これと同時に,最小lossを達成するようなモデルがcheckpoint.ptというファイル名で保存されています.ファイル名は後述するオプションで変更可能です.

実践的な例

実践的な例に関しては,同リポジトリMNIST_Early_Stopping_example.ipynbが非常に参考になります.このノートブックでは,ざっくり以下のような形でEarlyStoppingが使われています.

early_stopping = EarlyStopping(patience=20)
for エポック数:
    for 訓練データのミニバッチ:
        モデルの訓練
    
    for 開発データのミニバッチ:
        開発データのミニバッチlossを計算,リストに保存
    early_stopping(開発データのミニバッチlossの平均値, モデル)
  if early_stopping.early.stop:
    break # 打ち切り

訓練データのミニバッチを一周するたびに開発データのlossを計算します.開発データのlossはミニバッチ単位で複数得られるため,それらの平均値で打ち切るかどうかを判断します.

オプション

early_stopping = EarlyStopping()インスタンスを作る際には,オプションで以下のような設定ができます.

patience= (デフォルト: 7)

開発データのlossが何回連続で最小lossを更新できなければ打ち切るか指定します.例えば,patience=3 であれば,最小lossを3回連続更新できなければ打ち切ります.

verbose= (デフォルト: False)

普通,特別な出力があるのは開発データの最小lossを更新できなかったとき(悪くなったとき)の

EarlyStopping counter: 1 out of 3

みたいな表示だけですが,verbose=Trueのもとでは,最小lossを更新したとき(良くなったとき)にも以下のようなメッセージを出力します.

Validation loss decreased (5.000000 --> 4.000000).  Saving model ...
delta= (デフォルト: 0)

lossが前回からいくつ減少すれば良くなったと判断するか設定します.デフォルトは 0 なので,最小lossから少しでも値が減少していれば良くなったと判断します.一方で,仮に delta=0.1 に設定した場合,最小lossから 0.1以上減らないと,最小lossを更新したとみなされません. deltaを大きくすればするほど,打ち切られやすいと解釈できます.

path= (デフォルト: 'checkpoint.pt')

最小lossが更新された際にモデルが保存されるパスを指定します.

おわりに

今回はPytorchでEarly Stoppingを実装する方法を紹介しました.今回紹介したライブラリは一例に過ぎません.他にもpytorchのigniteというライブラリ群にEarly Stoppingの実装があったりします.お好みのものを探して使ってみてください.