[PyTorch] Add dtype information to QuantizedTensorStorage class#2676
ptrendx wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR adds dtype tracking to quantized tensor storage classes to eliminate dtype guessing during dequantization. Previously, the system had to guess the target dtype (hardcoded to torch.float32 in some paths).
Key changes:
Implementation approach: This is a clean refactoring that improves type safety and eliminates magic constants throughout the codebase.
Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
    participant User
    participant Quantizer
    participant TensorStorage
    participant QuantizedTensor
    participant Dequantize
    User->>Quantizer: quantize(tensor, fake_dtype=bfloat16)
    Quantizer->>TensorStorage: __new__(data, fake_dtype=bfloat16)
    TensorStorage->>TensorStorage: _dtype = fake_dtype or float32
    TensorStorage->>QuantizedTensor: super().__new__(dtype=fake_dtype)
    QuantizedTensor->>QuantizedTensor: validate fake_dtype == dtype
    QuantizedTensor->>QuantizedTensor: _dtype = dtype
    QuantizedTensor-->>User: return quantized_tensor
    User->>QuantizedTensor: dequantize()
    QuantizedTensor->>Dequantize: dequantize(dtype=self._dtype)
    Dequantize-->>User: return high_precision_tensor
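A minimal runnable sketch of the flow in the diagram above. The class names and the `fake_dtype`/`_dtype` attributes mirror the PR, but the bodies are simplified stand-ins for illustration, not the actual Transformer Engine implementation:

```python
from typing import Optional

import torch


class QuantizedTensorStorage:
    """Simplified stand-in: remembers the high-precision (fake) dtype."""

    def __init__(self, data: torch.Tensor, fake_dtype: Optional[torch.dtype] = None):
        self._data = data
        # Fall back to float32 when no fake dtype is provided, as in the diff.
        self._dtype = fake_dtype if fake_dtype is not None else torch.float32


class QuantizedTensor(QuantizedTensorStorage):
    """Simplified stand-in: validates dtype consistency and dequantizes."""

    def __init__(
        self,
        data: torch.Tensor,
        dtype: torch.dtype,
        fake_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(data, fake_dtype=fake_dtype)
        # Mirrors the "validate fake_dtype == dtype" step in the diagram.
        assert fake_dtype is None or fake_dtype == dtype, "fake_dtype must match dtype"
        self._dtype = dtype

    def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # Dequantize to the stored dtype unless the caller overrides it.
        return self._data.to(dtype if dtype is not None else self._dtype)


qt = QuantizedTensor(torch.randn(2, 2), dtype=torch.bfloat16, fake_dtype=torch.bfloat16)
assert qt.dequantize().dtype == torch.bfloat16
```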
/te-ci pytorch
timmoon10 left a comment
Overall this is a big improvement. I have some naming nits.
shape: Iterable[int],
dtype: torch.dtype,
*,
fake_dtype: Optional[torch.dtype] = None,
Isn't this redundant with the dtype kwarg?
data: Optional[torch.Tensor],
fp8_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
fake_dtype: Optional[torch.dtype] = None,
I'd prefer to just name it dtype since QuantizedTensor is already using that name in its constructor.
- fake_dtype: Optional[torch.dtype] = None,
+ dtype: Optional[torch.dtype] = None,
instance = super().__new__(cls, *args, **kwargs)
if cls is NVFP4TensorStorage:
    instance = object.__new__(cls)
instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
NVFP4 compatibility with FP32 is not very good, so maybe it would be worth changing the default to BF16? The downside is that this default is different from standard PyTorch and from the other quantized tensors.
- instance._dtype = fake_dtype if fake_dtype is not None else torch.float32
+ instance._dtype = fake_dtype if fake_dtype is not None else torch.bfloat16
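A small sketch of the trade-off raised here, assuming a hypothetical per-class fallback; the helper and mapping names are illustrative only and not part of the diff:

```python
import torch

# Hypothetical per-class fallback: NVFP4 storage defaults to bfloat16, since
# NVFP4 compatibility with FP32 is not very good (per the review comment),
# while other storages keep the float32 default that matches standard PyTorch
# and the other quantized tensors.
_FALLBACK_DTYPE = {"NVFP4TensorStorage": torch.bfloat16}


def resolve_fake_dtype(cls, fake_dtype=None):
    """Return the dtype to record when the caller does not pass one."""
    if fake_dtype is not None:
        return fake_dtype
    return _FALLBACK_DTYPE.get(cls.__name__, torch.float32)
```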
Description
This PR adds the fake dtype information to the QuantizedTensorStorage class. This eliminates the need to guess the correct dtype for dequantization, as was previously done in distributed.py, and it eliminates the unintentional dequantization to FP32 when calling dequantize() on the storage class with no dtype argument.
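A hedged sketch of the caller-visible effect described above, using a toy stand-in class rather than the real QuantizedTensorStorage; the actual call sites in distributed.py differ, this only illustrates the default-dtype behavior:

```python
from typing import Optional

import torch


class Storage:
    """Toy stand-in; only the dequantize dtype default matters here."""

    def __init__(self, data: torch.Tensor, fake_dtype: Optional[torch.dtype] = None):
        self._data = data
        self._dtype = fake_dtype if fake_dtype is not None else torch.float32

    def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # Use the recorded dtype when the caller does not specify one.
        return self._data.to(dtype if dtype is not None else self._dtype)


s = Storage(torch.randn(3), fake_dtype=torch.bfloat16)

# Previously a call site had to guess and pass the target dtype, and a bare
# dequantize() silently produced FP32; with the stored fake dtype, the bare
# call round-trips to the original precision.
assert s.dequantize().dtype == torch.bfloat16
```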
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: