【Pytorch】nn.Embeddingの使い方を丁寧に

はじめに

本記事では,Pytorchの埋め込み層を実現するnn.Embedding()について,入門の立ち位置で解説します.

ただし,結局公式ドキュメントが最強なので,まずはこちらを読むのをお勧めします.

pytorch.org

対象読者は,

  • 他のモデルの実装記事見ても,全人類nn.Embeddingをサラッと使ってて何だこれ〜〜って人

  • nn.Embeddingの入出力のイメージが分からない人

  • 公式ドキュメント(英語だし)分からね〜〜という人

目次

nn.Enbeddingの入出力

まずは入出力を確認します.

  • 入力

入力は,単語IDの並びです.例えば,いま,単語とIDの対応が

{'this': 0, 'is': 1, 'a': 2, 'sentence': 3}

だと仮定すると,
['This', 'is', 'a', 'sentence']
という単語列は
[0, 1, 2, 3]
という単語IDの並びに変換できます.このような単語IDの並びが入力です.one-hotにする必要はありません.

  • 出力

出力は,各単語の埋め込みベクトルです.入力として[0,1,2,3]を与えた場合,

[[単語ID 0 に対する埋め込みベクトル],
[単語ID 1 に対する埋め込みベクトル],
[単語ID 2 に対する埋め込みベクトル],
[単語ID 3 に対する埋め込みベクトル]]

が出力となります.各単語IDが埋め込みベクトルに変わるため,次元が一つ増えます.shapeとしては,(系列長) -> (系列長, 埋め込み次元)と変化します.詳しくは,あとで具体例で確認します.

なお,ここでは各データ型をpythonのリストのように書いていますが,本当は<class 'torch.Tensor'>である必要があります.

nn.Embeddingの宣言

nn.Embedding()は,基本的に2つの引数をとります.

  • 第一引数: 語彙サイズ(単語IDの最大値+1)

  • 第二引数: 埋め込む次元

例えば,入出力の項で用いた例では,全部で4単語を扱うので語彙サイズは4です.埋め込む次元は,自由に決めることができます.以下では5次元とかにしてみましょう.

import torch
import torch.nn as nn

# 語彙サイズ
vocab_size = 4
# 埋め込む次元
emb_dim = 5
embeddings = nn.Embedding(vocab_size, emb_dim)

とりあえず1単語埋め込む

とりあえず1単語埋め込んでみます.宣言は上記のものを流用するとして,単語ID0の単語だけ埋め込みましょう.単語とIDの対応は人ぞれぞれですが,入出力の項で用いた例ではthisの埋め込みベクトルを得る処理だと思えば良いです.

入力はあくまでも単語IDの並びなので,1単語だけだとしても,リストにする必要があります.また,入出力のところで触れたように,入力のデータ型は<class 'torch.Tensor'>でないといけないので,torch.tensor(リスト)で変換してから入力します.

import torch
import torch.nn as nn

vocab_size = 4
emb_dim = 5
embeddings = nn.Embedding(vocab_size, emb_dim)
# 0番目の単語なので,[0]をTensorに変換
word = torch.tensor([0])
embed_word = embeddings(word)
print(embed_word)
print(word.shape, '->', embed_word.shape)
'''
出力: 
tensor([[-0.5962, -1.2342,  1.1888, -1.1408, -0.3594]],
       grad_fn=<EmbeddingBackward>)
torch.Size([1]) -> torch.Size([1, 5])
'''

ということで,単語ID0の埋め込みベクトルは[-0.5962, -1.2342, 1.1888, -1.1408, -0.3594]だと分かりました.宣言時に埋め込む次元を5にしているので,出力も5次元になっています.

一応shapeを確認すると,(1) -> (1, 5) = (系列長, 埋め込み次元)と変化しているのが分かります.

複数単語を一気に埋め込む

いま,['This', 'is', 'a', 'sentence'] という単語列が,[0, 1, 2, 3] と単語IDの並びに変換されているとします.この時も,1単語の時と全く同じように書けます.

import torch
import torch.nn as nn

vocab_size = 4
emb_dim = 5
embeddings = nn.Embedding(vocab_size, emb_dim)

words = torch.tensor([0, 1, 2, 3])
embed_words = embeddings(words)
print(embed_words)
print(words.shape, '->', embed_words.shape)
'''
出力: 
tensor([[-0.0964,  0.0113,  0.5742,  0.7339, -1.9287],
        [ 0.8564,  1.8212, -0.6291,  0.4318,  1.5869],
        [ 0.2528, -0.3460,  0.0923, -0.7709, -0.6723],
        [ 1.7714,  0.0655, -0.6220, -0.3896, -0.5604]],
       grad_fn=<EmbeddingBackward>)
torch.Size([4]) -> torch.Size([4, 5])
'''

このように,複数単語も一気に計算できました.shapeが(4, 5) = (系列長, 埋め込み次元)と変化していることも確認しておきます.

ミニバッチ化してたらどうなるの

学習時にミニバッチ化している場合,次元が一つ増えます.具体的には,バッチサイズが2だとすると

[['this', 'is', 'a', 'sentence'],
 ['is', 'this', 'a', 'sentence']]

のようなデータが1回の入力になるわけです.単語IDには

[[0, 1, 2, 3],
 [1, 0, 2, 3]]

みたいに変換できるでしょう.このような場合も,同じように埋め込むことができます.

import torch
import torch.nn as nn

vocab_size = 4
emb_dim = 5
embeddings = nn.Embedding(vocab_size, emb_dim)
sents = torch.tensor([[0, 1, 2, 3],
                     [1, 0, 2, 3]])
embed_sents = embeddings(sents)
print(embed_sents)
print(sents.shape, '->', embed_sents.shape)

'''
出力: 
tensor([[[-1.2136, -2.5102, -0.1156, -0.1305, -0.0215],
         [ 2.7627, -0.6553, -0.2072,  0.7654, -2.3114],
         [-1.7942, -1.3689,  0.1742, -0.4785,  0.0510],
         [ 0.4606,  0.7227, -1.6526, -1.4224, -0.8632]],

        [[ 2.7627, -0.6553, -0.2072,  0.7654, -2.3114],
         [-1.2136, -2.5102, -0.1156, -0.1305, -0.0215],
         [-1.7942, -1.3689,  0.1742, -0.4785,  0.0510],
         [ 0.4606,  0.7227, -1.6526, -1.4224, -0.8632]]],
       grad_fn=<EmbeddingBackward>)
torch.Size([2, 4]) -> torch.Size([2, 4, 5])
'''

やはり,各単語IDがしっかり埋め込まれています.また,同じ単語IDは同じ埋め込みベクトルになっていることも確認できます.さらに,変換後の次元は(2, 4, 5) = (バッチサイズ, 系列長, 埋め込み次元)となっていることも分かります.本記事のテーマとは関係ありませんが,この形はいわゆるLSTMなどのRNNにおいて,batch_first=Trueとした時の入力で要求される形そのものです.

padding

同一バッチ内での系列長とpadding

一つ前でミニバッチでの例を示しましたが,これはバッチ内のデータの系列長が揃っているという意味で特殊な例です.2文とも4単語で揃っています.実データではそんなことは稀で,文によって単語数は異なるでしょう.nn.Embeddingは,同一バッチ内で系列長が異なるとエラーを吐くので,系列長を揃えるためのpadding周りも紹介します.

paddingは,系列長が足りない部分を何かしらの値で埋めることです.引き続き,具体例として入出力の項で用いた例を使います.いま,入力の単語を1始まりの単語IDで置き換えて,paddingは0で行うことにしましょう.また,単語とIDの対応は
{'this': 1, 'is': 2, 'a': 3, 'sentence': 4}
と仮定します.また,入力の文を

[['this', 'is', 'a', 'sentence'],
 ['this', 'sentence']]

とします.バッチサイズが2だとすると,バッチ内の系列長が4単語と2単語で異なるため,nn.Embeddingに入力できません.これを解決するため,単語IDに変換してpaddingすると

[[1, 2, 3, 4],
 [1, 4, 0, 0]]

となります.2文目について,足りない2単語を0でpaddingすることで,2つの文の系列長が一致しました.

padding_idxオプション

(一つ前の続き)これでnn.Embedding()の入力とできますが,paddingした0というのはダミーのIDであって,何かしらの単語を表すものではありません.0に対しては埋め込みベクトルも計算して欲しくないので,オプションのpadding_idxを用いて無視するようにします.また,ダミーIDのために単語IDが1つ余分に消費されるため,宣言時の語彙サイズも1つ大きくする必要があります.

import torch
import torch.nn as nn

# 語彙サイズを1つ大きく
vocab_size = 4+1
emb_dim = 5
# padding_idx=0を追加
embeddings = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
sents = torch.tensor([[1, 2, 3, 4],
                     [1, 4, 0, 0]])
embed_sents = embeddings(sents)
print(embed_sents)
print(sents.shape, '->', embed_sents.shape)

'''
出力:
tensor([[[ 0.1721,  0.7526,  0.3652,  0.2402, -0.4727],
         [ 0.3910, -0.0685,  2.2712,  0.3159, -0.6302],
         [-0.1785,  0.4621, -0.5341, -0.8397,  0.6010],
         [-0.4067, -1.0936, -0.5026,  0.2667, -0.1626]],

        [[ 0.1721,  0.7526,  0.3652,  0.2402, -0.4727],
         [-0.4067, -1.0936, -0.5026,  0.2667, -0.1626],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<EmbeddingBackward>)
torch.Size([2, 4]) -> torch.Size([2, 4, 5])
'''

paddingしたところは埋め込みベクトルがゼロベクトルになっていることが分かります.

おわりに

今回はnn.Embeddingを細かめに説明しました.入出力の具体例とshapeを意識して解説したので,参考になれば幸いです.もし誤りなどあれば,ご指摘お願いいたします.

paddingに関する個人的な実験

padding_idxを負にしたら?

padding_idxは負にもできます.この場合,普段のpythonにおける負のindex参照と同じように,「後ろから何番目」として動作します.例えば,vocab_size = 5のもとでpadding_idx = -1とすると,単語IDが4の埋め込みベクトルが計算されなくなります.

import torch
import torch.nn as nn

vocab_size = 4+1
emb_dim = 5
# padding_idxを-1にしてみた
embeddings = nn.Embedding(vocab_size, emb_dim, padding_idx=-1)
sents = torch.tensor([[1, 2, 3, 4],
                     [1, 4, 0, 0]])
embed_sents = embeddings(sents)
print(embed_sents)
print(sents.shape, '->', embed_sents.shape)

'''
出力: 
tensor([[[-0.1468,  1.0873, -1.0831, -0.4471,  1.3967],
         [ 0.3240, -2.7857, -0.6588,  0.4635, -1.5496],
         [ 1.4328,  0.7491, -0.9504, -0.5701, -0.8590],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1468,  1.0873, -1.0831, -0.4471,  1.3967],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0230, -1.0231,  0.5071, -2.9424, -1.3194],
         [ 0.0230, -1.0231,  0.5071, -2.9424, -1.3194]]],
       grad_fn=<EmbeddingBackward>)
torch.Size([2, 4]) -> torch.Size([2, 4, 5])
'''

paddingするダミーIDを負にしたら?

個人的に,-1でpaddingできたらvocab_sizeを1つ大きくしたりしなくて良いなーと思いましたが,それはできません.そもそもpadding_idx=-1の挙動的に,-1を認識させるのが無理ですね.

import torch
import torch.nn as nn

vocab_size = 4+1
emb_dim = 5
embeddings = nn.Embedding(vocab_size, emb_dim, padding_idx=-1)
# -1でpaddingしてみる
sents = torch.tensor([[1, 2, 3, 4],
                     [1, 4, -1, -1]])
embed_sents = embeddings(sents)
print(embed_sents)
print(sents.shape, '->', embed_sents.shape)

'''
出力:
~~ 色々エラーが出て ~~
RuntimeError: index out of range: Tried to access index -1 out of table with 4 rows. at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:418
'''