Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions transformer_engine/common/fused_router/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncT
// Some value is hanlded in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread
volatile double val = lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : default_val;
double val = lane_id < data_size ? static_cast<double>(data_ptr[lane_id]) : default_val;
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
val = reduce_func(val, data_ptr[i]);
}
Expand Down Expand Up @@ -85,7 +85,7 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat
// Some value is hanlded in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread
volatile double val =
double val =
lane_id < data_size && mask[lane_id] ? static_cast<double>(data_ptr[lane_id]) : default_val;
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
if (mask[i]) {
Expand Down Expand Up @@ -183,25 +183,25 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
// After looping topk times, the topk_indices will be the topk indices
for (int k = 0; k < topk; k++) {
// Find the max value and its index
volatile double val = (lane_id < data_size && !is_masked(k, lane_id))
? static_cast<double>(scores[lane_id])
: -std::numeric_limits<double>::infinity();
volatile int index = (lane_id < data_size) ? lane_id : 0;
double val = (lane_id < data_size && !is_masked(k, lane_id))
? static_cast<double>(scores[lane_id])
: -std::numeric_limits<double>::infinity();
int index = (lane_id < data_size) ? lane_id : 0;
// Some value is hanlded in local thread
// Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ...
// Reduce the value in local thread
for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) {
volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits<double>::infinity()
: static_cast<double>(scores[i]);
double cur_val = (is_masked(k, i)) ? -std::numeric_limits<double>::infinity()
: static_cast<double>(scores[i]);
if (cur_val > val) {
val = cur_val;
index = i;
}
}
// Warp shuffle between threads
for (int s = 16; s > 0; s /= 2) {
volatile auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
volatile auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
if (shuffled_val > val) {
val = shuffled_val;
index = shuffled_index;
Expand Down