diff --git a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu
index 70dd8b3772..f5c95fba80 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu
@@ -290,6 +290,8 @@ __launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kerne
   const auto input_segment_end = cat_ad_offsets[input_segment_offset_end];
   const auto num_elements = input_segment_end - input_segment_start;
 
+  Dtype* dst_ptr = reordered_cat_ad_indices.data() + output_segment_start;
+  Dtype* src_ptr = cat_ad_indices.data() + input_segment_start;
   if (broadcast_indices) {
     for (auto i = threadIdx.x; i < num_ads_b * num_elements; i += blockDim.x) {
       reordered_cat_ad_indices[output_segment_start + i] =
@@ -308,29 +310,53 @@ __launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kerne
             cat_ad_indices[input_segment_start + i];
       }
     } else if (num_elements > 64 && num_elements <= 128) {
-      auto dst =
-          (vec2_t*)(reordered_cat_ad_indices.data() + output_segment_start);
-      auto src = (vec2_t*)(cat_ad_indices.data() + input_segment_start);
-      for (auto i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
-        dst[i] = src[i];
-      }
-      if ((num_elements % 2) && threadIdx.x == 31) {
-        reordered_cat_ad_indices[output_segment_start + num_elements - 1] =
-            cat_ad_indices[input_segment_start + num_elements - 1];
+      // Check alignment for vec2_t (8-byte alignment required)
+      bool vec2_t_aligned =
+          reinterpret_cast<uintptr_t>(dst_ptr) % alignof(vec2_t) == 0 &&
+          reinterpret_cast<uintptr_t>(src_ptr) % alignof(vec2_t) == 0;
+      if (vec2_t_aligned) {
+        // Use vectorized loads if properly aligned
+        auto dst = (vec2_t*)dst_ptr;
+        auto src = (vec2_t*)src_ptr;
+        for (auto i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
+          dst[i] = src[i];
+        }
+        if ((num_elements % 2) && threadIdx.x == 31) {
+          reordered_cat_ad_indices[output_segment_start + num_elements - 1] =
+              cat_ad_indices[input_segment_start + num_elements - 1];
+        }
+      } else {
+        // Fall back to scalar loads if misaligned
+        for (auto i = threadIdx.x; i < num_elements; i += blockDim.x) {
+          reordered_cat_ad_indices[output_segment_start + i] =
+              cat_ad_indices[input_segment_start + i];
+        }
       }
     } else if (num_elements > 128) {
-      auto dst =
-          (vec4_t*)(reordered_cat_ad_indices.data() + output_segment_start);
-      auto src = (vec4_t*)(cat_ad_indices.data() + input_segment_start);
-      for (auto i = threadIdx.x; i < num_elements / 4; i += blockDim.x) {
-        dst[i] = src[i];
-      }
-      int remainder = num_elements % 4;
-      if (remainder && threadIdx.x < remainder) {
-        reordered_cat_ad_indices
-            [output_segment_start + num_elements - threadIdx.x - 1] =
-                cat_ad_indices
-                    [input_segment_start + num_elements - threadIdx.x - 1];
+      // Check alignment for vec4_t (16-byte alignment required)
+      bool vec4_t_aligned =
+          reinterpret_cast<uintptr_t>(dst_ptr) % alignof(vec4_t) == 0 &&
+          reinterpret_cast<uintptr_t>(src_ptr) % alignof(vec4_t) == 0;
+      if (vec4_t_aligned) {
+        // Use vectorized loads if properly aligned
+        auto dst = (vec4_t*)dst_ptr;
+        auto src = (vec4_t*)src_ptr;
+        for (auto i = threadIdx.x; i < num_elements / 4; i += blockDim.x) {
+          dst[i] = src[i];
+        }
+        int remainder = num_elements % 4;
+        if (remainder && threadIdx.x < remainder) {
+          reordered_cat_ad_indices
+              [output_segment_start + num_elements - threadIdx.x - 1] =
+                  cat_ad_indices
+                      [input_segment_start + num_elements - threadIdx.x - 1];
+        }
+      } else {
+        // Fall back to scalar loads if misaligned
+        for (auto i = threadIdx.x; i < num_elements; i += blockDim.x) {
+          reordered_cat_ad_indices[output_segment_start + i] =
+              cat_ad_indices[input_segment_start + i];
+        }
       }
     }
   }
@@ -432,12 +458,12 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
             cat_ad_offsets.scalar_type(),
             "reorder_batched_ad_indices_gpu_kernel_2",
             [&] {
+              constexpr auto reorder_batched_ad_indices_kernel_name =
+                  reorder_batched_ad_indices_kernel_vec;
 #if defined __HIP_PLATFORM_AMD__
               constexpr auto NUM_WARPS = 4;
               const dim3 threads(32, NUM_WARPS); // 32 x 4
               const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS));
-              constexpr auto reorder_batched_ad_indices_kernel_name =
-                  reorder_batched_ad_indices_kernel_vec;
 #else
               constexpr auto NUM_WARPS = 32;
               auto maxWarpSize = kMaxThreads / NUM_WARPS;
@@ -445,8 +471,6 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
                   NUM_WARPS,
                   maxWarpSize < kWarpSize ? maxWarpSize : kWarpSize); // 32 x 32
               const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS));
-              constexpr auto reorder_batched_ad_indices_kernel_name =
-                  reorder_batched_ad_indices_kernel;
 #endif
               FBGEMM_LAUNCH_KERNEL(
                   (reorder_batched_ad_indices_kernel_name),
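
The hunks above gate the existing vec2_t/vec4_t fast paths on a runtime pointer-alignment check and fall back to an element-wise copy when either pointer is misaligned. Below is a minimal standalone CUDA sketch of that same pattern, kept outside the FBGEMM kernel; the kernel name copy_segment_kernel, the uint4-based vector alias, and the launch configuration are hypothetical and not taken from the patch.

// Sketch only: alignment-guarded vectorized copy with a scalar fallback.
#include <cstdint>
#include <cuda_runtime.h>

using vec4_t = uint4; // 16-byte-aligned vector of four 32-bit lanes

template <typename T>
__device__ inline bool is_aligned(const T* p, size_t alignment) {
  // A pointer is suitably aligned when its address is a multiple of the
  // vector type's alignment.
  return reinterpret_cast<uintptr_t>(p) % alignment == 0;
}

__global__ void copy_segment_kernel(const int32_t* src, int32_t* dst, int n) {
  if (is_aligned(src, alignof(vec4_t)) && is_aligned(dst, alignof(vec4_t))) {
    // Vectorized path: 4 elements per transaction.
    auto* vsrc = reinterpret_cast<const vec4_t*>(src);
    auto* vdst = reinterpret_cast<vec4_t*>(dst);
    for (int i = threadIdx.x; i < n / 4; i += blockDim.x) {
      vdst[i] = vsrc[i];
    }
    // Copy the tail (n % 4 elements) element-wise.
    for (int i = (n / 4) * 4 + threadIdx.x; i < n; i += blockDim.x) {
      dst[i] = src[i];
    }
  } else {
    // Scalar fallback when either pointer is misaligned.
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
      dst[i] = src[i];
    }
  }
}

int main() {
  constexpr int n = 1000;
  int32_t *src, *dst;
  cudaMalloc(&src, n * sizeof(int32_t));
  cudaMalloc(&dst, n * sizeof(int32_t));
  cudaMemset(src, 1, n * sizeof(int32_t));
  copy_segment_kernel<<<1, 128>>>(src, dst, n);
  cudaDeviceSynchronize();
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

Because cudaMalloc returns generously aligned allocations, the vectorized branch is taken here; the scalar branch exists for segment base pointers (such as dst_ptr/src_ptr in the patch, which start at arbitrary element offsets) that do not land on a vec2_t/vec4_t boundary.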