
[JAX] Debugging inspect utility#2651

Open
jberchtold-nvidia wants to merge 20 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/ffi-inspect

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Feb 4, 2026

Description

Given that jax.debug.print/callback is currently broken (issue), this PR introduces an experimental alternative for our own internal debugging. This new debugging API allows us to inspect tensors, but it is experimental and may have breaking changes without a deprecation process.

Usage:

    from transformer_engine.jax.debug.experimental import inspect_array as te_inspect_array

    x = <some logic to compute x>
    x = te_inspect_array(x, "some_name")

    <something that consumes x>

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Introduces a new debugging tool for dumping binary blobs of tensor values for multi-GPU inspection.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft February 4, 2026 22:51
@greptile-apps
Contributor

greptile-apps bot commented Feb 4, 2026

Greptile Overview

Greptile Summary

Introduces an experimental debugging utility inspect_array for JAX that dumps tensor data to binary files for multi-GPU inspection, addressing broken jax.debug.print/callback functionality.

Key changes:

  • Adds new C++ FFI InspectFFI that copies tensor data from device to host and writes binary dumps with JSON metadata
  • Python wrapper with custom VJP support enables gradient-safe inspection
  • Test coverage for multiple dtypes including fp8 formats
  • Properly structured as experimental API with deprecation warnings

Issues already flagged in previous threads:

  • Hardcoded filename my_tensor_gpu{device}.bin prevents distinguishing between different tensors (name parameter not passed to C++)
  • Unconditional printf will spam output on every execution (consider environment variable gating)
  • File write error handling already implemented via NVTE_CHECK
  • outer_primitive guard already present at line 104

The implementation follows existing patterns in the codebase for primitive registration and FFI integration.

Confidence Score: 3/5

  • Safe to merge with known limitations for experimental debugging feature
  • Score reflects that while the implementation is sound and follows codebase patterns, there are usability issues (hardcoded filenames, console spam) already identified in previous review threads that should be addressed for a production-quality feature. However, since this is explicitly marked as experimental with warnings about breaking changes, these limitations are acceptable for initial merge.
  • Pay attention to transformer_engine/jax/csrc/extensions/inspect.cpp - the hardcoded filename and unconditional printf will cause issues when debugging multiple tensors or in production environments

Important Files Changed

Filename Overview
transformer_engine/jax/csrc/extensions/inspect.cpp New C++ FFI for tensor inspection - writes tensor data to binary files with metadata. Issues already flagged: hardcoded filename, unconditional printf spam, missing name parameter.
transformer_engine/jax/debug/experimental/inspect.py Python wrapper for inspect primitive with custom VJP - properly guards outer_primitive. Name parameter intentionally unused (documented in code).
tests/jax/test_custom_call_compute.py Adds test for inspect FFI across multiple dtypes - validates that dumped tensor matches expected value after computation.

Sequence Diagram

sequenceDiagram
    participant User
    participant JAX as JAX Runtime
    participant Python as inspect.py
    participant FFI as InspectFFI (C++)
    participant GPU as GPU Device
    participant FS as File System

    User->>Python: inspect_array(x, "my_array")
    Python->>Python: _inspect(x)
    Python->>Python: _inspect_fwd_rule(x)
    Python->>Python: _inspect_array_inner(x)
    Python->>Python: Compute min/max/mean/std
    Python->>Python: InspectPrimitive.outer_primitive.bind(x, stats)
    Python->>JAX: Register primitive call
    JAX->>FFI: InspectFFI(stream, input, min, max, mean, std)
    FFI->>GPU: cudaMemcpyAsync (tensor data)
    FFI->>GPU: cudaMemcpyAsync (stats)
    FFI->>GPU: cudaStreamSynchronize()
    FFI->>GPU: cudaGetDevice()
    FFI->>FS: Write my_tensor_gpu{device}.bin
    FFI->>FS: Write my_tensor_gpu{device}_meta.json
    FFI->>FS: printf metadata to console
    FFI->>JAX: Return (aliased output buffer)
    JAX->>Python: Return x (unchanged)
    Python->>User: Return x

Last reviewed commit: caf1f6b

Contributor

@greptile-apps greptile-apps bot left a comment


5 files reviewed, 4 comments


Comment on lines 103 to 111
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    # TODO: Handle the name of the tensor in the primitive and output files
    return _inspect(x)
Contributor


Unused name argument

inspect_array(x, name) documents printing the array name/shape/dtype/stats, but name is unused and the function currently just returns _inspect(x) (which itself is a no-op). This is misleading API surface and will confuse callers expecting output. Either implement passing/using name (and the actual inspection behavior) or drop the name parameter and update the docstring/export accordingly.

@greptile-apps
Contributor

greptile-apps bot commented Feb 4, 2026

Additional Comments (1)

transformer_engine/jax/cpp_extensions/gemm.py
Disabled cuBLAS alignment check

assert_cublas_requirements no longer enforces contracting_size % alignment == 0 for quantized GEMM (the assert is commented out). This will allow invalid shapes through to the cuBLAS custom call and can trigger runtime failures or incorrect behavior when using FP8/NVFP4 inputs. Please restore the check or replace it with an equivalent validation (and only relax it if the backend truly supports unaligned contracting sizes).
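The check being asked for here could look something like the following. This is a hedged sketch only: the function and parameter names are assumptions for illustration, not the actual code in `assert_cublas_requirements`.

```python
# Hedged sketch of an alignment validation like the one the comment asks to
# restore; names and the default alignment value are assumptions.
def check_contracting_alignment(contracting_size: int, alignment: int = 16) -> None:
    # Quantized (FP8/NVFP4) cuBLAS GEMMs typically require the contracting
    # dimension to be a multiple of a fixed alignment.
    if contracting_size % alignment != 0:
        raise ValueError(
            f"contracting dimension {contracting_size} is not a multiple of {alignment}"
        )
```

Keeping this as an explicit check (rather than a commented-out assert) surfaces invalid shapes at trace time instead of as opaque cuBLAS runtime failures.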

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
pre-commit-ci bot and others added 5 commits February 4, 2026 23:28
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review February 9, 2026 21:52
jberchtold-nvidia and others added 2 commits February 9, 2026 13:52
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci jax

Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, 4 comments




def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    return InspectPrimitive.outer_primitive.bind(
Contributor


Missing `assert InspectPrimitive.outer_primitive is not None` before `bind`. Other primitives in this codebase guard this to prevent an AttributeError if registration fails (see activation.py:351, amax.py:381, etc.).

"""
with open(filename, "rb") as f:
data = f.read()
array = jnp.frombuffer(data, dtype=dtype).reshape(shape)
Collaborator


Does this handle 4-bit data as well? Will the data printed be legible for e2m1 or e0m3?
This might be useful for debugging quantization outputs

Collaborator Author

@jberchtold-nvidia jberchtold-nvidia Feb 13, 2026


I tried and unfortunately it doesn't seem to work with fp4. I'm not sure if this is an issue with my FFI or if it's an issue with fp4 support in jnp.frombuffer. But I've added unit tests for the other dtypes with a note that fp4 doesn't work currently

Collaborator


Sounds good, we can figure out the expansion to 4-bit later
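For the eventual 4-bit expansion, one piece that `jnp.frombuffer` cannot do on its own is splitting packed bytes into nibbles. A hedged sketch of that step, assuming two values per byte with the low nibble first (the packing order is an assumption and would need to be verified against the FFI's actual layout):

```python
import numpy as np

# Hedged sketch: split each dump byte into two 4-bit fields before any
# fp4 (e2m1) decoding. Low-nibble-first ordering is an assumption.
def unpack_nibbles(raw: bytes) -> np.ndarray:
    b = np.frombuffer(raw, dtype=np.uint8)
    low = b & 0x0F   # low 4 bits of each byte
    high = b >> 4    # high 4 bits of each byte
    # Interleave so element order matches low-first packing.
    return np.stack([low, high], axis=1).reshape(-1)
```

The resulting 4-bit codes would still need to be mapped through the e2m1 value table to recover floats.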

):
    """ """
    output, _ = _inspect_fwd_rule(
        x,
Collaborator


is this for eventually inspecting grads? Maybe we can put in a TODO here for inspecting grads?

Collaborator Author


This isn't directly for inspecting gradients. If we insert a custom primitive into a JAX graph, JAX errors since it doesn't understand that our inspect FFI is really a no-op and therefore it doesn't know how to compute the gradient of it. So this VJP rule just helps JAX know how to handle this in the backward and pass gradients thru.

This current inspect API wraps a single tensor, which makes sense for the forward. We just wrap some intermediate result in inspect and we can look at the results. Usually this would be immediately after some custom operation we added, e.g. comparing MaxText vs. TE permute. However, since we're just wrapping the tensor, the gradient in this VJP will actually be the incoming gradient from the following operation, not the gradient of our custom op.

So I think we do want support for inspecting gradients, but instead of wrapping individual tensors on the forward, it may make more sense to wrap a given function, like our permute, and let us inspect the forward outputs and gradient outputs from our op (not the incoming gradients)

I have a somewhat hard-coded version of this idea working on my branch that for a GEMM logs stats on fwd output and dgrad and wgrad. If we make it generic enough, something similar could be used for permute and other ops too

Wdyt?
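The passthrough behavior described above can be sketched with `jax.custom_vjp`. This mirrors the idea only, not the PR's actual primitive machinery: the inspect op acts as an identity on the forward, and the VJP rule forwards the incoming cotangent unchanged.

```python
import jax
import jax.numpy as jnp

# Hedged sketch: an identity op whose VJP passes gradients through,
# so JAX can differentiate through an inspection no-op.
@jax.custom_vjp
def inspect_identity(x):
    return x  # the real primitive would also dump x to disk here

def _fwd(x):
    return inspect_identity(x), None  # no residuals needed

def _bwd(_, g):
    return (g,)  # pass the incoming gradient through untouched

inspect_identity.defvjp(_fwd, _bwd)

# Usage: gradients flow through as if the wrapper were not there.
grads = jax.grad(lambda x: jnp.sum(inspect_identity(x) ** 2))(jnp.array([1.0, 2.0]))
```

As noted in the thread, the cotangent seen here is the incoming gradient from the following op, not the gradient of any wrapped custom op.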

Collaborator


That seems like a cool idea, but if there are multiple outputs, we need a way to differentiate them, and potentially save them to different bins.

        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    # TODO: Handle the name of the tensor in the primitive and output files
Collaborator


This does not need to be in this PR, but I did ask Claude about this and it suggested 2 options:

  1. We can hash the string name into a number to pass to XLA, and maintain a hash map on the Python side to attach the name to the filename.
  2. We do not pass the name to _inspect (like what you are doing here), but we "mv" the file in the filesystem to a different name depending on the name argument.

Wdyt?
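Option 1 can be sketched in a few lines. This is hypothetical illustration only; the function names and registry are assumptions, not code from the PR.

```python
import zlib

# Hedged sketch of option 1: hash the tensor name to an integer that could be
# passed to XLA as a static attribute, with a Python-side map back to the
# original string for naming the dump files.
_NAME_REGISTRY: dict[int, str] = {}

def register_name(name: str) -> int:
    name_hash = zlib.crc32(name.encode())  # stable 32-bit hash, fits in an i64 attr
    _NAME_REGISTRY[name_hash] = name
    return name_hash

def dump_filename(name_hash: int, device: int) -> str:
    # Fall back to the hex hash if the hash was never registered in this process.
    name = _NAME_REGISTRY.get(name_hash, f"tensor_{name_hash:08x}")
    return f"{name}_gpu{device}.bin"
```

The registry lives in the Python process that traced the program, so the mapping is available whenever that same process reads the dumps back.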

Collaborator Author


I'm not sure if option 2 will work, since we are running the Python "inspect" function during trace time but the file doesn't get written until the JAX program is jit'd and actually running. So I'm not sure how we'd move the file unless we did something like in option 1 with a dictionary.

I do like option 1, that is a good idea and may be our best option. Another alternative I was considering but haven't looked into yet is a static buffer with a fixed maximum string capacity. On the attention side we pass static args thru a buffer that is a struct that can support decoding fields on the C++ side. There may be a way to pack the string on the Python side and read on the C++ side assuming the string doesn't exceed some fixed char array maxsize. But that idea still needs some investigation, so your option 1 sounds like the most promising.

Collaborator


Option 1 would likely be easier to implement IMO, since you do not have to try decoding the passed static buffer into multiple fields (I assume the struct might have multiple fields?). Anyway, I will leave this here without resolving the thread so we can re-visit.

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

Contributor

@greptile-apps greptile-apps bot left a comment


7 files reviewed, 1 comment


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


8 files reviewed, 2 comments


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

Contributor

@greptile-apps greptile-apps bot left a comment


8 files reviewed, 3 comments


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


8 files reviewed, 1 comment


Comment on lines 72 to 85
}

// Log the tensor metadata to the console
printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
  printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
  if (i < input_buf.dimensions().size() - 1) {
    printf(", ");
  }
}
printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

return ffi_with_cuda_error_check();
Contributor


Function closes prematurely at line 72, leaving printf and return statements outside function scope. This should cause a compilation error - the closing brace should be moved after line 85.

Suggested change (closing brace moved after the logging):

meta_file << "}";
meta_file.close();

// Log the tensor metadata to the console
printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
  printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
  if (i < input_buf.dimensions().size() - 1) {
    printf(", ");
  }
}
printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

return ffi_with_cuda_error_check();
}

@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

Contributor

@greptile-apps greptile-apps bot left a comment


8 files reviewed, 3 comments


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L0 jax

Contributor

@greptile-apps greptile-apps bot left a comment


8 files reviewed, 2 comments


Comment on lines +46 to +50
std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
std::ofstream file(filename, std::ios::binary);
NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
Contributor


Synchronous file I/O blocks the FFI execution thread. Writing potentially large tensor data (e.g., 256x128 = 32KB+ per tensor) to disk synchronously on every call will cause performance degradation, especially under jax.jit where this may execute frequently. For a debugging utility, consider making file writes optional via environment variable or moving to async I/O.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
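The environment-variable gating suggested above could look like the following sketch. The variable name `NVTE_JAX_INSPECT_DUMP` and the helper names are hypothetical, not part of the PR.

```python
import os

# Hedged sketch of opt-in gating for the debug dump; the env var name
# NVTE_JAX_INSPECT_DUMP is an assumption.
def dump_enabled() -> bool:
    return os.environ.get("NVTE_JAX_INSPECT_DUMP", "0") == "1"

def maybe_dump(filename: str, payload: bytes) -> bool:
    # Skip the synchronous write entirely unless the user opted in.
    if not dump_enabled():
        return False
    with open(filename, "wb") as f:
        f.write(payload)
    return True
```

The same check would apply equally well in the C++ FFI via `getenv`, so the disk write (and the console printf) costs nothing when debugging is off.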

_ = jax.jit(f)(x)

expected = x + 1
actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)
Contributor


Test assumes single GPU (gpu0). On multi-GPU systems where tests might run on gpu1+, this will fail. Check CUDA_VISIBLE_DEVICES or use a device query to determine the actual device ID for the filename.
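One way the test could derive the expected filename, following the CUDA_VISIBLE_DEVICES suggestion above. This is a sketch under assumptions: whether the FFI's `cudaGetDevice` reports the remapped or physical id depends on the runtime setup, so the mapping here would need to be verified against the actual dump behavior.

```python
import os

# Hedged sketch: resolve the device id the test should expect in the dump
# filename, honoring CUDA_VISIBLE_DEVICES instead of hardcoding gpu0.
def expected_device_id(local_index: int = 0) -> int:
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible:
        return int(visible.split(",")[local_index])
    return local_index

def expected_dump_name(local_index: int = 0) -> str:
    return f"my_tensor_gpu{expected_device_id(local_index)}.bin"
```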

@tdophung
Collaborator

LGTM! Thanks for making this change
