[JAX] Debugging inspect utility #2651
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Overview

Greptile Summary

Introduces an experimental debugging utility.

Key changes:
Issues already flagged in previous threads:
The implementation follows existing patterns in the codebase for primitive registration and FFI integration.

Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant JAX as JAX Runtime
    participant Python as inspect.py
    participant FFI as InspectFFI (C++)
    participant GPU as GPU Device
    participant FS as File System
    User->>Python: inspect_array(x, "my_array")
    Python->>Python: _inspect(x)
    Python->>Python: _inspect_fwd_rule(x)
    Python->>Python: _inspect_array_inner(x)
    Python->>Python: Compute min/max/mean/std
    Python->>Python: InspectPrimitive.outer_primitive.bind(x, stats)
    Python->>JAX: Register primitive call
    JAX->>FFI: InspectFFI(stream, input, min, max, mean, std)
    FFI->>GPU: cudaMemcpyAsync (tensor data)
    FFI->>GPU: cudaMemcpyAsync (stats)
    FFI->>GPU: cudaStreamSynchronize()
    FFI->>GPU: cudaGetDevice()
    FFI->>FS: Write my_tensor_gpu{device}.bin
    FFI->>FS: Write my_tensor_gpu{device}_meta.json
    FFI->>FS: printf metadata to console
    FFI->>JAX: Return (aliased output buffer)
    JAX->>Python: Return x (unchanged)
    Python->>User: Return x
```
Last reviewed commit: caf1f6b
```python
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    # TODO: Handle the name of the tensor in the primitive and output files
    return _inspect(x)
```
Unused name argument
inspect_array(x, name) documents printing the array name/shape/dtype/stats, but name is unused and the function currently just returns _inspect(x) (which itself is a no-op). This is misleading API surface and will confuse callers expecting output. Either implement passing/using name (and the actual inspection behavior) or drop the name parameter and update the docstring/export accordingly.
/te-ci jax
```python
def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    return InspectPrimitive.outer_primitive.bind(
```
Missing `assert InspectPrimitive.outer_primitive is not None` before `bind`. Other primitives in this codebase guard this to prevent an AttributeError if registration fails (see activation.py:351, amax.py:381, etc.).
| """ | ||
| with open(filename, "rb") as f: | ||
| data = f.read() | ||
| array = jnp.frombuffer(data, dtype=dtype).reshape(shape) |
Does this handle 4-bit data as well? Will the data printed be legible for e2m1 or e0m3?
This might be useful for debugging quantization outputs
I tried and unfortunately it doesn't seem to work with fp4. I'm not sure if this is an issue with my FFI or if it's an issue with fp4 support in jnp.frombuffer. But I've added unit tests for the other dtypes with a note that fp4 doesn't work currently
Sounds good, we can figure out the expansion to 4-bit later
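For reference, the loader discussed above can be reproduced with plain NumPy (a sketch assuming the same raw-bytes file layout; `frombuffer` covers the 8-bit-and-wider dtypes discussed here, not fp4):

```python
import numpy as np


def load_array_dump(filename, shape, dtype):
    # Read the raw bytes written by the FFI and reinterpret them in place.
    with open(filename, "rb") as f:
        data = f.read()
    return np.frombuffer(data, dtype=dtype).reshape(shape)
```

A round trip (write an array's `tobytes()` to disk, load it back with the same shape and dtype) recovers the original values exactly, since no conversion happens on either side.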
```python
):
    """ """
    output, _ = _inspect_fwd_rule(
        x,
```
is this for eventually inspecting grads? Maybe we can put in a TODO here for inspecting grads?
This isn't directly for inspecting gradients. If we insert a custom primitive into a JAX graph, JAX errors out since it doesn't understand that our inspect FFI is really a no-op, and therefore it doesn't know how to compute its gradient. So this VJP rule just helps JAX handle this in the backward pass and pass gradients through.
This current inspect API wraps a single tensor, which makes sense for the forward. We just wrap some intermediate result in inspect and we can look at the results. Usually this would be immediately after some custom operation we added, e.g. comparing MaxText vs. TE permute. However, since we're just wrapping the tensor, the gradient in this VJP will actually be the incoming gradient from the following operation, not the gradient of our custom op.
So I think we do want support for inspecting gradients, but instead of wrapping individual tensors on the forward, it may make more sense to wrap a given function, like our permute, and let us inspect the forward outputs and gradient outputs from our op (not the incoming gradients)
I have a somewhat hard-coded version of this idea working on my branch that for a GEMM logs stats on fwd output and dgrad and wgrad. If we make it generic enough, something similar could be used for permute and other ops too
Wdyt?
That seems like a cool idea, but if there are multiple outputs, we need a way to differentiate them, and potentially save them to different bins.
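The function-wrapping idea discussed above could look roughly like this (a hedged sketch with hypothetical names; the comments mark where the inspect primitive would be bound, and the backward recomputes the VJP of the wrapped function for simplicity):

```python
import jax
import jax.numpy as jnp


def inspect_fn(f):
    # Wrap `f` in a custom VJP so both its forward output and its
    # output gradient can be hooked, while values pass through unchanged.
    @jax.custom_vjp
    def wrapped(x):
        return f(x)

    def fwd(x):
        y = f(x)
        # <- bind the inspect primitive on y here (forward output)
        return y, (x,)

    def bwd(res, g):
        (x,) = res
        # <- bind the inspect primitive on g here (output gradient)
        _, vjp = jax.vjp(f, x)  # recompute the VJP of f for the sketch
        return vjp(g)

    wrapped.defvjp(fwd, bwd)
    return wrapped
```

Because the wrapper is transparent, gradients of the wrapped function match those of the original; the only difference is the two hook points.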
```python
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    # TODO: Handle the name of the tensor in the primitive and output files
```
This does not need to be in this PR, but I did ask Claude about this and it suggested two options:
- We can hash the string name into a number to pass to XLA, and maintain a hash map on the Python side to attach to the filename.
- We do not pass the name to `_inspect` (like what you are doing here), but we "mv" the file in the filesystem to a different name depending on the `name` argument.

Wdyt?
I'm not sure if option 2 will work, since we are running the Python "inspect" function during trace time, but the file doesn't get written until the JAX program is jit'd and actually running. So I'm not sure how we'd move the file unless we did something like option 1 with a dictionary.
I do like option 1; that is a good idea and may be our best option. Another alternative I was considering but haven't looked into yet is a static buffer with a fixed maximum string capacity. On the attention side we pass static args through a buffer that is a struct that supports decoding fields on the C++ side. There may be a way to pack the string on the Python side and read it on the C++ side, assuming the string doesn't exceed some fixed char-array max size. But that idea still needs some investigation, so your option 1 sounds like the most promising.
Option 1 would likely be easier to implement IMO, since you do not have to try decoding the passed static buffer into multiple fields (I assume the struct might have multiple fields?). Anyway, I will leave this here without resolving the thread so we can re-visit it.
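Option 1 could be sketched like this (helper names are hypothetical; the idea is that the hash travels to XLA as a scalar static attribute, and the Python-side map recovers the string when naming the output file):

```python
import zlib

# Python-side map from hash back to the human-readable tensor name.
_NAME_BY_HASH = {}


def register_name(name: str) -> int:
    # Stable 32-bit hash the primitive can carry as a static attribute.
    h = zlib.crc32(name.encode("utf-8"))
    _NAME_BY_HASH[h] = name
    return h


def filename_for_hash(h: int, device: int) -> str:
    # Fall back to the raw hash if the name was never registered.
    name = _NAME_BY_HASH.get(h, f"tensor_{h:08x}")
    return f"{name}_gpu{device}.bin"
```

One caveat of this approach: the map lives in the Python process, so a separate tool reading the dumps only sees the hash unless the name is also written into the `_meta.json` sidecar.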
/te-ci L0 jax
```cpp
}

// Log the tensor metadata to the console
printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
  printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
  if (i < input_buf.dimensions().size() - 1) {
    printf(", ");
  }
}
printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

return ffi_with_cuda_error_check();
```
Function closes prematurely at line 72, leaving the printf and return statements outside function scope. This should cause a compilation error; the closing brace should be moved after line 85.
Suggested change:

```diff
-}
-// Log the tensor metadata to the console
-printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
-for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
-  printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
-  if (i < input_buf.dimensions().size() - 1) {
-    printf(", ");
-  }
-}
-printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
-printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
-return ffi_with_cuda_error_check();
+meta_file << "}";
+meta_file.close();
+// Log the tensor metadata to the console
+printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
+for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
+  printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
+  if (i < input_buf.dimensions().size() - 1) {
+    printf(", ");
+  }
+}
+printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
+printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
+return ffi_with_cuda_error_check();
+}
```
```cpp
std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
std::ofstream file(filename, std::ios::binary);
NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
```
Synchronous file I/O blocks the FFI execution thread. Writing potentially large tensor data (e.g., 256x128 = 32KB+ per tensor) to disk synchronously on every call will cause performance degradation, especially under jax.jit where this may execute frequently. For a debugging utility, consider making file writes optional via environment variable or moving to async I/O.
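One cheap way to make the writes optional, as suggested, is a Python-side gate so the primitive is never bound (and thus the FFI never runs) unless explicitly enabled. This is a sketch; the `TE_INSPECT_DUMP` env var name is hypothetical:

```python
import os


def maybe_inspect(x, name, inspect_impl):
    # Skip the dump (and the primitive bind) entirely unless the user opts in.
    if os.environ.get("TE_INSPECT_DUMP", "0") != "1":
        return x
    return inspect_impl(x, name)
```

Gating at trace time like this keeps the jitted program identical to the un-instrumented one when dumping is off, so there is zero runtime cost in production.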
```python
_ = jax.jit(f)(x)

expected = x + 1
actual = load_array_dump("my_tensor_gpu0.bin", shape, dtype)
```
Test assumes single GPU (gpu0). On multi-GPU systems where tests might run on gpu1+, this will fail. Check CUDA_VISIBLE_DEVICES or use a device query to determine the actual device ID for the filename.
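One way to make the test robust, per the comment above, is to derive the device ordinal at runtime instead of hard-coding `gpu0`. A sketch (the assumption that the first local JAX device matches the ordinal the FFI's `cudaGetDevice()` reports is hedged by the fallback):

```python
def dump_filename(prefix: str = "my_tensor") -> str:
    # Query the id of the first local device; fall back to 0 when
    # JAX is unavailable or device query fails.
    try:
        import jax
        device_id = jax.local_devices()[0].id
    except Exception:
        device_id = 0
    return f"{prefix}_gpu{device_id}.bin"
```

The test would then call `load_array_dump(dump_filename(), shape, dtype)` rather than spelling out `"my_tensor_gpu0.bin"`.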
LGTM! Thanks for making this change
Description
Given `jax.debug.print`/`callback` is currently broken (issue), this PR introduces an experimental alternative for our own internal debugging. This new debugging API allows us to inspect tensors, but it is experimental and may have breaking changes without a deprecation process.

Usage:
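The original usage snippet did not survive extraction, so here is a NumPy stand-in with the semantics described in this PR (log name, shape, dtype, and stats; return the array unchanged). The real API is backed by the TE JAX primitive and FFI, so this is illustrative only:

```python
import numpy as np


def inspect_array(x, name):
    # Stand-in for the PR's utility: print stats, pass x through unchanged.
    arr = np.asarray(x)
    stats = arr.astype(np.float64)
    print(f"{name}: shape={arr.shape}, dtype={arr.dtype}, "
          f"min={stats.min():.4f}, max={stats.max():.4f}, "
          f"mean={stats.mean():.4f}, std={stats.std():.4f}")
    return x


x = np.ones((4, 8), dtype=np.float32)
y = inspect_array(x + 1, "after_add")
```

Because the call returns its input, it can be dropped into the middle of an existing computation (e.g. `y = inspect_array(y, "after_add")`) without changing results.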
Type of change
Changes
Checklist: