
[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True #2677

Open
sudhakarsingh27 wants to merge 8 commits into NVIDIA:main from sudhakarsingh27:fix_return_stats_max_cudnn

Conversation

@sudhakarsingh27
Collaborator

Description

cuDNN recently made it possible to return any subset of {Stats, SumExp, Max}. This PR adapts TE to always get Stats from cuDNN, plus the Max tensor when return_max_logit=True. (Note that Stats = log(SumExp) + Max.)
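
For reference, Stats is the row-wise log-sum-exp of the attention logits, with SumExp accumulated after max subtraction. A minimal PyTorch sketch of the identity (toy shapes and illustrative names, not TE's actual tensors):

    import torch

    s = torch.randn(4, 8)                                        # toy attention logits
    row_max = s.max(dim=-1, keepdim=True).values                 # Max
    sum_exp = torch.exp(s - row_max).sum(dim=-1, keepdim=True)   # SumExp (max-subtracted)
    stats = torch.log(sum_exp) + row_max                         # Stats = log(SumExp) + Max
    assert torch.allclose(stats, torch.logsumexp(s, dim=-1, keepdim=True))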

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • fused_attn_f16_arbitrary_seqlen.cu
    • Removed references to the SumExp tensor, since cuDNN now returns Stats by default and SumExp is no longer needed.
    • Set generate_stats=true, which forces cuDNN to always return the Stats tensor (needed in the backward pass).
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py
    • Removed the code that manually computed Stats = log(SumExp) + Max, since cuDNN returns Stats directly and TE no longer needs SumExp from cuDNN (see the sketch after this list).
  • Corresponding documentation updates
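
A hedged sketch of what the Python-side change amounts to (the index layout follows the reviewer summary below; the stand-in tensors and shapes are purely illustrative, not the exact TE code):

    import torch

    # Stand-ins for what the fused-attention forward returns (illustrative shapes):
    output_tensors = [torch.randn(2, 4, 8, 16),   # [0] attention output O
                      torch.randn(2, 4, 8, 1),    # [1] softmax Stats, always present now
                      torch.randn(2, 4, 8, 1)]    # [2] per-row Max, only when requested
    return_max_logit = True

    # Before this PR, TE rebuilt Stats by hand as Max + log(SumExp).
    # After: Stats comes straight from cuDNN at index 1.
    stats = output_tensors[1]
    if return_max_logit:
        max_logit = torch.amax(output_tensors[2])  # reduce the per-row Max to a scalar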

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhakarsingh27 and others added 5 commits February 12, 2026 13:12
@greptile-apps
Contributor

greptile-apps bot commented Feb 12, 2026

Greptile Summary

This PR adapts Transformer Engine to leverage cuDNN's new capability to return Stats (log(SumExp)+Max) directly, eliminating the need for manual computation. The changes simplify the code by removing Sum_Exp tensor handling throughout the stack.

Key changes:

  • Set generate_stats=true to force cuDNN to always return the Stats tensor (required for the backward pass; see the sketch after this list)
  • When return_max_logit=True, cuDNN now returns both Stats and Max tensors (previously it returned Max and Sum_Exp)
  • Removed the Python-side computation of Stats = Max + log(Sum_Exp), since cuDNN provides Stats directly
  • Updated tensor ordering: Stats is now always first, followed by Max when requested
  • Updated comments and documentation to reflect the new tensor order
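
Why the Stats tensor matters for the backward pass: with Stats saved from the forward pass, the softmax probabilities can be recomputed in a single pass instead of being stored. A minimal illustration (hypothetical names, not TE's kernel code):

    import torch

    s = torch.randn(4, 8)                               # toy attention logits
    stats = torch.logsumexp(s, dim=-1, keepdim=True)    # Stats, saved by the forward pass
    p = torch.exp(s - stats)                            # recomputed softmax(s, dim=-1)
    assert torch.allclose(p, torch.softmax(s, dim=-1))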

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are well-structured and maintain consistency across all three modified files. The tensor order is correctly updated from (Max, Sum_Exp) to (Stats, Max), eliminating redundant computation. The implementation correctly leverages cuDNN's new capability to return Stats directly, and all comments have been updated to reflect the new behavior.
  • No files require special attention

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu — Updated the cuDNN integration to always return the Stats tensor and to return Max only when return_max_logit=True. Removed Sum_Exp tensor handling, as Stats is now returned directly from cuDNN.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py — Removed the manual Stats computation (Stats = Max + log(Sum_Exp)). Now uses Stats from cuDNN directly and extracts Max from output_tensors[2] when return_max_logit=True.
  • transformer_engine/pytorch/csrc/extensions/attention.cpp — Updated comments to reflect the new tensor order: Stats is always first, followed by Max when return_max_logit=True (previously Max, then Sum_Exp).

Sequence Diagram

sequenceDiagram
    participant Py as Python (fused_attn.py)
    participant CPP as C++ (attention.cpp)
    participant CU as CUDA (fused_attn_f16_arbitrary_seqlen.cu)
    participant cuDNN as cuDNN Backend
    
    Py->>CPP: fused_attn_fwd(return_max_logit=True)
    CPP->>CU: fused_attn_arbitrary_seqlen_fwd()
    Note over CU: Set generate_stats=true (always)
    CU->>cuDNN: sdpa() with logit_max option
    Note over cuDNN: Computes Stats internally<br/>Stats = log(SumExp) + Max
    cuDNN-->>CU: Returns: O, Stats, Max
    Note over CU: Tensor order: S1=Stats, S2=Max
    CU-->>CPP: output_tensors[0]=O, [1]=Stats, [2]=Max
    CPP-->>Py: output_tensors list
    Note over Py: Extract max_logit = amax(output_tensors[2])<br/>Store Stats for backward pass
    Py-->>Py: Return: O, aux_ctx_tensors=[Stats, ...], max_logit

Last reviewed commit: 260380b

Contributor

@greptile-apps greptile-apps bot left a comment


3 files reviewed, 1 comment


…27/TransformerEngine into fix_return_stats_max_cudnn
Contributor

@greptile-apps greptile-apps bot left a comment


3 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment


3 files reviewed, no comments


