連続潜在空間で推論する大規模言語モデルの学習 Training Large Language Models to Reason in a Continuous Latent Space
S Hao 著 · 2024 · 被引用数: 35・FAIR at Meta・UC San Diego
GitHub
OpenReview
Focus

CoTとCoconutの比較
- CoTは、推論プロセスを単語トークン列として生成し、次の入力単語トークン列として使用
- Coconutは、最終隠れ層の状態(「連続的思考」)を次の入力埋め込みとして直接使用
Keyword
CoT
- 推論を行う際に、言葉のトークンを生成し、段階的に解決策を生成するようLLMにプロンプトを与える推論技術(プロンプトエンジニアリングフレームワーク)
- 出力トークンが次の入力として機能し、これによりLLMsの有効な深さが増加し、表現力が向上することが証明されている(Feng et al. 2023)
- 生成された出力が入力に循環的に戻されるため、トランスフォーマーの実効的な深さが増加する
- しかし推論が言語空間内に制約されている
- 特定の人間の認知に関する研究結果とは異なる
- その自己回帰的な生成の性質により、計画と探索を通常必要とするより複雑な問題において、人間的な推論の模倣が難しい
- 脳画像研究では、様々な推論タスク中、言語理解と生成を担う脳領域の集合である言語ネットワークがほとんど活性化していない
- 人間の言語は推論よりもコミュニケーションに最適化されている
- 推論チェーンの大部分のトークンは、文章の流暢さのためだけに生成され、実際の推論プロセスにはほとんど寄与していない
Coconut(Chain of Continuous Thought:連続的思考連鎖)
- 言語に制約されない潜在空間でのLLM推論技術
- 言語のトークンに一度戻すという中間ステップをスキップし、次の推論ステップの入力として使用する
- 最後の隠れ状態(連続的思考 continuous thought) → 次のステップの入力埋め込み
- 言語ベースの推論とは異なり、Coconutの連続的思考は、複数の潜在的な次のステップを同時にエンコードでき、幅優先探索(BFS)に似た推論プロセスを可能にする
- 連続的思考の中で多くの可能性のあるオプションを保持し、暗黙的な価値関数に導かれながら、推論を通じて誤った経路を徐々に排除する
幅優先探索(BFS)
- ある時点における複数の可能性を並行して検討し、有望な経路を広げつつ、そうでない経路を排除していく
バックトラック
- 推論や問題解決の過程で、一度試みた考え方や手順が行き詰まったり、誤りであることが判明したりした場合に、前の段階に戻って別の可能性を探ること
ProntoQA(論理的推論)
論理的推論のためのデータセット
- 5ホップの質問で構成されており、架空の概念名が使用される
- 各問題には、木構造のオントロジーがランダムに生成される
- 自然言語で既知の条件が記述される
- モデルは、これらの条件に基づいて、与えられたステートメントが正しいかどうかを判断する
ProsQA(論理的推論)
この研究で新たに提案されたデータセット(ProntoQAの課題に対処するため)
- ランダムに生成されたDAG(有向非巡回グラフ)で既知の条件を記述
- モデルは正しい推論チェーンを見つけるために、グラフ全体で実質的な計画と探索を実行する
- 質問「[エンティティ]は[概念A]または[概念B]ですか?」
- グラフは、[エンティティ]から正解である[概念A]へのパスは存在するが、[概念B]へのパスは存在しないように構築
GSM8k(数学的推論)
- 小学校レベルの数学の文章題
- 問題がより多様でオープン・ドメインであり、現実世界の利用事例に近い
- データセットサイズ:
- 訓練データ:385,620件、検証データ:500件、テストデータ:1319件
- より複雑な文脈理解とモデリングを必要とするため、計算能力への要求が高い
- 対照的に、ProntoQA/ProsQAは計算能力がボトルネックにならない
多段階学習 multi-stage training
- モデルを訓練する際に、一度に最終的な目標を達成させるのではなく、いくつかの段階に分けて、徐々に複雑な学習内容へと移行していく方法
- 段階的な教育やステップバイステップの訓練と解釈できる
- モデルは、最初から言語の制約なしに推論することを学ぶのではなく、まず言語による推論のパターンを捉え、それを手がかりにして、より抽象的な潜在空間での推論能力を獲得していく
Overview
WHAT(これは何?)
- 言語空間のように制約のない潜在空間 Latent Space でLLMが推論を行うための新しいパラダイムであるCoconut(Chain of Continuous Thought:連続的思考連鎖)を提案し、その有効性を実験的に示した
- 言語空間は必ずしも最適な推論方法ではない
- 言語空間での推論は、次トークン予測という訓練目標を課せられ、トークンで生成されなければならないという基本的な制約がある
- 前のトークンの生成が次のトークンの生成に影響を与えるという自己回帰的な性質に基づくアプローチ
- 言語のトークンのほとんどは、主にテキストの一貫性のためのものであり、推論に不可欠なものではない
WHY(提案手法の価値は?)
- 言語のトークンに制約されずにCoTと同様の思考の連鎖効果が、潜在空間でも観察できることを示した
- 潜在空間での連続的思考により、計画的タスクにおいて高度な推論パターンが可能になり、言語推論を上回った
- 推論ステップが前のステップに大きく依存する数学の文章題などの一般的な問題において、Coconutアーキテクチャがより効果的
- ProsQAは、ランダムなDAG(有向非巡回グラフ)構造を持っているため、より多くの計画と探索が必要とされる、より「複雑な推論」を必要とするタスク
- 一方、GSM8kとProntoQAは、直感的な問題構造と限られた分岐のため、次のステップの予測が比較的容易になる
- なぜ計画的タスクで有利なのか
- 単なる推論の連鎖ではなく、幅優先探索(BFS)に似た探索木のように問題を解決できるから
- 有望なノードを優先しながら、関連性の低いノードを刈り込むことが可能
- BFSの場合は、すべてのフロンティアノードを均一に探索する
- なぜ探索木のように問題解決できるのか
- 潜在空間では複数の代替になる次の推論ステップを同時にエンコードできるため
- 単一の決定に早期にコミットするのではなく、複数の推論経路を維持し、確定的な決定を遅らせることが可能
- 探索木の終端状態により近づけることができる
- さらに多くの連続的思考を重ねることで、モデルの出力する確率分布(価値関数)を洗練できる
- 正しいノードと不正解のノードをより容易に区別することが可能になる
- 多段階カリキュラムによる潜在的推論の学習効果
- より簡単な目標に分解する多段階カリキュラムを用いることで、Coconutは様々なタスクで最高の性能を達成する
- 多段階カリキュラムの効果とは
- 多段階カリキュラム導入により、異なるステージを混合するトレーニング手法自体が、モデルが先を見通す能力を向上させる
- CoTとCoconut(k = 0)の比較
- CoTのトレーニングでは、目的は常に次のステップの生成に集中しており、モデルを近視眼的にする→有効な説明でパスを完成できない場合に存在しないエッジを幻覚する
- Coconutトレーニングの後期段階では、最初の数ステップが隠されているため、モデルはより将来のステップに焦点を当てることができる
- LLMの先を見通す能力の向上についてはGloeckle et al.(2024)で論じられている
- CoTよりも有意に少ないトークン数で同等以上の性能を達成した
- 今後の研究課題
- Coconutに使用された多段階学習は効果的であることが証明されたが、特に言語推論チェーンからの監督なしで、潜在空間での推論を学習するためのより良い、より一般的な戦略を開発するにはさらなる研究が必要
WHERE(技術のキモはどこ?)
標準的なLLM (言語モード)の定式化
- 変数
- :入力シーケンス
- :位置までのトークン埋め込みの列
- :位置tまでのすべてのトークンの最終隠れ状態の行列
- :位置の最終隠れ状態(つまり、)
- :トークン埋め込み関数
- :言語モデルヘッドのパラメータ
- 潜在モードの場合、は定義されない
- 潜在思考は言語空間に戻すことを意図していないため
言語モードと潜在モードの切り替え

- 特殊トークン<bot>と<eot>は、それぞれ潜在的思考モードの開始と終了を示すために使用
- 例:
- = <bot>, = <eot> で、モデルが潜在モードの場合()
- となる
- 前のトークンからの最終隠れ状態を入力埋め込みの代わりに使用
- 潜在モードが終了して言語モードに戻ると()
- 再びトークン埋め込みを入力として使用する
- 最終隠れ状態は最終正規化層によって処理されているため、その大きさは過度に大きくならない
学習プロセス
初期段階(推論の過程を明示的な言葉で表現することをモデルに教える段階)
- モデルは通常の言語モード
- CoTにより最終的な答えを生成する
後続の段階(k番目の段階)
- CoTの最初のステップを 個の連続的思考に置換
- c :ハイパーパラメータ

最適化手法
- 連続思考は完全に微分可能のため、逆伝播可能
- モードを切り替えると、Optimizerの状態がリセットされる
- 学習の初期段階で獲得された知識が、連続的思考の学習に悪影響を与えるのを防ぐため、または各段階で異なる学習ダイナミクスを促すため
損失関数
- 負の対数尤度
- n個の潜在思考が予定されている場合、n+1回順伝播を実行
- n 回の順伝播でそれぞれの連続的思考を計算
- 1回の順伝播で、残りのテキストシーケンスに対する損失を計算
- KVキャッシュを使うことで、繰り返しの計算を節約可能
課題
- 複数の順伝播が順番に行われるため、並列処理が難しい
推論プロセス
推論
- 質問の後は当面、言葉を使わない思考(潜在モード)言葉で出力する
- 開始時は質問直後に<bot>トークンを入れる
- ここから潜在モードが始まることを示す
- 終了時、<eot>トークンを出力する
- <eot>の出力タイミング決定の2つの方法
- モデル自身が潜在モードを終えるべきかどうかを判断する(二値分類器)
- 潜在モードでの思考の長さを常に一定にする(固定長パディング)
実験
数学的推論
推論データセット
- GSM8k
学習データセット
- https://arxiv.org/abs/2311.01460 の合成データセット
パラメータ
- c=2(2つの潜在思考)
- First stage:言語によるCoT(epoch=6)
- stage1:最初の1つの推論ステップを1 x cの連続思考に置き換え(epoch=3)
- stage2:最初の2つの推論ステップを2 x cの連続思考に置き換え(epoch=3)
- stage3:最初の3つの推論ステップを3 x cの連続思考に置き換え(epoch=3)
- Last stage:全ての推論ステップを除去し、3 × c の連続思考に置き換え(epoch=3 + 50)
- 各stageでepoch=3 の学習
論理的推論
推論データセット
- ProntoQA
- ProsQA
パラメータ
- c=1(1つの潜在思考)
- First stage:言語によるCoT(epoch=5)
- stage1:最初の1つの推論ステップを1 x cの連続思考に置き換え(epoch=5)
- stage2:最初の2つの推論ステップを2 x cの連続思考に置き換え(epoch=5)
- stage3:最初の3つの推論ステップを3 x cの連続思考に置き換え(epoch=5)
- stage4:最初の3つの推論ステップを4 x cの連続思考に置き換え(epoch=5)
- stage5:最初の3つの推論ステップを5 x cの連続思考に置き換え(epoch=5)
- stage6:最初の3つの推論ステップを6 x cの連続思考に置き換え(epoch=5)
- Last stage:全ての推論ステップを除去し、6 × c の連続思考に置き換え(epoch=5 + 50)
モデル
GPT-2
- 学習率は1×10^-4
- バッチサイズ128
ベースライン
- CoT
- 完全な推論チェーンを使用してFineTuningする
- No-CoT
- 質問と最終的な答えのペアのみでFineTuningする
- iCoT
- CoTと同様に、完全な推論チェーンを使用してFineTuningする
- ただし言語による推論ステップの最初の部分を、徐々に訓練データから取り除いていく
- 最初は言語による明示的な推論を学習させ、徐々にその推論をモデルの内部に落とし込むことを目指す
- Pause token
- No-CoT と同様に、質問と最終的な答えのペアのみでFineTuningする
- ただし質問と答えの間に特殊なトークン <pause> を挿入する
- 明示的な推論ステップを学習させる代わりに、特殊なトークンを使ってモデルの潜在的な計算能力を引き出すことを目指す
- <pause>トークンの数はCoconutの連続的思考と同じ設定
Coconutのバリエーション
- カリキュラムなし
- 多段階学習なし
- 質問と回答のみを含む最終段階のデータを直接使用してCoconutを学習
- 思考なし(iCoTに類似)
- 多段階学習カリキュラム
- 連続的な潜在思考なし
- 正確な学習スケジュールはiCoTでなくCoconutに一致させる
- 思考としての一時停止
- 多段階学習カリキュラム
- 連続的思考の代わりに特別な<pause>トークンを使用
評価
※精度が高いほど推論能力が強く、生成トークン数が少ないほど効率が良い

ProntoQAとProsQAにおいて、Coconutが精度と効率性でCoTを上回った
- 連続的思考の「連鎖」は推論を強化する
ケーススタディ
- 潜在空間での推論を試しに言語の単語のようなものに変換する実験
- 通常、Coconutは言語を使わずに潜在空間で推論する
- 変換された単語の中に、その数学の問題を人間が段階的に解く際に計算する途中に出てくる数値(例:「180」)が含まれていた
- 潜在空間での推論が、直接言語化されないまでも、問題解決のプロセスにおける重要な情報を捉えている可能性を示唆している

潜在的推論プロセスの分析
- テストデータ:ProsQA
- ハイパーパラメータ k ∈ {0, 1, 2, 3, 4, 5, 6}
- Coconutにk個の連続的思考を使用させる
- モデルはk + 1ステップ目から残りの推論チェーンを言語で出力する
- 他の段階のデータを一定の確率(p = 0.3)で常に混合させた
- 初期の学習段階を忘れる問題への対処
- 評価指標
- 回答に対する精度
- 推論過程に対する精度
- 正しいパス:正解への最短パスの1つである出力
- より長いパス:質問に正しく答えているが、最短パスよりも長い有効なパス
- 幻覚:存在しないエッジを含むか、または切断されているパス
- 誤ったターゲット:グラフ内の有効なパスだが、問われているノードではない
- 部分的なパスなしで最終的な答えのみを出力する場合
- 正しいラベル
- 不正解なラベル
潜在的推論の計画能力

- 潜在空間でより多くの推論が行われるほど、計画能力が向上する
- 連続的思考でkを増加させると
- 最終的な回答の正確性と正しい推論プロセスの割合(「正しいラベル」と「正しいパス」)の両方が向上
- 「幻覚」と「誤ったターゲット」の割合も減少
連続的思考は複数の潜在的な次のステップを符号化できる→単なる推論の「連鎖」ではなく、探索木として解釈できる
- 潜在的推論モデルは、探索を探索木の終端状態により近づけることができ、正しいノードと不正解のノードをより容易に区別することが可能

- 最初のステップはAlexの子ノードのいずれか{lempus、sterpus、zhorpus、grimpus}を選択する
- CoTで訓練されたモデルは行き詰まった後に存在しないエッジを幻覚する

- 潜在的推論は単なる推論の「連鎖」ではなく、探索木として解釈可能
- モデルの確率分布を価値関数として解釈
- 価値関数:内部的に各推論の状態によって最終的な目標達成にどれだけ貢献しそうかを評価する
- 確率分布:言語モデルの通常の出力層(ソフトマックス関数)の結果
- 「もし次に言語で出力するとしたらどのようなトークンになるか」という確率分布

- 左図は線間に顕著な差があり、モデルが並列的に代替となる潜在的思考を探索する能力を反映している
- 右図は線間の差が狭くなっていることから、探索木が発展するにつれて並列性が減少し、推論における確実性が増加している

- 第1および第2の潜在的推論ステップにおけるモデルの予測確率とノードの高さの相関関係をテストセット全体で分析
- "sterpus"は葉ノードなので、ターゲットノード"bompus"に到達できない
- 即座に不正解とわかる
- 他のノードにはさらに探索すべき子孫ノードがあり、その評価はより困難
- 高さが2(他の候補よりも高い)である"grimpus"と"lempus"の間でモデルはより大きな不確実性を示す
- モデルは高さが低い場合、不正解のノードには低い値を、正解のノードには高い値を適切に割り当てることに成功
- しかし、高さが増加するにつれて、区別が不明確になり、正確な評価がより困難になることを示している