Rollout Importance Sampling - Migration Guide
Last updated: 10/11/2025.
This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl.
References
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
Off-policy RL: https://fengyao.notion.site/off-policy-rl
Overview
Rollout Importance Sampling corrects for distribution mismatch between:
Rollout policy: e.g., vLLM with BFloat16
Training policy: e.g., FSDP with FP32
This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
What Changed
Removed (Old Implementation)
# Old TIS configuration (REMOVED)
actor:
tis_imp_ratio_cap: 2.0 # ❌ No longer supported
The old implementation:
Only supported token-level truncate mode
Had no metrics tracking
Lacked numerical stability safeguards
No configurability for different scenarios
Added (New Implementation)
# New Rollout IS configuration (all in algorithm config)
algorithm:
# Main control: set threshold to enable (null = disabled)
rollout_is_threshold: 2.0
# Whether to apply weights to loss (default: false = metrics only)
rollout_is: true
rollout_is_threshold_lower: null # Auto-reciprocal
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 1e-4
# REQUIRED: Enable log prob calculation
actor_rollout_ref:
rollout:
calculate_log_probs: true
The new implementation:
✅ Three aggregation levels: token, sequence, geometric
✅ Two bounding modes: truncate, mask
✅ Dual threshold support (upper/lower)
✅ Veto mechanism for catastrophic outliers
✅ 30+ comprehensive metrics
✅ Log-space computation for numerical stability
✅ Memory-efficient implementation
Files Modified
Core Implementation
NEW:
verl/trainer/ppo/mismatch_helper.pyContains
compute_rollout_importance_weights()- main functionContains
compute_is_metrics()- comprehensive metrics
MODIFIED:
verl/trainer/ppo/core_algos.py(lines 962-991)Replaced old TIS implementation (lines 962-967)
Added new rollout IS with metrics support
MODIFIED:
verl/workers/actor/dp_actor.pyUpdated to use
rollout_is_thresholdinstead oftis_imp_ratio_capCollects and logs all rollout IS metrics
Configuration Files
MODIFIED:
verl/trainer/config/algorithm.py(lines 95-100)Added 6 new rollout IS parameters to
AlgoConfig
MODIFIED:
verl/workers/config/actor.py(lines 110-115)Added 6 new rollout IS parameters to
ActorConfig
MODIFIED:
verl/trainer/config/actor/actor.yaml(lines 77-89)Added rollout IS configuration section
MODIFIED:
verl/trainer/config/ppo_trainer.yaml(lines 116-133)Added rollout IS to algorithm config
Documentation
MODIFIED:
docs/examples/config.rstUpdated actor config with rollout IS parameters
Updated algorithm config with rollout IS parameters
Added detailed parameter descriptions
Example Scripts
MODIFIED:
recipe/dapo/run_dapo_qwen2.5_32b_tis.shUpdated from
tis_imp_ratio_capto rollout IS parametersAdded comprehensive comments
NEW:
examples/rollout_importance_sampling/README.mdComprehensive guide with usage patterns
Troubleshooting section
Performance considerations
NEW:
examples/rollout_importance_sampling/run_with_rollout_is.shBasic example with token-level truncate
Tests
NEW:
tests/trainer/ppo/test_rollout_is.pyUnit tests for rollout IS functionality
NEW:
tests/trainer/ppo/test_rollout_is_integration.pyIntegration tests with PPO
Configuration Parameters
algorithm.rollout_is_threshold (float or null)
Main on/off switch. Upper threshold for IS weights.
null= disabled (no computation, no metrics)floatvalue (e.g., 2.0) = enabled (compute weights and metrics)
algorithm.rollout_is (bool)
Whether to apply IS weights to policy loss. Default: False
true= apply weights to loss (full IS correction)false= compute metrics only (useful for monitoring before enabling)
Recommended threshold ranges:
Token level: 1.5 - 5.0
Sequence level: 2.0 - 10.0
Geometric level: 1.0002 - 1.001
algorithm.rollout_is_threshold_lower (float or null)
Lower threshold for IS weights. If null, defaults to 1/upper (reciprocal).
algorithm.rollout_is_level (str)
Aggregation level for IS weights:
"token": Per-token ratios"sequence": Product of ratios"geometric": Geometric mean (experimental)
algorithm.rollout_is_mode (str)
Bounding mode:
"truncate": Cap weights at upper threshold only"mask": Zero out weights outside [lower, upper]
algorithm.rollout_is_veto_threshold (float)
Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
Default: 1e-4 (ratio 10,000x off)
Migration Steps
Step 1: Update Your Configuration
Before (Old):
actor_rollout_ref:
actor:
tis_imp_ratio_cap: 2.0
rollout:
calculate_log_probs: true
After (New):
algorithm:
rollout_is_threshold: 2.0 # Main control
rollout_is: true # Apply to loss (default: false)
rollout_is_level: token
rollout_is_mode: truncate
actor_rollout_ref:
rollout:
calculate_log_probs: true # Still required!
Step 2: Monitor New Metrics
All metrics are prefixed with mismatch/. For example, rollout_is_mean appears as mismatch/rollout_is_mean in logs.
Core IS Weight Metrics
rollout_is_mean: Mean importance sampling weight across all valid tokensIdeal value: Close to 1.0 (indicates minimal distribution mismatch)
Warning: < 0.5 or > 2.0 suggests significant policy mismatch
rollout_is_std: Standard deviation of IS weightsIdeal value: < 0.5 for stable training
Warning: > 1.0 indicates high variance, may need tighter thresholds
rollout_is_min: Minimum IS weight observedShows the most underweighted token/sequence
rollout_is_max: Maximum IS weight observed (before truncation/masking)Shows the most overweighted token/sequence
Compare with
rollout_is_thresholdto see truncation impact
Percentile Metrics
rollout_is_p25: 25th percentile of IS weightsrollout_is_p50: Median IS weight (50th percentile)Should be close to
rollout_is_meanif distribution is symmetric
rollout_is_p75: 75th percentile of IS weightsrollout_is_p95: 95th percentile of IS weightsUse to detect outliers
rollout_is_p99: 99th percentile of IS weightsShould be close to
rollout_is_thresholdif truncation is working
Effective Sample Size
rollout_is_eff_sample_size: Effective sample size after IS weightingFormula:
1 / mean(weights²)where weights are normalizedRange: 0.0 to 1.0 (as fraction of original batch)
Ideal value: > 0.5 (retaining at least 50% effective samples)
Warning: < 0.3 means high variance, losing too many effective samples
Veto Mechanism Metrics
rollout_is_veto_fraction: Fraction of sequences rejected by veto mechanismIdeal value: < 0.05 (less than 5% vetoed)
Warning: > 0.1 suggests policies are too different or numerical issues
rollout_is_catastrophic_token_fraction: Fraction of tokens below veto thresholdIdentifies problematic tokens before sequence-level veto
Warning: > 0.01 indicates widespread distribution issues
Threshold Exceedance Metrics
rollout_is_ratio_fraction_high: Fraction of weights exceeding upper thresholdShows how often truncation/masking occurs on high end
Ideal value: < 0.1 (most weights within bounds)
rollout_is_ratio_fraction_low: Fraction of weights below lower thresholdShows how often masking occurs on low end (mask mode only)
Ideal value: < 0.1
Sequence-Level Metrics (for sequence/geometric modes)
rollout_is_seq_mean: Mean IS weight at sequence levelShould match
rollout_is_meanfor sequence-level aggregation
rollout_is_seq_std: Standard deviation of sequence-level IS weightsrollout_is_seq_min: Minimum sequence-level IS weightrollout_is_seq_max: Maximum sequence-level IS weightrollout_is_seq_max_deviation: Maximum absolute deviation from 1.0 at sequence levelIdeal value: < 1.0
Shows worst-case sequence mismatch
rollout_is_seq_fraction_high: Fraction of sequences exceeding upper thresholdrollout_is_seq_fraction_low: Fraction of sequences below lower threshold
Masking Metrics (mask mode only)
rollout_is_masked_fraction: Fraction of tokens masked (set to zero)Ideal value: < 0.1
Warning: > 0.3 means losing too much data
rollout_is_seq_masked_fraction: Fraction of sequences with at least one masked tokenShows sequence-level impact of masking
Distribution Mismatch Metrics (Training vs Rollout Policy)
mismatch_training_ppl: Perplexity of training policy (e.g., FSDP FP32)Formula:
exp(-mean(log_probs))Lower is better (model is more confident)
mismatch_rollout_ppl: Perplexity of rollout policy (e.g., vLLM BF16)Should be close to
mismatch_training_pplif policies match well
mismatch_ppl_ratio: Ratio of training PPL to rollout PPLFormula:
exp(mean(log(training_ppl / rollout_ppl)))Ideal value: Close to 1.0
Meaning: > 1.0 means training is less confident than rollout
mismatch_training_log_ppl: Log perplexity of training policyUseful for identifying trends (linear scale)
mismatch_rollout_log_ppl: Log perplexity of rollout policymismatch_log_ppl_diff: Mean difference in log perplexitiesFormula:
mean(log_ppl_rollout - log_ppl_training)Ideal value: Close to 0.0
Sign indicates which policy is more confident
mismatch_log_ppl_abs_diff: Mean absolute log perplexity differenceMagnitude of mismatch regardless of direction
mismatch_log_ppl_diff_max: Maximum log perplexity difference across sequencesIdentifies worst-case sequence
mismatch_log_ppl_diff_min: Minimum log perplexity difference across sequencesmismatch_kl: KL divergence KL(π_rollout || π_training)Formula:
mean(log_prob_rollout - log_prob_training)Ideal value: Close to 0.0 (policies match)
Warning: > 0.1 indicates significant mismatch
Note: Can be negative (rollout is less confident)
mismatch_k3_kl: K3 KL estimatorFormula:
mean(exp(log_ratio) - log_ratio - 1)More stable for small KL values
Always non-negative
Example: Accessing Metrics in Code
# Metrics are returned from compute_rollout_importance_weights
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=training_log_probs, # from training policy
rollout_log_prob=rollout_log_probs, # from rollout policy
response_mask=response_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
# All metrics have 'mismatch/' prefix
print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}")
print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}")
print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}")
print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}")
# Check for warning conditions
if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:
print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected")
if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3:
print("⚠️ Warning: Low effective sample size, high variance in IS weights")
if metrics['mismatch/rollout_is_veto_fraction'] > 0.1:
print("⚠️ Warning: High veto fraction, policies may be too different")
Example: Monitoring Metrics During Training
# In your training loop
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(dataloader):
# ... rollout phase ...
# Compute IS weights and get metrics
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=batch.old_log_prob,
rollout_log_prob=batch.rollout_log_prob,
response_mask=batch.response_mask,
rollout_is_level=config.rollout_is_level,
rollout_is_mode=config.rollout_is_mode,
rollout_is_threshold=config.rollout_is_threshold,
rollout_is_veto_threshold=config.rollout_is_veto_threshold,
)
# Log to tensorboard/wandb
for metric_name, metric_value in metrics.items():
logger.log_scalar(metric_name, metric_value, step=global_step)
# Use IS weights in training
is_weights = weights_proto.batch["rollout_is_weights"]
# ... apply weights to policy gradient ...
Example: Conditional Alerting Based on Metrics
def check_rollout_is_health(metrics, config):
"""Check if rollout IS metrics indicate healthy training."""
warnings = []
# Check mean IS weight
mean_weight = metrics['mismatch/rollout_is_mean']
if mean_weight < 0.5 or mean_weight > 2.0:
warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0")
# Check effective sample size
ess = metrics['mismatch/rollout_is_eff_sample_size']
if ess < 0.3:
warnings.append(f"Effective sample size {ess:.3f} is too low")
# Check veto fraction
veto_frac = metrics['mismatch/rollout_is_veto_fraction']
if veto_frac > 0.1:
warnings.append(f"Veto fraction {veto_frac:.3f} is too high")
# Check variance
std = metrics['mismatch/rollout_is_std']
if std > 1.0:
warnings.append(f"IS weight std {std:.3f} is too high")
# Check KL divergence
kl = metrics['mismatch/mismatch_kl']
if abs(kl) > 0.1:
warnings.append(f"KL divergence {kl:.3f} indicates significant mismatch")
if warnings:
print("⚠️ Rollout IS Health Warnings:")
for warning in warnings:
print(f" - {warning}")
return False
else:
print("✅ Rollout IS metrics look healthy")
return True
# Use in training
_, metrics = compute_rollout_importance_weights(...)
is_healthy = check_rollout_is_health(metrics, config)
if not is_healthy:
# Consider adjusting config or investigating issues
print("Consider:")
print(" - Tightening rollout_is_threshold")
print(" - Switching to geometric aggregation level")
print(" - Checking if rollout and training policies are too different")
Step 3: Test Your Training
Start with the basic token-level truncate configuration:
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
Monitor metrics for 1-2 epochs before adjusting parameters.
Configuration Examples
Example 1: Full IS Correction
algorithm:
rollout_is_threshold: 2.0
rollout_is: true # Apply weights to loss
rollout_is_level: token
rollout_is_mode: truncate
Example 2: Metrics Only (Monitoring Mode)
algorithm:
rollout_is_threshold: 2.0
rollout_is: false # Compute metrics, don't apply weights
rollout_is_level: token
rollout_is_mode: truncate
Example 3: Geometric Mean with Mask
algorithm:
rollout_is_threshold: 1.0002
rollout_is: true
rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric
rollout_is_mode: mask
Example 4: Asymmetric Thresholds
algorithm:
rollout_is_threshold: 5.0
rollout_is: true
rollout_is_threshold_lower: 0.8
rollout_is_level: token
rollout_is_mode: mask
Troubleshooting
Issue: High variance in IS weights
Symptoms: rollout_is_std > 1.0, rollout_is_eff_sample_size < 0.3
Solutions:
Switch from
sequencetogeometriclevelTighten thresholds
Verify rollout and training aren’t too different
Issue: Too many sequences vetoed
Symptoms: rollout_is_veto_fraction > 0.1
Solutions:
Relax veto threshold:
rollout_is_veto_threshold: 1e-3Check for numerical issues in log prob computation
Verify policies aren’t completely different
Issue: Mean IS weight far from 1.0
Symptoms: rollout_is_mean < 0.5 or > 2.0
Solutions:
Verify
calculate_log_probs=Trueis setCheck rollout_log_probs are correctly passed
Check for systematic bias
Debugging: Visualizing Metrics
Example: Plot IS weight distribution
import matplotlib.pyplot as plt
import numpy as np
def plot_is_metrics(metrics_history):
"""Plot rollout IS metrics over training steps."""
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Plot 1: Mean IS weight over time
axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean'])
axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[0, 0].set_title('Mean IS Weight')
axes[0, 0].set_xlabel('Step')
axes[0, 0].legend()
# Plot 2: Effective sample size
axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size'])
axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')
axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')
axes[0, 1].set_title('Effective Sample Size')
axes[0, 1].set_xlabel('Step')
axes[0, 1].legend()
# Plot 3: Veto fraction
axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction'])
axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning')
axes[0, 2].set_title('Veto Fraction')
axes[0, 2].set_xlabel('Step')
axes[0, 2].legend()
# Plot 4: IS weight distribution (latest step)
latest_idx = -1
percentiles = [25, 50, 75, 95, 99]
values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles]
axes[1, 0].bar([f'p{p}' for p in percentiles], values)
axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 0].set_title('IS Weight Percentiles (Latest)')
axes[1, 0].legend()
# Plot 5: KL divergence over time
axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3)
axes[1, 1].set_title('KL Divergence')
axes[1, 1].set_xlabel('Step')
axes[1, 1].legend()
# Plot 6: PPL ratio over time
axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 2].set_title('PPL Ratio (Training/Rollout)')
axes[1, 2].set_xlabel('Step')
axes[1, 2].legend()
plt.tight_layout()
plt.savefig('rollout_is_metrics.png', dpi=150)
print("Saved plot to rollout_is_metrics.png")
Example: Metric collection during training
# Collect metrics over time
metrics_history = {
'mismatch/rollout_is_mean': [],
'mismatch/rollout_is_eff_sample_size': [],
'mismatch/rollout_is_veto_fraction': [],
'mismatch/rollout_is_p25': [],
'mismatch/rollout_is_p50': [],
'mismatch/rollout_is_p75': [],
'mismatch/rollout_is_p95': [],
'mismatch/rollout_is_p99': [],
'mismatch/mismatch_kl': [],
'mismatch/mismatch_k3_kl': [],
'mismatch/mismatch_ppl_ratio': [],
}
# In training loop
for step in range(num_steps):
# ... compute IS weights ...
_, metrics = compute_rollout_importance_weights(...)
# Store metrics
for key in metrics_history.keys():
if key in metrics:
metrics_history[key].append(metrics[key])
# Plot every 100 steps
if step % 100 == 0:
plot_is_metrics(metrics_history)
Performance Impact
Memory overhead: ~1% of model memory
Computational overhead: 1-3% depending on level
Training stability: Significantly improved when mismatch exists
Backward Compatibility
The old tis_imp_ratio_cap parameter is completely removed. There is no backward compatibility mode.
All scripts and configurations must be updated to use the new rollout IS parameters.
Testing
Run the test suite to verify everything works:
# Basic unit tests
python test_rollout_is.py
# Integration tests (if pytest is available)
pytest tests/trainer/ppo/test_rollout_is_integration.py -v
Expected output: All tests pass ✓
Additional Resources
Implementation:
verl/trainer/ppo/mismatch_helper.pyExamples:
examples/rollout_importance_sampling/DAPO Example:
recipe/dapo/run_dapo_qwen2.5_32b_tis.sh
Summary
The new Rollout Importance Sampling implementation provides:
✅ More robust handling of distribution mismatch
✅ Better numerical stability
✅ Comprehensive metrics for monitoring
✅ Flexibility for different scenarios
✅ Memory-efficient computation
Migration is straightforward: replace tis_imp_ratio_cap with the new rollout_is_* parameters in the algorithm config section.