[JAX] Integrate BF16 Grouped GEMM with on-device group sizes #2680
jberchtold-nvidia wants to merge 2 commits into NVIDIA:main
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary

This PR replaces JAX's grouped GEMM implementation with a new CUDA-graph-safe version that uses on-device group sizes, eliminating the need for a D2H memcpy and stream synchronization. The new implementation uses the TE common grouped GEMM API.

Major changes:
Limitations and TODOs:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User as JAX/Flax User
participant Python as gemm.py
participant FFI as gemm.cpp (FFI)
participant Wrapper as GroupedTensorWrapper
participant CuBLAS as nvte_grouped_gemm
User->>Python: grouped_gemm(lhs, rhs, group_sizes)
Python->>Python: Compute group_offset from cumsum(group_sizes)
Python->>Python: Create alpha/beta tensors
Python->>FFI: GroupedGemmFFI via XLA FFI
FFI->>FFI: Create workspace buffers (setup + cublas)
FFI->>Wrapper: make_grouped_tensor(lhs_data, scaling_mode)
Wrapper-->>FFI: lhs_tensor
FFI->>Wrapper: set_group_info(group_sizes, group_offset_lhs)
FFI->>Wrapper: make_grouped_tensor(rhs_data, scaling_mode)
Wrapper-->>FFI: rhs_tensor
FFI->>Wrapper: make_grouped_tensor(output, NO_SCALING)
Wrapper-->>FFI: out_tensor
FFI->>Wrapper: set_group_info(group_sizes, group_offset_out)
FFI->>FFI: cudaMemsetAsync(output, 0)
FFI->>CuBLAS: nvte_grouped_gemm(rhs, lhs, out, alpha, beta, workspaces)
CuBLAS-->>FFI: Grouped GEMM complete
FFI-->>Python: output, workspace
Python-->>User: result
Last reviewed commit: e6e3bd0
group_offset = group_offset or jnp.zeros((1,), jnp.int32)
assert group_offset is None, "group_offset is not yet implemented"
assert (
The assertion error message contains f-string-style placeholders but doesn't actually format anything, since there's no f prefix on the outer string.
Suggested change:
- assert (
+ assert jax.config.jax_enable_x64, "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior"
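To make the f-string point concrete, here is a minimal, self-contained illustration (the variable names are hypothetical, not the PR's code):

```python
# Without the f prefix the braces are never interpolated, so an assertion
# message would print the literal placeholder instead of the value.
num_groups = 4
plain_msg = "expected {num_groups} groups"       # stays literal
formatted_msg = f"expected {num_groups} groups"  # interpolates to "expected 4 groups"

assert formatted_msg == "expected 4 groups"
assert plain_msg == "expected {num_groups} groups"
```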
workspace_size += (
    1024 * 1024
This hardcoded 1MB workspace allocation is flagged as a HACK in the code. The workspace size calculation should account for the setup and cuBLAS workspace needs separately; otherwise this could lead to buffer overruns or inefficient memory usage.
See gemm.cpp:669-673, where this workspace is split into setup (1MB) and cuBLAS portions.
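As a rough illustration of that split (sketched in Python for brevity; the actual code in gemm.cpp is C++, and the fixed 1 MiB constant is exactly the hack being flagged):

```python
# Illustrative sketch only: split one combined workspace allocation into a
# fixed-size setup region and the remainder for cuBLAS, mirroring the
# setup/cublas split described at gemm.cpp:669-673.
SETUP_WORKSPACE_BYTES = 1024 * 1024  # the hard-coded 1 MiB under discussion

def split_workspace(total_bytes: int) -> tuple[int, int]:
    if total_bytes <= SETUP_WORKSPACE_BYTES:
        raise ValueError("workspace too small to hold the setup region")
    return SETUP_WORKSPACE_BYTES, total_bytes - SETUP_WORKSPACE_BYTES
```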
// Output needs to be zeroed in case any group sizes have size zero, meaning the expert weight isn't used in the fwd, meaning the corresponding output gradient should be zero. But using the grouped GEMM, the output buffer contains uninitialized data.
// TODO(jberchtold): make this memset smaller by only zeroing the expert weights that correspond to groups with size zero.
Zeroing the entire output buffer for every grouped GEMM call is inefficient. Consider tracking which expert weights have zero-sized groups and only zeroing those specific regions.
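One possible shape of that optimization, sketched in JAX under assumed shapes and names (this is not the PR's code):

```python
import jax.numpy as jnp

def zero_empty_groups(wgrad, group_sizes):
    # wgrad: (num_groups, M, N) per-expert weight gradients
    # group_sizes: (num_groups,) number of tokens routed to each expert
    # Experts with zero tokens contributed nothing in the forward pass, so their
    # gradient slices must be zero; all other slices are left untouched.
    mask = (group_sizes > 0).astype(wgrad.dtype)
    return wgrad * mask[:, None, None]
```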
// This memset is required because the group sizes may not fill the full buffer since we overallocate for the worst case. However, in theory unused space on the grouped axis should not be utilized downstream, but it seems like somehow it is utilized.
// TODO(jberchtold): try removing this
Same concern as line 702 - this memset zeros the entire output buffer on every call. If the buffer is large, this could be a significant performance cost.
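If the memset cannot simply be removed, a narrower alternative (assumed shapes, illustrative only) would be to mask just the unused tail of the overallocated grouped axis rather than zeroing the whole buffer:

```python
import jax.numpy as jnp

def mask_ragged_tail(out, group_sizes):
    # out: (max_rows, N), overallocated for the worst case on the grouped axis;
    # only the first sum(group_sizes) rows hold valid results.
    valid = jnp.arange(out.shape[0]) < jnp.sum(group_sizes)
    return jnp.where(valid[:, None], out, 0)
```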
// NVTE_CHECK(ret.lda % 16 == 0,
//            "Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
Check that the new grouped GEMM properly handles FP8 inputs. The leading dimension alignment requirement validation (lda % 16 == 0) has been commented out, which could cause correctness issues if unaligned inputs are passed.
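For reference, the alignment rule that check enforced can be satisfied by padding on the caller side; a hypothetical helper (name and pad placement are assumptions) might look like:

```python
import jax.numpy as jnp

def pad_leading_dim(x, multiple=16):
    # FP8 GEMM expects the leading dimension of A to be a multiple of `multiple`
    # (for a row-major matrix this is the last-axis length); pad up to the next
    # multiple so unaligned inputs never reach cuBLAS.
    pad = (-x.shape[-1]) % multiple
    if pad == 0:
        return x
    return jnp.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])
```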
# TODO(jberchtold): move the int64 and offset computation to C++ side in a kernel to avoid needing JAX to support int64
Computing offsets in Python with JAX int64 is a workaround. Move this computation to C++ to avoid requiring jax_enable_x64 and reduce overhead.
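For context, the Python-side computation this TODO wants to move into a C++ kernel is essentially an exclusive cumulative sum of the group sizes; a sketch is below (exact dtypes and names are assumptions), and the int64 cast is what currently forces jax_enable_x64:

```python
import jax.numpy as jnp

def compute_group_offsets(group_sizes):
    # Exclusive cumsum: offset[i] = sum(group_sizes[:i]). Everything stays on
    # device, so no D2H copy or stream synchronization is needed.
    sizes = group_sizes.astype(jnp.int64)  # requires jax_enable_x64=True today
    return jnp.concatenate([jnp.zeros((1,), sizes.dtype), jnp.cumsum(sizes)[:-1]])
```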
return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general")
def make_ragged_dot_cls(quantization_recipe):
The assertion message should clarify what needs to be implemented for quantization support, or provide a reference to a tracking issue.
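A hedged example of the kind of message being asked for (the assertion condition and wording are assumptions, not the PR's code):

```python
def make_ragged_dot_cls(quantization_recipe=None):
    # Hypothetical rewording of the existing assertion with an actionable message.
    assert quantization_recipe is None, (
        "make_ragged_dot_cls currently supports only non-quantized (BF16) grouped "
        "GEMM; quantized ragged dot is not yet implemented (see the PR TODO list)."
    )
```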
NVTE_CHECK(group_sizes.dimensions().size() == 1);
size_t num_gemms = group_sizes.dimensions()[0];
NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
Confirmed that only non-quantized grouped GEMM is supported. Ensure this is documented in the user-facing docs and that attempting quantization fails gracefully with a clear error message.
Description
Integrate the new grouped GEMM from TE common/cuBLASLt, which supports on-device group sizes without a D2H memcpy and stream sync. This grouped GEMM is faster and CUDA-graph safe.
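For readers unfamiliar with the operation, a reference-semantics sketch is shown below (host-side loop, assumed shapes); the point of this PR is to express the same computation as a single cuBLASLt grouped GEMM call with the group sizes kept on device, so no host loop, D2H copy, or stream sync is needed:

```python
import jax.numpy as jnp

def reference_grouped_gemm(lhs, rhs, group_sizes):
    # lhs: (total_tokens, K) rows sorted by group; rhs: (num_groups, K, N);
    # group_sizes: host-side sequence of per-group row counts.
    # This loop is NOT CUDA-graph safe; it only documents the semantics.
    outputs, start = [], 0
    for i, size in enumerate(group_sizes):
        outputs.append(lhs[start:start + size] @ rhs[i])
        start += size
    return jnp.concatenate(outputs, axis=0)
```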
TODO:
Fixes #2659
Type of change
Changes
make_ragged_dot_cls for easy integration into existing models. This will be most useful when quantization is supported and storing recipe state is required.

Checklist: