@RissyRan RissyRan commented Feb 9, 2026

Description

  • Update the default mHC expansion rate to 1, which is equivalent to disabling the feature.
  • Update the shape of activations in decoders.py when the feature is enabled.
  • Enable load-balancing loss tracking in MoE layers when using mHC.
  • Cast weights to the activation precision before matmul, which aligns with the existing pattern in MaxText.
  • Adjust mHC slightly to align with normalization (flexible with pre-norm and post-norm). Please note that mHC also applies a norm across the last k * dim dimensions internally.
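The weight-to-activation precision alignment mentioned above can be sketched as follows. This is a hypothetical illustration (not MaxText code); the function name `matmul_with_cast` and the use of float16 as a stand-in for bfloat16 are assumptions for the sake of a self-contained numpy example.

```python
import numpy as np

def matmul_with_cast(x, w):
    """Cast weights to the activation dtype before the matmul.

    Mirrors the pattern of keeping master weights in a higher precision
    (e.g. float32) while computing in the activation precision.
    """
    w_cast = w.astype(x.dtype)  # align weight precision with activations
    return x @ w_cast

x = np.ones((2, 4), dtype=np.float16)        # low-precision activations
w = np.full((4, 3), 0.5, dtype=np.float32)   # high-precision master weights
y = matmul_with_cast(x, w)
# The result stays in the activation dtype rather than being upcast.
```

The point of casting before the matmul (rather than after) is that the multiply itself runs in the cheaper activation precision, matching how mixed-precision training is typically wired.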

General pre-norm when the mHC feature is disabled:

Input (x) ───────────────────────────┐
  │                                  │
  ▼                                  │ (Residual Connection)
[ Pre-Norm ]                         │
  │                                  │
  ▼                                  │
[ Attention / MLP ]                  │
  │                                  │
  ▼                                  │
[ Layer Output ]                     │
  │                                  │
  └───────────► ( + ) ◄──────────────┘
                 │
                 ▼
           Final Output
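The standard pre-norm block in the diagram above can be written in a few lines. This is a minimal sketch, assuming an RMS-style norm and an arbitrary `layer_fn` standing in for attention or the MLP; the names are illustrative, not MaxText's.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Normalize over the last (feature) dimension.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def pre_norm_block(x, layer_fn):
    # Residual connection skips the layer: output = x + layer(norm(x)).
    return x + layer_fn(rms_norm(x))

x = np.array([[1.0, 2.0, 3.0]])
out = pre_norm_block(x, lambda h: 2.0 * h)  # toy "layer": scale by 2
```

The key property is that the residual path carries `x` unchanged; only the layer's contribution passes through the norm.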

Pre-norm when the mHC feature is enabled (ref):


Input (x) ───────────────────────────────────┐
          │                                  │
          ▼                                  │
    [ Pre-Norm ]                             │
          │    pre mapping                   │
          ▼                            residual mapping
  [ Attention / MLP ]                        │
          │                                  │
          ▼                                  │
  [ Layer Output ]                           │
          │    post mapping                  │
          ▼                                  │
     layer_output                        res_output
          └───────────► ( + ) ◄──────────────┘
                          │
                          ▼
                    Final Output
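The mHC data flow in the diagram above can be sketched as a toy example. This is a loose illustration of the pre/residual/post mappings over n expanded residual streams, not the actual MaxText implementation: the mapping shapes, the norm placement (the real mHC norms across the last k * dim dimensions internally), and all names here are assumptions.

```python
import numpy as np

dim, n = 4, 2  # hidden size and a hypothetical expansion rate

def rms_norm(x, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def mhc_block(h, layer_fn, w_pre, w_res, w_post):
    """h: (n, dim) expanded residual streams."""
    x = w_pre @ h                # pre mapping: (1, n) @ (n, dim) -> (1, dim)
    y = layer_fn(rms_norm(x))    # attention / MLP on the mixed stream
    res = w_res @ h              # residual mapping: (n, n) @ (n, dim)
    out = w_post @ y             # post mapping: (n, 1) @ (1, dim) -> (n, dim)
    return res + out             # final output: (n, dim)

h = np.ones((n, dim))
w_pre = np.full((1, n), 1.0 / n)  # average the streams into the layer input
w_res = np.eye(n)                 # identity residual mixing (toy choice)
w_post = np.ones((n, 1))          # broadcast layer output to all streams
out = mhc_block(h, lambda z: z, w_pre, w_res, w_post)
```

With expansion_rate=1 all three mappings collapse to scalars, which is why the default of 1 is equivalent to the plain pre-norm residual block.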

Tests

  • Update unit tests
  • End-to-end sanity check test - link
  • Check MoE related load balance is captured in TB - link
  • Check mHC end-to-end (500 steps with seed datasets) - comparison. Please note that the paper compares mHC vs. HC; here we compare mHC vs. the baseline.
    • mHC - loss: 6.701 (expansion_rate=4), slightly lower on this toy model with a real dataset
    • Without mHC - loss: 6.796 (expansion_rate=1)
# cmd to run

python3 -m MaxText.train maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=${RUN_NAME} per_device_batch_size=8 enable_checkpointing=false model_name=deepseek-custom ici_fsdp_parallelism=4 steps=500 max_target_length=4096 async_checkpointing=false dtype=bfloat16 weight_dtype=float32 scan_layers=True dataset_type=synthetic attention=dot_product train_split=train dataset_type=hf hf_path='HuggingFaceFW/fineweb-edu' hf_name=default enable_tensorboard=true tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V3.2 data_shuffle_seed=1234
  • DeepSeek v2 sanity tests (expect no impact for existing models)
python3 -m MaxText.train maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=${RUN_NAME} per_device_batch_size=8 enable_checkpointing=false model_name=deepseek2-16b ici_fsdp_parallelism=4 steps=20 max_target_length=4096 async_checkpointing=false tokenizer_path=src/MaxText/assets/tokenizer.mistral-v1 dtype=bfloat16 weight_dtype=float32 scan_layers=True dataset_type=synthetic attention=flash

# before change

I0210 00:18:46.050221 139944566173248 metric_logger.py:181] completed step: 19, seconds: 4.363, TFLOP/s/device: 123.215, Tokens/s/device: 7510.083, total_weights: 131072, loss: 8.135

# after change

I0210 00:09:42.323213 139931725155904 metric_logger.py:181] completed step: 19, seconds: 4.365, TFLOP/s/device: 123.166, Tokens/s/device: 7507.116, total_weights: 131072, loss: 8.135

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov bot commented Feb 9, 2026

Codecov Report

❌ Patch coverage is 94.11765% with 4 lines in your changes missing coverage. Please review.

Files with missing lines    Patch %   Lines
src/MaxText/train.py        60.00%    2 Missing and 2 partials ⚠️


@RissyRan force-pushed the mhc_integration branch 5 times, most recently from 8a7e2a0 to 5b938c0 on February 10, 2026 at 00:57
@RissyRan RissyRan changed the title [WIP] Integrate MHC with DeepSeek custom model [WIP] Integrate mHC with DeepSeek custom model Feb 10, 2026
@RissyRan RissyRan changed the title [WIP] Integrate mHC with DeepSeek custom model Integrate mHC with DeepSeek custom model Feb 10, 2026

@github-actions github-actions bot left a comment


📋 Review Summary

This Pull Request integrates Manifold-Constrained Hyper Connections (mHC) with the DeepSeek custom model. The changes involve modifying decoder layers, deepseek layers, and the mhc implementation itself, along with updates to configuration and unit tests. The implementation appears sound, addressing the core objective of integrating mHC.

🔍 General Feedback

  • The refactoring in src/MaxText/layers/mhc.py to explicitly use self.dtype and self.matmul_precision is a good improvement for clarity and consistency in precision handling.
  • The updated unit tests in tests/unit/mhc_test.py adequately cover the new return values from the mHC module.
  • The addition of an AOT compilation test for mHC integration is a positive step towards ensuring long-term stability and correct compilation.

@RissyRan force-pushed the mhc_integration branch 4 times, most recently from 1436770 to 2fdb6e1 on February 11, 2026 at 20:52
