# Rollout Importance Sampling - Migration Guide Last updated: 10/11/2025. This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl. ## References - **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda - **Off-policy RL**: https://fengyao.notion.site/off-policy-rl ## Overview Rollout Importance Sampling corrects for distribution mismatch between: - **Rollout policy**: e.g., vLLM with BFloat16 - **Training policy**: e.g., FSDP with FP32 This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases. ## What Changed ### **Removed (Old Implementation)** ```yaml # Old TIS configuration (REMOVED) actor: tis_imp_ratio_cap: 2.0 # ❌ No longer supported ``` The old implementation: - Only supported token-level truncate mode - Had no metrics tracking - Lacked numerical stability safeguards - No configurability for different scenarios ### **Added (New Implementation)** ```yaml # New Rollout IS configuration (all in algorithm config) algorithm: # Main control: set threshold to enable (null = disabled) rollout_is_threshold: 2.0 # Whether to apply weights to loss (default: false = metrics only) rollout_is: true rollout_is_threshold_lower: null # Auto-reciprocal rollout_is_level: token rollout_is_mode: truncate rollout_is_veto_threshold: 1e-4 # REQUIRED: Enable log prob calculation actor_rollout_ref: rollout: calculate_log_probs: true ``` The new implementation: - ✅ Three aggregation levels: token, sequence, geometric - ✅ Two bounding modes: truncate, mask - ✅ Dual threshold support (upper/lower) - ✅ Veto mechanism for catastrophic outliers - ✅ 30+ comprehensive metrics - ✅ Log-space computation for numerical stability - ✅ Memory-efficient implementation ## Files Modified ### **Core Implementation** 1. **NEW**: `verl/trainer/ppo/mismatch_helper.py` - Contains `compute_rollout_importance_weights()` - main function - Contains `compute_is_metrics()` - comprehensive metrics 2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991) - Replaced old TIS implementation (lines 962-967) - Added new rollout IS with metrics support 3. **MODIFIED**: `verl/workers/actor/dp_actor.py` - Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap` - Collects and logs all rollout IS metrics ### **Configuration Files** 4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100) - Added 6 new rollout IS parameters to `AlgoConfig` 5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115) - Added 6 new rollout IS parameters to `ActorConfig` 6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89) - Added rollout IS configuration section 7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133) - Added rollout IS to algorithm config ### **Documentation** 8. **MODIFIED**: `docs/examples/config.rst` - Updated actor config with rollout IS parameters - Updated algorithm config with rollout IS parameters - Added detailed parameter descriptions ### **Example Scripts** 9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh` - Updated from `tis_imp_ratio_cap` to rollout IS parameters - Added comprehensive comments 10. **NEW**: `examples/rollout_importance_sampling/README.md` - Comprehensive guide with usage patterns - Troubleshooting section - Performance considerations 11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh` - Basic example with token-level truncate ### **Tests** 12. **NEW**: `tests/trainer/ppo/test_rollout_is.py` - Unit tests for rollout IS functionality 13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py` - Integration tests with PPO ## Configuration Parameters ### `algorithm.rollout_is_threshold` (float or null) **Main on/off switch.** Upper threshold for IS weights. - `null` = disabled (no computation, no metrics) - `float` value (e.g., 2.0) = enabled (compute weights and metrics) ### `algorithm.rollout_is` (bool) Whether to apply IS weights to policy loss. Default: `False` - `true` = apply weights to loss (full IS correction) - `false` = compute metrics only (useful for monitoring before enabling) **Recommended threshold ranges:** - Token level: 1.5 - 5.0 - Sequence level: 2.0 - 10.0 - Geometric level: 1.0002 - 1.001 ### `algorithm.rollout_is_threshold_lower` (float or null) Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal). ### `algorithm.rollout_is_level` (str) Aggregation level for IS weights: - `"token"`: Per-token ratios - `"sequence"`: Product of ratios - `"geometric"`: Geometric mean (experimental) ### `algorithm.rollout_is_mode` (str) Bounding mode: - `"truncate"`: Cap weights at upper threshold only - `"mask"`: Zero out weights outside [lower, upper] ### `algorithm.rollout_is_veto_threshold` (float) Per-token veto threshold. If any token ratio < this, entire sequence is rejected. Default: `1e-4` (ratio 10,000x off) ## Migration Steps ### Step 1: Update Your Configuration **Before (Old):** ```yaml actor_rollout_ref: actor: tis_imp_ratio_cap: 2.0 rollout: calculate_log_probs: true ``` **After (New):** ```yaml algorithm: rollout_is_threshold: 2.0 # Main control rollout_is: true # Apply to loss (default: false) rollout_is_level: token rollout_is_mode: truncate actor_rollout_ref: rollout: calculate_log_probs: true # Still required! ``` ### Step 2: Monitor New Metrics All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs. #### **Core IS Weight Metrics** - **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens - **Ideal value**: Close to 1.0 (indicates minimal distribution mismatch) - **Warning**: < 0.5 or > 2.0 suggests significant policy mismatch - **`rollout_is_std`**: Standard deviation of IS weights - **Ideal value**: < 0.5 for stable training - **Warning**: > 1.0 indicates high variance, may need tighter thresholds - **`rollout_is_min`**: Minimum IS weight observed - Shows the most underweighted token/sequence - **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking) - Shows the most overweighted token/sequence - Compare with `rollout_is_threshold` to see truncation impact #### **Percentile Metrics** - **`rollout_is_p25`**: 25th percentile of IS weights - **`rollout_is_p50`**: Median IS weight (50th percentile) - Should be close to `rollout_is_mean` if distribution is symmetric - **`rollout_is_p75`**: 75th percentile of IS weights - **`rollout_is_p95`**: 95th percentile of IS weights - Use to detect outliers - **`rollout_is_p99`**: 99th percentile of IS weights - Should be close to `rollout_is_threshold` if truncation is working #### **Effective Sample Size** - **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting - **Formula**: `1 / mean(weights²)` where weights are normalized - **Range**: 0.0 to 1.0 (as fraction of original batch) - **Ideal value**: > 0.5 (retaining at least 50% effective samples) - **Warning**: < 0.3 means high variance, losing too many effective samples #### **Veto Mechanism Metrics** - **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism - **Ideal value**: < 0.05 (less than 5% vetoed) - **Warning**: > 0.1 suggests policies are too different or numerical issues - **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold - Identifies problematic tokens before sequence-level veto - **Warning**: > 0.01 indicates widespread distribution issues #### **Threshold Exceedance Metrics** - **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold - Shows how often truncation/masking occurs on high end - **Ideal value**: < 0.1 (most weights within bounds) - **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold - Shows how often masking occurs on low end (mask mode only) - **Ideal value**: < 0.1 #### **Sequence-Level Metrics** (for sequence/geometric modes) - **`rollout_is_seq_mean`**: Mean IS weight at sequence level - Should match `rollout_is_mean` for sequence-level aggregation - **`rollout_is_seq_std`**: Standard deviation of sequence-level IS weights - **`rollout_is_seq_min`**: Minimum sequence-level IS weight - **`rollout_is_seq_max`**: Maximum sequence-level IS weight - **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level - **Ideal value**: < 1.0 - Shows worst-case sequence mismatch - **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold - **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold #### **Masking Metrics** (mask mode only) - **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero) - **Ideal value**: < 0.1 - **Warning**: > 0.3 means losing too much data - **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token - Shows sequence-level impact of masking #### **Distribution Mismatch Metrics** (Training vs Rollout Policy) - **`mismatch_training_ppl`**: Perplexity of training policy (e.g., FSDP FP32) - **Formula**: `exp(-mean(log_probs))` - Lower is better (model is more confident) - **`mismatch_rollout_ppl`**: Perplexity of rollout policy (e.g., vLLM BF16) - Should be close to `mismatch_training_ppl` if policies match well - **`mismatch_ppl_ratio`**: Ratio of training PPL to rollout PPL - **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))` - **Ideal value**: Close to 1.0 - **Meaning**: > 1.0 means training is less confident than rollout - **`mismatch_training_log_ppl`**: Log perplexity of training policy - Useful for identifying trends (linear scale) - **`mismatch_rollout_log_ppl`**: Log perplexity of rollout policy - **`mismatch_log_ppl_diff`**: Mean difference in log perplexities - **Formula**: `mean(log_ppl_rollout - log_ppl_training)` - **Ideal value**: Close to 0.0 - Sign indicates which policy is more confident - **`mismatch_log_ppl_abs_diff`**: Mean absolute log perplexity difference - Magnitude of mismatch regardless of direction - **`mismatch_log_ppl_diff_max`**: Maximum log perplexity difference across sequences - Identifies worst-case sequence - **`mismatch_log_ppl_diff_min`**: Minimum log perplexity difference across sequences - **`mismatch_kl`**: KL divergence KL(π_rollout || π_training) - **Formula**: `mean(log_prob_rollout - log_prob_training)` - **Ideal value**: Close to 0.0 (policies match) - **Warning**: > 0.1 indicates significant mismatch - **Note**: Can be negative (rollout is less confident) - **`mismatch_k3_kl`**: K3 KL estimator - **Formula**: `mean(exp(log_ratio) - log_ratio - 1)` - More stable for small KL values - Always non-negative #### **Example: Accessing Metrics in Code** ```python # Metrics are returned from compute_rollout_importance_weights from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights weights_proto, metrics = compute_rollout_importance_weights( old_log_prob=training_log_probs, # from training policy rollout_log_prob=rollout_log_probs, # from rollout policy response_mask=response_mask, rollout_is_level="token", rollout_is_mode="truncate", rollout_is_threshold=2.0, rollout_is_veto_threshold=1e-4, ) # All metrics have 'mismatch/' prefix print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}") print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}") print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}") print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}") # Check for warning conditions if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0: print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected") if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3: print("⚠️ Warning: Low effective sample size, high variance in IS weights") if metrics['mismatch/rollout_is_veto_fraction'] > 0.1: print("⚠️ Warning: High veto fraction, policies may be too different") ``` #### **Example: Monitoring Metrics During Training** ```python # In your training loop for epoch in range(num_epochs): for batch_idx, batch in enumerate(dataloader): # ... rollout phase ... # Compute IS weights and get metrics weights_proto, metrics = compute_rollout_importance_weights( old_log_prob=batch.old_log_prob, rollout_log_prob=batch.rollout_log_prob, response_mask=batch.response_mask, rollout_is_level=config.rollout_is_level, rollout_is_mode=config.rollout_is_mode, rollout_is_threshold=config.rollout_is_threshold, rollout_is_veto_threshold=config.rollout_is_veto_threshold, ) # Log to tensorboard/wandb for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step=global_step) # Use IS weights in training is_weights = weights_proto.batch["rollout_is_weights"] # ... apply weights to policy gradient ... ``` #### **Example: Conditional Alerting Based on Metrics** ```python def check_rollout_is_health(metrics, config): """Check if rollout IS metrics indicate healthy training.""" warnings = [] # Check mean IS weight mean_weight = metrics['mismatch/rollout_is_mean'] if mean_weight < 0.5 or mean_weight > 2.0: warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0") # Check effective sample size ess = metrics['mismatch/rollout_is_eff_sample_size'] if ess < 0.3: warnings.append(f"Effective sample size {ess:.3f} is too low") # Check veto fraction veto_frac = metrics['mismatch/rollout_is_veto_fraction'] if veto_frac > 0.1: warnings.append(f"Veto fraction {veto_frac:.3f} is too high") # Check variance std = metrics['mismatch/rollout_is_std'] if std > 1.0: warnings.append(f"IS weight std {std:.3f} is too high") # Check KL divergence kl = metrics['mismatch/mismatch_kl'] if abs(kl) > 0.1: warnings.append(f"KL divergence {kl:.3f} indicates significant mismatch") if warnings: print("⚠️ Rollout IS Health Warnings:") for warning in warnings: print(f" - {warning}") return False else: print("✅ Rollout IS metrics look healthy") return True # Use in training _, metrics = compute_rollout_importance_weights(...) is_healthy = check_rollout_is_health(metrics, config) if not is_healthy: # Consider adjusting config or investigating issues print("Consider:") print(" - Tightening rollout_is_threshold") print(" - Switching to geometric aggregation level") print(" - Checking if rollout and training policies are too different") ``` ### Step 3: Test Your Training Start with the basic token-level truncate configuration: ```bash bash examples/rollout_importance_sampling/run_with_rollout_is.sh ``` Monitor metrics for 1-2 epochs before adjusting parameters. ## Configuration Examples ### Example 1: Full IS Correction ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: true # Apply weights to loss rollout_is_level: token rollout_is_mode: truncate ``` ### Example 2: Metrics Only (Monitoring Mode) ```yaml algorithm: rollout_is_threshold: 2.0 rollout_is: false # Compute metrics, don't apply weights rollout_is_level: token rollout_is_mode: truncate ``` ### Example 3: Geometric Mean with Mask ```yaml algorithm: rollout_is_threshold: 1.0002 rollout_is: true rollout_is_threshold_lower: 0.9998 rollout_is_level: geometric rollout_is_mode: mask ``` ### Example 4: Asymmetric Thresholds ```yaml algorithm: rollout_is_threshold: 5.0 rollout_is: true rollout_is_threshold_lower: 0.8 rollout_is_level: token rollout_is_mode: mask ``` ## Troubleshooting ### Issue: High variance in IS weights **Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3 **Solutions:** 1. Switch from `sequence` to `geometric` level 2. Tighten thresholds 3. Verify rollout and training aren't too different ### Issue: Too many sequences vetoed **Symptoms:** `rollout_is_veto_fraction` > 0.1 **Solutions:** 1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3` 2. Check for numerical issues in log prob computation 3. Verify policies aren't completely different ### Issue: Mean IS weight far from 1.0 **Symptoms:** `rollout_is_mean` < 0.5 or > 2.0 **Solutions:** 1. Verify `calculate_log_probs=True` is set 2. Check rollout_log_probs are correctly passed 3. Check for systematic bias ### Debugging: Visualizing Metrics **Example: Plot IS weight distribution** ```python import matplotlib.pyplot as plt import numpy as np def plot_is_metrics(metrics_history): """Plot rollout IS metrics over training steps.""" fig, axes = plt.subplots(2, 3, figsize=(15, 10)) # Plot 1: Mean IS weight over time axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean']) axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal') axes[0, 0].set_title('Mean IS Weight') axes[0, 0].set_xlabel('Step') axes[0, 0].legend() # Plot 2: Effective sample size axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size']) axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good') axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning') axes[0, 1].set_title('Effective Sample Size') axes[0, 1].set_xlabel('Step') axes[0, 1].legend() # Plot 3: Veto fraction axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction']) axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning') axes[0, 2].set_title('Veto Fraction') axes[0, 2].set_xlabel('Step') axes[0, 2].legend() # Plot 4: IS weight distribution (latest step) latest_idx = -1 percentiles = [25, 50, 75, 95, 99] values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles] axes[1, 0].bar([f'p{p}' for p in percentiles], values) axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal') axes[1, 0].set_title('IS Weight Percentiles (Latest)') axes[1, 0].legend() # Plot 5: KL divergence over time axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL') axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL') axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3) axes[1, 1].set_title('KL Divergence') axes[1, 1].set_xlabel('Step') axes[1, 1].legend() # Plot 6: PPL ratio over time axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio']) axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal') axes[1, 2].set_title('PPL Ratio (Training/Rollout)') axes[1, 2].set_xlabel('Step') axes[1, 2].legend() plt.tight_layout() plt.savefig('rollout_is_metrics.png', dpi=150) print("Saved plot to rollout_is_metrics.png") ``` **Example: Metric collection during training** ```python # Collect metrics over time metrics_history = { 'mismatch/rollout_is_mean': [], 'mismatch/rollout_is_eff_sample_size': [], 'mismatch/rollout_is_veto_fraction': [], 'mismatch/rollout_is_p25': [], 'mismatch/rollout_is_p50': [], 'mismatch/rollout_is_p75': [], 'mismatch/rollout_is_p95': [], 'mismatch/rollout_is_p99': [], 'mismatch/mismatch_kl': [], 'mismatch/mismatch_k3_kl': [], 'mismatch/mismatch_ppl_ratio': [], } # In training loop for step in range(num_steps): # ... compute IS weights ... _, metrics = compute_rollout_importance_weights(...) # Store metrics for key in metrics_history.keys(): if key in metrics: metrics_history[key].append(metrics[key]) # Plot every 100 steps if step % 100 == 0: plot_is_metrics(metrics_history) ``` ## Performance Impact - **Memory overhead**: ~1% of model memory - **Computational overhead**: 1-3% depending on level - **Training stability**: Significantly improved when mismatch exists ## Backward Compatibility **The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode. All scripts and configurations must be updated to use the new rollout IS parameters. ## Testing Run the test suite to verify everything works: ```bash # Basic unit tests python test_rollout_is.py # Integration tests (if pytest is available) pytest tests/trainer/ppo/test_rollout_is_integration.py -v ``` Expected output: All tests pass ✓ ## Additional Resources - **Implementation**: `verl/trainer/ppo/mismatch_helper.py` - **Examples**: `examples/rollout_importance_sampling/` - **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh` ## Summary The new Rollout Importance Sampling implementation provides: - ✅ More robust handling of distribution mismatch - ✅ Better numerical stability - ✅ Comprehensive metrics for monitoring - ✅ Flexibility for different scenarios - ✅ Memory-efficient computation Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section.