From aa4060e1cd5fa137c561215e7b9ea00b20bedf36 Mon Sep 17 00:00:00 2001 From: yangjianfengo1 Date: Fri, 28 Nov 2025 19:57:34 +0800 Subject: [PATCH] fix deepep fp8 dispatch err --- .../distributed/collective/deep_ep/kernels/internode_ll.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu index a0589f986ea7c4..4bdea1d6739240 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu @@ -251,6 +251,10 @@ __global__ __launch_bounds__( if (kUseFP8 && !use_expertwise_scale && kNumPerChannels == -1) { // fp8 per-token dynamic quant __shared__ float amax_cache[num_warps - 1]; + for (int i = thread_id; i < num_warps - 1; i += num_threads) { + amax_cache[i] = 0.0f; + } + asm volatile("bar.sync 1, %0;" ::"r"(num_threads)); float amax = kFP8Margin, scale, scale_inv; #pragma unroll for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {