Conversation

@gagika gagika (Collaborator) commented Feb 8, 2026

Description

This PR adds full support for AllenAI's Olmo3-7B and Olmo3-32B models in MaxText, including checkpoint conversion and forward pass correctness verification.

This implementation addresses several unique architectural features of Olmo3 that required changes to the core layers:

  1. Global QK Normalization: Olmo3 applies RMSNorm across the entire hidden dimension (e.g., 4096) before splitting into heads, whereas MaxText's default RMSNorm applies per-head.

    • Change: Updated src/MaxText/layers/attentions.py to reshape query/key tensors [B, L, H, D] -> [B, L, H*D] before normalization when is_olmo3 is detected (see the sketch after this list).
  2. Mixed RoPE Strategy: Olmo3 uses a hybrid positional embedding strategy where "Sliding Window" layers use standard RoPE, while "Global" layers use YaRN.

    • Change: Updated src/MaxText/layers/olmo3.py to explicitly override the rope_type to "default" for local sliding layers.
    • Change: Updated src/MaxText/layers/attentions.py to accept a rope_type override in __init__, enabling layer-specific RoPE configurations.
  3. Configuration Alignments:

    • Set rope_interleave: False to match Hugging Face's concatenated RoPE.
    • Set rope_truncate: False to prevent frequency drift in YaRN layers.
    • Set normalize_embedding_logits: False as Olmo3 does not normalize output logits.
    • Renamed config files to use hyphens (olmo3-7b.yml) to match standard naming conventions.
  4. Checkpoint Conversion:

    • Added OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING and hooks in param_mapping.py.
    • Implemented identity hooks for Norms to preserve the specific scaling used in Olmo3 checkpoints.
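
As a reference for item 1, here is a minimal JAX sketch of the global QK-norm behavior; the helpers rmsnorm and global_qk_rmsnorm are illustrative, not MaxText's actual API:

import jax
import jax.numpy as jnp

def rmsnorm(x, scale, eps=1e-6):
  # Standard RMSNorm over the last axis.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * scale

def global_qk_rmsnorm(q, scale, eps=1e-6):
  # q: [batch, seq, num_heads, head_dim]; scale: [num_heads * head_dim].
  # Normalize across heads * head_dim instead of per head, then restore the layout.
  b, l, h, d = q.shape
  q_flat = q.reshape(b, l, h * d)          # [B, L, H, D] -> [B, L, H*D]
  q_flat = rmsnorm(q_flat, scale, eps)     # RMSNorm over the full hidden dimension
  return q_flat.reshape(b, l, h, d)        # back to per-head layout

q = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 4, 16))
per_head = rmsnorm(q, jnp.ones(16))               # MaxText default: per-head norm
global_norm = global_qk_rmsnorm(q, jnp.ones(64))  # Olmo3-style: global norm
print(jnp.allclose(per_head, global_norm))        # False in general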

Tests

Tested via checkpoint conversion from Hugging Face and running forward_pass_logit_checker.py to verify KL divergence against the reference HF implementation (BF16 and FP32).
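
The KL check corresponds roughly to the computation below (a simplified sketch; the actual logic in forward_pass_logit_checker.py may differ in details such as masking and dtype handling):

import jax.numpy as jnp
from jax import nn

def max_token_kl(hf_logits, mt_logits):
  # hf_logits, mt_logits: [seq_len, vocab] logits for the same prompt.
  p_log = nn.log_softmax(hf_logits, axis=-1)   # reference distribution (HF)
  q_log = nn.log_softmax(mt_logits, axis=-1)   # candidate distribution (MaxText)
  kl_per_token = jnp.sum(jnp.exp(p_log) * (p_log - q_log), axis=-1)  # KL(P || Q) per token
  return jnp.max(kl_per_token)                 # compared against max_kl_div (e.g. 0.005)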

1. Checkpoint Conversion:

python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
    model_name=olmo3-7b \
    hf_access_token=${HF_TOKEN} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    scan_layers=True

2. Logit Verification (Olmo3-7B):

python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
    model_name=olmo3-7b \
    load_parameters_path=${CHECKPOINT_PATH} \
    tokenizer_path="allenai/Olmo-3-7B-Instruct" \
    hf_model_path="allenai/Olmo-3-7B-Instruct" \
    run_hf_model=True \
    max_kl_div=0.005 \
    scan_layers=True \
    normalize_embedding_logits=False

Tested both Olmo3-7B and Olmo3-32B logits:
Max KL divergence for a single token in the set: 0.000005

https://paste.googleplex.com/4668070767493120
https://paste.googleplex.com/5281833875013632

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions github-actions bot commented Feb 8, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

📋 Review Summary

This pull request introduces support for Olmo3 models, including checkpoint conversion and necessary architectural adjustments for features like global QK normalization and mixed RoPE strategies. The changes are well-structured and the code is clear and maintainable.

🔍 General Feedback

  • The implementation of the Olmo3-specific features within the existing architecture is clean and minimally invasive.
  • The addition of checkpoint conversion utilities for Olmo3 is a valuable contribution.
  • The updates to the testing utilities to better support different dtypes improve the overall quality of the test suite.

Overall, this is a high-quality contribution that is ready for merging.

@gagika gagika force-pushed the agagik-olmo3 branch 2 times, most recently from d8d60d8 to 5ff2703 on February 8, 2026 23:25
@github-actions github-actions bot commented Feb 8, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

📋 Review Summary

This pull request introduces support for Olmo3 models, which is a valuable addition. The implementation correctly handles the specific architectural features of Olmo3, such as global QK normalization and the mixed RoPE strategy. The checkpoint conversion utilities and updates to the forward pass checker are also well-implemented.

🔍 General Feedback

  • The use of a rope_type override in the attention layer is a clean way to manage the mixed RoPE requirements of Olmo3.
  • The new GlobalRMSNorm is a good example of extending functionality while maintaining a clear separation of concerns.
  • The parameter mapping for checkpoint conversion is thorough and handles both scanned and unscanned layer configurations correctly.

The overall quality of the code is high, and the changes are well-documented in the PR description. I have one minor suggestion to improve the modularity of the attention layer, but it does not block the merge.

@dirkgr dirkgr commented Feb 10, 2026

We talked about this in a call, but one important clarification: Olmo does not use a "mixed RoPE strategy" during pretraining. We pre-train with a context window of 8192 on the full-attention layers, and a context window of 4096 on the sliding window attention layers. No special RoPE sauce at this point. Then, after pre-training, when we do the long-context extension, we use Yarn to extend RoPE on the full attention layers, but we leave the sliding window attention layers the same.

@RissyRan RissyRan (Collaborator) commented

> We talked about this in a call, but one important clarification: Olmo does not use a "mixed RoPE strategy" during pretraining. We pre-train with a context window of 8192 on the full-attention layers, and a context window of 4096 on the sliding window attention layers. No special RoPE sauce at this point. Then, after pre-training, when we do the long-context extension, we use Yarn to extend RoPE on the full attention layers, but we leave the sliding window attention layers the same.

To match the logit test, it seems we need to do this in the decoder layer. Gagik, do you think we could add an if/else branch to indicate pre-/post-training mode in this case?
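
Something along these lines, purely as a sketch (use_post_training_rope and is_sliding_window_layer are placeholder names, not existing MaxText config keys):

def select_rope_type(config, is_sliding_window_layer: bool) -> str:
  if is_sliding_window_layer:
    # Sliding-window layers keep plain RoPE in both modes.
    return "default"
  # Full-attention layers: YaRN only for the post-training (long-context) checkpoints.
  return "yarn" if config.use_post_training_rope else "default"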

@RissyRan RissyRan (Collaborator) left a comment

LGTM, just minor comments related to that if/else for rope.

@dirkgr dirkgr commented Feb 10, 2026

You needed to do that because you grabbed a model from Huggingface that had already undergone context extension. We did not use this during pre-training. Gagik knows what to do :-)

@gagika gagika force-pushed the agagik-olmo3 branch 2 times, most recently from 022697d to cdbb45a on February 11, 2026 05:21
@gagika gagika (Collaborator, Author) commented Feb 11, 2026

> We talked about this in a call, but one important clarification: Olmo does not use a "mixed RoPE strategy" during pretraining. We pre-train with a context window of 8192 on the full-attention layers, and a context window of 4096 on the sliding window attention layers. No special RoPE sauce at this point. Then, after pre-training, when we do the long-context extension, we use Yarn to extend RoPE on the full attention layers, but we leave the sliding window attention layers the same.

@dirkgr I added a new pretraining config olmo3-7b-pt.yml with default RoPE.

I tested logit matching against: https://huggingface.co/allenai/Olmo-3-1025-7B/tree/stage1-step1413814

python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
    model_name=olmo3-7b-pt \
    hf_access_token=${HF_TOKEN} \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    scan_layers=True \
    use_multimodal=false \
    hardware=cpu \
    skip_jax_distributed_system=true \
    --revision "stage1-step1413814"

python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
    tokenizer_path=${TOKENIZER} \
    load_parameters_path=$OLMO_CHECKPOINT \
    model_name=olmo3-7b-pt \
    scan_layers=true \
    max_prefill_predict_length=16 \
    max_target_length=16 \
    use_multimodal=false \
    attention=dot_product \
    dtype=float32 \
    weight_dtype=float32 \
    --hf_model_path=$HF_LOCAL_PATH \
    --run_hf_model=True \
    --max_kl_div=0.001

The logits match quite well for short inputs; I will test longer text and share the top-1 match rate.

cc @RissyRan

@gagika gagika force-pushed the agagik-olmo3 branch 2 times, most recently from 43be593 to 348f2cf on February 11, 2026 07:38
@copybara-service copybara-service bot merged commit ef90f2d into main Feb 11, 2026
26 of 28 checks passed
@copybara-service copybara-service bot deleted the agagik-olmo3 branch February 11, 2026 19:42