Skip to content

Conversation

@parambole
Copy link
Collaborator

@parambole parambole commented Feb 4, 2026

Description

The fix involves adding a single, targeted sharding annotation to target_token_embedding within the MultiTokenPredictionLayer.__call__ method using sharding.maybe_shard_with_logical. This ensures that sharding is correctly propagated through the MTP layer's internal operations, including the projection and the subsequent transformer layer.

FIXES: b/481469708

Tests

Verified with train_compile.py on the TPU backend and validated on ironwood run, where the model now trains without erros.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Feb 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@parambole parambole force-pushed the parambole/481469708 branch from d7c3fa6 to ff787aa Compare February 4, 2026 23:59
@parambole parambole changed the title Fix b/481469708 Fix: Shard vocab axis to resolve MTP HBM OOM (b/481469708) Feb 9, 2026
@parambole parambole force-pushed the parambole/481469708 branch from f26e050 to feaf592 Compare February 9, 2026 21:48
@parambole parambole marked this pull request as ready for review February 9, 2026 21:49
@parambole parambole force-pushed the parambole/481469708 branch from 91c73bc to c068f39 Compare February 9, 2026 22:10
Copy link
Collaborator

@suexu1025 suexu1025 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @parambole

@parambole parambole changed the title Fix: Shard vocab axis to resolve MTP HBM OOM (b/481469708) Fix: Added sharding Constraints for MTP block(b/481469708) Feb 10, 2026
@parambole parambole changed the title Fix: Added sharding Constraints for MTP block(b/481469708) Fix: Added sharding Constraints for MTP block (b/481469708) Feb 10, 2026
Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one comment.

@copybara-service copybara-service bot merged commit 50ef2df into main Feb 11, 2026
30 checks passed
@copybara-service copybara-service bot deleted the parambole/481469708 branch February 11, 2026 22:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants