The Answer Key Trick That Cuts Reasoning LLM Training Time in Half
A*-PO is a new RL training algorithm for LLMs that precomputes an 'optimal value' offline, then trains with just one sample per prompt instead of many. It matches or beats PPO and GRPO while training up to 2x faster than both and using 30%+ less peak GPU memory than PPO.
The biggest bottleneck in RL training for LLMs isn’t the model. It’s the overhead of estimating, during training, how good each response is compared with a baseline. A new algorithm called A*-PO moves most of that cost out of the training loop by precomputing the answer key before training starts, then running online updates with only a single response per prompt instead of many.
- Up to 2x shorter total training time than PPO and GRPO (49h vs. 88h and 90h on long-context tasks)
- 30%+ lower peak GPU memory vs. PPO at all model sizes tested (1.5B, 3B, 7B)
- Matches or beats PPO and GRPO on math benchmarks: GSM8K, MATH500, AIME, HMMT
- REBEL, a related regression-based method, runs out of memory in long-context (16K-token) settings; A*-PO does not
- Using the offline stage to filter out unsolvable prompts cuts Stage 2 (online) training time by 28% with equal or better accuracy
- Works on base models directly — no supervised fine-tuning step needed
Why RL Training for Reasoning Is So Expensive
Models like DeepSeek-R1 and o1 owe much of their strength on hard reasoning tasks to reinforcement learning. The training loop is simple in concept: generate a response, check whether it’s correct (a binary reward), and update the model to produce more correct responses.
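To make the “check whether it’s correct” step concrete, here is a minimal sketch of a binary reward for math answers. It assumes, hypothetically, that responses put their final answer in a LaTeX `\boxed{...}` expression and that exact string matching is good enough; real verifiers normalize answers (fractions, units, whitespace) far more carefully, and `binary_reward` is an illustrative name, not part of any library.

```python
import re

# Matches \boxed{...} whose contents contain no nested braces (a simplifying assumption).
BOXED = re.compile(r"\\boxed\{([^{}]*)\}")


def binary_reward(response: str, reference_answer: str) -> float:
    """Return 1.0 if the last boxed answer in `response` equals `reference_answer`, else 0.0."""
    answers = BOXED.findall(response)
    if not answers:
        return 0.0  # no parseable final answer counts as wrong
    return 1.0 if answers[-1].strip() == reference_answer.strip() else 0.0


print(binary_reward(r"First guess \boxed{41}, final answer \boxed{42}", "42"))  # 1.0
print(binary_reward("I think it is 42", "42"))  # 0.0
```

The sketch deliberately skips nested braces such as `\boxed{\frac{1}{2}}` to stay short.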
The expensive part is computing the advantage — a number that tells the model “was this particular response better or worse than what you usually do?” There are two common ways to estimate it:
PPO maintains a separate “critic” network: a second model, trained alongside the main model, that predicts expected rewards. The critic is often as large as the policy itself, so it adds substantial memory and computation to every training step.
GRPO avoids the critic by generating multiple responses for the same prompt (typically 8–16) and comparing them against each other. There’s no extra network, but every training step now generates 8 to 16 times as much text.
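The group comparison can be sketched as follows, using the common GRPO recipe of normalizing each reward by its group’s mean and standard deviation. This is a simplified scalar version (real implementations work on tensors and differ in details such as sample vs. population standard deviation), and `grpo_advantages` is a hypothetical helper name.

```python
import statistics


def grpo_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Group-relative advantages for the G responses sampled for one prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; some implementations use the sample std
    return [(r - mean) / (std + eps) for r in rewards]


# Eight rollouts for one prompt, only the first one correct.
advantages = grpo_advantages([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
print([round(a, 3) for a in advantages])
```

The baseline for each rollout exists only because all eight were generated, and a group whose rewards are all identical yields all-zero advantages, that is, no learning signal for the cost of eight generations.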
Both approaches exist because, during training, you don’t know what “baseline performance” looks like unless you either measure it on the fly (GRPO) or keep a model around to estimate it (PPO).
The Key Insight: Precompute the Ceiling Offline
A*-PO (Policy Optimization via Optimal Advantage Regression) makes one observation that changes the cost structure entirely.
For KL-regularized RL, the standard formulation used in LLM post-training, there’s a closed-form expression for the optimal value function V*(x). It depends only on the reward and the reference model (the base model before RL): it is a “soft maximum” of the rewards the reference model earns on prompt x, sitting between its average reward and its best-case reward depending on the strength of the KL penalty. You don’t need to know the current policy’s value function at all.
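The closed form is the standard solution of the KL-regularized objective. Written with a KL coefficient β (notation chosen for this sketch; the paper’s symbols may differ), the objective is

$$
\max_{\pi}\; \mathbb{E}_{x,\, y \sim \pi(\cdot \mid x)}\big[r(x,y)\big] - \beta\, \mathrm{KL}\big(\pi(\cdot \mid x)\,\|\,\pi_{\mathrm{ref}}(\cdot \mid x)\big)
$$

and its optimal policy and value are

$$
\pi^*(y \mid x) = \frac{\pi_{\mathrm{ref}}(y \mid x)\, e^{r(x,y)/\beta}}{Z(x)}, \qquad
V^*(x) = \beta \ln Z(x) = \beta \ln \mathbb{E}_{y \sim \pi_{\mathrm{ref}}(\cdot \mid x)}\big[e^{r(x,y)/\beta}\big].
$$

Taking the log of the first identity and rearranging gives the optimal advantage, which involves only the reward and quantities computable from the reference model:

$$
\beta \ln \frac{\pi^*(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} = r(x,y) - V^*(x).
$$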
This means you can estimate V*(x) for every prompt in your training set before training begins, using only the frozen reference model. No gradients. No critic network. Fully parallelizable inference.
That’s Stage 1. You generate N=8 responses per prompt from the reference model and compute the estimate V̂*(x) from their correctness scores. It’s a one-time cost that runs on a fast inference engine such as vLLM, with no backprop.
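A minimal sketch of the Stage 1 estimate, assuming the log-mean-exp form of the KL-regularized optimal value, V̂*(x) = β ln((1/N) Σᵢ exp(r(x, yᵢ)/β)), over the N reference samples. The function name, the plain-Python setting, and β = 1.0 are illustrative assumptions; the paper’s own code (built on verl) will look different.

```python
import math


def estimate_v_star(rewards: list[float], beta: float) -> float:
    """Estimate V*(x) = beta * ln E_{y ~ pi_ref}[exp(r(x, y) / beta)] from N reference samples.

    `rewards` are the scores of N responses sampled once from the frozen reference model.
    The max-shift (log-sum-exp trick) keeps exp() from overflowing when beta is small.
    """
    if not rewards:
        raise ValueError("need at least one reference sample")
    if beta <= 0:
        raise ValueError("beta must be positive")
    scaled = [r / beta for r in rewards]
    shift = max(scaled)
    mean_exp = sum(math.exp(s - shift) for s in scaled) / len(scaled)
    return beta * (shift + math.log(mean_exp))


# Stage 1 runs once over the whole prompt set; here, N = 8 binary rewards per prompt.
offline_rewards = {
    "hard prompt": [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    "easy prompt": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
}
v_star_hat = {p: estimate_v_star(r, beta=1.0) for p, r in offline_rewards.items()}
```

The estimate always lands between the average sampled reward and the best sampled reward: a large β pulls it toward the average (the reference model’s empirical pass rate), a small β toward the best case.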
Stage 2 is then remarkably simple. At each training step, generate one response per prompt using the current policy, and update via a least-squares regression loss: push the model’s β-scaled log-probability ratio against the reference model, β·log(π(y|x)/π_ref(y|x)), toward the “optimal advantage” r(x, y) - V̂*(x).
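A scalar sketch of that regression objective, assuming the log-probability ratio is computed from the summed token log-probs of the whole response and scaled by a coefficient β. The function names are hypothetical, and a real trainer would compute the same quantity on tensors with autograd rather than on Python floats.

```python
def a_star_po_loss(
    policy_logprob: float,
    ref_logprob: float,
    reward: float,
    v_star_hat: float,
    beta: float,
) -> float:
    """Squared error between beta * ln(pi/pi_ref) and the optimal-advantage target r - V*.

    policy_logprob and ref_logprob are the summed token log-probabilities of the single
    sampled response y under the current policy and the frozen reference model.
    """
    predicted = beta * (policy_logprob - ref_logprob)
    target = reward - v_star_hat  # V*(x) was precomputed once in Stage 1
    return (predicted - target) ** 2


def batch_loss(batch: list[tuple[float, float, float, float]], beta: float) -> float:
    """Mean loss over (policy_logprob, ref_logprob, reward, v_star_hat), one rollout per prompt."""
    return sum(a_star_po_loss(*item, beta=beta) for item in batch) / len(batch)
```

The loss is zero exactly when β·ln(π/π_ref) equals r - V̂*; there is no ratio clipping and no per-group normalization, and each prompt contributes a single rollout.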
If the response was correct and the reference model usually fails on this prompt, the advantage is large and positive, so the response is reinforced. If the response was wrong but the reference model usually gets it right, the advantage is negative, so the response is suppressed. The “answer key” is already baked in from Stage 1.
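A worked example with N = 8 binary rewards and, purely for illustration, β = 1, using the log-mean-exp estimate V̂*(x) = β ln((1/N) Σᵢ exp(r(x, yᵢ)/β)):

$$
\begin{aligned}
\text{hard prompt (reference solves 1 of 8):}\quad & \hat V^*(x) = \ln\tfrac{7 + e}{8} \approx 0.195, \\
& A_{\text{correct}} = 1 - 0.195 \approx 0.805, \qquad A_{\text{wrong}} = 0 - 0.195 \approx -0.195, \\
\text{easy prompt (reference solves 7 of 8):}\quad & \hat V^*(x) = \ln\tfrac{1 + 7e}{8} \approx 0.918, \\
& A_{\text{correct}} = 1 - 0.918 \approx 0.082, \qquad A_{\text{wrong}} = 0 - 0.918 \approx -0.918.
\end{aligned}
$$

A correct answer on a hard prompt earns a large positive advantage and a miss on an easy prompt a large negative one, while the other two cases barely move the model.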
No clipping (unlike PPO). No response-wise normalization heuristic (unlike GRPO). No critic network. One rollout per prompt.
How A*-PO Compares in Setup
| | PPO | GRPO | A*-PO |
|---|---|---|---|
| Critic network | Yes | No | No |
| Samples per prompt (online) | 1 (+ critic) | 8–16 | 1 |
| Offline pre-generation | No | No | Yes (8 per prompt, once) |
| Response-wise normalization | — | Yes | No |
| Works at 16K context length | Yes (slower) | Yes (slower) | Yes (fastest) |
What the Benchmarks Show
The paper trains Qwen2.5 models at three sizes (1.5B, 3B, 7B) on GSM8K (grade-school math) and MATH (competition math). For the long-context competition benchmarks (AIME, HMMT), it uses the DeepSeek-distilled 1.5B model at 16K context length.
On standard math (MATH dataset, 7B model): A*-PO completes training in 11.01h, versus 20.15h for GRPO and 20.53h for PPO. REBEL finishes in 14.67h but uses 98.77% of GPU memory, close to running out, vs. A*-PO’s 76.57%. Accuracy is competitive or better across all evaluation sets.
[Chart: Training time, MATH dataset, 7B model (lower is better). Measured on 4x H100 GPUs; A*-PO time includes offline Stage 1 generation.]
[Chart: Peak GPU memory, MATH dataset, 7B model (lower is better). Percentage of total GPU memory used during backpropagation, averaged over 100 batches.]
On long-context competition problems (AIME, HMMT): This is where the absolute savings are largest. A*-PO trains in 49 hours vs. 88h for PPO and 90h for GRPO. REBEL runs out of memory entirely at 16K context length and is excluded from the comparison. A*-PO achieves the best Avg@32 score (accuracy averaged over 32 sampled responses per problem), averaged across all four competition benchmarks.
A bonus side effect: Because the offline stage already shows which prompts the reference model failed in all N=8 attempts, you can filter those prompts out of Stage 2 training. They’re effectively unlearnable: RL post-training won’t fix a problem the base model essentially never solves. Filtering with N=8 cuts Stage 2 training time by 28% while maintaining or improving final accuracy on MATH500.
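Filtering needs nothing beyond the Stage 1 rewards. A minimal sketch, assuming binary rewards stored per prompt (the dictionary layout and the function name are hypothetical):

```python
def filter_unsolvable(offline_rewards: dict[str, list[float]]) -> list[str]:
    """Keep only prompts where at least one of the N reference samples was correct."""
    return [p for p, rewards in offline_rewards.items() if any(r > 0 for r in rewards)]


offline_rewards = {
    "prompt A": [0.0] * 8,  # 0 of 8 correct: dropped
    "prompt B": [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 1 of 8 correct: kept
}
print(filter_unsolvable(offline_rewards))  # ['prompt B']
```

With binary rewards, a dropped prompt also has V̂*(x) = 0 under the log-mean-exp estimate, so a wrong Stage 2 response on it would get a target advantage of exactly 0 and supply little useful signal anyway.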
Caveats Worth Understanding
Stage 1 is a one-time cost, but it’s not free. Generating 8 responses per prompt offline takes GPU time (inference only, no gradients). The paper includes this in its “total training time” comparisons, and A*-PO still wins. The cost amortizes best when Stage 2 makes several passes over the same prompts; in a short Stage 2 run, Stage 1 accounts for a larger share of the total.
The assumption: the reference model can solve the problem at least sometimes. The theoretical guarantees require that, for every training prompt, the base model has a non-zero probability of getting it right. The paper is explicit: RL post-training cannot rescue problems where the base model’s pass@K is zero even for reasonably large K. That’s consistent with recent empirical findings from other labs. If you want to teach a model genuinely new skills, this approach is no better than other RL methods.
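Whether a prompt has non-zero pass@K can only be estimated from samples. A sketch of the standard unbiased pass@k estimator, 1 - C(n-c, k)/C(n, k) for n samples with c correct, shows that an observed 0 of 8 yields an estimate of 0 for every k up to 8, even if the true pass rate is small but non-zero (the function name is illustrative):

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimate of pass@k from n sampled responses, c of which are correct."""
    if not (0 <= c <= n and 1 <= k <= n):
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        return 1.0  # every size-k subset contains at least one correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)


print(pass_at_k(8, 1, 1))  # 0.125
print(pass_at_k(8, 0, 8))  # 0.0
```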
Gains plateau at N=8 offline samples. Going from N=1 to N=8 significantly improves Stage 2 training quality, but going beyond N=8 (to N=16 or N=32) yields diminishing returns on MATH500 accuracy. Eight is the sweet spot they recommend; past that, extra samples mostly cost compute for minimal gain.
Results are on math benchmarks. The experiments focus heavily on mathematical reasoning with verifiable binary rewards (+1 correct, 0 wrong). Math is the canonical test bed for this kind of RL training, but real-world applications with noisier or more complex reward signals might show different trade-offs.
The code is open-sourced at github.com/ZhaolinGao/A-PO and builds on the verl framework.