Olmo3 checkpoint conversion and Refactor Olmo3 model to support interleaved RoPE (and attention) #3112
Conversation
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces support for Olmo3 models, including checkpoint conversion and necessary architectural adjustments for features like global QK normalization and mixed RoPE strategies. The changes are well-structured and the code is clear and maintainable.
🔍 General Feedback
- The implementation of the Olmo3-specific features within the existing architecture is clean and minimally invasive.
- The addition of checkpoint conversion utilities for Olmo3 is a valuable contribution.
- The updates to the testing utilities to better support different dtypes improve the overall quality of the test suite.
Overall, this is a high-quality contribution that is ready for merging.
Force-pushed from d8d60d8 to 5ff2703.
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces support for Olmo3 models, which is a valuable addition. The implementation correctly handles the specific architectural features of Olmo3, such as global QK normalization and the mixed RoPE strategy. The checkpoint conversion utilities and updates to the forward pass checker are also well-implemented.
🔍 General Feedback
- The use of a `rope_type` override in the attention layer is a clean way to manage the mixed RoPE requirements of Olmo3.
- The new `GlobalRMSNorm` is a good example of extending functionality while maintaining a clear separation of concerns.
- The parameter mapping for checkpoint conversion is thorough and handles both scanned and unscanned layer configurations correctly.
The overall quality of the code is high, and the changes are well-documented in the PR description. I have one minor suggestion to improve the modularity of the attention layer, but it does not block the merge.
We talked about this in a call, but one important clarification: Olmo does not use a "mixed RoPE strategy" during pretraining. We pre-train with a context window of 8192 on the full-attention layers, and a context window of 4096 on the sliding window attention layers. No special RoPE sauce at this point. Then, after pre-training, when we do the long-context extension, we use YaRN to extend RoPE on the full attention layers, but we leave the sliding window attention layers the same.

To match the logit test, it seems we need to do this in the decoder layer. Gagik, do you think we could add an if/else branch to indicate pre/post-training mode in this case?
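A minimal sketch of the kind of if/else branch discussed here, assuming a hypothetical `use_post_training_rope` flag and a per-layer attention-type string (the names are illustrative, not actual MaxText config keys):

```python
def select_rope_type(layer_type: str, use_post_training_rope: bool) -> str:
  """Illustrative per-layer rope_type selection (hypothetical names).

  Pretraining: every layer uses default RoPE.
  Post-training long-context extension: only full-attention ("global")
  layers switch to YaRN; sliding-window layers keep default RoPE.
  """
  if not use_post_training_rope:
    return "default"
  return "yarn" if layer_type == "global" else "default"


# Illustrative layer pattern mixing sliding-window and global attention layers.
pattern = ["sliding_window", "sliding_window", "sliding_window", "global"]
print([select_rope_type(t, use_post_training_rope=True) for t in pattern])
# -> ['default', 'default', 'default', 'yarn']
```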
RissyRan left a comment
LGTM, just minor comments related to that if/else for rope.
You needed to do that because you grabbed a model from Huggingface that had already undergone context extension. We did not use this during pre-training. Gagik knows what to do :-)
Force-pushed from 022697d to cdbb45a.
@dirkgr I added a new pretraining config olmo3-7b-pt.yml with default RoPE. I tested logit matching against: https://huggingface.co/allenai/Olmo-3-1025-7B/tree/stage1-step1413814
The logits match quite well for short inputs; I will test longer text and share the top-1 match rate. cc @RissyRan
Force-pushed from 43be593 to 348f2cf.
Description
This PR adds full support for AllenAI's Olmo3-7B and Olmo3-32B models in MaxText, including checkpoint conversion and forward pass correctness verification.
This implementation addresses several unique architectural features of Olmo3 that required changes to the core layers:
- Global QK Normalization: Olmo3 applies RMSNorm across the entire hidden dimension (e.g., 4096) before splitting into heads, whereas MaxText's default `RMSNorm` applies per-head. `src/MaxText/layers/attentions.py` now reshapes query/key tensors `[B, L, H, D] -> [B, L, H*D]` before normalization when `is_olmo3` is detected (see the sketch after this list).
- Mixed RoPE Strategy: Olmo3 uses a hybrid positional embedding strategy where "Sliding Window" layers use standard RoPE, while "Global" layers use YaRN. `src/MaxText/layers/olmo3.py` explicitly overrides the `rope_type` to `"default"` for local sliding layers, and `src/MaxText/layers/attentions.py` accepts a `rope_type` override in `__init__`, enabling layer-specific RoPE configurations.
- Configuration Alignments: `rope_interleave: False` to match Hugging Face's concatenated RoPE; `rope_truncate: False` to prevent frequency drift in YaRN layers; `normalize_embedding_logits: False` as Olmo3 does not normalize output logits; the config file (`olmo3-7b.yml`) is named to match standard naming conventions.
- Checkpoint Conversion: Added `OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING` and hooks in `param_mapping.py`.
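As referenced in the Global QK Normalization item above, here is a minimal sketch of the reshape-and-normalize pattern, written as a standalone JAX function rather than MaxText's actual `RMSNorm` module (the `scale` parameter handling is assumed for illustration):

```python
import jax.numpy as jnp


def global_qk_norm(x: jnp.ndarray, scale: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
  """RMSNorm over the full hidden dimension for a per-head query/key tensor.

  x:     query or key tensor of shape [B, L, H, D]
  scale: learned scale of shape [H * D] (assumed; the real module owns its params)
  """
  b, l, h, d = x.shape
  flat = x.reshape(b, l, h * d)                 # [B, L, H, D] -> [B, L, H*D]
  rms = jnp.sqrt(jnp.mean(flat**2, axis=-1, keepdims=True) + eps)
  normed = (flat / rms) * scale                 # normalize across the whole hidden dim
  return normed.reshape(b, l, h, d)             # back to per-head layout for attention


# Example: hidden dim 4096 split into 32 heads of 128.
q = jnp.ones((2, 16, 32, 128))
assert global_qk_norm(q, scale=jnp.ones(32 * 128)).shape == q.shape
```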
Tests

Tested via checkpoint conversion from Hugging Face and running `forward_pass_logit_checker.py` to verify KL divergence against the reference HF implementation (BF16 and FP32).

1. Checkpoint Conversion:
```sh
python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
  model_name=olmo3-7b \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  scan_layers=True
```

2. Logit Verification (Olmo3-7B):
```sh
python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \
  model_name=olmo3-7b \
  load_parameters_path=${CHECKPOINT_PATH} \
  tokenizer_path="allenai/Olmo-3-7B-Instruct" \
  hf_model_path="allenai/Olmo-3-7B-Instruct" \
  run_hf_model=True \
  max_kl_div=0.005 \
  scan_layers=True \
  normalize_embedding_logits=False
```

Tested both Olmo3-7B and Olmo3-32B logits:

Max KL divergence for a single token in the set: 0.000005

https://paste.googleplex.com/4668070767493120
https://paste.googleplex.com/5281833875013632
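For the follow-up on longer inputs mentioned above, a rough sketch of how per-token KL divergence and top-1 match rate can be computed between two logit arrays; this is an illustrative helper under assumed shapes, not the actual code in `forward_pass_logit_checker.py`:

```python
import numpy as np
from scipy.special import log_softmax


def compare_logits(maxtext_logits: np.ndarray, hf_logits: np.ndarray):
  """Max per-token KL(HF || MaxText) and top-1 agreement for logits of shape [T, V]."""
  log_p = log_softmax(hf_logits.astype(np.float32), axis=-1)       # reference distribution
  log_q = log_softmax(maxtext_logits.astype(np.float32), axis=-1)  # candidate distribution
  kl_per_token = np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)
  top1_match = np.mean(np.argmax(hf_logits, -1) == np.argmax(maxtext_logits, -1))
  return float(kl_per_token.max()), float(top1_match)


# Example with synthetic logits; real runs feed both models the same prompt.
rng = np.random.default_rng(0)
ref = rng.normal(size=(8, 100))
cand = ref + 0.01 * rng.normal(size=(8, 100))
max_kl, match = compare_logits(cand, ref)
print(f"max per-token KL: {max_kl:.6f}, top-1 match rate: {match:.2%}")
```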
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.