STaR(Self-Taught Reasoner)徹底解説:AIが自分で書いた推論を再学習する斬新なアプローチ
こんにちは、ゆずかきです。
ここでは、STaR(Self-Taught Reasoner) という、GPTなどの大規模言語モデルにおける「推論チェーン(chain-of-thought)」の能力をブートストラップ(段階的に性能を引き上げる)するための手法について、論文原文に基づき網羅的に解説します。
この記事は初心者~中級者の方にも読みやすいように、STaRを丁寧に分解していきます。今回も文字数はかなり多めですが、その分、このSTaRという手法が抱える「自己言語化で推論性能を高める」メカニズムを余すところなくお伝えしたいと思います。
参考にした論文URL
👇
STaR: Bootstrapping Reasoning With Reasoning
§本記事の構成
- STaRとは?
- 背景:少数サンプルでの推論(Few-Shot Prompting)の問題点と、CoTとの関係
- STaRの基本プロセス
- 「Rationalization(正答をヒントに自己推論補完)」による強化
- 実験結果:算数・常識推論・小中学校レベルの数学問題への応用
- 考察:温度パラメータやバイアス、忠実性の課題
- まとめと今後の展望
では順を追って見ていきましょう!
§STaRとは?
STaR (Self-Taught Reasoner) とは、以下のような研究モチベーションのもと提案された手法です。
「大規模言語モデルは、数例の“チェイン・オブ・ソート(推論の途中経過)”を見せることで、推論過程を自然言語で書き出すこと(=Rationale Generation)が可能になる。しかし、通常のFew-Shot Promptingだけでは大量データへの学習が十分に行われず、推論性能が頭打ちになりがちである。そこで、自己生成した推論過程(Rationale)をうまくフィルタしながら再学習(Fine-Tuning)するループを回すことで、推論能力を段階的に引き上げられないだろうか?」
従来のCoTは「少数サンプル+大規模モデルを前提としたIn-Context学習」が主流。ところがSTaRでは、
- 少量の「解説付きサンプル(Few-ShotのRationale例)」からスタート
- モデル自身が全データに対して推論を書き出す
- 正しい答えが出ていたデータだけRationaleを保存して再学習
- さらに、正答が出なかったデータに対しては「Rationalization」という、答えだけ与えて逆算でRationaleを書かせる手法を追加
- 再学習を繰り返す
というブートストラップ(段階的自己改善)を行います。これにより、最初は曖昧だったモデルの推論能力が、自己学習のループを通じて強化されていくわけです。
§背景:少数サンプルでの推論と、CoTだけでは物足りない?
近年、大規模言語モデル(LLM)はFew-Shot In-Context Learning(プロンプトに少数サンプルを並べて推論させる)という手法で高い性能を示しています。しかし:
- 複雑な数理・常識推論 といった多段思考が必要な課題では、単なるFew-Shotだけでは不十分
- CoT(Chain-of-Thought) は途中推論を書き出すことで精度を向上させるが、あくまで小規模データ しか活用できない
このため、従来は以下の2つの手段がよく使われてきました。
- 人手で大量の推論付きデータを作りFine-Tuning
- Few-Shot PromptingでCoT例を見せるだけ
1はコストが膨大で汎用性に欠け、2は推論精度が大きく伸びない。そこで、「モデルが自分で書いた推論を、正解のデータだけ抜き出して再学習」 という手法、つまりSTaR が考案されています。
§STaRの基本プロセス
論文の中で提示されているSTaRアルゴリズムを、簡単にまとめます。以下はRationalizationを含まない基本版です。
- 初期状態:
- データセット ({(x_i, y_i)}) と、少数の「推論付きサンプル例」(P)(これは数件)を用意
- モデル (M)(事前学習済みLLM)を用意
- ループ:
- (Rationale Generation)
データセット全件 (x_i) に対して、Few-ShotのCoTプロンプトを使ってモデルに「(r_i)(推論)+ y_i’(答え)」を生成させる - (フィルタリング)
生成された答え (y_i’) が正解 (y_i) と一致したものだけを「(x_i, r_i, y_i)」として抜き出す - (Fine-Tuning)
抜き出したデータ上でモデルを学習し直し(Fine-Tune) - モデル更新 → 次イテレーションへ
- 終了条件:
- ある程度ループした時点で性能が頭打ちになったら終了
これだけでも、「自分で書いた解説で再学習」 というスキームが面白いのですが、問題は「答えを間違えたケースが学習に使われない」という点。誤答データからは何も学べないので、モデルが一定水準以上進歩すると、それ以上難しい問題が解けるようにはなりにくい。
§Rationalization(正答を与えて逆算推論)
そこで著者らは、Rationalization という追加ステップを導入しています。これは、
「モデルが答えを間違えた問題」に対してだけ、正解ラベル ( y_i) を“ヒント”として与えた上で、改めて“こういう理由で (y_i) が答えになる”というRationaleを書かせる
という工程ですね。言い換えれば、Backward reasoning(答えが先に分かっている状態)を利用して正しい推論を生成するのです。
Rationalizationで得られた推論は、「モデル自身が何のヒントもなく導いたRationale」 ではありません。が、その後のFine-Tuningで“正しい推論”例として取り込むことにより、モデルは複雑な問題への推論方針を学べる可能性があります。
STaR全体は、これにより「正解を導けなかったサンプル」からも学習できるようになり、性能が大幅に底上げされるメリットを得るわけです。
§実験結果
論文では、主に以下のタスクでSTaRを検証しています。
- Arithmetic(n桁の足し算)
- Commonsense QA(CommonsenseQAデータセット)
- Grade School Math(GSM8K)
モデルとしては GPT-J(6Bパラメータ) を使用し、既存のスクリプトによるFine-Tuningを行ったそうです。
(1) Arithmetic
- n桁(1~5桁)の整数加算タスク
- 少数の手書きCoT例を初期プロンプトとして与え、最大で5桁分を学習
- STaRループを回す と、2桁計算は初回だけでも精度30%超に到達し(ベースは1%未満…)、繰り返しで89%超 まで到達した
- Rationalization ありだと多桁学習がさらにスムーズ化する
(2) Commonsense QA
- 5択問題で、約1万件のトレーニングセット
- 比較対象
- GPT-Jを単純に答えのみでFine-Tuning
- Few-ShotでRationale付き(CoT)のみ
- 30倍規模のGPT-3 / LaMDA 137B など
- 結果:
- STaRあり の最終モデルが 約72.5% の正解率
- GPT-J直接学習は60%、LaMDA 137Bが55.6%、GPT-3 Fine-Tuneが73.0%という報告もあり、STaRはより少ないパラメータ/データでGPT-3とほぼ同等 の性能に迫った
また、興味深いのは「STaRで得られたRationaleの質が改善」されている点。人間による簡易評価でも、Few-ShotだけのRationaleよりSTaRのRationaleを好む傾向が報告されています。
(3) GSM8K
- 小中学校レベルの文章題(7,473件)
- ベースライン(Few-Shotのみ)精度3%~5%前後に対し、STaRは10% 付近まで伸びる
- Rationalizationは今回そこまで寄与せず(わずかな向上に留まる)。これはタスク特性の差かもしれないと考察されています。
§考察:温度パラメータやバイアス、忠実性の課題
論文終盤で議論されている重要なポイントをピックアップします。
- Rationalizationの役割
- 「正解を提示してから逆算でRationaleを書かせる」のは、RL的に言えば「答えが正しいと分かっている経路」を強制的に辿らせる検索プロセスに近い。
- 解が不明な状態で生成するより、答えが先に分かった方が正しい推論を導きやすい(モデルにとって難易度が下がる)。
- 高い温度(高ランダム性)で多数サンプルを生成→学習
- 一見、学習データ増加のためにTemperatureを上げて多様なRationaleを出せそうだが、ノイズ・誤推論が大量に紛れ込み学習が破綻 しがちである。
- STaRではむしろRationalization の方が効果的。
- Few-Shotプロンプトを途中でも使用するか?
- 学習が進むとプロンプトが不要になるとも考えられるが、必ずしもそうではない。
- 著者の実験では、最後までFew-Shot例を付ける方が性能が安定する傾向。
- バイアスや忠実性(Faithfulness)の問題
- 「モデルが書き出す推論が本当に内部思考を反映しているか」は疑問。背後で別の思考をしながら説明を後付けする可能性がある。
- また、「正解になりやすい(けど実際には納得感の薄い)Rationale」へ偏る可能性もあり、データセットにバイアスがあればそれを強化してしまうリスク。
- これらはCoT全般に潜む課題でもある。
§まとめと今後の展望
STaR は、少数サンプルのCoT例からスタートし、モデル自身が出した正答付きRationaleを繰り返し再学習することで、チェイン・オブ・ソート(推論過程)の性能をブートストラップする仕組みです。とりわけ、正解できなかった問題へのRationalization(答えを先に見せて逆算で推論を生み出す工程)が、難問への対応・誤答克服に大きく寄与する可能性があります。
STaR はRL的な文脈(Policy Gradientの近似)とも関連付けられ、正答に結びつくRationale を強化学習的にフィルタ+再学習するという構図ですね。
以下の点が特徴と言えます。
- 学習データへのRationaleラベル自動付与:
人手で大量にRationaleをアノテーションしなくても、モデル自身が書いた推論を「正解かどうか」でフィルタするだけで学習可能。 - 高い性能:
GPT-Jレベル(6Bパラメータ)でも、CommonsenseQAで72.5% 、算数等でも顕著な精度向上。 - まだ課題あり:
初期Few-Shot性能が0に近いと学習が進まない
バイアス強化やRationaleの忠実性などの問題
それでも、「推論を自己生成し、それを学習してさらに推論力を高める」 という枠組みは、LLMの自己強化における新しい手段として非常に注目に値します。GSM8KやCommonsenseQAなど広範なタスクで汎用的に効きそうですし、今後さらに大規模モデル(GPT-3やPaLMなど)で発展させれば、より強力な自己言語化学習が進む可能性が高いですね。
§おわりに
以上、STaR: Self-Taught Reasoner について、論文を可能な限り詳しく解説しました。少量のRationale例 + モデル自己生成の再学習 で、かなり高いパフォーマンスが得られるのは面白いですね。
実際、論文でも「この手法はRL(強化学習)の方策勾配を近似している」「温度やヒント設計など色々なハイパラを詰められそう」など、まだまだ深堀りの余地が多いとしています。私の観点でも、以下の発展が期待できるかと。
- LLM(GPT-3, PaLMなど)との組合せ:GPT-Jよりさらに強力なFew-Shot性能を基点にすると、もっと高速に精度が上がる?
- 多彩なタスク(法律文書、論理パズルなど) での応用
- エラー分析とバイアス除去:データに偏りがあるタスクでは、それがRationaleに反映され強化されるリスク。
Chain-of-Thoughtが盛り上がる中、このような「自己書き下ろし推論を学習材料に使う」という手法はますます注目されるのではないでしょうか。
§参考文献(論文リンク)
- STaR: Bootstrapping Reasoning With Reasoning: httpss://arxiv.org/abs/2203.14465
- GitHub – STaR (Self-Taught Reasoner): httpss://github.com/ezelikman/STaR
ディスカッション
コメント一覧
まだ、コメントがありません