[Common][PyTorch] Add a new score func sqrtsoftplus to the fused router #2633

yaox12 wants to merge 4 commits into NVIDIA:main

Conversation
Greptile Overview

Greptile Summary: This PR adds a new score function, sqrtsoftplus, to the fused router.

Confidence Score: 4/5

Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant PyTorchAPI as PyTorch API
participant CPPExt as C++ Extension
participant CUDAKernel as CUDA Kernel
participant Memory
User->>PyTorchAPI: Call with sqrtsoftplus
PyTorchAPI->>CPPExt: Forward function call
CPPExt->>CPPExt: Validate score function
CPPExt->>Memory: Allocate FP32 buffers
CPPExt->>CUDAKernel: Launch forward kernel
Note over CUDAKernel: Forward Pass FP32 math
CUDAKernel->>CUDAKernel: Load and cast logits to FP32
CUDAKernel->>Memory: Save original logits
CUDAKernel->>CUDAKernel: Apply sqrtsoftplus
CUDAKernel->>CUDAKernel: Add expert bias
CUDAKernel->>CUDAKernel: TopK selection
CUDAKernel->>CUDAKernel: Normalize if needed
CUDAKernel->>Memory: Write results
CUDAKernel-->>CPPExt: Complete
CPPExt-->>PyTorchAPI: Return tensors
PyTorchAPI-->>User: Return probs and routing map
User->>PyTorchAPI: Call backward
PyTorchAPI->>CPPExt: Backward function call
CPPExt->>CUDAKernel: Launch backward kernel
Note over CUDAKernel: Backward Pass FP32 math
CUDAKernel->>Memory: Load gradients and saved logits
CUDAKernel->>CUDAKernel: Recompute sqrtsoftplus
CUDAKernel->>CUDAKernel: Normalization backward
CUDAKernel->>CUDAKernel: Sqrtsoftplus gradient
CUDAKernel->>Memory: Write grad logits
CUDAKernel-->>CPPExt: Complete
CPPExt-->>PyTorchAPI: Return gradients
PyTorchAPI-->>User: Propagate gradients
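The forward and backward steps in the diagram revolve around the sqrtsoftplus score and its gradient. A minimal Python sketch of that math, assuming sqrtsoftplus(x) = sqrt(softplus(x)) as the name (and the MCore reference in the description) suggests; this is an illustration of the formulas, not the actual CUDA kernel code:

```python
import math

def sqrtsoftplus(x):
    # Forward: score = sqrt(softplus(x)) = sqrt(log(1 + exp(x))).
    # For x > 0, compute softplus as x + log1p(exp(-x)) to avoid overflow.
    sp = x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))
    return math.sqrt(sp)

def sqrtsoftplus_grad(x):
    # Backward: d/dx sqrt(softplus(x)) = sigmoid(x) / (2 * sqrt(softplus(x))),
    # since d/dx softplus(x) = sigmoid(x).
    sig = 1.0 / (1.0 + math.exp(-x))
    return sig / (2.0 * sqrtsoftplus(x))
```

A quick finite-difference check confirms the gradient formula matches the forward definition.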
Additional Comments (1)

The header still documents …

transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu (Outdated)
Signed-off-by: Xin Yao <xiny@nvidia.com>
Force-pushed from 70136ee to 3bace73 (Compare)
transformer_engine/common/fused_router/fused_topk_with_score_function.cu

Additional Comments (1)
auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
auto score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
Creating the output tensor on the Python side, because `intermediate_output` is now always in FP32 while `grad_logits` should have the same dtype as the input.
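The dtype concern in this comment can be sketched in NumPy. This is illustrative only: `router_backward_sketch` is a hypothetical helper showing the pattern (FP32 intermediates, grad cast back to the input dtype), not the actual TE code:

```python
import numpy as np

def router_backward_sketch(grad_probs, logits):
    """Illustrative: intermediate math stays in FP32, but the returned
    grad_logits matches the dtype of the input logits."""
    # Upcast to FP32, mirroring intermediate_output always being FP32.
    g = grad_probs.astype(np.float32)
    x = logits.astype(np.float32)
    # Stand-in for the real gradient math:
    # d/dx sqrt(softplus(x)) = sigmoid(x) / (2 * sqrt(softplus(x)))
    sig = 1.0 / (1.0 + np.exp(-x))
    grad = g * sig / (2.0 * np.sqrt(np.log1p(np.exp(x))))
    # Cast back so grad_logits has the same dtype as the input.
    return grad.astype(logits.dtype)

logits = np.ones(4, dtype=np.float16)
grad = router_backward_sketch(np.ones(4, dtype=np.float16), logits)
print(grad.dtype)  # float16
```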
Signed-off-by: Xin Yao <xiny@nvidia.com>
Description

Add a new score function, `sqrtsoftplus`, to the fused router:
- `scores` is always in FP32 (to match the MCore implementation).
- `intermediate_output` is always in FP32 for better backward precision.

Type of change
Changes
Please list the changes introduced in this PR:
Checklist:
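As a side note on why keeping `scores` in FP32 matters for routing: logits that are distinct in FP32 can collapse to the same representable FP16 value, turning a clear ranking into a tie for top-k selection. A small NumPy demonstration (the specific values are chosen for illustration):

```python
import numpy as np

logits = [8.1, 8.1004, 8.1008]
fp16 = np.array(logits, dtype=np.float16)
fp32 = np.array(logits, dtype=np.float32)

# Near 8.0 the FP16 spacing is 2^-7 ≈ 0.0078, so all three logits round
# to the same FP16 value and a top-k over FP16 scores would see a tie;
# FP32 spacing there is ~1e-6, so the three stay distinct.
print(len(set(fp16.tolist())), len(set(fp32.tolist())))  # 1 3
```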