Gramformerを動かしてみた

本記事は,GEC (Grammatical Error Correction) Advent Calendar 2021 の9日目の記事です.

はじめに

2021/11月末にGramformerというリポジトリを見つけました.

github.com

特に何かの論文の実装というわけではなさそうで,プロジェクトを立ち上げたという感じの雰囲気です.出来たてのプロジェクトですが(first commitは今年7月),既にStarが800以上付いており,注目されていることが伺えます.まだまだ未実装の部分も多く未知数ですが,面白そうな取り組みなので速報的に紹介します.

何に使えるか

READMEによると,Gramformerは次のようなケースで使えるとしています.

  1. Post-processing machine generated text

  2. Human-In-The-Loop (HITL) text

  3. Assisted writing for humans

  4. Custom Platform integration

(今から書くのは僕の妄想ですが)これらを見る限り,単にGECシステムを扱うリポジトリというわけではなく,もっと広いところを見ているように感じます.1.からはGECに止まらないタスクとの関わりを予感します.また,2.からは誤りのない正しい文とか,システム出力をpost-processing的に人手修正したようなデータが蓄積されそうな気がしています.最後に4.はよりアプリケーションを意識した記述です.おそらくgrammarlyなどの既に知られた訂正ツールは企業運営のものばかりなので,オープンソースなものが構築されることに意義を主張しているのだと思います.3.は現在のGECコミュニティが最も強く意識している項目だと思います.GECに関する大きなプロジェクトになりそうなので,楽しみですね.

何ができるか

現状公開されている範囲では,Correcter(訂正器),Detector(検出器),Get Edits),Highlighterの機能を提供するようです.

Install

pythonは3.7推奨のようです.

pip3 install pip==20.1.1 
# IMPORTANT NOTE: (If install runs endlessly resolving package versions in for instance colab, refer to issue #22 - https://github.com/PrithivirajDamodaran/Gramformer/issues/22)
pip3 install -U git+https://github.com/PrithivirajDamodaran/Gramformer.git

Correcter

文中の誤りを訂正します.Gramformerオブジェクトの.correct()を呼ぶだけなので,簡単です.

from gramformer import Gramformer
import torch

def set_seed(seed):
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

set_seed(1212)

gf = Gramformer(models = 1, use_gpu=False) # 1=corrector, 2=detector

influent_sentences = [
    "He are moving here.",
    "I am doing fine. How is you?"
]   

for influent_sentence in influent_sentences:
    corrected_sentences = gf.correct(influent_sentence, max_candidates=1)
    print("[Input] ", influent_sentence)
    for corrected_sentence in corrected_sentences:
      print("[Correction] ",corrected_sentence)
    print("-" *100)

'''
出力:
[Gramformer] Grammar error correct/highlight model loaded..
[Input]  He are moving here.
[Correction]  ('He is moving here.', -31.02850341796875)
----------------------------------------------------------------------------------------------------
[Input]  I am doing fine. How is you?
[Correction]  ('I am doing fine, how are you?', -37.6710205078125)
----------------------------------------------------------------------------------------------------
'''

Get Edits

ERRANTのアライメント手法を用いてアライメントを取ります.Gramformerオブジェクトの.get_edits()を呼ぶだけです.ERRNATは2019年12月ごろにpythonモジュールとしてimportできるようになっているので,それを活用した形になります.

... <前略> ...
for influent_sentence in influent_sentences:
    corrected_sentences = gf.correct(influent_sentence, max_candidates=1)
    print("[Input] ", influent_sentence)
    for corrected_sentence in corrected_sentences:
      print("[Edits] ", gf.get_edits(influent_sentence, corrected_sentence[0]))
    print("-" *100)

'''
出力:
[Gramformer] Grammar error correct/highlight model loaded..
[Input]  He are moving here.
[Edits]  [('VERB:SVA', 'are', 1, 2, 'is', 1, 2)]
----------------------------------------------------------------------------------------------------
[Input]  I am doing fine. How is you?
[Edits]  [('OTHER', 'fine.', 3, 4, 'fine,', 3, 4), ('ORTH', 'How', 4, 5, 'how', 4, 5), ('VERB:SVA', 'is', 5, 6, 'are', 5, 6)]
----------------------------------------------------------------------------------------------------
'''

Highlighter

ソースに訂正スパンを埋め込む形で出力します.Gramformerオブジェクトの.highlight()を呼ぶことで使えます.

... <前略> ...
for influent_sentence in influent_sentences:
    corrected_sentences = gf.correct(influent_sentence, max_candidates=1)
    print("[Input] ", influent_sentence)
    for corrected_sentence in corrected_sentences:
      print("[Edits] ", gf.highlight(influent_sentence, corrected_sentence[0]))
    print("-" *100)
'''
出力:
[Gramformer] Grammar error correct/highlight model loaded..
[Input]  He are moving here.
[Edits]  He <c type='VERB:SVA' edit='is'>are</c> moving here.
----------------------------------------------------------------------------------------------------
[Input]  I am doing fine. How is you?
[Edits]  I am doing <c type='OTHER' edit='fine,'>fine.</c> <c type='ORTH' edit='how'>How</c> <c type='VERB:SVA' edit='are'>is</c> you?
----------------------------------------------------------------------------------------------------
'''

Detector

検出器は現在(2021/12/8)未実装ですが,Gramformerオブジェクトの.detect()で利用できることが示されています.

モデル

モデルはgrammar_error_correcter_v1と呼ばれるSeq2Seqなモデルを使っているようです.アーキテクチャはよくわかりません.訓練データはWikiEdits,C4ベースの擬似誤りデータ(Stahlberg+ 2020),PIEの擬似誤りデータ(Awasthi+ 2019)を使っているようです.推論時には,GPT-2のスコアによるリランキングをしています.

性能をざっくり検証

せっかくなのでCoNLL-2014 test setの性能を見てみたいと思います.Gramformerがトップで出力した文をそのまま評価に使います.

  • コード(p.pyとします)
from gramformer import Gramformer
import torch
import argparse

def set_seed(seed):
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

gf = Gramformer(models = 1, use_gpu=False) # 1=corrector, 2=detector

def main(args):
    set_seed(1212)
    gf = Gramformer(models = 1, use_gpu=False)
    outputs = []
    with open(args.in_file) as fp:
        for src in fp:
            src = src.rstrip()
            corrected_sentences = gf.correct(src, max_candidates=1)
            outputs.append(corrected_sentences[0][0])
    with open(args.out_file, "w") as fp:
        fp.writelines(outputs)
    

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_file', required=True)
    parser.add_argument('--out_file', required=True)
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_parser()
    main(args)


* コマンドたち

# データの準備
wget https://www.comp.nus.edu.sg/~nlp/conll14st/conll14st-test-data.tar.gz
tar -xf conll14st-test-data.tar.gz
cat conll14st-test-data/noalt/official-2014.combined.m2 | grep '^S ' | cut -d ' ' -f 2- > conll14st-test-data/noalt/orig.txt
# 実行
python p.py --in_file conll14st-test-data/noalt/orig.txt --out_file conll14_out.txt
# 評価
wget https://www.comp.nus.edu.sg/~nlp/sw/m2scorer.tar.gz
tar -xf m2scorer.tar.gz
cd m2scorer
./m2scorer ../conll14_out.txt ../conll14st-test-data/noalt/official-2014.combined.m2


* 結果

Precision   : 0.1873
Recall      : 0.2637
F_0.5       : 0.1988

変に低いですね.出力を眺めると,Detokenizeまでされているので,単語のインデックスが合わなくなっているようです(CoNLL-2014の評価データのソースはtokenizeされている).そこで,Post-processingとしてspacyでtokenizeしました(これはちょっと適当です.ゆるして).その後もう一度評価すると

Precision   : 0.4565
Recall      : 0.2778
F_0.5       : 0.4045

ということで,RNNのベースラインくらいの性能でしょうか(よりも低い?).ちょっと雑なので正確な性能ではないかもしれませんが,今後のモデルの追加に期待です.

おわり

Gramformerが面白そうですねという記事でした.今の感じだと結構簡単に使えそうなので,この先どれくらい強いモデルが入るか?というところに期待です.この記事の情報はすぐに古くなると思いますが,使用感が伝わればと思います.