STaR(Self-Taught Reasoner)徹底解説:AIが自分で書いた推論を再学習する斬新なアプローチ

2025年1月13日

こんにちは、ゆずかきです。
ここでは、STaR(Self-Taught Reasoner) という、GPTなどの大規模言語モデルにおける「推論チェーン(chain-of-thought)」の能力をブートストラップ(段階的に性能を引き上げる)するための手法について、論文原文に基づき網羅的に解説します。

この記事は初心者~中級者の方にも読みやすいように、STaRを丁寧に分解していきます。今回も文字数はかなり多めですが、その分、このSTaRという手法が抱える「自己言語化で推論性能を高める」メカニズムを余すところなくお伝えしたいと思います。

参考にした論文URL
👇
STaR: Bootstrapping Reasoning With Reasoning


§本記事の構成

  • STaRとは?
  • 背景:少数サンプルでの推論(Few-Shot Prompting)の問題点と、CoTとの関係
  • STaRの基本プロセス
  • 「Rationalization(正答をヒントに自己推論補完)」による強化
  • 実験結果:算数・常識推論・小中学校レベルの数学問題への応用
  • 考察:温度パラメータやバイアス、忠実性の課題
  • まとめと今後の展望

では順を追って見ていきましょう!


§STaRとは?

STaR (Self-Taught Reasoner) とは、以下のような研究モチベーションのもと提案された手法です。

「大規模言語モデルは、数例の“チェイン・オブ・ソート(推論の途中経過)”を見せることで、推論過程を自然言語で書き出すこと(=Rationale Generation)が可能になる。しかし、通常のFew-Shot Promptingだけでは大量データへの学習が十分に行われず、推論性能が頭打ちになりがちである。そこで、自己生成した推論過程(Rationale)をうまくフィルタしながら再学習(Fine-Tuning)するループを回すことで、推論能力を段階的に引き上げられないだろうか?」

従来のCoTは「少数サンプル+大規模モデルを前提としたIn-Context学習」が主流。ところがSTaRでは、

  1. 少量の「解説付きサンプル(Few-ShotのRationale例)」からスタート
  2. モデル自身が全データに対して推論を書き出す
  3. 正しい答えが出ていたデータだけRationaleを保存して再学習
  4. さらに、正答が出なかったデータに対しては「Rationalization」という、答えだけ与えて逆算でRationaleを書かせる手法を追加
  5. 再学習を繰り返す

というブートストラップ(段階的自己改善)を行います。これにより、最初は曖昧だったモデルの推論能力が、自己学習のループを通じて強化されていくわけです。


§背景:少数サンプルでの推論と、CoTだけでは物足りない?

近年、大規模言語モデル(LLM)はFew-Shot In-Context Learning(プロンプトに少数サンプルを並べて推論させる)という手法で高い性能を示しています。しかし:

  • 複雑な数理・常識推論 といった多段思考が必要な課題では、単なるFew-Shotだけでは不十分
  • CoT(Chain-of-Thought) は途中推論を書き出すことで精度を向上させるが、あくまで小規模データ しか活用できない

このため、従来は以下の2つの手段がよく使われてきました。

  1. 人手で大量の推論付きデータを作りFine-Tuning
  2. Few-Shot PromptingでCoT例を見せるだけ

1はコストが膨大で汎用性に欠け、2は推論精度が大きく伸びない。そこで、「モデルが自分で書いた推論を、正解のデータだけ抜き出して再学習」 という手法、つまりSTaR が考案されています。


§STaRの基本プロセス

論文の中で提示されているSTaRアルゴリズムを、簡単にまとめます。以下はRationalizationを含まない基本版です。

  1. 初期状態
  • データセット ({(x_i, y_i)}) と、少数の「推論付きサンプル例」(P)(これは数件)を用意
  • モデル (M)(事前学習済みLLM)を用意
  1. ループ
  2. (Rationale Generation)
    データセット全件 (x_i) に対して、Few-ShotのCoTプロンプトを使ってモデルに「(r_i)(推論)+ y_i’(答え)」を生成させる
  3. (フィルタリング)
    生成された答え (y_i’) が正解 (y_i) と一致したものだけを「(x_i, r_i, y_i)」として抜き出す
  4. (Fine-Tuning)
    抜き出したデータ上でモデルを学習し直し(Fine-Tune)
  5. モデル更新 → 次イテレーションへ
  6. 終了条件
  • ある程度ループした時点で性能が頭打ちになったら終了

これだけでも、「自分で書いた解説で再学習」 というスキームが面白いのですが、問題は「答えを間違えたケースが学習に使われない」という点。誤答データからは何も学べないので、モデルが一定水準以上進歩すると、それ以上難しい問題が解けるようにはなりにくい。


§Rationalization(正答を与えて逆算推論)

そこで著者らは、Rationalization という追加ステップを導入しています。これは、

「モデルが答えを間違えた問題」に対してだけ、正解ラベル ( y_i) を“ヒント”として与えた上で、改めて“こういう理由で (y_i) が答えになる”というRationaleを書かせる

という工程ですね。言い換えれば、Backward reasoning(答えが先に分かっている状態)を利用して正しい推論を生成するのです。

Rationalizationで得られた推論は、「モデル自身が何のヒントもなく導いたRationale」 ではありません。が、その後のFine-Tuningで“正しい推論”例として取り込むことにより、モデルは複雑な問題への推論方針を学べる可能性があります。

STaR全体は、これにより「正解を導けなかったサンプル」からも学習できるようになり、性能が大幅に底上げされるメリットを得るわけです。


§実験結果

論文では、主に以下のタスクでSTaRを検証しています。

  1. Arithmetic(n桁の足し算)
  2. Commonsense QA(CommonsenseQAデータセット)
  3. Grade School Math(GSM8K)

モデルとしては GPT-J(6Bパラメータ) を使用し、既存のスクリプトによるFine-Tuningを行ったそうです。

(1) Arithmetic

  • n桁(1~5桁)の整数加算タスク
  • 少数の手書きCoT例を初期プロンプトとして与え、最大で5桁分を学習
  • STaRループを回す と、2桁計算は初回だけでも精度30%超に到達し(ベースは1%未満…)、繰り返しで89%超 まで到達した
  • Rationalization ありだと多桁学習がさらにスムーズ化する

(2) Commonsense QA

  • 5択問題で、約1万件のトレーニングセット
  • 比較対象
  • GPT-Jを単純に答えのみでFine-Tuning
  • Few-ShotでRationale付き(CoT)のみ
  • 30倍規模のGPT-3 / LaMDA 137B など
  • 結果:
  • STaRあり の最終モデルが 約72.5% の正解率
  • GPT-J直接学習は60%、LaMDA 137Bが55.6%、GPT-3 Fine-Tuneが73.0%という報告もあり、STaRはより少ないパラメータ/データでGPT-3とほぼ同等 の性能に迫った

また、興味深いのは「STaRで得られたRationaleの質が改善」されている点。人間による簡易評価でも、Few-ShotだけのRationaleよりSTaRのRationaleを好む傾向が報告されています。

(3) GSM8K

  • 小中学校レベルの文章題(7,473件)
  • ベースライン(Few-Shotのみ)精度3%~5%前後に対し、STaRは10% 付近まで伸びる
  • Rationalizationは今回そこまで寄与せず(わずかな向上に留まる)。これはタスク特性の差かもしれないと考察されています。

§考察:温度パラメータやバイアス、忠実性の課題

論文終盤で議論されている重要なポイントをピックアップします。

  1. Rationalizationの役割
  • 「正解を提示してから逆算でRationaleを書かせる」のは、RL的に言えば「答えが正しいと分かっている経路」を強制的に辿らせる検索プロセスに近い。
  • 解が不明な状態で生成するより、答えが先に分かった方が正しい推論を導きやすい(モデルにとって難易度が下がる)。
  1. 高い温度(高ランダム性)で多数サンプルを生成→学習
  • 一見、学習データ増加のためにTemperatureを上げて多様なRationaleを出せそうだが、ノイズ・誤推論が大量に紛れ込み学習が破綻 しがちである。
  • STaRではむしろRationalization の方が効果的。
  1. Few-Shotプロンプトを途中でも使用するか?
  • 学習が進むとプロンプトが不要になるとも考えられるが、必ずしもそうではない。
  • 著者の実験では、最後までFew-Shot例を付ける方が性能が安定する傾向。
  1. バイアスや忠実性(Faithfulness)の問題
  • 「モデルが書き出す推論が本当に内部思考を反映しているか」は疑問。背後で別の思考をしながら説明を後付けする可能性がある。
  • また、「正解になりやすい(けど実際には納得感の薄い)Rationale」へ偏る可能性もあり、データセットにバイアスがあればそれを強化してしまうリスク。
  • これらはCoT全般に潜む課題でもある。

§まとめと今後の展望

STaR は、少数サンプルのCoT例からスタートし、モデル自身が出した正答付きRationaleを繰り返し再学習することで、チェイン・オブ・ソート(推論過程)の性能をブートストラップする仕組みです。とりわけ、正解できなかった問題へのRationalization(答えを先に見せて逆算で推論を生み出す工程)が、難問への対応・誤答克服に大きく寄与する可能性があります。

STaR はRL的な文脈(Policy Gradientの近似)とも関連付けられ、正答に結びつくRationale を強化学習的にフィルタ+再学習するという構図ですね。
以下の点が特徴と言えます。

  • 学習データへのRationaleラベル自動付与
    人手で大量にRationaleをアノテーションしなくても、モデル自身が書いた推論を「正解かどうか」でフィルタするだけで学習可能。
  • 高い性能
    GPT-Jレベル(6Bパラメータ)でも、CommonsenseQAで72.5% 、算数等でも顕著な精度向上。
  • まだ課題あり
    初期Few-Shot性能が0に近いと学習が進まない
    バイアス強化やRationaleの忠実性などの問題

それでも、「推論を自己生成し、それを学習してさらに推論力を高める」 という枠組みは、LLMの自己強化における新しい手段として非常に注目に値します。GSM8KやCommonsenseQAなど広範なタスクで汎用的に効きそうですし、今後さらに大規模モデル(GPT-3やPaLMなど)で発展させれば、より強力な自己言語化学習が進む可能性が高いですね。


§おわりに

以上、STaR: Self-Taught Reasoner について、論文を可能な限り詳しく解説しました。少量のRationale例 + モデル自己生成の再学習 で、かなり高いパフォーマンスが得られるのは面白いですね。

実際、論文でも「この手法はRL(強化学習)の方策勾配を近似している」「温度やヒント設計など色々なハイパラを詰められそう」など、まだまだ深堀りの余地が多いとしています。私の観点でも、以下の発展が期待できるかと。

  • LLM(GPT-3, PaLMなど)との組合せ:GPT-Jよりさらに強力なFew-Shot性能を基点にすると、もっと高速に精度が上がる?
  • 多彩なタスク(法律文書、論理パズルなど) での応用
  • エラー分析とバイアス除去:データに偏りがあるタスクでは、それがRationaleに反映され強化されるリスク。

Chain-of-Thoughtが盛り上がる中、このような「自己書き下ろし推論を学習材料に使う」という手法はますます注目されるのではないでしょうか。



§参考文献(論文リンク)


生成AI

Posted by yuzukaki-dialog