GPT-2を使って文のパープレキシティを計算する

とある手法の再現実装をするために学んだので覚え書き.

transformersのGPT-2を使って文のパープレキシティ(perplexity)を計算するための実装について書きます.
フレームワークはPyTorch,python3.8.10で試しています.

インストール

# 仮想環境作るなら
# python -m venv env
# source env/bin/activate
pip install torch transformers

一文のパープレキシティを計算

トークナイズ

訓練済みモデルを使うときは,語彙を揃えるために対応するトークナイザーを使います.transformersにはGPT-2のためのトークナイザーとしてGPT2TokenizerFastがあるので,これを使うことにします.モデルのIDにはgpt2を指定します.他にも,パラメータ数がより多いgpt2-largeなどが使えます.

from transformers import GPT2TokenizerFast

model_id = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
sentence = ['This is a pen .']
inputs = tokenizer(sentence, return_tensors='pt', padding=True)
print(inputs)
# 出力:{'input_ids': tensor([[1212,  318,  257, 3112,  764]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

トークナイザーは文のリスト('str'オブジェクトのリスト)を入力とし,dictオブジェクトを返します(厳密には,一文ならリストにしなくてもいいです).返り値であるdictオブジェクトは2つの要素を含んでいます.一つはinput_idsで,トークンがIDに変換されたものです.もう一つはattention_maskで,バッチ化するときに使うものです.共にshapeは(バッチサイズ,系列長)です.

パープレキシティの計算

モデルにはtransformersのGPT2LMHeadModelを使います.トークナイザーと同じように,モデルのIDを指定して訓練済みのモデルをロードします.

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import torch
import math

model_id = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id)
sentence = 'This is a pen .'
inputs = tokenizer(sentence, return_tensors='pt')
with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'], labels=inputs['input_ids'])
print(torch.exp(outputs.loss)) # tensor(312.8972)

modelの返り値はCausalLMOutputWithCrossAttentionsオブジェクトです.入力のlabels=input_ids=と同じテンソルを渡すとlossが計算される仕組みになっています.lossは.lossで参照できます.これはtorch.Tensorオブジェクトなので,torch.exp() で囲むことでパープレキシティが得られます.

CausalLMOutputWithCrossAttentionsの詳細.
公式のページはここ: https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt2/modeling_gpt2.py#L1084
return CausalLMOutputWithCrossAttentions(
    loss=loss,
    logits=lm_logits,
    past_key_values=transformer_outputs.past_key_values,
    hidden_states=transformer_outputs.hidden_states,
    attentions=transformer_outputs.attentions,
    cross_attentions=transformer_outputs.cross_attentions,
)

このlossはクロスエントロピー損失で計算されます.

# 引用元:https://github.com/huggingface/transformers/blob/2c3fcc647a6d04f21668b1f5400c0fd33905bbb1/src/transformers/models/gpt2/modeling_gpt2.py#L1071
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

lm_logitsは(バッチサイズ,系列長,語彙サイズ)のshapeで,各トークンの生成確率が保存されています.lm_logits[..., :-1, :]とすることで,系列の最後尾以外の部分を抜き出しています. 一方,labelsは(バッチサイズ,系列長)のshapeで,labels[..., 1:]とすることで系列の先頭以外の部分を抜き出します.このように1トークンずらすことで,一般的なデコーディングの流れ( t_{i-1}の推定結果を使ってt_iを推定する流れ)を再現できます.

CrossEntropyLoss()はデフォルトのオプションでreduction='mean'が指定されているので,各トークンに対する損失の平均が計算されます.

複数文のパープレキシティを一度に計算(バッチ化)

バッチ化することで,複数文のパープレキシティを一度に計算することができます.基本的には上で述べた一文のみの場合と同じですが,トークナイズにpaddingの設定をする点と,クロスエントロピーの損失を計算するパートを若干自分で書く点が異なります.

トークナイズ

文によって文長が異なるので,バッチ化するときにはpadding=Trueを指定する必要があります.それから,トークナイザーの語彙にはいわゆるpad_tokenが設定されていないので,tokenizer.pad_token = tokenizer.eos_tokenとすることで追加しておきます.他にもtokenizer.add_special_tokens({'pad_token': '[PAD]'})とする方法もありますが,こうすると語彙サイズが1つ増えることでモデル側でindex out of rangeを起こして面倒なので,eos_tokenで代用します(実際,eos_tokenで代用するプログラムが多い印象です).

attention_maskを見ると,paddingされたトークンは0,そうでないトークンは1であることが分かります.

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import torch
import math

model_id = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_id)
sentences = ['This is a pen .',
            'This a is pen .',
            'This is a pen pen pen pen .']
inputs = tokenizer(sentences, return_tensors='pt', padding=True)
print(inputs['attention_mask'])
'''
tensor([[1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]])
'''

パープレキシティの計算

モデルへの入力は基本的に一文のときと同じですが,attention_maskも追加で渡す点が異なります.しかしながら,三文を入力したのにもかかわらず,返り値のlossは一つの値になっています.これはlossの計算に使われているtorch.nn.CrossentropyLoss()reduction=オプションが,デフォルトで'mean'になっているためです.'mean'では,文の境界にかかわらず,全てのlossが平均されて一つの値を返します.

model_id = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_id)
sentences = ['This is a pen .', 'This a is pen .', 'This is a pen pen pen pen .']
inputs = tokenizer(sentences, return_tensors='pt', padding=True)
with torch.no_grad():
    outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
print(outputs.loss) # tensor(6.3251)

そのため,outputs.logitsオブジェクトから自分でlossの計算を書きます.具体的には,torch.nn.CrossEntropyLoss(reduction='none')とすることで各トークンの損失を(平均など取ることなく)獲得し,dim=1で(つまり各文について)合計をとります.その後,それぞれの文の系列長で割ります.文の系列長は,attention_maskdim=1で合計すると得られます(attention_maskは,pad_tokenでないトークンが1となっているため).

model_id = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(model_id)
sentences = ['This is a pen .', 'This a is pen .', 'This is a pen pen pen pen .']
inputs = tokenizer(sentences, return_tensors='pt', padding=True)
with torch.no_grad():
    outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'], labels=inputs['input_ids'])
print(outputs.logits.shape) # torch.Size([3, 8, 50257])
# 1トークンずらす
shift_logits = outputs.logits[:, :-1, :].contiguous() # 確率
shift_labels = inputs['input_ids'][:, 1:].contiguous() # 正解のトークンID
shift_mask = inputs['attention_mask'][:, 1:].contiguous() # マスク
batch_size, seq_len = shift_labels.shape
loss_fn = torch.nn.CrossEntropyLoss(reduction='none') # reduction='none'に
loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(batch_size, seq_len)
print(loss.shape) # torch.Size([3, 7])
# shift_maskと積をとることで,pad_tokenに対する損失を無視する.
# shift_mask.sum(dim=1)とすることで,各文のpad_tokenを除いた系列長が得られる
loss = (loss * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
print(torch.exp(loss)) # tensor([ 312.8972, 3360.7671,  125.8699])