
Add KL-penalized advantage adjustment #562

Merged
arcticfly merged 2 commits into main from kl-advantage on Feb 24, 2026

Conversation

@arcticfly (Collaborator) commented on Feb 15, 2026

Summary

  • Adds a new mechanism that adjusts per-token advantages based on KL divergence from a reference model — tokens where the policy has drifted more get reduced advantages, while tokens that drifted less get increased advantages. The adjustment is zero-mean (centered) across tokens; a sketch follows this list.
  • New LocalBackend.train() parameters: kl_penalty_coef, kl_penalty_reference_step, and kl_ref_adapter_path
  • Fixes a pre-existing bug in preprocessing/inputs.py where warmup config used incorrect field names (lr → learning_rate, kl_coef → kl_penalty_coef)
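
A minimal sketch of the adjustment, assuming PyTorch and per-token tensors (the function name and signature are illustrative, not the repo's internal API):

import torch

def kl_adjusted_advantages(
    advantages: torch.Tensor,    # per-token advantages, shape [num_tokens]
    per_token_kl: torch.Tensor,  # per-token KL from the reference model, shape [num_tokens]
    kl_penalty_coef: float = 1.0,
) -> torch.Tensor:
    # A coefficient of 0.0 disables the mechanism; advantages pass through unchanged.
    if kl_penalty_coef == 0.0:
        return advantages
    # Centering the KL term makes the adjustment zero-mean across tokens, so the
    # mean advantage is preserved: above-average drift is penalized, below-average
    # drift is rewarded.
    adjustment = kl_penalty_coef * (per_token_kl - per_token_kl.mean())
    return advantages - adjustment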

Usage

To penalize divergence from the base (step-0) checkpoint:

result = await backend.train(
    model,
    trajectory_groups,
    kl_penalty_coef=1.0,
    kl_penalty_reference_step=0,
)
  • kl_penalty_coef controls the strength of the penalty (default 1.0; set to 0.0 to disable).
  • kl_penalty_reference_step selects which checkpoint to use as the reference. Use 0 for the base checkpoint, or any other saved step. If omitted, the base model (LoRA disabled) is used.
  • kl_ref_adapter_path points the reference at an arbitrary LoRA adapter path instead of a saved step (see the example below).
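
For example, a call that pins the reference to a specific adapter on disk (the path is illustrative, not from this PR):

result = await backend.train(
    model,
    trajectory_groups,
    kl_penalty_coef=1.0,
    kl_ref_adapter_path="/checkpoints/run-a/step-40/adapter",  # illustrative path
)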

Test plan

  • All linting/formatting checks pass (uv run prek run --all-files)
  • 5 unit tests for the advantage adjustment formula pass
  • Remote sweep of 9 kl_penalty_coef values (0.0001–1.0001) with kl_penalty_reference_step=0 completed successfully on Kubernetes H200 GPUs; each run finished all 20 steps

🤖 Generated with Claude Code

arcticfly and others added 2 commits on February 12, 2026 at 16:36
Introduces a new mechanism that adjusts per-token advantages based on KL
divergence from a reference model. Tokens where the policy has drifted more
get reduced advantages, while tokens that drifted less get increased
advantages. The adjustment is zero-mean (centered) across tokens.

New parameters on LocalBackend.train():
- kl_penalty_coef: coefficient for the adjustment (0.0 = disabled)
- kl_penalty_reference_step: use a specific checkpoint step as reference
- kl_ref_adapter_path: use an arbitrary LoRA adapter path as reference

Also fixes a pre-existing bug in preprocessing/inputs.py where warmup
config used incorrect field names (lr → learning_rate, kl_coef → kl_penalty_coef).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
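
For context, a sketch of the corrected warmup config shape, assuming it is a plain dict (values are illustrative; only the corrected field names come from this fix):

warmup_config = {
    "learning_rate": 1e-5,   # previously misnamed "lr"; value illustrative
    "kl_penalty_coef": 1.0,  # previously misnamed "kl_coef"; value illustrative
}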
@arcticfly arcticfly requested review from bradhilton and corbt and removed the request for bradhilton on February 21, 2026 at 01:23
@bradhilton (Collaborator) left a comment:
LGTM

arcticfly merged commit 078990e into main on Feb 24, 2026
2 checks passed