Fix FP8 block scaling with sequence parallel #2637

Open

cuichenx wants to merge 7 commits into NVIDIA:main from cuichenx:chcui/fix_subchannel_fp8+sp

Conversation

@cuichenx
Contributor

@cuichenx cuichenx commented Jan 31, 2026

Description

Problem

Using Float8BlockQuantizer with sequence parallel fails with "AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer" when the local tensor dimensions aren't divisible by 128.

Solution

  • Skip the assert_dim_for_all_gather check for Float8BlockQuantizer, since gather_along_first_dim already has a fallback path
  • Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before the high-precision all-gather

Note

The fallback path (high-precision all-gather followed by quantization) may increase communication overhead, since the gather operates on the unquantized tensor.
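A minimal sketch of this fallback control flow, with assumed argument and helper names (the actual implementation is _start_all_gather_fp8_blockwise in transformer_engine/pytorch/distributed.py):

import warnings

import torch
import torch.distributed

# Placeholder for Transformer Engine's QuantizedTensorStorage; the real class lives
# in TE internals and is used here only to show the control flow of the fallback.
class QuantizedTensorStorage:
    def dequantize(self, dtype):
        raise NotImplementedError

def gather_along_first_dim_sketch(inp, quantizer, process_group, out_shape, dtype, device):
    # `quantizer` stands in for Float8BlockQuantizer: it exposes is_quantizable()
    # and is callable to quantize a high-precision tensor.
    if not quantizer.is_quantizable(inp):
        # Fallback: local dims are not divisible by the 128-element block, so
        # all-gather in high precision and quantize the gathered result.
        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
        if isinstance(inp, QuantizedTensorStorage):
            inp = inp.dequantize(dtype=dtype)  # back to high precision before the collective
        out = torch.empty(out_shape, dtype=dtype, device=device)
        torch.distributed.all_gather_into_tensor(
            out, inp.contiguous(), group=process_group, async_op=False
        )
        return quantizer(out)  # re-quantize the gathered tensor
    ...  # fast path: all-gather the FP8 data and block scales directly

Because the collective runs on the dequantized tensor, this path moves more bytes than the FP8 fast path, which is the overhead noted above.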

Verification

The code change does not alter convergence behavior (screenshot attached in the PR).

When SP is True, the previous code path did not run; when SP is False, this change does not affect anything (screenshot attached in the PR).

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Removed the assert_dim_for_all_gather check for Float8BlockQuantizer in the affected modules
  • Updated the _start_all_gather_fp8_blockwise fallback to dequantize already-quantized inputs before the high-precision all-gather

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Chen Cui <chcui@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Jan 31, 2026

Greptile Overview

Greptile Summary

This PR fixes an AssertionError when using Float8BlockQuantizer with sequence parallel on tensors whose dimensions aren't divisible by 128 (the block length).

Key Changes:

  • Removed assert_dim_for_all_gather checks from layernorm_linear, layernorm_mlp, and linear modules, since the fallback path can handle non-quantizable dimensions
  • Enhanced fallback paths in _start_all_gather_fp8_blockwise, _all_gather_nvfp4, and _all_gather_mxfp8 to dequantize already-quantized inputs before performing high-precision all-gather
  • Fixed dequantize() calls to pass an explicit dtype parameter for consistency
  • Added warning messages when falling back to high-precision all-gather

Issues Found:

  • Three new fallback paths are missing .contiguous() calls before all_gather_into_tensor, which can cause runtime errors with non-contiguous tensors (e.g., after transpose or slicing). Other all-gather paths in the codebase consistently use .contiguous().

Confidence Score: 3/5

  • This PR is mostly safe but has a critical bug in the fallback paths that needs to be fixed
  • The core logic correctly removes unnecessary assertions and adds proper fallback handling for non-quantizable dimensions. However, three new all-gather fallback paths are missing .contiguous() calls, which can cause runtime errors with non-contiguous tensors. This is a well-understood pattern violation (other paths use .contiguous()), making it straightforward to fix.
  • Pay close attention to transformer_engine/pytorch/distributed.py - the three missing .contiguous() calls must be added before merging

Important Files Changed

  • transformer_engine/pytorch/distributed.py: Adds a fallback path for FP8 block quantization when dimensions aren't divisible by 128; fixes dequantize calls to pass the dtype parameter; missing .contiguous() calls in three all-gather fallback paths
  • transformer_engine/pytorch/module/layernorm_linear.py: Removed the assert_dim_for_all_gather check for Float8BlockQuantizer, allowing the fallback path to handle non-quantizable dimensions
  • transformer_engine/pytorch/module/layernorm_mlp.py: Removed the assert_dim_for_all_gather check for Float8BlockQuantizer, allowing the fallback path to handle non-quantizable dimensions
  • transformer_engine/pytorch/module/linear.py: Removed the assert_dim_for_all_gather check for Float8BlockQuantizer, allowing the fallback path to handle non-quantizable dimensions
  • transformer_engine/pytorch/utils.py: Removed the assert_dim_for_all_gather function that checked quantizer compatibility; no longer needed with the fallback paths

Sequence Diagram

sequenceDiagram
    participant Module as LayerNorm/Linear Module
    participant Utils as utils.py
    participant Gather as gather_along_first_dim
    participant FP8Block as _start_all_gather_fp8_blockwise
    participant Quantizer as Float8BlockQuantizer
    participant NCCL as torch.distributed

    Module->>Module: Remove assert_dim_for_all_gather check
    Module->>Gather: Call with sequence_parallel=True
    Gather->>FP8Block: Route to FP8 blockwise handler
    FP8Block->>Quantizer: is_quantizable(inp)?
    
    alt Tensor dimensions divisible by 128
        Quantizer-->>FP8Block: True
        FP8Block->>FP8Block: Quantize to FP8 blockwise
        FP8Block->>NCCL: all_gather FP8 data + scales
        NCCL-->>FP8Block: Gathered FP8 tensor
    else Dimensions not divisible by 128
        Quantizer-->>FP8Block: False
        FP8Block->>FP8Block: Check if already quantized
        alt Already quantized
            FP8Block->>FP8Block: Dequantize to high precision
        end
        FP8Block->>NCCL: all_gather in high precision
        NCCL-->>FP8Block: Gathered high-precision tensor
        FP8Block->>Quantizer: Quantize output
        Quantizer-->>FP8Block: FP8 blockwise tensor
    end
    
    FP8Block-->>Module: Return gathered tensor

greptile-apps[bot]

This comment was marked as outdated.

@cyanguwa cyanguwa requested a review from timmoon10 February 2, 2026 18:48
@timmoon10

This comment was marked as outdated.

timmoon10

This comment was marked as outdated.

greptile-apps[bot]

This comment was marked as outdated.

@timmoon10 timmoon10 self-requested a review February 2, 2026 19:38
Perform all-gather in high precision if the input tensor is too small to quantize.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, no comments


timmoon10
timmoon10 previously approved these changes Feb 2, 2026
Collaborator

@timmoon10 timmoon10 left a comment


I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.

@timmoon10

This comment was marked as outdated.

@timmoon10
Collaborator

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, 3 comments


out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed
Contributor


The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
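For example, the fallback could forward the activation dtype explicitly (a sketch; `dtype` is assumed to be the activation dtype already available at this point in the function):

if isinstance(inp, QuantizedTensorStorage):
    # Pass dtype explicitly so a bf16 activation is not widened to float32
    # by the dequantize() default before the all-gather.
    inp = inp.dequantize(dtype=dtype)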

):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
Contributor


Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
Contributor


Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


5 files reviewed, no comments


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Member

ptrendx commented Feb 10, 2026

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment


5 files reviewed, 1 comment


Comment on lines 1102 to 1108
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    out = quantizer(out)
Contributor


Non-contiguous gather input

In the new high-precision fallback (if not quantizer.is_quantizable(inp) ...), all_gather_into_tensor(out, inp, ...) passes inp directly. Elsewhere in this same module the plain-tensor path uses inp.contiguous() (distributed.py:1737-1742) and the FP8 path uses _data.contiguous() (distributed.py:1031-1035), which strongly suggests the collective expects contiguous inputs. If inp is a non-contiguous view (common after transpose/slicing), this fallback can raise at runtime. This same issue also appears in the NVFP4 and MXFP8 high-precision fallbacks (distributed.py:1353 and :1523).
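As a standalone illustration of the failure mode (plain PyTorch, not Transformer Engine code):

import torch

x = torch.randn(4, 128).t()  # a transposed view is non-contiguous
assert not x.is_contiguous()

# NCCL collectives such as all_gather_into_tensor generally require contiguous
# buffers, which is why the other all-gather paths in distributed.py call
# .contiguous() on their inputs before the collective.
x_safe = x.contiguous()
assert x_safe.is_contiguous()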

@ptrendx
Member

ptrendx commented Feb 12, 2026

/te-ci pytorch

Contributor

@greptile-apps greptile-apps bot left a comment


5 files reviewed, 3 comments


if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
out = torch.empty(out_shape, dtype=dtype, device=device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
Contributor


Missing .contiguous() call on inp before all-gather

Other all-gather paths in this file use .contiguous() (lines 1739, 1033). Non-contiguous tensors (from transpose/slicing) can cause runtime errors.

Suggested change
- torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
+ torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group, async_op=False)

@greptile-apps
Contributor

greptile-apps bot commented Feb 12, 2026

Additional Comments (2)

transformer_engine/pytorch/distributed.py
Missing .contiguous() call on inp before all-gather

        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)

transformer_engine/pytorch/distributed.py
Missing .contiguous() call on inp before all-gather

        torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)

