
[PyTorch] Add dtype information to QuantizedTensorStorage class #2676

Open
ptrendx wants to merge 3 commits into NVIDIA:main from ptrendx:pr_dtype_in_storage

Conversation

@ptrendx
Member

@ptrendx ptrendx commented Feb 12, 2026

Description

This PR adds the fake dtype information to the QuantizedTensorStorage class. This eliminates the need to guess the correct dtype for dequantization, as was previously done in distributed.py, and it prevents unintentional dequantization to FP32 when calling dequantize() on the storage class with no dtype argument.
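
As a minimal sketch of the behavior change (the QuantizedTensorStorage role, the _dtype field, and the float32 default come from this PR; the class below is a simplified stand-in, not the real implementation):

import torch
from typing import Optional

class StorageSketch:
    """Simplified stand-in for QuantizedTensorStorage with the new _dtype field."""

    def __init__(self, data: torch.Tensor, fake_dtype: Optional[torch.dtype] = None):
        self._data = data
        # New in this PR: remember the "fake" high-precision dtype.
        self._dtype = fake_dtype if fake_dtype is not None else torch.float32

    def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # Old behavior: dtype=None silently meant torch.float32.
        # New behavior: dtype=None falls back to the stored self._dtype.
        if dtype is None:
            dtype = self._dtype
        return self._data.to(dtype)

storage = StorageSketch(torch.zeros(4), fake_dtype=torch.bfloat16)
assert storage.dequantize().dtype == torch.bfloat16  # no accidental FP32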

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added the _dtype field to the QuantizedTensorStorage class
  • Modified the dequantize call to use that new field when calling dequantize with no arguments
  • Removed the dtype guessing from distributed.py, as sketched below
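
For the last item, a hedged before/after sketch (the function names are illustrative, not the actual helpers in distributed.py; only the inp._dtype access and the hardcoded torch.bfloat16 it replaces come from the PR):

import torch

def gather_dequantize_before(inp) -> torch.Tensor:
    # Old: the target dtype had to be guessed and was hardcoded to BF16.
    return inp.dequantize(dtype=torch.bfloat16)

def gather_dequantize_after(inp) -> torch.Tensor:
    # New: the storage object carries its fake dtype, so no guessing is needed.
    return inp.dequantize(dtype=inp._dtype)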

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from timmoon10 February 12, 2026 19:07
@greptile-apps
Contributor

greptile-apps bot commented Feb 12, 2026

Greptile Overview

Greptile Summary

This PR adds dtype tracking to quantized tensor storage classes to eliminate dtype guessing during dequantization. Previously, the system had to guess the target dtype (hardcoded to torch.bfloat16 in distributed.py) or default to torch.float32 when calling dequantize() without arguments.

Key changes:

  • Added _dtype field to QuantizedTensorStorage base class to track the "fake" high-precision dtype
  • Modified all storage class constructors (Float8TensorStorage, NVFP4TensorStorage, MXFP8TensorStorage, Float8BlockwiseQTensorStorage) to accept and store a fake_dtype parameter
  • Updated dequantize() methods to use self._dtype as default instead of torch.float32
  • Removed dtype guessing in distributed.py (replaced hardcoded torch.bfloat16 with inp._dtype)
  • Updated C++ quantizer code to pass fake_dtype using GetATenDType(dtype)
  • Ensured fake_dtype is propagated through all tensor operations (view, copy, split)

Implementation approach:
The change maintains backward compatibility by defaulting fake_dtype to torch.float32 when not provided. QuantizedTensor.__new__() validates that fake_dtype matches dtype to prevent inconsistencies. All storage classes follow a consistent pattern where _dtype is set in __new__() and propagated through operations like view(), as in the sketch below.
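
A minimal sketch of that pattern (the real classes are torch.Tensor subclasses with far more machinery; the _dtype field, fake_dtype parameter, float32 default, and validation come from the PR, while the class names here are hypothetical):

import torch
from typing import Optional

class StorageSketch:
    """Stand-in for a QuantizedTensorStorage subclass."""

    def __init__(self, data: torch.Tensor, fake_dtype: Optional[torch.dtype] = None):
        self._data = data
        # Backward compatibility: default the fake dtype to torch.float32.
        self._dtype = fake_dtype if fake_dtype is not None else torch.float32

    def view(self, *shape: int) -> "StorageSketch":
        # Propagation: every op that builds a new storage carries _dtype along,
        # so a later no-argument dequantize() keeps the intended precision.
        return StorageSketch(self._data.view(*shape), fake_dtype=self._dtype)

class TensorSketch(StorageSketch):
    """Stand-in for QuantizedTensor, whose constructor also takes dtype."""

    def __init__(self, data: torch.Tensor, dtype: torch.dtype,
                 fake_dtype: Optional[torch.dtype] = None):
        # Validation as described: an explicit fake_dtype must agree with dtype.
        if fake_dtype is not None and fake_dtype != dtype:
            raise ValueError(f"fake_dtype {fake_dtype} does not match dtype {dtype}")
        super().__init__(data, fake_dtype=fake_dtype)

t = TensorSketch(torch.zeros(2, 2), dtype=torch.bfloat16, fake_dtype=torch.bfloat16)
assert t.view(4)._dtype == torch.bfloat16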

This is a clean refactoring that improves type safety and eliminates magic constants throughout the codebase.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The implementation is well-structured with consistent patterns across all storage classes, proper validation, backward compatibility via sensible defaults, and comprehensive coverage of all tensor creation paths including C++ bindings
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py: adds the fake_dtype parameter to the storage class, defaulting to torch.float32, and uses it in dequantize() when dtype is not specified
  • transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py: adds the fake_dtype parameter following the same pattern as Float8TensorStorage, and propagates it in the view() method
  • transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py: adds the fake_dtype parameter with an implementation consistent with the other storage classes
  • transformer_engine/pytorch/distributed.py: removes the hardcoded torch.bfloat16 dtype guessing and uses inp._dtype from the storage classes instead
  • transformer_engine/pytorch/quantized_tensor.py: adds the _dtype field to the base class, validates fake_dtype against dtype in __new__, and updates dequantize() calls to use the stored default dtype
  • transformer_engine/pytorch/csrc/quantizer.cpp: adds the fake_dtype argument to all C++ quantizer create_tensor calls via GetATenDType(dtype)

Sequence Diagram

sequenceDiagram
    participant User
    participant Quantizer
    participant TensorStorage
    participant QuantizedTensor
    participant Dequantize

    User->>Quantizer: quantize(tensor, fake_dtype=bfloat16)
    Quantizer->>TensorStorage: __new__(data, fake_dtype=bfloat16)
    TensorStorage->>TensorStorage: _dtype = fake_dtype or float32
    TensorStorage->>QuantizedTensor: super().__new__(dtype=fake_dtype)
    QuantizedTensor->>QuantizedTensor: validate fake_dtype == dtype
    QuantizedTensor->>QuantizedTensor: _dtype = dtype
    QuantizedTensor-->>User: return quantized_tensor

    User->>QuantizedTensor: dequantize()
    QuantizedTensor->>Dequantize: dequantize(dtype=self._dtype)
    Dequantize-->>User: return high_precision_tensor

Contributor

@greptile-apps greptile-apps bot left a comment


13 files reviewed, no comments


@ptrendx
Member Author

ptrendx commented Feb 12, 2026

/te-ci pytorch

Member

@ksivaman ksivaman left a comment


LGTM

Collaborator

@timmoon10 timmoon10 left a comment


Overall this is a big improvement. I have some naming nits.

shape: Iterable[int],
dtype: torch.dtype,
*,
fake_dtype: Optional[torch.dtype] = None,
Collaborator


Isn't this redundant with the dtype kwarg?

data: Optional[torch.Tensor],
fp8_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
fake_dtype: Optional[torch.dtype] = None,
Collaborator


I'd prefer to just name it dtype since QuantizedTensor is already using that name in its constructor.

Suggested change
- fake_dtype: Optional[torch.dtype] = None,
+ dtype: Optional[torch.dtype] = None,

instance = super().__new__(cls, *args, **kwargs)
if cls is NVFP4TensorStorage:
instance = object.__new__(cls)
instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
Collaborator


NVFP4 compatibility with FP32 is not very good, so maybe it would be worth changing the default to BF16? The downside is that this default is different from standard PyTorch and from the other quantized tensors.

Suggested change
- instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
+ instance._dtype = fake_dtype if fake_dtype is not None else torch.bfloat16
