Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 49 additions & 25 deletions fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ __launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kerne
const auto input_segment_end = cat_ad_offsets[input_segment_offset_end];
const auto num_elements = input_segment_end - input_segment_start;

Dtype* dst_ptr = reordered_cat_ad_indices.data() + output_segment_start;
Dtype* src_ptr = cat_ad_indices.data() + input_segment_start;
if (broadcast_indices) {
for (auto i = threadIdx.x; i < num_ads_b * num_elements; i += blockDim.x) {
reordered_cat_ad_indices[output_segment_start + i] =
Expand All @@ -308,29 +310,53 @@ __launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kerne
cat_ad_indices[input_segment_start + i];
}
} else if (num_elements > 64 && num_elements <= 128) {
auto dst =
(vec2_t*)(reordered_cat_ad_indices.data() + output_segment_start);
auto src = (vec2_t*)(cat_ad_indices.data() + input_segment_start);
for (auto i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
dst[i] = src[i];
}
if ((num_elements % 2) && threadIdx.x == 31) {
reordered_cat_ad_indices[output_segment_start + num_elements - 1] =
cat_ad_indices[input_segment_start + num_elements - 1];
// Both pointers must satisfy alignof(vec2_t) before being cast to vec2_t*
// (8 bytes if vec2_t packs two 4-byte elements -- TODO confirm for wider Dtype).
bool vec2_t_aligned =
reinterpret_cast<uintptr_t>(dst_ptr) % alignof(vec2_t) == 0 &&
reinterpret_cast<uintptr_t>(src_ptr) % alignof(vec2_t) == 0;
if (vec2_t_aligned) {
// Source and destination are both aligned: copy with vec2_t loads and stores
auto dst = (vec2_t*)dst_ptr;
auto src = (vec2_t*)src_ptr;
for (auto i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
dst[i] = src[i];
}
if ((num_elements % 2) && threadIdx.x == 31) {
reordered_cat_ad_indices[output_segment_start + num_elements - 1] =
cat_ad_indices[input_segment_start + num_elements - 1];
}
} else {
// Either pointer is misaligned for vec2_t: fall back to an element-by-element copy
for (auto i = threadIdx.x; i < num_elements; i += blockDim.x) {
reordered_cat_ad_indices[output_segment_start + i] =
cat_ad_indices[input_segment_start + i];
}
}
} else if (num_elements > 128) {
auto dst =
(vec4_t*)(reordered_cat_ad_indices.data() + output_segment_start);
auto src = (vec4_t*)(cat_ad_indices.data() + input_segment_start);
for (auto i = threadIdx.x; i < num_elements / 4; i += blockDim.x) {
dst[i] = src[i];
}
int remainder = num_elements % 4;
if (remainder && threadIdx.x < remainder) {
reordered_cat_ad_indices
[output_segment_start + num_elements - threadIdx.x - 1] =
cat_ad_indices
[input_segment_start + num_elements - threadIdx.x - 1];
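// Both pointers must satisfy alignof(vec4_t) before being cast to vec4_t*
// (16 bytes if vec4_t packs four 4-byte elements -- TODO confirm for wider Dtype).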
bool vec4_t_aligned =
reinterpret_cast<uintptr_t>(dst_ptr) % alignof(vec4_t) == 0 &&
reinterpret_cast<uintptr_t>(src_ptr) % alignof(vec4_t) == 0;
if (vec4_t_aligned) {
// Source and destination are both aligned: copy with vec4_t loads and stores
auto dst = (vec4_t*)dst_ptr;
auto src = (vec4_t*)src_ptr;
for (auto i = threadIdx.x; i < num_elements / 4; i += blockDim.x) {
dst[i] = src[i];
}
int remainder = num_elements % 4;
if (remainder && threadIdx.x < remainder) {
reordered_cat_ad_indices
[output_segment_start + num_elements - threadIdx.x - 1] =
cat_ad_indices
[input_segment_start + num_elements - threadIdx.x - 1];
}
} else {
// Either pointer is misaligned for vec4_t: fall back to an element-by-element copy
for (auto i = threadIdx.x; i < num_elements; i += blockDim.x) {
reordered_cat_ad_indices[output_segment_start + i] =
cat_ad_indices[input_segment_start + i];
}
}
}
}
Expand Down Expand Up @@ -432,21 +458,19 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
cat_ad_offsets.scalar_type(),
"reorder_batched_ad_indices_gpu_kernel_2",
[&] {
constexpr auto reorder_batched_ad_indices_kernel_name =
reorder_batched_ad_indices_kernel_vec<scalar_t, index_t>;
#if defined __HIP_PLATFORM_AMD__
constexpr auto NUM_WARPS = 4;
const dim3 threads(32, NUM_WARPS); // 32 x 4
const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS));
constexpr auto reorder_batched_ad_indices_kernel_name =
reorder_batched_ad_indices_kernel_vec<scalar_t, index_t>;
#else
constexpr auto NUM_WARPS = 32;
auto maxWarpSize = kMaxThreads / NUM_WARPS;
const dim3 threads(
NUM_WARPS,
maxWarpSize < kWarpSize ? maxWarpSize : kWarpSize); // 32 x 32
const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS));
constexpr auto reorder_batched_ad_indices_kernel_name =
reorder_batched_ad_indices_kernel<scalar_t, index_t>;
#endif
FBGEMM_LAUNCH_KERNEL(
(reorder_batched_ad_indices_kernel_name),
Expand Down
Loading