[JAX] TE Permutation integration to Maxtext #2672
Merged
tdophung merged 11 commits into NVIDIA:main on Feb 13, 2026
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
…ger than num tokens Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: JAX Toolbox <jax@nvidia.com>
Collaborator
Author
This PR contains changes cherry-picked from #2651. I can wait until that one is merged and then merge mine, but if my PR is needed more urgently, I'm happy to remove the cherry-picked change.
Contributor
Greptile Overview
Greptile Summary
This PR adds the changes needed to support MaxText integration with TE permutation operations when Expert Parallelism (EP) > 1.
Key Changes:
Technical Implementation:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as MaxText User
participant API as sort_chunks_by_index
participant FWD as _sort_chunks_by_index_fwd_rule
participant Kernel as _make_chunk_sort_map_kernel
participant BWD as _sort_chunks_by_index_bwd_rule
User->>API: sort_chunks_by_index(inp, split_sizes, sorted_indices)
API->>FWD: Forward pass
FWD->>Kernel: Generate row_id_map with padding handling
Note over Kernel: Compute total_valid_tokens<br/>Apply identity mapping for pid >= total_valid_tokens
Kernel-->>FWD: row_id_map (with padding masked)
FWD->>FWD: sort_chunks_by_map(inp, row_id_map)
FWD-->>API: (output, row_id_map), residuals
Note over FWD: residuals now include split_sizes<br/>and sorted_indices (not nondiff_argnums)
User->>BWD: Backward pass (gradient)
BWD->>BWD: Extract split_sizes, sorted_indices from residuals
BWD->>BWD: sort_chunks_by_map(output_grad, row_id_map, is_forward=False)
BWD-->>User: (inp_grad, zeros_like(split_sizes), zeros_like(sorted_indices))
Last reviewed commit: 11a45d3 |
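The padding handling described in the diagram can be illustrated with a small pure-Python reference. This is only a sketch under stated assumptions, not the TE kernel: the helper names make_chunk_sort_map and apply_row_id_map are hypothetical, the map direction (source row to destination row) is chosen for illustration, and rejecting a split total larger than the buffer is an assumption of this sketch. Rows at or beyond the sum of split_sizes are treated as padding and keep their position, mirroring the "identity mapping for pid >= total_valid_tokens" note.

```python
from typing import List, Sequence


def make_chunk_sort_map(
    split_sizes: Sequence[int], sorted_indices: Sequence[int], num_tokens: int
) -> List[int]:
    """Hypothetical reference: row_id_map[src_row] = dst_row."""
    total_valid_tokens = sum(split_sizes)
    if total_valid_tokens > num_tokens:
        # Assumption of this sketch: the buffer must hold all valid tokens.
        raise ValueError("sum(split_sizes) exceeds num_tokens")

    # Start offset of each chunk in the source (unsorted) layout.
    src_starts = []
    offset = 0
    for size in split_sizes:
        src_starts.append(offset)
        offset += size

    # Identity by default, so padding rows (>= total_valid_tokens) stay put.
    row_id_map = list(range(num_tokens))

    # Lay the chunks out in the order given by sorted_indices.
    dst = 0
    for chunk in sorted_indices:
        for k in range(split_sizes[chunk]):
            row_id_map[src_starts[chunk] + k] = dst
            dst += 1
    return row_id_map


def apply_row_id_map(rows: Sequence, row_id_map: Sequence[int], is_forward: bool = True) -> list:
    """Scatter rows along the map (forward) or gather them back (backward)."""
    out = [None] * len(rows)
    for src, dst in enumerate(row_id_map):
        if is_forward:
            out[dst] = rows[src]
        else:
            out[src] = rows[dst]
    return out


# Three chunks of sizes 2, 1, 3 in a buffer of 8 rows (last 2 rows are padding).
row_id_map = make_chunk_sort_map([2, 1, 3], [2, 0, 1], num_tokens=8)
tokens = ["a0", "a1", "b0", "c0", "c1", "c2", "pad", "pad"]
sorted_tokens = apply_row_id_map(tokens, row_id_map)
restored = apply_row_id_map(sorted_tokens, row_id_map, is_forward=False)
```

Reusing the same row_id_map with is_forward=False undoes the permutation, which is the role the diagram assigns to the backward rule when it routes output_grad back to inp_grad.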
transformer_engine/jax/inspect.py
Outdated
_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
Contributor
The name parameter is unused: it is neither passed to the C++ backend nor used in the filename.
Suggested change
- def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
+ def inspect_array(x: jnp.ndarray) -> jnp.ndarray:
Comment on lines 116 to 120
std::ofstream file(filename, std::ios::binary);
if (file.is_open()) {
  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
  file.close();
}
Contributor
There is no error handling if the file fails to open: the function silently continues without writing any data.
Suggested change
- std::ofstream file(filename, std::ios::binary);
- if (file.is_open()) {
-   file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
-   file.close();
- }
+ std::ofstream file(filename, std::ios::binary);
+ if (!file.is_open()) {
+   return ffi::Error(ffi::ErrorCode::kInternal, "Failed to open file for writing");
+ }
+ file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
+ file.close();
Collaborator
Author
/te-ci
Collaborator
Author
/te_ci
…nsformerEngine into maxtext_integ_2
jberchtold-nvidia previously approved these changes on Feb 13, 2026
Collaborator jberchtold-nvidia left a comment:
LGTM, thanks!
…rmerEngine into maxtext_integ_2
jberchtold-nvidia approved these changes on Feb 13, 2026
Collaborator
Author
/te_ci
Collaborator
Author
/te-ci
Description
Changes needed on the TE side to make the MaxText integration work.
Issue #2585
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: