
[JAX] Integrate BF16 Grouped GEMM with on-device group sizes #2680

Draft
jberchtold-nvidia wants to merge 2 commits into NVIDIA:main from jberchtold-nvidia:gmm

Conversation

@jberchtold-nvidia
Collaborator

Description

Integrate the new grouped GEMM from TE common/cuBLASLt that supports on-device group sizes, eliminating the D2H memcpy and stream sync of the current implementation. The new grouped GEMM is faster and CUDA-graph safe.

TODO:

  • Fix workspace size logic for GMM setup
  • Remove the JAX 64-bit requirement by computing offsets inside the C++ FFI (see the sketch after this list)
  • Instead of removing the old implementation, gate the backend choice behind a flag on the primitive. Keep the new API disabled by default at first, since it does not support quantization; once feature support is equal, make the new GMM backend the default.
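
A minimal sketch of the current Python-side workaround this TODO refers to, based on the offset computation described in the review below (the helper name and location are assumptions): group offsets are the exclusive cumulative sum of the on-device group sizes, computed in int64, which is why jax_enable_x64 is currently required.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # temporary requirement until offsets move to the C++ FFI

def group_offsets_from_sizes(group_sizes):
    # Exclusive prefix sum: offsets[i] is the starting row of group i,
    # and offsets[-1] is the total number of rows.
    sizes = group_sizes.astype(jnp.int64)
    return jnp.concatenate([jnp.zeros((1,), jnp.int64), jnp.cumsum(sizes)])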

Fixes #2659

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Replaced TE/JAX's existing grouped GEMM, which used a D2H memcpy and stream sync to launch a loop of kernels, with a faster version that is CUDA-graph safe
  • Exposed a new make_ragged_dot_cls for easy integration into existing models (see the sketch after this list). This will be most useful once quantization is supported and storing recipe state is required
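
For reference, the shape convention a ragged/grouped dot follows, shown with jax.lax.ragged_dot as a stand-in (whether make_ragged_dot_cls exposes exactly this signature is an assumption; the sketch only illustrates the contraction the new backend performs):

import jax
import jax.numpy as jnp

m, k, n, g = 256, 128, 64, 4
lhs = jnp.zeros((m, k), jnp.bfloat16)                   # rows of all groups, stored contiguously
rhs = jnp.zeros((g, k, n), jnp.bfloat16)                # one weight matrix per group/expert
group_sizes = jnp.array([100, 0, 60, 96], jnp.int32)    # sums to m; zero-sized groups are allowed

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)         # [m, n]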

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft February 13, 2026 17:42
@greptile-apps
Contributor

greptile-apps bot commented Feb 13, 2026

Greptile Overview

Greptile Summary

This PR replaces JAX's grouped GEMM implementation with a new CUDA-graph-safe version that uses on-device group sizes, eliminating the need for D2H memcpy and stream synchronization. The new implementation uses the TE common grouped GEMM API with nvte_grouped_gemm instead of the previous multi-stream approach with nvte_multi_tensor_gemm.

Major changes:

  • Refactored GroupedGemmFFI in gemm.cpp to use new grouped tensor wrappers (JAXX_GroupedTensorWrapper) that manage grouped tensor lifecycle
  • Updated Python bindings to compute group offsets from group sizes using JAX int64 (temporary workaround until moved to C++)
  • Added make_ragged_dot_cls helper for easy integration into Flax models
  • Fixed output matrix dimension ordering in cublaslt_grouped_gemm.cu setup kernel
  • Currently requires jax_enable_x64=True and does not support quantization

Limitations and TODOs:

  • Quantization support blocked (assertion in make_ragged_dot_cls and check in gemm.cpp)
  • Hardcoded 1MB workspace hack needs proper buffer allocation logic
  • Full output buffer memset on every call (performance concern for large buffers)
  • FP8 leading dimension alignment checks commented out
  • Group offset computation should move from Python/JAX to C++ kernel

Confidence Score: 3/5

  • This PR is safe to merge with moderate risk - the implementation is a major refactoring with known limitations (no quantization, requires jax_enable_x64) but includes proper error checks
  • Score reflects significant implementation changes with several acknowledged TODOs and workarounds: hardcoded workspace allocation, full buffer memsets, commented-out FP8 validation, and temporary JAX int64 requirement. The new grouped GEMM API is cleaner but quantization support is blocked. Core logic appears sound with proper error handling.
  • Pay close attention to transformer_engine/jax/csrc/extensions/gemm.cpp and transformer_engine/jax/cpp_extensions/gemm.py which contain the core refactoring with multiple TODOs and workarounds

Important Files Changed

  • transformer_engine/common/gemm/cublaslt_gemm.cu - Commented out FP8 leading dimension alignment checks; verify the new grouped GEMM handles unaligned inputs correctly
  • transformer_engine/jax/cpp_extensions/gemm.py - Major refactoring to support the new grouped GEMM API with on-device group sizes; includes a hardcoded workspace hack and requires jax_enable_x64
  • transformer_engine/jax/csrc/extensions/gemm.cpp - Complete rewrite of the grouped GEMM implementation; replaces the multi-stream approach with the new grouped tensor API, adds memsets for output buffers, and supports only the non-quantized mode
  • transformer_engine/jax/flax/module.py - Added the make_ragged_dot_cls helper function for grouped GEMM; currently blocks quantization support with an assertion

Sequence Diagram

sequenceDiagram
    participant User as JAX/Flax User
    participant Python as gemm.py
    participant FFI as gemm.cpp (FFI)
    participant Wrapper as GroupedTensorWrapper
    participant CuBLAS as nvte_grouped_gemm

    User->>Python: grouped_gemm(lhs, rhs, group_sizes)
    Python->>Python: Compute group_offset from cumsum(group_sizes)
    Python->>Python: Create alpha/beta tensors
    Python->>FFI: GroupedGemmFFI via XLA FFI
    FFI->>FFI: Create workspace buffers (setup + cublas)
    FFI->>Wrapper: make_grouped_tensor(lhs_data, scaling_mode)
    Wrapper-->>FFI: lhs_tensor
    FFI->>Wrapper: set_group_info(group_sizes, group_offset_lhs)
    FFI->>Wrapper: make_grouped_tensor(rhs_data, scaling_mode)
    Wrapper-->>FFI: rhs_tensor
    FFI->>Wrapper: make_grouped_tensor(output, NO_SCALING)
    Wrapper-->>FFI: out_tensor
    FFI->>Wrapper: set_group_info(group_sizes, group_offset_out)
    FFI->>FFI: cudaMemsetAsync(output, 0)
    FFI->>CuBLAS: nvte_grouped_gemm(rhs, lhs, out, alpha, beta, workspaces)
    CuBLAS-->>FFI: Grouped GEMM complete
    FFI-->>Python: output, workspace
    Python-->>User: result
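
For comparison with the diagram above, a host-side reference of what the replaced implementation conceptually did: read the group sizes on the host, then launch one GEMM per group. It is useful for numerically checking the new FFI path but is not CUDA-graph safe; the names below are illustrative, not the TE/JAX API.

import numpy as np
import jax.numpy as jnp

def reference_grouped_gemm(lhs, rhs, group_sizes):
    sizes = np.asarray(group_sizes)                       # D2H copy + sync, as in the old path
    offsets = np.concatenate([[0], np.cumsum(sizes)])
    outs = [lhs[offsets[i]:offsets[i + 1]] @ rhs[i] for i in range(rhs.shape[0])]
    return jnp.concatenate(outs, axis=0)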

Last reviewed commit: e6e3bd0


@greptile-apps greptile-apps bot left a comment


8 files reviewed, 8 comments


group_offset = group_offset or jnp.zeros((1,), jnp.int32)

assert group_offset is None, "group_offset is not yet implemented"
assert (

The assertion error message contains an f-string but doesn't actually format anything since there's no f prefix on the outer string.

Suggested change
assert (
assert jax.config.jax_enable_x64, "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior"

Comment on lines +1506 to +1508

workspace_size += (
1024 * 1024

This hardcoded 1MB workspace buffer allocation is a HACK as noted. The workspace size calculation should properly account for setup vs cublas workspace needs separately, or this could lead to buffer overruns or inefficient memory usage.

Reference gemm.cpp:669-673 where this workspace is split into setup (1MB) and cublas portions.
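
A sketch (all names hypothetical) of computing the two regions explicitly instead of the hardcoded 1 MiB bump, assuming the single buffer is laid out as a setup region followed by the cuBLAS region as referenced above:

SETUP_WORKSPACE_BYTES = 1024 * 1024  # assumed upper bound for the setup/metadata region

def grouped_gemm_workspace_bytes(cublas_workspace_bytes, alignment=256):
    # Round each region up to the alignment so the cuBLAS pointer stays aligned
    # when the FFI splits the buffer into its setup and cublas portions.
    def round_up(nbytes):
        return (nbytes + alignment - 1) // alignment * alignment
    return round_up(SETUP_WORKSPACE_BYTES) + round_up(cublas_workspace_bytes)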

Comment on lines +702 to +704

// Output needs to be zeroed in case any group sizes have size zero, meaning the expert weight isn't used in the fwd, meaning the corresponding output gradient should be zero. But using the grouped GEMM, the output buffer contains uninitialized data.
// TODO(jberchtold): make this memset smaller by only zeroing the expert weights that correspond to groups with size zero.

Zeroing the entire output buffer for every grouped GEMM call is inefficient. Consider tracking which expert weights have zero-sized groups and only zeroing those specific regions.
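
A sketch of this suggestion expressed at the JAX level (the real fix would live in the C++ FFI; the function name is illustrative): zero only the per-expert output slices whose group size is zero, instead of memsetting the whole buffer.

import jax.numpy as jnp

def mask_unused_expert_grads(dw, group_sizes):
    # dw: [g, k, n] per-expert weight gradients; group_sizes: [g]
    used = (group_sizes > 0)[:, None, None]  # broadcastable mask over the group axis
    return jnp.where(used, dw, 0)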

Comment on lines +740 to +742

// This memset is required because the group sizes may not fill the full buffer since we overallocate for the worst case. However, in theory unused space on the grouped axis should not be utilized downstream, but it seems like somehow it is utilized.
// TODO(jberchtold): try removing this

Same concern as line 702 - this memset zeros the entire output buffer on every call. If the buffer is large, this could be a significant performance cost.

Comment on lines +157 to +158
// NVTE_CHECK(ret.lda % 16 == 0,
// "Leading dimension requirement on A for FP8 GEMM. Caller must pad.");

Check that the new grouped GEMM properly handles FP8 inputs. The leading dimension alignment requirement validation (lda % 16 == 0) has been commented out, which could cause correctness issues if unaligned inputs are passed.
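
An illustration of the requirement the commented-out check used to enforce: the caller pads the row-major leading dimension of an FP8 operand to a multiple of 16 elements before the GEMM. The helper name and padding policy are assumptions, not TE API.

import jax.numpy as jnp

def pad_leading_dim_to_16(x):
    # For a row-major [m, k] operand, the leading dimension is k (the last axis).
    pad = (-x.shape[-1]) % 16
    if pad == 0:
        return x
    return jnp.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])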

Comment on lines 2097 to +2098

# TODO(jberchtold): move the int64 and offset computation to C++ side in a kernel to avoid needing JAX to support int64

Computing offsets in Python with JAX int64 is a workaround. Move this computation to C++ to avoid requiring jax_enable_x64 and reduce overhead.


return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general")


def make_ragged_dot_cls(quantization_recipe):

The assertion message should clarify what needs to be implemented for quantization support, or provide a reference to a tracking issue.


NVTE_CHECK(group_sizes.dimensions().size() == 1);
size_t num_gemms = group_sizes.dimensions()[0];

NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,

Confirmed only non-quantized grouped GEMM supported. Ensure this is documented in user-facing docs and that attempting quantization fails gracefully with clear error messages.
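
A sketch of the kind of caller-side guard this comment asks for (illustrative only, not existing TE API): fail with a clear Python error before reaching the C++ check when quantization is requested.

def check_grouped_gemm_quantization(quantizer_set=None):
    if quantizer_set is not None:
        raise NotImplementedError(
            "Grouped GEMM with on-device group sizes does not support quantization yet; "
            "use the existing grouped GEMM path or run in BF16."
        )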



Development

Successfully merging this pull request may close these issues.

[Core] TE common nvte_grouped_gemm treats output layout as column-wise instead of rowwise

1 participant