Fix FP8 block scaling with sequence parallel #2637
cuichenx wants to merge 7 commits into NVIDIA:main from
Conversation
Signed-off-by: Chen Cui <chcui@nvidia.com>
Greptile Overview
Greptile Summary
This PR fixes an assertion failure when using Float8BlockQuantizer with sequence parallelism and the local tensor dimensions are not divisible by 128.
Key Changes:
Issues Found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Module as LayerNorm/Linear Module
participant Utils as utils.py
participant Gather as gather_along_first_dim
participant FP8Block as _start_all_gather_fp8_blockwise
participant Quantizer as Float8BlockQuantizer
participant NCCL as torch.distributed
Module->>Module: Remove assert_dim_for_all_gather check
Module->>Gather: Call with sequence_parallel=True
Gather->>FP8Block: Route to FP8 blockwise handler
FP8Block->>Quantizer: is_quantizable(inp)?
alt Tensor dimensions divisible by 128
Quantizer-->>FP8Block: True
FP8Block->>FP8Block: Quantize to FP8 blockwise
FP8Block->>NCCL: all_gather FP8 data + scales
NCCL-->>FP8Block: Gathered FP8 tensor
else Dimensions not divisible by 128
Quantizer-->>FP8Block: False
FP8Block->>FP8Block: Check if already quantized
alt Already quantized
FP8Block->>FP8Block: Dequantize to high precision
end
FP8Block->>NCCL: all_gather in high precision
NCCL-->>FP8Block: Gathered high-precision tensor
FP8Block->>Quantizer: Quantize output
Quantizer-->>FP8Block: FP8 blockwise tensor
end
FP8Block-->>Module: Return gathered tensor
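As a minimal, framework-free sketch of the routing decision the diagram describes: the divisibility-by-128 check below stands in for Float8BlockQuantizer.is_quantizable (which may check more than this), and the block size of 128 is taken from the PR description. This is not Transformer Engine code.

```python
import torch

BLOCK = 128  # block size assumed from the "divisible by 128" condition in this PR

def all_gather_path(local_shard: torch.Tensor) -> str:
    """Return which path the diagram above would take for a sequence-parallel shard."""
    # Stand-in for Float8BlockQuantizer.is_quantizable: every dimension must
    # tile evenly into 128-element blocks.
    quantizable = all(dim % BLOCK == 0 for dim in local_shard.shape)
    if quantizable:
        return "FP8 blockwise: quantize locally, all-gather FP8 data + scales"
    return "fallback: all-gather in high precision, then quantize the output"

# Example: a 1000-token sequence split over 8 ranks leaves 125 local rows,
# which is not divisible by 128, so the fallback path is taken.
print(all_gather_path(torch.empty(125, 4096)))  # fallback path
print(all_gather_path(torch.empty(256, 4096)))  # FP8 blockwise path
```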
This comment was marked as outdated.
Perform all-gather in high-precision if the input tensor is too small to quantize.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 left a comment
I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.
This comment was marked as outdated.
/te-ci pytorch L1
out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed
The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
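A plain-PyTorch toy illustration of this concern (not TE code): a float32-default dequantize changes the dtype, and the per-element size, of the buffer that is subsequently all-gathered, whereas passing the original precision keeps it in, e.g., bfloat16, which is what a later revision does with dequantize(dtype=dtype).

```python
import torch

# `original` plays the role of the dequantized activation; bfloat16 is assumed
# to be the precision the module actually runs in.
original = torch.randn(125, 4096, dtype=torch.bfloat16)

deq_default = original.float()              # a float32-default dequantize upcasts the buffer
deq_explicit = original.to(original.dtype)  # an explicit dtype keeps the original precision

print(deq_default.dtype, deq_explicit.dtype)                # torch.float32 vs torch.bfloat16
print(deq_default.element_size(), deq_explicit.element_size())  # 4 vs 2 bytes per element gathered
```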
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/te-ci pytorch L1
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    out = quantizer(out)
Non-contiguous gather input
In the new high-precision fallback (if not quantizer.is_quantizable(inp) ...), all_gather_into_tensor(out, inp, ...) passes inp directly. Elsewhere in this same module the plain-tensor path uses inp.contiguous() (distributed.py:1737-1742) and the FP8 path uses _data.contiguous() (distributed.py:1031-1035), which strongly suggests the collective expects contiguous inputs. If inp is a non-contiguous view (common after transpose/slicing), this fallback can raise at runtime. This same issue also appears in the NVFP4 and MXFP8 high-precision fallbacks (distributed.py:1353 and :1523).
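A quick plain-PyTorch check of the failure mode described here (TE not required): transposed or sliced views are non-contiguous, and the collectives used in this file are generally expected to receive contiguous buffers, which is what the existing `.contiguous()` calls guarantee.

```python
import torch

x = torch.randn(128, 256)
view = x.t()        # transpose returns a non-contiguous view
sliced = x[:, ::2]  # strided slicing is non-contiguous as well

print(view.is_contiguous())               # False
print(sliced.is_contiguous())             # False
print(view.contiguous().is_contiguous())  # True -- the fix proposed in the suggestion below
```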
/te-ci pytorch
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
out = torch.empty(out_shape, dtype=dtype, device=device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
Missing .contiguous() call on inp before all-gather
Other all-gather paths in this file use .contiguous() (lines 1739, 1033). Non-contiguous tensors (from transpose/slicing) can cause runtime errors.
Suggested change:
- torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
+ torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group, async_op=False)
Additional Comments (2)
Description
Problem
Using Float8BlockQuantizer with sequence parallel fails with
AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer
when local tensor dimensions aren't divisible by 128.
Solution
Skip the assert_dim_for_all_gather check for Float8BlockQuantizer since gather_along_first_dim already has a fallback path
Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before the high-precision all-gather (a sketch of the resulting flow follows below)
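A single-process sketch of the resulting flow, with torch.cat standing in for all_gather_into_tensor and a crude round-to-scale function standing in for Float8BlockQuantizer (assumptions: 128-row blocks and bfloat16 activations). It shows the control flow only, not TE's actual kernels.

```python
import torch

BLOCK = 128  # block size assumed from the "divisible by 128" condition above

def fake_quantize(t: torch.Tensor, scale: float = 64.0) -> torch.Tensor:
    # Crude stand-in for FP8 quantization: scale, round, unscale.
    return torch.round(t * scale) / scale

def gather_with_fallback(local_shards: list[torch.Tensor]) -> torch.Tensor:
    quantizable = all(s.shape[0] % BLOCK == 0 for s in local_shards)
    if quantizable:
        # Quantized path: each rank quantizes its shard, then the FP8 data is gathered.
        return torch.cat([fake_quantize(s) for s in local_shards], dim=0)
    # Fallback path: gather in high precision, then quantize the gathered tensor once.
    # (In TE, a shard that already arrived quantized is first dequantized via
    # inp.dequantize(dtype=dtype); here the shards are plain tensors.)
    return fake_quantize(torch.cat(local_shards, dim=0))

# 1000 tokens over 8 "ranks" -> 125 local rows, not divisible by 128 -> fallback path.
shards = [torch.randn(125, 512, dtype=torch.bfloat16) for _ in range(8)]
print(gather_with_fallback(shards).shape)  # torch.Size([1000, 512])
```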
Note
The fallback path (high-precision all-gather → quantize) may increase the communication overhead.
Verification
The code change does not alter convergence behavior

When SP is True, the previous code did not run. When SP is False, this change doesn't affect anything.

Type of change
Changes
Please list the changes introduced in this PR:
Checklist: