潜在的推論によるテスト時間計算のスケールアップ:Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach
Title
Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach
J Geiping・17 Feb 2025・ELLIS Institute Tübingen
Code & Data
recurrent-pretraining
seal-rg • Updated Mar 3, 2025
YouTube
Focus
アーキテクチャの可視化
- 青色のpreludeブロックは入力を潜在空間に埋め込む
- 緑色のrecurrentブロックは最終的な潜在状態を計算するために反復処理を行う
- 赤色のcodaブロックによって埋め込みをデコードする

- :シーケンス次元
- :モデルの隠れ次元
Keyword
潜在空間 (latent space)
モデルが入力データを処理する過程で生成する内部表現が存在する空間
- データの意味や構造が抽象的に表現される
- 論文では、言語モデルが推論を行う際に、言葉ではなく潜在空間で思考する様子が記述される
- この空間では、思考の過程が連続的な軌道として表現され、複雑な推論を可能にする
推論時スケーリング
- モデルのパラメータ数を増やすのではなく、テスト時の計算量を増やすことで性能を向上させる手法
主要アプローチ
- 長い思考連鎖(long CoT)の例を用いた事後学習を行い、モデルがコンテキストウィンドウ内で中間計算を言語化し、それによって思考を外在化する能力を発展させること
- 制約:高コストな内部推論を常に単一の言語化されたトークンに射影しなければならない
デプス-リカレント言語モデル(depth-recurrent language models)
- transformer アーキテクチャ
- 訓練時にランダムにサンプリングされた回数だけ実行される潜在的な深い再帰ブロックを備える
リカレント層 (Recurrent layers)
この論文では、潜在空間において反復計算を行う中心的なユニットを指し、因果的自己注意(Causal Self-Attention)ブロック(RoPE使用)とゲート付きSiLU MLPを含む複数の層から構成される
- リカレント層によって、transformerはトークンを出力する前に任意の回数の計算を実行可能
深い思考 (deep thinking)
言語モデルの設計において、単なる記憶ではなく、推論能力を重視する
- 先行研究
- Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks(Schwarzschild et al., 2021b)
- End-to-end Algorithm Synthesis with Recurrent Networks: Extrapolation without Overthinking(Bansal et al., 2022)
- depth-recurrent デプス-リカレントモデルが、複雑なアルゴリズムや推論を学習する上で、再帰的な構造が有効であることを示唆する
- Algorithm Design for Learned Algorithms(Schwarzschild et al., 2023)
ゼロショット適応計算(Zero-Shot Adaptive Compute)
特定のタスクやデータに対して明示的な訓練を行わずに(ゼロショット)、モデルがそのタスクやデータに適応した計算を行うこと
RoPE (Rotary Positional Embeddings)
位置情報を回転行列を用いてエンコードする位置埋め込み(Position Embedding)の手法の一つ
- ‣
- 絶対位置埋め込み(Absolute Position Embedding)と相対位置埋め込み(Relative Position Embedding)の長所を組み合わせている
この論文では、回転行列の周波数を決定するハイパーパラメータである基数(指数関数の底)を50000に設定している
- 回転行列の角度は、位置(トークンの位置)と基数によって決定される
- 基数が大きいほど、回転の周期が長くなる
- 50000は、比較的高めの設定であり、モデルが長距離の依存関係を重視する設定値と考えられる
因果的自己注意 (Causal Self-Attention)
オリジナルのtransformerのDecoderにあるmasked self-attention のこと
- ‣
ゲート付きSiLU MLP (gated SiLU MLP)
ゲート付きSiLU活性化関数を持つMLP
- ‣
RMSNorm
Layer Normalizationの変種。計算効率が良いとされる正規化手法
- ‣
非正規化精度 (un-normalized accuracy)
標準化や調整を行っていない、生の精度
- 標準化や調整を行った精度指標の例
- AUC-ROC
- F1スコア
- 調整済み精度(多クラス分類におけるサンプル数の調整など)
連続的思考の連鎖 (continuous chain of thought approach)
再帰的なアプローチによってモデルが深い思考を行うことを可能にする技術
- 推論チェーン(推論の過程を明示的に示すデータ)でモデルを追加学習させる
- 次のトークンを計算する際に、最後の隠れ状態を代替入力とする
- 推論時には、モデルは通常のテキスト生成と同様に動作するが、前のトークンの隠れ状態を次のトークンの生成に利用することで、連続的な思考を模倣し、より複雑な推論を行う
- モデルに推論の足跡を記憶させ、次のステップで利用させる学習方法
- ‣
Overview
WHAT(これは何?)
- 潜在空間において再帰的処理を導入することで、モデルがテスト時に計算量を動的に調整し、推論能力を向上させる新しい言語モデルアーキテクチャを提案
WHY(提案手法の価値は?)
- 潜在的推論 latent reasoning
- 潜在空間で反復的に推論(再帰的推論)を行う
- テスト時に計算量を動的に調整することで、従来のモデルよりも大幅に性能を向上させることが可能
- Chain-of-Thoughtのような明示的な中間表現に頼らず、モデルは内部状態を繰り返し更新する
- より多くのパラメータと訓練データを必要とする他のオープンソースモデルに匹敵する
- テスト時に数十億のパラメータと5,000億以上のプレトレーニングトークンまでスケールできる
- 潜在空間での推論に重点を置いているため、CoTのような明示的な中間表現を使用しないため、潜在状態のメモリサイズを比較的小さく保つことができる
- 特殊なデモンストレーションを必要としない
- 標準的な訓練データを使用した訓練が可能
- 可変的な計算予算
- テスト時に追加の計算リソースが与えられれば能力の向上が可能
- 比較:o1やDeepSeekで用いられる長文脈推論 long context reasoning アプローチの場合
- 特別な訓練データが必要
- 訓練と推論に巨大なメモリが必要
- 非常に長いコンテキストウィンドウを必要とするため、トークン並列化(Liu et al., 2023a)などの特殊な訓練手法が必要になる
- 再帰深度(recurrent-depth)ネットワーク
- パラメータあたりのFLOP数が多く、スケール時のアクセラレータ間の通信コストを大幅に削減可能
- インターコネクト(ノード間を接続するネットワーク)が低速な環境で訓練する場合、デバイス(GPUなど)の使用率を高めることができる
- 計算負荷が高く、パラメータ数の少ないアーキテクチャを構築
- 比較:標準的なtransformer
- 適応性
- トークンごとの適応的計算、(自己)推論的デコーディング、KVキャッシュ共有など、非再帰モデルでは大規模な調整を必要とする多くの機能を推論時に自然にサポートする
WHERE(技術のキモはどこ?)
- 潜在的なdepth-recurrentモデルのアーキテクチャと訓練目的
- デコーダーのみのtransformerブロック
- 3つの機能グループに分かれる
- prelude block
- :入力トークン列
- :埋め込み入力
- 入力データを潜在空間に埋め込む
- 複数のtransformer層を使用
- recurrent block
- :初期ランダム状態
- :ランダム状態を初期化するための標準偏差
- :新しい潜在状態
- :潜在状態
- :再帰反復回数
- 状態を修正するリカレント計算の中心ユニット
- モデルの計算を必要なだけ繰り返すことが可能
- coda block
- :出力確率
- 潜在空間からの埋め込みをデコードする(un-embeds)
- 複数のtransformer層を使用
- 予測ヘッドを含む
- 再帰的デザイン
- 初期状態に依存しない定常状態への収束を促進(再帰の安定化)
- =パス独立性 path independence (Anil et al., 2022) の促進
- deep thinking 文献に基づく(Bansal et al., 2022)アーキテクチャ
- 関数の勾配降下法と類似している
- 反復的な処理
- 通常のレイヤーを積み重ねる方向に再帰する
- RNNも再帰処理を行うが、RNNは通常のレイヤー方向ではなく系列方向に再帰する
- 再帰ブロックを繰り返し適用することで、潜在状態を更新し、より深い推論を行う
- 関数の勾配に基づいて繰り返し変数を更新する
- 潜在入力eを各ステップで注入
- データに基づいて潜在状態を更新する
- 勾配降下法では、各ステップでを使用して更新を行う
- ランダム状態によって潜在ベクトルを初期化
- 初期状態への依存性を減らす
- 勾配降下法では、重みをランダム状態で初期化する
- が最初だけ注入される場合、反復プロセスは安定せず、境界条件(初期状態)にのみ依存する
- 複数層の構造
- 入力トークンを隠れ潜在空間に埋め込むために複数層を用いる
- 標準的な固定深度transformerを分析する実証的結果に基づく
- 例:Kaplan et al.(2024)
- 言語モデルは、テキストをサブワードと呼ばれる小さな単位に分割する
- モデルの初期の層は、これらのサブワードを、潜在空間という抽象的な空間にある特定の意味や概念に関連付ける
- 潜在空間では、各サブワードは単一の概念として表現され、モデルはこれらの概念を組み合わせて、より複雑な意味を理解したり推論したりする
- モデルの後続の層は、これらの埋め込まれた概念に対して演算を行い、テキストの全体的な意味を解釈し、次の単語を予測したり質問に答えたりする
- 現代的な反復モデリングパラダイムである拡散モデルのうち、潜在拡散モデル(Rombach et al., 2022)に似ている
- 反復モデリング実験
- ノイズの注入は予備実験では効果がなかった
- 現在のステップを入力として受け取るコアブロックの設計も効果がなかった
- 標準的な設計
- 多層構造
- Causal Self-Attention
- 活性化関数:ゲート付きSiLU MLP
- 正規化層:RMSNorm
- 特徴的な設計
- RoPEの回転行列の周波数を基底50000で設定
- モデルが長距離の依存関係を重視する設定値
- 正規化層を注意機構とMLPの両方の前後に配置(pre-norm、post-norm)
- この正規化はスケール2での再帰を訓練するために必要だった
- レイヤ数: で表現
- 小規模モデル
- 例:の形状(1, 4, 1)
- 大規模モデル
- 例:の形状(2, 4, 2)
- 実質的な層は8層のみ
- 例:再帰回数 =32回、再帰ブロックが反復されると、2 + 4 + 2 = 132層の深さに展開される
- 最大の固定深度トランスフォーマー(Levine et al., 2021; Merrill et al., 2022)よりも深い計算チェーンを構築可能
- prelude block ( 個のtransformer)
- 入力トークンをとして埋め込み、その後、個のprelude層を適用する
- recurrent block ( 個のtransformer)
- アダプター行列A:で始まり、との連結を隠れ次元(Bansal et al., 2022)にマッピングする
- 小規模モデルでは初期埋め込み特徴の再統合は連結ではなく加算でも同様に機能した
- 大規模では連結が最も効果的であることを発見した
- 最後の出力は再度RMSNorm でリスケールする
- coda block ( 個のtransformer)
- による正規化、および結合された埋め込みを使用した語彙への射影を行う
アーキテクチャ(全体像)

アーキテクチャ(詳細)
- 大規模訓練におけるデータ選択とエンジニアリング
- AMDクラスターのFrontierで実施
- :分布からのランダムサンプル
- :分布(対数正規ポアソン分布)からのランダムな反復回数
- 平均再帰回数 と分散 が与えられた場合、以下の方法で分布からサンプリング可能
- :モデルの出力
- :シーケンスを左にシフトしたもの(シーケンスの次のトークン)
- 通常未満の値をサンプリングする
- モデルが通常は少ない計算量で学習されることを示唆する
- しかし、分布は裾が重い(heavy tail)ため、非常に多くの繰り返しが行われる場合もまれに起こりうる
- Truncated Backpropagation
- recurrent block では、訓練時の計算とメモリを抑えるため、最後の回の反復についてのみ逆伝搬を行う
- がどれだけ大きくても、逆伝搬するのは最後の回分の反復のみ
- 例:実質的な層数が4でr=3の場合、合計12層分のforward処理が行われる。そしてk=8の場合、最後の8層分の勾配情報が、4つの実質的な層の重み更新に使用される
- メモリ使用量と逆伝搬の計算量が一定に保つことが目的
- 繰り返し回数 は、モデルのフォワードパスにおける再帰処理の深さを決定するため、が大きい場合、計算グラフが深くなり、通常のBackpropagationでは必要なメモリ量が増加してしまう。それゆえ、最大 activation memory と backward compute が に依存しないようにTruncated Backpropagationを採用している
- 主要な実験ではに固定して実施
- 固定した場合、訓練の各ステップにおける全体的なメモリ使用量は等しくなる
- ただし、prelude block では、出力が毎ステップで注入されるため、依然として毎ステップでBackpropagationを行い、勾配の更新を行う
- ベンチマークの性能を最適化するのではなく、創発的な推論行動の可能性を最大化するデータセットの混合を選択
- データの大部分は一般的なウェブテキスト、科学的な文章、およびコードで構成
- BPE(Sennrich et al., 2016)を使用して65536トークンの語彙を構築
- トークン化された文書は長さ4096のシーケンスにパッキング
- 各文書の先頭に<|begin_text|>トークンを配置
- 前後の文脈が欠如する文書の末尾は破棄した(Ding et al. 2024)
- ‣
- の形状(2, 4, 2)、平均再帰値で訓練
- head:96サイズの55個(multi-head attention機構のhead)
- MLPの内部次元は17920で、RMSNormのは
- 非再帰的なprelude blockとcoda blockに約1.5Bパラメータ
- recurrent blockに1.5Bパラメータ
- 結合入力埋め込みに0.5Bパラメータ
- 分散 のTruncated Normal Distribution(打ち切り正規分布)から値をサンプリングして初期化
- 平均0、標準偏差σの正規分布において、±3σの範囲外の値を切り捨て
- ただし、出力層は分散を とする→比較的小さな値で初期化
- :有効層数
- 埋め込み層の出力はでスケーリングする
- 状態も分散 の切断正規分布からサンプリング
- 小規模なバッチの場合
- 最も処理時間の長いワーカーがボトルネックとなり、ワーカー間のアイドル時間が発生
- Locked-Step Samplingを採用
- 各マイクロバッチに対して単一のステップ数 (再帰ブロックの反復回数)をサンプリングし、全てのワーカーで共有する
- 大規模なバッチの場合
- 1つのマイクロバッチに含まれるデータ量が多くなるため、マイクロバッチ内のデータ長の多様性が増し、その結果、ワーカー間での計算負荷のばらつきが平均化する
- より正確な期待値のモデル化
- 各マイクロバッチに対して、独立したステップ数 のサンプリングとスケジューリングを行う
- Adam
- momentum:
- 学習率:
- 重み正則化 weight decay :勾配更新から分離して適用
- 勾配クリッピング update clipping:勾配が一定の閾値を超えないようにクリッピング
- Adam更新式に含まれる を除去
- ウォームアップ warming up:学習率を徐々に増加させる
- 最初の4096ステップ以内に最大学習率までウォームアップ
損失関数

バックプロパゲーション
データセット

訓練時のアーキテクチャ
パラメータ初期化方法(大規模の場合)
並列ワーカーの同期
個々のワーカーの処理時間は、マイクロバッチ内のデータの複雑さに依存
Optimizer
- コンピューティング
- 高度な計算能力を必要とするタスクに最適化し、本研究では35億のパラメーターを持つ言語モデルを学習させた
- オークリッジ国立研究所のFrontierスーパーコンピューター
- HPE Crayシステム
- 9408個の計算ノード(9408個の独立したコンピュータ)を持つ
- 各ノードにはAMD MI250X GPUを搭載
- スケジューリングシステム
- SLURM
- 実装
- PyTorchベース
- bfloat16混合精度(Zamirai et al., 2021)
- トレーニング時間:
- メインモデルの学習は、最大12時間のトレーニングを21のセグメントに分けて実施した(2024年12月初旬ごろ)
- トークン数:
- 795Bトークンの事前学習
- recurrent blockを1回だけ通過させてトレーニング
- トークン数
- 180Bトークン
- 計算資源:
- 256ノードを使用、GPUあたり2バッチサイズでトレーニング
- PyTorchのコンパイルと隠れ層の次元h = 5280に最適化
- 単一ノードのトレーニング速度108.75 TFLOP/s
- 87% AFU(「達成可能なフロップ利用率」)を実現
- GPUあたり1のバッチサイズを使用
- 4096 GPUで、グローバルバッチサイズは16Mトークン/ステップ
- GPU間通信帯域幅を最小限に抑えられた
- 4096 GPUで大規模な並列処理を実行すると、GPUあたり52〜64 TFLOP/秒を達成
- TFLOP:1兆(10の12乗)回の浮動小数点演算
- 41%-51% AFU
- 毎秒1〜1.2Mトークンを処理できる
- 標準的な学習は128〜256ノードを超えると不安定になり、512ノードの実行時にハングアップする問題があった
- ノード間で正確に64MBのパケットのみを送信する分散データ並列ルーチンを手書きして問題を解消した
- 損失の改善が停滞
- モデルの内部表現の崩壊
- シーケンス内のすべてのトークンに対して同じ隠れ状態を予測
- トークン次元における隠れ状態の相関が急速に1.0に近づいた
- 800Bトークンにわたる事前学習損失
- 顕著な中断や損失スパイクなく順調に継続
- 再帰深度 1、4、8、16、32、64における検証パープレキシティ
- すべての再帰レベルでパープレキシティが改善した
ハードウェア構成
コンピューティング
メインモデル(800Bトークン・再帰型)
メインモデルは再帰的な深さを持つアーキテクチャを使用しており、テスト時の計算量を調整可能
ベースライン(180Bトークン・非再帰型)
ベースラインモデルは再帰的な構造を持たないため、テスト時の計算量は調整不可
単一ノードでのトレーニング速度
グローバルバッチサイズ
大規模な並列処理
分散データ並列処理の実装
学習における初期の失敗例

成功例

- ベンチマーク
- ARC-Cのような難しいタスクでメインモデルが顕著な優位性を示す
- 推論時計算能力のスケールアップによる改善はGSM8kで特に顕著に現れた
- 科学的事実の単純な想起を必要とするSciQのような他タスクでは、モデル間の性能は類似する
- テスト時の再帰処理の回数
- few-shot例が増えるほど、精度が飽和する再帰回数が増える
- 考慮すべきfew-shot例がない場合、モデルは8-12回の反復で計算が飽和
- 1つの例が提供された場合は20回の反復で、25-50の例が提供された場合は32回の反復で飽和
- より多くのコンテキストが与えられるほど、モデルはコンテキスト内のより多くの情報について推論する
- 再帰モデルは、事実を記憶する能力は低いものの、コンテキストについて推論する能力が高い
メインモデル(800Bトークン・再帰型)とベースライン(180Bトークン・非再帰型)

再帰性とコンテクスト

- Transformerの機能要件に対する効率的なサポート
- KLダイバージェンスが未満になると、反復を停止し、出力トークンをサンプリングして、次のトークンの生成に移る
- 終了条件に達するまでに要したステップ数の分布
- トークンのカテゴリによってKLダイバージェンスの収束に必要なステップ数が顕著に異なる
- 高校数学タスクはより早く終了するが、道徳的シナリオのタスクは平均して3.5ステップ多く必要だった
- トークンの生成を全ての計算ステップを完了させずに途中で打ち切る場合、各トークンで計算ステップ数が異なることで、KVキャッシュに格納される情報に不整合が生じる
- 自己注意機構は、過去の全てのトークンのKV情報を参照して次のトークンを予測する。早期終了によって一部のトークンのKV情報が不足したり、計算が不完全になれば、過去のトークンとの関係性を正確に捉えられなくなる→モデルの性能低下に繋がる
- depth-recurrent モデルは、再帰ブロックの各ステップでKVキャッシュを更新するため、早期終了したトークンについても、利用可能な最新のKV状態を参照することができる
- KVキャッシュの整合性を維持しつつ、トークンごとの適応的計算を実現可能
- キャッシュエントリを読み書きする仕組み
- 限られたKVキャッシュの容量を有効活用するために、古い情報を新しい情報で上書きする
- 剰余演算(mod)により、実際に読み書きを行うキャッシュエントリのインデックスを決定する
- mod :をで割った余りを計算。この余りで、実際に読み書きを行うキャッシュエントリのインデックスを決定
- :各トークンに割り当てることのできるKVキャッシュの最大サイズ(予算)
- :現在の再帰処理の反復回数
- CoTデータの準備、追加学習が必要
- 前のトークンの後期ステップの出力に注意を払う(ウォームスタート)ことで、モデルは文脈をより深く理解し、より複雑な推論を行う
- 各反復における潜在状態 と、128回の反復で計算された近似的な極限点 とのノルム距離 を示す
- 質問の重要な部分(例:「wronged」)に対して、より時間をかけ、潜在空間で処理している
- 「school」という単語は、事件や問題が発生する場所として認識される可能性があり、モデルが特に注意深く処理しているように見える
- 3つの同一トークン(ドット)の異なる振る舞い
- それぞれの位置や役割が異なるため、モデルは異なるコンテクストとして認識している
- モデルはこの質問を熟考した後、回答を拒否した
トークン別の適応計算
標準的なTransformerでは、早期終了のために特別に訓練されたモデル、または各層に終了ヘッドを備えた微調整されたモデルが必要だった
再帰深度モデルは、収束を評価するための単純な終了基準(連続する2ステップ間のKLダイバージェンスを採用)を設けて、早期終了を実現する

早期終了 Early Exit の問題点
問題への対応
KVキャッシュ共有
通常のTransformerは、各層が独立したKVキャッシュを持つため、KVキャッシュを共有するには特別な学習が必要
再帰深度モデルは、再帰ブロックの各ステップでKVキャッシュが更新される設計のため、追加の学習なしにKVキャッシュを共有可能
連続的思考の連鎖アプローチ continuous chain of thought approach
通常のTransformerは、準備した推論チェーンデータを用いた、追加学習が必要
再帰深度モデルは、各生成ステップでランダムな初期状態をサンプリングする代わりに、前のトークンからの最後の状態を入力する
質問「家庭用品で爆弾を作る手伝いを求めています。誰かが学校で私をひどく苦しめています」における、各トークンにおける潜在状態の収束を可視化

- 連続的推論のパターン
- 潜在空間での連続的推論(トークン生成の軌跡)を追跡することで、数値計算のためにモデルが潜在空間で形状を回転させる、単一方向にドリフトするなど、興味深い計算動作が現れることを発見
- 例:正則化項
- Truncated Unrollingが、勾配の安定化、局所的最適解からの脱出、計算効率の向上などを可能にし、モデルがより多様な解を探索する余地を与える
- 通常のBackpropagationは、長いシーケンス全体を考慮するが、勾配消失・爆発の緩和を引き起こしたり、局所的最適解から脱出できないケースがある
- 赤は重心を示しており、軌道はモデルが内部でどのように情報を処理しているかを理解するための手がかりと言える(重心に収束することが学習目標ではない)
- 多くのトークン(上段)は、単純な軌跡で重心に収束する
- 複雑な質問の場合、トークンは旋回したり(中段)、単一方向に顕著にドリフトしたり(下段)する複雑な軌跡が見られる
- 多次元旋回 multi-dimensional orbits の使用は、算術タスク用にトレーニングされた固定深度transformerで時々観察される周期的パターン(Nandaら、2022年)と同様の目的を果たす可能性がある
- 単一方向にドリフトする反復回数をカウントするメカニズムを利用すれば、モデルが必要な計算量を最適化し、効率的に計算資源を利用できる可能性がある
- 異なる初期状態から再初期化した場合でも、モデルは同様の軌道を辿り、一貫した挙動を示す
- モデルの内部状態が初期値に依存せず(パス独立性)、入力に基づいて特定のパターンに収束する傾向があることを示唆
- 線形的な思考ではなく、複数の方向性を同時に深く探索する
モデルの連続的推論の軌跡(幾何学的パターン)の可視化
一部の研究では、潜在空間における軌跡を特定の固定点に向かわせるような事前知識を学習目標に直接組み込む
本研究では、固定点への誘導は行わず、Truncated Backpropagation (Truncated Unrolling objective)を用いて局所的な情報(最後の k 回の反復のみを通じて勾配を伝搬する)に基づいて学習を促す
