From 1391466264b080514eb79c8eeae25621c82be11e Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Thu, 9 Oct 2025 15:16:08 -0700 Subject: [PATCH 1/4] fmt --- .../blackwell_gen_impl.cu | 95 +- ...m100_fmha_gen_mainloop_warpspecialized.hpp | 670 ++++--- ...m100_fmha_load_cpasync_warpspecialized.hpp | 267 +-- .../sm100_fmha_gen_kernel_warpspecialized.hpp | 531 +++--- .../sm100_fmha_mla_tma_warpspecialized.hpp | 1607 ++++++++++------- 5 files changed, 1926 insertions(+), 1244 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 71c4603eea..543e575b88 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -6,8 +6,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -19,14 +19,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ @@ -89,7 +90,8 @@ struct GenRunner { using ElementOut = cutlass::bfloat16_t; using ProblemShape = - Shape<_1, int, int, Shape, int>>; // (Sq, Sk, D, ((H, Hr), B)) + Shape<_1, int, int, Shape, int>>; // (Sq, Sk, D, ((H, Hr), + // B)) using StrideQ = Stride<_0, _1, Stride, int>>; // Q D ((H, Hr), B) @@ -114,10 +116,8 @@ struct GenRunner { StrideCacheV, StrideO>; - using Epilogue = - cutlass::fmha::collective::Sm100FmhaGenEpilogueWarpspecialized< - ElementOut, - StrideO>; + using Epilogue = cutlass::fmha::collective:: + Sm100FmhaGenEpilogueWarpspecialized; using TileScheduler = std::conditional_t< kKernelType == KernelType::UMMA_P, @@ -148,7 +148,6 @@ struct GenRunner { const at::Tensor& v_input, const at::Tensor& seqlen_kv_input, const at::Tensor& batch_idx_input) { - this->q = q_input; this->k = k_input; this->v = v_input; @@ -250,10 +249,7 @@ struct GenRunner { return; } - - status = op.run( - at::cuda::getCurrentCUDAStream() - ); + status = op.run(at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { std::cerr << "Failed to launch CUTLASS kernel." << std::endl; return; @@ -263,28 +259,30 @@ struct GenRunner { // Dispatch macros for different element types // TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types. -#define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \ - [&] { \ - if (DTYPE == at::kFloat8_e4m3fn) { \ - using ELEMENT_TYPE = cutlass::float_e4m3_t; \ - return __VA_ARGS__(); \ - } else { \ - throw std::runtime_error("Unsupported dtype: " + std::to_string(static_cast(DTYPE))); \ - } \ +#define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \ + [&] { \ + if (DTYPE == at::kFloat8_e4m3fn) { \ + using ELEMENT_TYPE = cutlass::float_e4m3_t; \ + return __VA_ARGS__(); \ + } else { \ + throw std::runtime_error( \ + "Unsupported dtype: " + std::to_string(static_cast(DTYPE))); \ + } \ }() // Dispatch macro for different kernel types -#define DISPATCH_KERNEL_TYPE(KTYPE, KERNEL_TYPE, ...) \ - [&] { \ - if (KTYPE == static_cast(KernelType::UMMA_P)) { \ - constexpr auto KERNEL_TYPE = KernelType::UMMA_P; \ - return __VA_ARGS__(); \ - } else if (KTYPE == static_cast(KernelType::UMMA_I)) { \ - constexpr auto KERNEL_TYPE = KernelType::UMMA_I; \ - return __VA_ARGS__(); \ - } else { \ - throw std::runtime_error("Unsupported kernel type: " + std::to_string(KTYPE)); \ - } \ +#define DISPATCH_KERNEL_TYPE(KTYPE, KERNEL_TYPE, ...) 
\ + [&] { \ + if (KTYPE == static_cast(KernelType::UMMA_P)) { \ + constexpr auto KERNEL_TYPE = KernelType::UMMA_P; \ + return __VA_ARGS__(); \ + } else if (KTYPE == static_cast(KernelType::UMMA_I)) { \ + constexpr auto KERNEL_TYPE = KernelType::UMMA_I; \ + return __VA_ARGS__(); \ + } else { \ + throw std::runtime_error( \ + "Unsupported kernel type: " + std::to_string(KTYPE)); \ + } \ }() at::Tensor dispatch_fmha_gen_fwd( @@ -306,20 +304,19 @@ at::Tensor dispatch_fmha_gen_fwd( }); } - // ------------------------------------------------------------------------------------------------- // Op registration // ------------------------------------------------------------------------------------------------- TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - m.def("fmha_gen_fwd(" - " Tensor query, " - " Tensor key, " - " Tensor value, " - " Tensor seqlen_kv, " - " Tensor batch_idx, " - " int kernel_type = 0" - ") -> Tensor" - ); + m.def( + "fmha_gen_fwd(" + " Tensor query, " + " Tensor key, " + " Tensor value, " + " Tensor seqlen_kv, " + " Tensor batch_idx, " + " int kernel_type = 0" + ") -> Tensor"); } TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index f0442e06a8..b4a73f8aee 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -1,13 +1,13 @@ // @nolint /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -19,14 +19,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once @@ -46,27 +47,26 @@ namespace cutlass::fmha::collective { using namespace cute; -template< - class Element_, - class ElementQK_, - class ElementPV_, - class ElementOut_, - class TileShape_, - class StrideQ_, - class StrideNewK_, - class StrideNewV_, - class StrideK_, - class StrideV_, - class StrideO_, - class Mask_ = ResidualMask, - // shape here is QG K H - // and referes to the two softmax warps - // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) - // (1, 2, 1) means they sit side by side (best for small Q / large K) - class ThreadShape = Shape<_1, _2, _1> -> +template < + class Element_, + class ElementQK_, + class ElementPV_, + class ElementOut_, + class TileShape_, + class StrideQ_, + class StrideNewK_, + class StrideNewV_, + class StrideK_, + class StrideV_, + class StrideO_, + class Mask_ = ResidualMask, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads + // the least K/V) (1, 2, 1) means they sit side by side (best for small Q / + // large K) + class ThreadShape = Shape<_1, _2, _1>> struct Sm100FmhaGenMainloopWarpspecialized { - using Element = Element_; using ElementQK = ElementQK_; using ElementPV = ElementPV_; @@ -87,38 +87,58 @@ struct Sm100FmhaGenMainloopWarpspecialized { static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 
1 : 2; static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); - + using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; - + using ClusterShape = Shape<_1, _1, _1>; static const int Alignment = 128 / sizeof_bits_v; using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); - using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + using TileShapePV = decltype(select<0, 2, 1>(TileShapeQK{})); using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - Element, StrideQ, Alignment, - Element, StrideK, Alignment, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, + Element, + StrideQ, + Alignment, + Element, + StrideK, + Alignment, ElementQK, - TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + TileShapeQK, + ClusterShape, + cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, // the stride for A does not matter since we do not load from smem at all - Element, StrideK, Alignment, - Element, decltype(select<1,0,2>(StrideV{})), Alignment, + Element, + StrideK, + Alignment, + Element, + decltype(select<1, 0, 2>(StrideV{})), + Alignment, ElementPV, - TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + TileShapePV, + ClusterShape, + cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; - using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); - using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); - using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + using SmemLayoutQ = decltype(unstageSmemLayout( + typename CollectiveMmaQK::SmemLayoutA{}, + Int{})); + using SmemLayoutK = decltype(unstageSmemLayout( + typename CollectiveMmaQK::SmemLayoutB{}, + Int{})); + using SmemLayoutV = decltype(unstageSmemLayout( + typename CollectiveMmaPV::SmemLayoutB{}, + Int{})); struct TensorStorage { cute::array_aligned> smem_q; @@ -134,7 +154,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { kSizeP = 32, S0 = 0, S1 = S0 + kSizeS, - V0 = S0, // stats storage from softmax to correction + V0 = S0, // stats storage from softmax to correction V1 = S1, P0 = S0 + kSizeP, P1 = S1 + kSizeP, @@ -153,19 +173,18 @@ struct Sm100FmhaGenMainloopWarpspecialized { // from load to mma warp, protects q in smem using PipelineQ = cutlass::PipelineUmmaConsumerAsync< - StageCountQ, - typename CollectiveMmaQK::AtomThrShapeMNK - >; + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK>; // from load to mma warp, protects k/v in smem using PipelineKV = cutlass::PipelineUmmaConsumerAsync< - StageCountKV, - typename CollectiveMmaQK::AtomThrShapeMNK - >; + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK>; // from mma to softmax0/1 warp, protects S in tmem // (not sure yet about the reverse direction) - // there is one pipe per softmax warp, and the mma warp alternates between them + // there is one pipe per softmax warp, and the mma warp alternates between + // them using PipelineS = 
cutlass::PipelineUmmaAsync<1>; // from softmax0/1/ to correction wg @@ -178,17 +197,34 @@ struct Sm100FmhaGenMainloopWarpspecialized { using PipelineE = cutlass::PipelineAsync<2>; using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< - /*stages*/ 1, /*groups*/ 2>; + /*stages*/ 1, + /*groups*/ 2>; - static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v), "K and V smem layouts must be of equal size"); + static_assert( + cutlass::bits_to_bytes( + cosize(take<0, 3>(SmemLayoutK{})) * cute::sizeof_bits_v) == + cutlass::bits_to_bytes( + cosize(take<0, 3>(SmemLayoutV{})) * cute::sizeof_bits_v), + "K and V smem layouts must be of equal size"); using Load = Sm100FmhaLoadCpAsyncWarpspecialized< - Element, StrideQ, StrideNewK, StrideNewV, StrideCacheK, StrideCacheV, - TensorStorage, CollectiveMmaQK, CollectiveMmaPV, - SmemLayoutQ, SmemLayoutK, SmemLayoutV, - PipelineQ, PipelineKV, TileShape, Mask - >; - + Element, + StrideQ, + StrideNewK, + StrideNewV, + StrideCacheK, + StrideCacheV, + TensorStorage, + CollectiveMmaQK, + CollectiveMmaPV, + SmemLayoutQ, + SmemLayoutK, + SmemLayoutV, + PipelineQ, + PipelineKV, + TileShape, + Mask>; + struct Arguments { typename Load::Arguments load; @@ -213,20 +249,21 @@ struct Sm100FmhaGenMainloopWarpspecialized { float scale_output; }; - template - static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + template + static bool can_implement( + ProblemShape const& problem_shape, + Arguments const& args) { return true; } - template + template static Params to_underlying_arguments( ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - float scale_softmax = args.scale_softmax; if (scale_softmax == 0.0f) { - scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + scale_softmax = 1.0f / (float)std::sqrt(get<2>(problem_shape)); } float log2_e = static_cast(std::log2(std::exp(1.0))); @@ -234,8 +271,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { Load::to_underlying_arguments(problem_shape, args.load, workspace), args.scale_q * args.scale_k * scale_softmax, args.scale_q * args.scale_k * log2_e * scale_softmax, - args.scale_v * args.inv_scale_o - }; + args.scale_v * args.inv_scale_o}; } CUTLASS_DEVICE @@ -243,38 +279,51 @@ struct Sm100FmhaGenMainloopWarpspecialized { Load::prefetch_tma_descriptors(params.load); } - template - CUTLASS_DEVICE void - load( - BlkCoord const& blk_coord, ProblemShape const& problem_shape, - Params const& params, ParamsProblemShape const& params_problem_shape, + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + Params const& params, + ParamsProblemShape const& params_problem_shape, TensorStorage& storage, - PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, - PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { - + PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, + typename PipelineKV::PipelineState& pipeline_kv_producer_state) { Load load; - load.load(blk_coord, problem_shape, params.load, params_problem_shape, + load.load( + blk_coord, + problem_shape, + params.load, + params_problem_shape, storage, - pipeline_q, pipeline_q_producer_state, - pipeline_kv, pipeline_kv_producer_state); + pipeline_q, + pipeline_q_producer_state, + pipeline_kv, + 
pipeline_kv_producer_state); } - template - CUTLASS_DEVICE auto - mma( + template + CUTLASS_DEVICE auto mma( BlkCoord const& blk_coord, - Params const& params, ProblemShape const& problem_shape, + Params const& params, + ProblemShape const& problem_shape, TensorStorage& storage, - PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, - PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, - PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, - PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, - PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { - + PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, + typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, + typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, + typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, + typename PipelineO::PipelineState& pipeline_corr_producer_state) { auto pipeline_q_release_state = pipeline_q_consumer_state; auto pipeline_kv_release_state = pipeline_kv_consumer_state; - int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int mask_tile_count = + Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); typename CollectiveMmaQK::TiledMma mma_qk; ThrMMA thr_mma_qk = mma_qk.get_slice(0); @@ -283,9 +332,12 @@ struct Sm100FmhaGenMainloopWarpspecialized { TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); - Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + Tensor sQ = + make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = + make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = + make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); Tensor tSrK = thr_mma_qk.make_fragment_B(sK); @@ -295,8 +347,8 @@ struct Sm100FmhaGenMainloopWarpspecialized { // S0 S1`O0 O1 // sequential in memory, where S overlaps with P and V - Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); - Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + Tensor tStS = partition_fragment_C(mma_qk, select<0, 1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0, 1>(TileShapePV{})); Tensor tStS0 = tStS; tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); @@ -308,8 +360,11 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor tOtO1 = tOtO; tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); - Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); - Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + Tensor sP = make_tensor( + make_smem_ptr((Element*)nullptr), + typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = + thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging Tensor tOrP0 = tOrP; tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); @@ -325,8 +380,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { 
pipeline_q.consumer_wait(pipeline_q_consumer_state); ++pipeline_q_consumer_state; - Tensor tSrQ0 = tSrQ(_,_,_,q_index); - + Tensor tSrQ0 = tSrQ(_, _, _, q_index); // wait for K1 k_index = pipeline_kv_consumer_state.index(); @@ -336,7 +390,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { // gemm Q1 * K1 -> S1 pipeline_s0.producer_acquire(pipeline_s0_producer_state); - gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_, _, _, k_index), tStS0); pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; @@ -354,7 +408,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { ++pipeline_q_consumer_state; } - Tensor tSrQ1 = tSrQ(_,_,_,q_index); + Tensor tSrQ1 = tSrQ(_, _, _, q_index); if constexpr (get<1>(ThreadShape{}) > 1) { k_index = pipeline_kv_consumer_state.index(); @@ -365,7 +419,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_s1.producer_acquire(pipeline_s1_producer_state); // gemm Q2 * K1 -> S2 - gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_, _, _, k_index), tStS1); pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; @@ -387,7 +441,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_s0.producer_acquire(pipeline_s0_producer_state); // gemm P1 * V1 -> O1 - gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_, _, _, v_index), tOtO0); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; @@ -402,14 +456,13 @@ struct Sm100FmhaGenMainloopWarpspecialized { // loop: mask_tile_count -= 1; for (; mask_tile_count > 0; mask_tile_count -= 1) { - // wait for Ki k_index = (pipeline_kv_consumer_state.index()); pipeline_kv.consumer_wait(pipeline_kv_consumer_state); ++pipeline_kv_consumer_state; // gemm Q1 * Ki -> S1 - gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_, _, _, k_index), tStS0); pipeline_s0.producer_commit(pipeline_s0_producer_state); ++pipeline_s0_producer_state; @@ -429,7 +482,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s1.producer_acquire(pipeline_s1_producer_state); - gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_, _, _, v_index), tOtO1); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; @@ -445,7 +498,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { } // gemm Q2 * Ki -> S2 - gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_, _, _, k_index), tStS1); pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; @@ -464,7 +517,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_s0.producer_acquire(pipeline_s0_producer_state); - gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_, _, _, v_index), tOtO0); pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; @@ -496,7 +549,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_corr.producer_acquire(pipeline_corr_producer_state); pipeline_s1.producer_acquire(pipeline_s1_producer_state); - gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_, _, _, v_index), tOtO1); 
pipeline_corr.producer_commit(pipeline_corr_producer_state); ++pipeline_corr_producer_state; @@ -511,56 +564,81 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_s1.producer_commit(pipeline_s1_producer_state); ++pipeline_s1_producer_state; - // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... - // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 + // B2, T1 S00 B1, T0 S11 B2, ... Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , + // S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... } - template - CUTLASS_DEVICE auto - softmax_step( - float& row_max, float& row_sum, - Stage stage, bool final_call, - BlkCoord const& blk_coord, CoordTensor const& cS, - Params const& params, ProblemShape const& problem_shape, - PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, - PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + template < + bool need_apply_mask, + class Stage, + class BlkCoord, + class CoordTensor, + class ProblemShape> + CUTLASS_DEVICE auto softmax_step( + float& row_max, + float& row_sum, + Stage stage, + bool final_call, + BlkCoord const& blk_coord, + CoordTensor const& cS, + Params const& params, + ProblemShape const& problem_shape, + PipelineS& pipeline_s, + typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, + typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { + Tensor tScS = + typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - - Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); - tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + Tensor tStS = partition_fragment_C( + typename CollectiveMmaQK::TiledMma{}, select<0, 1>(TileShapeQK{})); + tStS.data() = + uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); - tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + tStS_v.data() = + uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); - auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; - Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); - tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); - Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + auto tilePlikeFP32 = + size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = + tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform( + uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = + tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); // Each thread owns a single row - using TMEM_LOAD = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem - using TMEM_STORE = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem - using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + using TMEM_LOAD = conditional_t< + size<1>(TileShapeQK{}) < _128{}, + SM100_TMEM_LOAD_32dp32b8x, + SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE = conditional_t< + size<1>(TileShapeQK{}) < _128{}, + SM100_TMEM_STORE_32dp32b8x, + SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = + SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); - auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); - auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); - auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); @@ -586,10 +664,10 @@ struct Sm100FmhaGenMainloopWarpspecialized { float row_max_3 = row_max; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { - row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); - row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); - row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); - row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i + 1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i + 2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i + 3)); } row_max = ::fmax(row_max_0, row_max_1); row_max = ::fmax(row_max, row_max_2); @@ -606,39 +684,39 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_c.producer_commit(pipeline_c_producer_state); ++pipeline_c_producer_state; - // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + // notify correction wg that they are ready (might need addtl ordering + // between S0 and S1 WG's) ElementQK scale = params.scale_softmax_log2; ElementQK row_max_scale = row_max_safe * scale; float2 scale_fp32x2 = make_float2(scale, scale); - float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + float2 minus_row_max_scale_fp32x2 = + make_float2(-row_max_scale, -row_max_scale); Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); constexpr int kConversionsPerStep = 2; - Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + Tensor tTMEM_STORErS_x4_e = + recast>(tTMEM_STORErS_x4); 
NumericArrayConverter convert; - const int kReleasePipeCount = 10; // must be multiple of 2 - + const int kReleasePipeCount = 10; // must be multiple of 2 + order_s.wait(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { - float2 in = make_float2( - tTMEM_LOADrS(i + 0), - tTMEM_LOADrS(i + 1) - ); + float2 in = make_float2(tTMEM_LOADrS(i + 0), tTMEM_LOADrS(i + 1)); float2 out; cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); tTMEM_LOADrS(i + 0) = out.x; tTMEM_LOADrS(i + 1) = out.y; - tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); - tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + tTMEM_LOADrS(i + 0) = ::exp2f(tTMEM_LOADrS(i + 0)); + tTMEM_LOADrS(i + 1) = ::exp2f(tTMEM_LOADrS(i + 1)); Array in_conv; CUTLASS_PRAGMA_UNROLL @@ -647,7 +725,6 @@ struct Sm100FmhaGenMainloopWarpspecialized { } tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); - if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { order_s.arrive(); } @@ -655,7 +732,10 @@ struct Sm100FmhaGenMainloopWarpspecialized { // this prevents register spills in fp16 if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { if (i == size(tTMEM_LOADrS) - 6) { - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + copy( + tiled_tmem_store, + tTMEM_STORErS_x4(_, _, 0), + tTMEM_STOREtS_x4(_, _, 0)); } } } @@ -663,7 +743,10 @@ struct Sm100FmhaGenMainloopWarpspecialized { // tmem_store(reg_S8) -> op_P CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + copy( + tiled_tmem_store, + tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), + tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); cutlass::arch::fence_view_async_tmem_store(); @@ -684,16 +767,16 @@ struct Sm100FmhaGenMainloopWarpspecialized { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { // row_sum += tTMEM_LOADrS(i); - float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i + 1)); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); - in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + in = make_float2(tTMEM_LOADrS(i + 2), tTMEM_LOADrS(i + 2 + 1)); cute::add(local_row_sum_1, local_row_sum_1, in); - in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + in = make_float2(tTMEM_LOADrS(i + 4), tTMEM_LOADrS(i + 4 + 1)); cute::add(local_row_sum_2, local_row_sum_2, in); - in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + in = make_float2(tTMEM_LOADrS(i + 6), tTMEM_LOADrS(i + 6 + 1)); cute::add(local_row_sum_3, local_row_sum_3, in); } @@ -701,7 +784,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; - + row_sum = local_row_sum; if (final_call) { @@ -715,26 +798,28 @@ struct Sm100FmhaGenMainloopWarpspecialized { } } - template - CUTLASS_DEVICE auto - softmax( + template + CUTLASS_DEVICE auto softmax( Stage stage, BlkCoord const& blk_coord, - Params const& params, ProblemShape const& problem_shape, - PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, - PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + Params const& params, + ProblemShape const& 
problem_shape, + PipelineS& pipeline_s, + typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, + typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { - - int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + int mask_tile_count = + Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); ElementQK row_max = -INFINITY; ElementQK row_sum = 0; - Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor cS_base = make_identity_tensor(select<0, 1>(TileShapeQK{})); auto logical_offset = make_coord( - get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), - 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) - ); + get<0>(blk_coord) * get<0>(TileShape{}) + + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})); Tensor cS = domain_offset(logical_offset, cS_base); pipeline_c.producer_acquire(pipeline_c_producer_state); @@ -742,32 +827,49 @@ struct Sm100FmhaGenMainloopWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; mask_tile_count > 0; mask_tile_count -= 1) { softmax_step( - row_max, row_sum, stage, + row_max, + row_sum, + stage, (mask_tile_count == 1) && - (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), - blk_coord, cS, params, problem_shape, - pipeline_s, pipeline_s_consumer_state, - pipeline_c, pipeline_c_producer_state, - order_s - ); - - cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + (Mask{}.get_masked_trip_count( + blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, + cS, + params, + problem_shape, + pipeline_s, + pipeline_s_consumer_state, + pipeline_c, + pipeline_c_producer_state, + order_s); + + cS.data() = + cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); } // Masked iterations - mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + mask_tile_count = + Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); CUTLASS_PRAGMA_NO_UNROLL for (; mask_tile_count > 0; mask_tile_count -= 1) { softmax_step( - row_max, row_sum, stage, mask_tile_count == 1, - blk_coord, cS, params, problem_shape, - pipeline_s, pipeline_s_consumer_state, - pipeline_c, pipeline_c_producer_state, - order_s - ); - - cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + row_max, + row_sum, + stage, + mask_tile_count == 1, + blk_coord, + cS, + params, + problem_shape, + pipeline_s, + pipeline_s_consumer_state, + pipeline_c, + pipeline_c_producer_state, + order_s); + + cS.data() = + cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); } pipeline_c.producer_commit(pipeline_c_producer_state); @@ -779,13 +881,21 @@ struct Sm100FmhaGenMainloopWarpspecialized { ++pipeline_s_consumer_state; } - template - CUTLASS_DEVICE auto - correction_epilogue( - float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1, - GTensor& gO, CTensor const& cO, Shape const& g_shape, + template < + class Vector, + class GTensor, + class CTensor, + class Shape, + class Epilogue> + CUTLASS_DEVICE auto correction_epilogue( + float scale_softmax_log2, + float scale_out, + Vector const& v0, + Vector const& v1, + GTensor& gO, + CTensor const& cO, + Shape const& g_shape, Epilogue const& epilogue) { - using ElementOut = typename GTensor::value_type; int thread_idx = threadIdx.x % (4 * 
cutlass::NumThreadsPerWarp); @@ -795,16 +905,22 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32 / sizeof(ElementOut); - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = std::conditional_t< + kCorrectionTileSize == 32, + SM100_TMEM_LOAD_32dp32b32x, + SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; - Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0, 1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOgO = mma.get_slice(0).partition_C(gO); - - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO_i = tOtO.compose( + make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose( + make_layout(make_shape(_128{}, Int{}))); + Tensor tOgO_i = tOgO.compose( + make_layout(make_shape(_128{}, Int{}))); Tensor tOtO0 = tOtO_i; tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); @@ -812,8 +928,8 @@ struct Sm100FmhaGenMainloopWarpspecialized { tOtO1.data() = tOtO1.data().get() + uint32_t(TmemAllocation::O1); auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); - auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); - + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0); Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); @@ -834,27 +950,30 @@ struct Sm100FmhaGenMainloopWarpspecialized { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0; - tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize); + tTMEM_LOADtO0_i.data() = + tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1; - tTMEM_LOADtO1_i.data() = tTMEM_LOADtO1_i.data().get() + uint32_t(i * kCorrectionTileSize); + tTMEM_LOADtO1_i.data() = + tTMEM_LOADtO1_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMEM_LOADgO_i = tTMEM_LOADgO; - tTMEM_LOADgO_i.data() = tTMEM_LOADgO_i.data().get() + i * kCorrectionTileSize * stride<1>(gO); + tTMEM_LOADgO_i.data() = + tTMEM_LOADgO_i.data().get() + i * kCorrectionTileSize * stride<1>(gO); Tensor tTMrO0 = make_tensor(shape(tTMEM_LOADcO)); Tensor tTMrO1 = make_tensor(shape(tTMEM_LOADcO)); - + copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0); copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1); - + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO0); j += 2) { - float2 in0 = make_float2(tTMrO0(j), tTMrO0(j+1)); - float2 in1 = make_float2(tTMrO1(j), tTMrO1(j+1)); + float2 in0 = make_float2(tTMrO0(j), tTMrO0(j + 1)); + float2 in1 = make_float2(tTMrO1(j), tTMrO1(j + 1)); float2 out; cute::mul(out, scale0_f32x2, in0); cute::fma(out, scale1_f32x2, in1, out); tTMrO0(j) = out.x; - tTMrO0(j+1) = out.y; + tTMrO0(j + 1) = out.y; } constexpr int N = 4 / sizeof(ElementOut); @@ -880,11 +999,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { } } - CUTLASS_DEVICE auto - correction_rescale( - float scale, - uint32_t tmem_O) { - + CUTLASS_DEVICE auto correction_rescale(float scale, uint32_t tmem_O) { int thread_idx = threadIdx.x % (4 * 
cutlass::NumThreadsPerWarp); // As opposed to the softmax, we do not have enough registers here @@ -892,24 +1007,28 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32; - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem - using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = + SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = + SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; - Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); - Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor cO = make_identity_tensor(select<0, 1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0, 1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO_i = tOtO.compose( + make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose( + make_layout(make_shape(_128{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; - + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); - auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); - auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); - + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); @@ -918,18 +1037,21 @@ struct Sm100FmhaGenMainloopWarpspecialized { float2 scale_f32x2 = make_float2(scale, scale); - Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int(TileShape{}) / kCorrectionTileSize>{})); - + Tensor tTMrO = make_tensor(make_shape( + shape(tTMEM_LOADcO), Int(TileShape{}) / kCorrectionTileSize>{})); + auto copy_in = [&](int i) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; - tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + tTMEM_LOADtO_i.data() = + tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); }; auto copy_out = [&](int i) { Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; - tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + tTMEM_STOREtO_i.data() = + tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); }; @@ -945,59 +1067,72 @@ struct Sm100FmhaGenMainloopWarpspecialized { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < count; i++) { if (i != count - 1) { - copy_in(i+1); + copy_in(i + 1); } Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO_i); j += 2) { - float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j + 1)); float2 out; cute::mul(out, scale_f32x2, in); tTMrO_i(j) = out.x; - tTMrO_i(j+1) = out.y; + tTMrO_i(j + 1) = out.y; } copy_out(i); } } - template - CUTLASS_DEVICE auto - correction( + 
template < + class BlkCoord, + class ProblemShape, + class TensorStorageEpi, + class Epilogue> + CUTLASS_DEVICE auto correction( BlkCoord const& blk_coord, - Params const& params, ProblemShape const& problem_shape, + Params const& params, + ProblemShape const& problem_shape, TensorStorageEpi& shared_storage_epi, - PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, - PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, - PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, - PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + PipelineC& pipeline_s0_c, + typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, + typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, + typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, + typename PipelineE::PipelineState& pipeline_epi_producer_state, Epilogue const& epilogue) { - - int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int mask_tile_count = + Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); - Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + Tensor tStS = partition_fragment_C( + typename CollectiveMmaQK::TiledMma{}, select<0, 1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0, 1>(TileShapeQK{})); + Tensor tScS = + typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); - Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); - using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + using TMEM_LOAD_V = + SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); - auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; - tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + tTMEM_LOADVtS0.data() = + tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; - tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + tTMEM_LOADVtS1.data() = + tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); // ignore first signal from softmax as no correction is required pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); @@ -1011,7 +1146,6 @@ struct Sm100FmhaGenMainloopWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL for (; mask_tile_count > 0; mask_tile_count -= 1) { - pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); @@ -1020,7 +1154,9 @@ struct Sm100FmhaGenMainloopWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); // e^(scale * (old_max - new_max) - float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + float 
scale = ::exp2f( + params.scale_softmax_log2 * + (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); @@ -1038,7 +1174,9 @@ struct Sm100FmhaGenMainloopWarpspecialized { copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); - scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + scale = ::exp2f( + params.scale_softmax_log2 * + (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); pipeline_o.consumer_wait(pipeline_o_consumer_state); @@ -1093,12 +1231,23 @@ struct Sm100FmhaGenMainloopWarpspecialized { // F2FP // store to smem - Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); - auto g_shape = select<0,2>(problem_shape); - auto mO = make_tensor(make_gmem_ptr(epilogue.params.ptr_o), append<3>(select<0,1>(TileShapePV{}), get<3>(problem_shape)), epilogue.params.dO); + Tensor cO = make_identity_tensor(select<0, 1>(TileShapePV{})); + auto g_shape = select<0, 2>(problem_shape); + auto mO = make_tensor( + make_gmem_ptr(epilogue.params.ptr_o), + append<3>(select<0, 1>(TileShapePV{}), get<3>(problem_shape)), + epilogue.params.dO); auto gO = mO(_, _, get<2>(blk_coord)); - correction_epilogue(params.scale_softmax_log2, params.scale_output, tTMEM_LOADVrS0, tTMEM_LOADVrS1, gO, cO, g_shape, epilogue); + correction_epilogue( + params.scale_softmax_log2, + params.scale_output, + tTMEM_LOADVrS0, + tTMEM_LOADVrS1, + gO, + cO, + g_shape, + epilogue); cutlass::arch::fence_view_async_tmem_load(); @@ -1108,7 +1257,6 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_o.consumer_release(pipeline_o_release_state); ++pipeline_o_release_state; } - }; -} // namespace cutlass::fmha::collective +} // namespace cutlass::fmha::collective diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp index c247302eea..6ad4df5501 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp @@ -1,13 +1,13 @@ // @nolint /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -19,14 +19,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once @@ -44,31 +45,28 @@ namespace cutlass::fmha::collective { using namespace cute; -template< - class Element, - class StrideQ, - class StrideNewK, - class StrideNewV, - class StrideCacheK, - class StrideCacheV, - class TensorStorage, - class CollectiveMmaQK, - class CollectiveMmaPV, - class SmemLayoutQ, - class SmemLayoutK, - class SmemLayoutV, - class PipelineQ, - class PipelineKV, - class TileShape, - class Mask -> +template < + class Element, + class StrideQ, + class StrideNewK, + class StrideNewV, + class StrideCacheK, + class StrideCacheV, + class TensorStorage, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class PipelineQ, + class PipelineKV, + class TileShape, + class Mask> struct Sm100FmhaLoadCpAsyncWarpspecialized { - using TileShapeQK = typename CollectiveMmaQK::TileShape; using TileShapePV = typename CollectiveMmaPV::TileShape; struct Arguments { - const int* cache_batch_idx; const Element* ptr_q; @@ -87,65 +85,70 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { using Params = Arguments; - template + template static Params to_underlying_arguments( ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; } CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& params) { - } + static void prefetch_tma_descriptors(Params const& params) {} - template + template CUTLASS_DEVICE auto constexpr transpose(Tensor const& t) { CUTE_STATIC_ASSERT_V(rank(t) == _2{}); - return t.compose(make_layout(make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{}))); + return t.compose(make_layout( + make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{}))); } - template< - class CAtom, class TA, class TB, - class CountTensor, class CountLimit, - class SrcTensor, class DstTensor - > + template 
< + class CAtom, + class TA, + class TB, + class CountTensor, + class CountLimit, + class SrcTensor, + class DstTensor> CUTLASS_DEVICE void copy_with_limit( TiledCopy const& tiled_copy, - CountTensor const& c, CountLimit const& l, - SrcTensor const& src, DstTensor&& dst) { - - //copy(tiled_copy, src, dst); + CountTensor const& c, + CountLimit const& l, + SrcTensor const& src, + DstTensor&& dst) { + // copy(tiled_copy, src, dst); #if 1 auto c_f = make_tensor(c.data(), flatten(c.layout())); auto src_f = make_tensor(src.data(), flatten(src.layout())); auto dst_f = make_tensor(dst.data(), flatten(dst.layout())); - auto c_v = group_modes<1,rank_v>(c_f); - auto src_v = group_modes<1,rank_v>(src_f); - auto dst_v = group_modes<1,rank_v>(dst_f); + auto c_v = group_modes<1, rank_v>(c_f); + auto src_v = group_modes<1, rank_v>(src_f); + auto dst_v = group_modes<1, rank_v>(dst_f); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(src_v); i++) { if (elem_less(c_v(_0{}, i), l)) { copy(CAtom{}, src_v(_, i), dst_v(_, i)); - } - else { + } else { clear(dst_v(_, i)); } } #endif } - template - CUTLASS_DEVICE void - load( - BlkCoord const& blk_coord, ProblemShape const& problem_shape, - Params const& params, ParamsProblemShape const& params_problem_shape, + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + Params const& params, + ParamsProblemShape const& params_problem_shape, TensorStorage& storage, - PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, - PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { - - int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + PipelineQ& pipeline_q, + typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, + typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + int mask_tile_count = + Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); mask_tile_count *= 2; int warp_idx = (threadIdx.x / 32) % 2; @@ -156,13 +159,17 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { // this one is only executed by one thread, no need to elect_one auto blk_coord_cache = blk_coord; if (params.cache_batch_idx != nullptr) { - get<2,1>(blk_coord_cache) = params.cache_batch_idx[get<2,1>(blk_coord_cache)]; + get<2, 1>(blk_coord_cache) = + params.cache_batch_idx[get<2, 1>(blk_coord_cache)]; } // Q1, K1, K2, V1, K3, V2, ... 
Kn, Vn-1, Vn // two pipes: Q and KV - auto cQ = make_identity_tensor(select<0,2>(TileShape{})); - auto mQ = make_tensor(make_gmem_ptr(params.ptr_q), append<3>(select<0,2>(TileShapeQK{}), get<3>(problem_shape)), params.dQ); + auto cQ = make_identity_tensor(select<0, 2>(TileShape{})); + auto mQ = make_tensor( + make_gmem_ptr(params.ptr_q), + append<3>(select<0, 2>(TileShapeQK{}), get<3>(problem_shape)), + params.dQ); auto gQ = mQ(_, _, get<2>(blk_coord)); auto sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); @@ -171,13 +178,19 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { auto tSgQ = thr_mma_qk.partition_A(gQ); auto tScQ = thr_mma_qk.partition_A(cQ); - auto atom_q_tv = Layout, _16>, Stride, _1>>{}; - auto atom_kv_tv = Layout, _16>, Stride, _1>>{}; + auto atom_q_tv = + Layout, _16>, Stride, _1>>{}; + auto atom_kv_tv = + Layout, _16>, Stride, _1>>{}; auto tiled_copy_q = make_cotiled_copy( Copy_Atom, Element>{}, atom_q_tv, - make_layout(shape(tSgQ), replace<0>(stride(tSgQ), replace<0>(stride<0>(tSgQ), get<2>(TileShape{}))))); + make_layout( + shape(tSgQ), + replace<0>( + stride(tSgQ), + replace<0>(stride<0>(tSgQ), get<2>(TileShape{}))))); auto thr_copy_q = tiled_copy_q.get_slice(thread_idx); @@ -202,31 +215,45 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { int vlen = sizeof(Vec) / sizeof(Element); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(src); i++) { - auto cc = c(vlen*i); + auto cc = c(vlen * i); Vec* dst_ptr = &dst(i); const Vec* src_ptr = &src(i); bool guard = elem_less(cc, limitQ); - cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>( - dst_ptr, src_ptr, guard - ); + cutlass::arch:: + cp_async_zfill<16, cutlass::arch::CacheOperation::Always>( + dst_ptr, src_ptr, guard); } - + pipeline_q.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); }; load_q(q0_index, pipeline_q_producer_state); ++pipeline_q_producer_state; - auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{})); - auto cK = make_tensor(cK_t.data(), make_layout(get<0>(cK_t.layout()), get<1>(cK_t.layout()), make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t)))); - auto mK = make_tensor(make_gmem_ptr(params.ptr_cache_k), select<1,2,3>(problem_shape), params.dCacheK); - auto gK = local_tile(mK(_, _, get<2>(blk_coord_cache)), TileShapeQK{}, make_coord(_, _, _0{}), Step{}); + auto cK_t = make_identity_tensor(select<1, 2>(TileShapeQK{})); + auto cK = make_tensor( + cK_t.data(), + make_layout( + get<0>(cK_t.layout()), + get<1>(cK_t.layout()), + make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t)))); + auto mK = make_tensor( + make_gmem_ptr(params.ptr_cache_k), + select<1, 2, 3>(problem_shape), + params.dCacheK); + auto gK = local_tile( + mK(_, _, get<2>(blk_coord_cache)), + TileShapeQK{}, + make_coord(_, _, _0{}), + Step{}); auto sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); auto tSgK = thr_mma_qk.partition_B(gK); auto tScK = thr_mma_qk.partition_B(cK); - auto tSlK = thr_mma_qk.partition_B(make_tensor((Element*) nullptr, make_ordered_layout(select<1,2>(TileShapeQK{}), Step<_1, _0>{}))); + auto tSlK = thr_mma_qk.partition_B(make_tensor( + (Element*)nullptr, + make_ordered_layout(select<1, 2>(TileShapeQK{}), Step<_1, _0>{}))); auto tiled_copy_k = make_cotiled_copy( Copy_Atom, Element>{}, atom_kv_tv, @@ -238,20 +265,34 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { auto tKgK = thr_copy_k.partition_S(tSgK); auto tKcK = thr_copy_k.partition_S(tScK); - int seqlen_cache_kv = get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 
1 : 0); + int seqlen_cache_kv = + get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 1 : 0); auto limitK = append<2>(seqlen_cache_kv, _128{}); - auto cV_t = make_identity_tensor(select<1,2>(TileShapePV{})); - auto cV = make_tensor(cV_t.data(), make_layout(get<0>(cV_t.layout()), get<1>(cV_t.layout()), make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t)))); - auto mV = make_tensor(make_gmem_ptr(params.ptr_cache_v), select<2,1,3>(problem_shape), select<1,0,2>(params.dCacheV)); - auto gV = local_tile(mV(_, _, get<2>(blk_coord_cache)), TileShapePV{}, make_coord(_, _0{}, _), Step{}); + auto cV_t = make_identity_tensor(select<1, 2>(TileShapePV{})); + auto cV = make_tensor( + cV_t.data(), + make_layout( + get<0>(cV_t.layout()), + get<1>(cV_t.layout()), + make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t)))); + auto mV = make_tensor( + make_gmem_ptr(params.ptr_cache_v), + select<2, 1, 3>(problem_shape), + select<1, 0, 2>(params.dCacheV)); + auto gV = local_tile( + mV(_, _, get<2>(blk_coord_cache)), + TileShapePV{}, + make_coord(_, _0{}, _), + Step{}); auto sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); typename CollectiveMmaPV::TiledMma mma_pv; ThrMMA thr_mma_pv = mma_pv.get_slice(0); auto tOgV = thr_mma_pv.partition_B(gV); auto tOcV = thr_mma_pv.partition_B(cV); - auto tOlV = thr_mma_pv.partition_B(make_tensor((Element*) nullptr, make_layout(select<1,2>(TileShapePV{})))); + auto tOlV = thr_mma_pv.partition_B(make_tensor( + (Element*)nullptr, make_layout(select<1, 2>(TileShapePV{})))); auto tiled_copy_v = make_cotiled_copy( Copy_Atom, Element>{}, @@ -264,13 +305,19 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { auto tVgV = thr_copy_v.partition_S(tOgV); auto tVcV = thr_copy_v.partition_S(tOcV); - auto limitV = select<1,0>(limitK); + auto limitV = select<1, 0>(limitK); int full_tiles_cache = seqlen_cache_kv / get<1>(TileShapeQK{}); bool has_new = params.ptr_new_k != nullptr; - Tensor mNewK = make_tensor(make_gmem_ptr(params.ptr_new_k), select<1,2,3>(problem_shape), params.dNewK); - Tensor mNewV = make_tensor(make_gmem_ptr(params.ptr_new_v), select<1,2,3>(problem_shape), params.dNewV); + Tensor mNewK = make_tensor( + make_gmem_ptr(params.ptr_new_k), + select<1, 2, 3>(problem_shape), + params.dNewK); + Tensor mNewV = make_tensor( + make_gmem_ptr(params.ptr_new_v), + select<1, 2, 3>(problem_shape), + params.dNewV); Tensor gNewK = mNewK(_, _, get<2>(blk_coord)); Tensor gNewV = mNewV(_, _, get<2>(blk_coord)); @@ -278,8 +325,12 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { pipeline_kv.producer_acquire(state); if (k_index < full_tiles_cache) { - copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index())); - pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + copy( + tiled_copy_k, + tKgK(_, _, _, _, k_index), + tKsK(_, _, _, _, state.index())); + pipeline_kv.producer_commit( + state, cutlass::arch::cpasync_barrier_arrive); } else { using Vec = uint128_t; Vec vzero = uint128_t(0, 0); @@ -290,7 +341,7 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { int vlen = sizeof(Vec) / sizeof(Element); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(src); i++) { - auto cc = c(vlen*i); + auto cc = c(vlen * i); Vec* dst_ptr = &dst(i); const Vec* src_ptr = &src(i); bool guard = elem_less(cc, limitK); @@ -298,12 +349,13 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { src_ptr = &src2(_0{}, get<1>(cc) / vlen); guard = true; } - cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( - dst_ptr, src_ptr, guard - ); + 
cutlass::arch:: + cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard); } - - pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + + pipeline_kv.producer_commit( + state, cutlass::arch::cpasync_barrier_arrive); } }; @@ -311,8 +363,12 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { pipeline_kv.producer_acquire(state); if (v_index < full_tiles_cache) { - copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index())); - pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + copy( + tiled_copy_v, + tVgV(_, _, _, _, v_index), + tVsV(_, _, _, _, state.index())); + pipeline_kv.producer_commit( + state, cutlass::arch::cpasync_barrier_arrive); } else { using Vec = uint128_t; Vec vzero = uint128_t(0, 0); @@ -324,7 +380,7 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(src); i++) { - auto cc = c(vlen*i); + auto cc = c(vlen * i); Vec* dst_ptr = &dst(i); const Vec* src_ptr = &src(i); bool guard = elem_less(cc, limitV); @@ -332,12 +388,13 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { src_ptr = &src2(_0{}, get<0>(cc) / vlen); guard = true; } - cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( - dst_ptr, src_ptr, guard - ); + cutlass::arch:: + cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard); } - pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + pipeline_kv.producer_commit( + state, cutlass::arch::cpasync_barrier_arrive); } }; @@ -353,12 +410,11 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { mask_tile_count -= 1; for (; mask_tile_count > 0; mask_tile_count -= 1) { - load_k(k_index, pipeline_kv_producer_state); ++pipeline_kv_producer_state; k_index += 1; - + load_v(v_index, pipeline_kv_producer_state); ++pipeline_kv_producer_state; @@ -371,7 +427,7 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { ++pipeline_kv_producer_state; v_index += 1; - + if (has_new) { for (int i = thread_idx; i < get<2>(TileShape{}); i += 64) { gK(seqlen_cache_kv, i, 0) = gNewK(0, i); @@ -379,7 +435,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { } } } - }; -} // namespace cutlass::fmha::collective +} // namespace cutlass::fmha::collective diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp index a1c6d627be..c3b4486966 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp @@ -1,13 +1,13 @@ // @nolint /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -19,14 +19,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -47,7 +48,6 @@ using namespace cute; using namespace cutlass::fmha::collective; struct Sm100FmhaGenKernelWarpspecializedSchedule { - enum class WarpRole { Softmax0, Softmax1, @@ -59,12 +59,17 @@ struct Sm100FmhaGenKernelWarpspecializedSchedule { }; static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { - if (warp_idx == 0) return WarpRole::Softmax0; // 0 - 3 - if (warp_idx == 1) return WarpRole::MMA; // 12 - if (warp_idx == 2 || warp_idx == 3) return WarpRole::Load; // 13 - if (warp_idx == 4) return WarpRole::Softmax1; // 4 - 7 - if (warp_idx == 8) return WarpRole::Correction; // 8 - 11 - return WarpRole::Empty; // 15 + if (warp_idx == 0) + return WarpRole::Softmax0; // 0 - 3 + if (warp_idx == 1) + return WarpRole::MMA; // 12 + if (warp_idx == 2 || warp_idx == 3) + return WarpRole::Load; // 13 + if (warp_idx == 4) + return WarpRole::Softmax1; // 4 - 7 + if (warp_idx == 8) + return WarpRole::Correction; // 8 - 11 + return WarpRole::Empty; // 15 } static const int NumWarpsSoftmax = 1; @@ -78,18 +83,15 @@ struct Sm100FmhaGenKernelWarpspecializedSchedule { static const int NumRegsEmpty = 24; static const int NumWarps = 12; - }; -template< - class ProblemShapeIn, - class CollectiveMainloop, - class CollectiveEpilogue, - class TileScheduler, - class KernelSchedule = Sm100FmhaGenKernelWarpspecializedSchedule -> +template < + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaGenKernelWarpspecializedSchedule> struct Sm100FmhaGenKernelWarpspecialized { - using TileShape = typename CollectiveMainloop::TileShape; using ProblemShape = decltype(replace<0>(ProblemShapeIn{}, 0)); @@ -121,14 +123,19 @@ struct 
Sm100FmhaGenKernelWarpspecialized { struct PipelineStorage { alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; - alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) + typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; - alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; - alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; - alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + alignas(16) + typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) + typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) + typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage + order_s01; } pipelines; uint32_t tmem_base_ptr; @@ -154,18 +161,18 @@ struct Sm100FmhaGenKernelWarpspecialized { const int* seqlen_kv; const int* cache_batch_idx; - const Element* ptr_q; // 1 x D x (H x B) + const Element* ptr_q; // 1 x D x (H x B) StrideQOrig dQ; const Element* ptr_new_k; // 1 x D x (H x B) StrideNewK dNewK; const Element* ptr_new_v; // 1 x D x (H x B) StrideNewV dNewV; - - Element* ptr_cache_k; // seqlen_max x D x (H x B) + + Element* ptr_cache_k; // seqlen_max x D x (H x B) StrideCacheK dCacheK; - Element* ptr_cache_v; // seqlen_max x D x (H x B) + Element* ptr_cache_v; // seqlen_max x D x (H x B) StrideCacheV dCacheV; - ElementOut* ptr_o; // 1 x D x (H x B) + ElementOut* ptr_o; // 1 x D x (H x B) StrideOOrig dO; cutlass::KernelHardwareInfo hw_info; @@ -185,8 +192,11 @@ struct Sm100FmhaGenKernelWarpspecialized { static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; using ArchTag = cutlass::arch::Sm100; - static size_t get_workspace_size(Arguments const& args) { return 0; } - static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + static size_t get_workspace_size(Arguments const& args) { + return 0; + } + static cutlass::Status + initialize_workspace(Arguments const&, void*, cudaStream_t) { return cutlass::Status::kSuccess; } @@ -203,42 +213,55 @@ struct Sm100FmhaGenKernelWarpspecialized { return block; } - static Params to_underlying_arguments(Arguments const& args, void* workspace) { - ProblemShape problem_shape = replace<0>(args.problem_shape, static_cast(get<0>(args.problem_shape))); + static Params to_underlying_arguments( + Arguments const& args, + void* workspace) { + ProblemShape problem_shape = replace<0>( + args.problem_shape, static_cast(get<0>(args.problem_shape))); CUTE_STATIC_ASSERT_V(get<0>(args.problem_shape) == _1{}); StrideQ dQ = replace<0>(args.dQ, 0); StrideO dO = replace<0>(args.dO, 0); - get<0>(problem_shape) = get<3,0,0>(args.problem_shape); - get<3,0,0>(problem_shape) = 1; - get<0>(dQ) = get<2,0,0>(dQ); - get<0>(dO) = get<2,0,0>(dO); - - typename CollectiveMainloop::Arguments mainloop_args { - { - args.cache_batch_idx, - args.ptr_q, dQ, - args.ptr_new_k, args.dNewK, - args.ptr_new_v, args.dNewV, - args.ptr_cache_k, args.dCacheK, - args.ptr_cache_v, args.dCacheV, - }, - args.scale_softmax - }; - - typename CollectiveEpilogue::Arguments epilogue_args { - args.ptr_o, dO, + get<0>(problem_shape) = get<3, 0, 0>(args.problem_shape); + get<3, 
0, 0>(problem_shape) = 1; + get<0>(dQ) = get<2, 0, 0>(dQ); + get<0>(dO) = get<2, 0, 0>(dO); + + typename CollectiveMainloop::Arguments mainloop_args{ + { + args.cache_batch_idx, + args.ptr_q, + dQ, + args.ptr_new_k, + args.dNewK, + args.ptr_new_v, + args.dNewV, + args.ptr_cache_k, + args.dCacheK, + args.ptr_cache_v, + args.dCacheV, + }, + args.scale_softmax}; + + typename CollectiveEpilogue::Arguments epilogue_args{ + args.ptr_o, + dO, }; return Params{ problem_shape, args.seqlen_kv, - CollectiveMainloop::to_underlying_arguments(problem_shape, mainloop_args, workspace), - CollectiveEpilogue::to_underlying_arguments(problem_shape, epilogue_args, workspace), - TileScheduler::to_underlying_arguments(problem_shape, args.hw_info, ClusterShape{}, TileShape{}) - }; + CollectiveMainloop::to_underlying_arguments( + problem_shape, mainloop_args, workspace), + CollectiveEpilogue::to_underlying_arguments( + problem_shape, epilogue_args, workspace), + TileScheduler::to_underlying_arguments( + problem_shape, args.hw_info, ClusterShape{}, TileShape{})}; } - CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + CUTLASS_DEVICE auto apply_batch( + const Params& params, + ProblemShape const& problem_shape, + int batch_idx) { ProblemShape result = problem_shape; get<1>(result) = params.seqlen_kv[batch_idx]; if (params.mainloop.load.ptr_new_k != nullptr) { @@ -247,8 +270,7 @@ struct Sm100FmhaGenKernelWarpspecialized { return result; } - CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { - + CUTLASS_DEVICE void operator()(const Params& params, char* smem) { TileScheduler tile_scheduler{params.tile_scheduler}; int warp_idx = cutlass::canonical_warp_idx_sync(); @@ -267,116 +289,153 @@ struct Sm100FmhaGenKernelWarpspecialized { typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; if (role == WarpRole::Load) { - pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + pipeline_load_q_params.role = + CollectiveMainloop::PipelineQ::ThreadCategory::Producer; } if (role == WarpRole::MMA) { - pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + pipeline_load_q_params.role = + CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; } - pipeline_load_q_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + pipeline_load_q_params.producer_arv_count = + NumWarpsLoad * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::PipelineQ pipeline_load_q( - shared_storage.pipelines.load_q, - pipeline_load_q_params, - ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); - + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, + cute::true_type{}, + /*mask calc*/ cute::false_type{}); + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; if (role == WarpRole::Load) { - pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + pipeline_load_kv_params.role = + CollectiveMainloop::PipelineKV::ThreadCategory::Producer; } if (role == WarpRole::MMA) { - pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + pipeline_load_kv_params.role = + CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; } - pipeline_load_kv_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + pipeline_load_kv_params.producer_arv_count = + NumWarpsLoad * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::PipelineKV pipeline_load_kv( - 
shared_storage.pipelines.load_kv, - pipeline_load_kv_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; if (role == WarpRole::MMA) { - pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + pipeline_mma_s0_params.role = + CollectiveMainloop::PipelineS::ThreadCategory::Producer; } if (role == WarpRole::Softmax0) { - pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + pipeline_mma_s0_params.role = + CollectiveMainloop::PipelineS::ThreadCategory::Consumer; } - pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_mma_s0_params.consumer_arv_count = + NumWarpsSoftmax * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::PipelineS pipeline_mma_s0( - shared_storage.pipelines.mma_s0, - pipeline_mma_s0_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; if (role == WarpRole::MMA) { - pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + pipeline_mma_s1_params.role = + CollectiveMainloop::PipelineS::ThreadCategory::Producer; } if (role == WarpRole::Softmax1) { - pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + pipeline_mma_s1_params.role = + CollectiveMainloop::PipelineS::ThreadCategory::Consumer; } - pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_mma_s1_params.consumer_arv_count = + NumWarpsSoftmax * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::PipelineS pipeline_mma_s1( - shared_storage.pipelines.mma_s1, - pipeline_mma_s1_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; if (role == WarpRole::Softmax0) { - pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + pipeline_s0_corr_params.role = + CollectiveMainloop::PipelineC::ThreadCategory::Producer; } if (role == WarpRole::Correction) { - pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + pipeline_s0_corr_params.role = + CollectiveMainloop::PipelineC::ThreadCategory::Consumer; } - pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; - pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.producer_arv_count = + NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = + NumWarpsCorrection * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::PipelineC pipeline_s0_corr( - shared_storage.pipelines.s0_corr, - pipeline_s0_corr_params, - /*barrier init*/ cute::true_type{}); + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); typename 
CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; if (role == WarpRole::Softmax1) { - pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + pipeline_s1_corr_params.role = + CollectiveMainloop::PipelineC::ThreadCategory::Producer; } if (role == WarpRole::Correction) { - pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + pipeline_s1_corr_params.role = + CollectiveMainloop::PipelineC::ThreadCategory::Consumer; } - pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; - pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.producer_arv_count = + NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = + NumWarpsCorrection * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::PipelineC pipeline_s1_corr( - shared_storage.pipelines.s1_corr, - pipeline_s1_corr_params, - /*barrier init*/ cute::true_type{}); + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; if (role == WarpRole::MMA) { - pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + pipeline_mma_corr_params.role = + CollectiveMainloop::PipelineO::ThreadCategory::Producer; } if (role == WarpRole::Correction) { - pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + pipeline_mma_corr_params.role = + CollectiveMainloop::PipelineO::ThreadCategory::Consumer; } - pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_mma_corr_params.consumer_arv_count = + NumWarpsCorrection * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::PipelineO pipeline_mma_corr( - shared_storage.pipelines.mma_corr, - pipeline_mma_corr_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; if (role == WarpRole::Correction) { - pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + pipeline_corr_epi_params.role = + CollectiveMainloop::PipelineE::ThreadCategory::Producer; } if (role == WarpRole::Epilogue) { - pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + pipeline_corr_epi_params.role = + CollectiveMainloop::PipelineE::ThreadCategory::Consumer; } - pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; - pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp); + pipeline_corr_epi_params.producer_arv_count = + NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = + cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp); typename CollectiveMainloop::PipelineE pipeline_corr_epi( - shared_storage.pipelines.corr_epi, - pipeline_corr_epi_params, - /*barrier init*/ cute::true_type{}); + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; params_order_s01.group_id = role == WarpRole::Softmax1 ? 
1 : 0; params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; typename CollectiveMainloop::OrderBarrierSoftmax order_s01( - shared_storage.pipelines.order_s01, params_order_s01); + shared_storage.pipelines.order_s01, params_order_s01); TmemAllocator tmem_allocator; @@ -388,29 +447,53 @@ struct Sm100FmhaGenKernelWarpspecialized { pipeline_mma_s1.init_masks(ClusterShape{}); pipeline_mma_corr.init_masks(ClusterShape{}); - typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; - typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); - - typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; - typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); - - typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; - typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); - - typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; - typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); - - typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; - typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); - - typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; - typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); - - typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; - typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); - - typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; - typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + typename CollectiveMainloop::PipelineQ::PipelineState + pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState + pipeline_load_q_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineQ>(); + + typename CollectiveMainloop::PipelineKV::PipelineState + pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState + pipeline_load_kv_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineKV>(); + + typename CollectiveMainloop::PipelineS::PipelineState + pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState + pipeline_mma_s0_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineS>(); + + typename CollectiveMainloop::PipelineS::PipelineState + pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState + pipeline_mma_s1_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineS>(); + + typename CollectiveMainloop::PipelineC::PipelineState + pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState + pipeline_s0_corr_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineC>(); + + typename 
CollectiveMainloop::PipelineC::PipelineState + pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState + pipeline_s1_corr_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineC>(); + + typename CollectiveMainloop::PipelineE::PipelineState + pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState + pipeline_corr_epi_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineE>(); + + typename CollectiveMainloop::PipelineO::PipelineState + pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState + pipeline_mma_corr_producer_state = cutlass::make_producer_start_state< + typename CollectiveMainloop::PipelineO>(); CollectiveMainloop mainloop; CollectiveEpilogue epilogue(params.epilogue); @@ -422,156 +505,168 @@ struct Sm100FmhaGenKernelWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto logical_problem_shape = apply_batch(params, - params.problem_shape, get<2,1>(blk_coord)); + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); - if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + if (get<0>(blk_coord) * get<0>(TileShape{}) >= + get<0>(logical_problem_shape)) { continue; } bool is_softmax_0 = role == WarpRole::Softmax0; mainloop.softmax( - is_softmax_0 ? 0 : 1, blk_coord, - params.mainloop, logical_problem_shape, - is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, - is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, - is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, - is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, - order_s01 - ); - + is_softmax_0 ? 0 : 1, + blk_coord, + params.mainloop, + logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state + : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? 
pipeline_s0_corr_producer_state + : pipeline_s1_corr_producer_state, + order_s01); } - } - else if (role == WarpRole::Correction) { + } else if (role == WarpRole::Correction) { cutlass::arch::warpgroup_reg_dealloc(); CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto logical_problem_shape = apply_batch(params, - params.problem_shape, get<2,1>(blk_coord)); + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); - if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + if (get<0>(blk_coord) * get<0>(TileShape{}) >= + get<0>(logical_problem_shape)) { continue; } mainloop.correction( - blk_coord, - params.mainloop, logical_problem_shape, - shared_storage.epilogue, - pipeline_s0_corr, pipeline_s0_corr_consumer_state, - pipeline_s1_corr, pipeline_s1_corr_consumer_state, - pipeline_mma_corr, pipeline_mma_corr_consumer_state, - pipeline_corr_epi, pipeline_corr_epi_producer_state, - epilogue - ); - - + blk_coord, + params.mainloop, + logical_problem_shape, + shared_storage.epilogue, + pipeline_s0_corr, + pipeline_s0_corr_consumer_state, + pipeline_s1_corr, + pipeline_s1_corr_consumer_state, + pipeline_mma_corr, + pipeline_mma_corr_consumer_state, + pipeline_corr_epi, + pipeline_corr_epi_producer_state, + epilogue); } if constexpr (NumWarpsEpilogue == 0) { static_assert(NumWarpsCorrection == 1); uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + tmem_allocator.free( + free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); } - } - else if (role == WarpRole::MMA) { + } else if (role == WarpRole::MMA) { warpgroup_reg_set(); - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + tmem_allocator.allocate( + TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); __syncwarp(); CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto logical_problem_shape = apply_batch(params, - params.problem_shape, get<2,1>(blk_coord)); + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); - if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + if (get<0>(blk_coord) * get<0>(TileShape{}) >= + get<0>(logical_problem_shape)) { continue; } - mainloop.mma( - blk_coord, - params.mainloop, logical_problem_shape, - shared_storage.mainloop, - pipeline_load_q, pipeline_load_q_consumer_state, - pipeline_load_kv, pipeline_load_kv_consumer_state, - pipeline_mma_s0, pipeline_mma_s0_producer_state, - pipeline_mma_s1, pipeline_mma_s1_producer_state, - pipeline_mma_corr, pipeline_mma_corr_producer_state - ); - - + blk_coord, + params.mainloop, + logical_problem_shape, + shared_storage.mainloop, + pipeline_load_q, + pipeline_load_q_consumer_state, + pipeline_load_kv, + pipeline_load_kv_consumer_state, + pipeline_mma_s0, + pipeline_mma_s0_producer_state, + pipeline_mma_s1, + pipeline_mma_s1_producer_state, + pipeline_mma_corr, + pipeline_mma_corr_producer_state); } - } - else if (role == WarpRole::Load) { + } else if (role == WarpRole::Load) { warpgroup_reg_set(); CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto logical_problem_shape = apply_batch(params, - params.problem_shape, get<2,1>(blk_coord)); 
+ auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); - if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + if (get<0>(blk_coord) * get<0>(TileShape{}) >= + get<0>(logical_problem_shape)) { continue; } mainloop.load( - blk_coord, logical_problem_shape, - params.mainloop, params.problem_shape, - shared_storage.mainloop, - pipeline_load_q, pipeline_load_q_producer_state, - pipeline_load_kv, pipeline_load_kv_producer_state - ); - + blk_coord, + logical_problem_shape, + params.mainloop, + params.problem_shape, + shared_storage.mainloop, + pipeline_load_q, + pipeline_load_q_producer_state, + pipeline_load_kv, + pipeline_load_kv_producer_state); } - } - else if (role == WarpRole::Epilogue) { + } else if (role == WarpRole::Epilogue) { warpgroup_reg_set(); CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto logical_problem_shape = apply_batch(params, - params.problem_shape, get<2,1>(blk_coord)); + auto logical_problem_shape = + apply_batch(params, params.problem_shape, get<2, 1>(blk_coord)); - if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + if (get<0>(blk_coord) * get<0>(TileShape{}) >= + get<0>(logical_problem_shape)) { continue; } epilogue.store( - blk_coord, logical_problem_shape, - params.epilogue, params.problem_shape, - shared_storage.epilogue, - pipeline_corr_epi, pipeline_corr_epi_consumer_state - ); - + blk_coord, + logical_problem_shape, + params.epilogue, + params.problem_shape, + shared_storage.epilogue, + pipeline_corr_epi, + pipeline_corr_epi_consumer_state); } static_assert(NumWarpsEpilogue <= 1); if constexpr (NumWarpsEpilogue == 1) { uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + tmem_allocator.free( + free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); } - } - else if (role == WarpRole::Empty) { + } else if (role == WarpRole::Empty) { warpgroup_reg_set(); /* no-op, donate regs and exit */ } } - }; -} // namespace cutlass::fmha::kernel +} // namespace cutlass::fmha::kernel diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp index b89d41c3a8..7f975e6e90 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -1,13 +1,12 @@ -// @nolint /*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -19,24 +18,24 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - #pragma once #include "cutlass/cutlass.h" -#include "cute/tensor.hpp" #include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" #include "cutlass/barrier.h" #include "cutlass/arch/arch.h" @@ -44,15 +43,15 @@ #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "gather_tensor.hpp" // from examples/common #include "common/pow_2.hpp" +#include "gather_tensor.hpp" // from examples/common #include "sm100_mla_tile_scheduler.hpp" namespace cutlass::fmha::kernel { using namespace cute; -template< +template < class TileShape, class Element_, class ElementAcc_, @@ -64,14 +63,14 @@ template< #else bool kIsCpAsync = false #endif -> + > struct Sm100FmhaMlaKernelTmaWarpspecialized { using Element = Element_; using ElementAcc = ElementAcc_; using ElementOut = ElementOut_; using ElementLSE = ElementLSE_; - + // only 2Sm mode is supported static const bool kIs2Sm = true; static const int MaxThreadsPerBlock = 256; @@ -80,7 +79,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { static const int TotalPNum = 2; using ArchTag = cutlass::arch::Sm100; - using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + using ClusterShape = + cute::conditional_t, Shape<_1, _1, _1>>; using TileShapeH = tuple_element_t<0, TileShape>; using TileShapeS = tuple_element_t<1, TileShape>; @@ -88,11 +88,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { using TileShapeL = tuple_element_t<0, TileShapeD>; using TileShapeR = tuple_element_t<1, TileShapeD>; - static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + static_assert( + TileShapeL{} % TileShapeR{} == 0, + "Rope head dim must divide latent head dim"); - using ProblemShape = Shape; - using TensorStride = Stride; - using TmemAllocator = cute::conditional_t; + using ProblemShape = Shape; 
+ using TensorStride = Stride; + using TmemAllocator = cute:: + conditional_t; static_assert(TileShapeH{} == 128); static const int kWarpsInN = kIs2Sm ? 2 : 1; @@ -101,62 +104,97 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; enum class WarpRole { - kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + kMma = 0x1, + kLoad = 0x2, + kCompute = 0x3, + kLoadPageTable = 0x4, + kEmpty = 0x0 }; - static const long long unsigned int kWarpAssignment = kIsCpAsync ? 0x4221'3333ull : 0x0021'3333ull; + static const long long unsigned int kWarpAssignment = + kIsCpAsync ? 0x4221'3333ull : 0x0021'3333ull; static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { - return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); } static const int Alignment = 128 / sizeof_bits_v; static const int AlignmentOut = 128 / sizeof_bits_v; - using TileShapeQK = Shape; - static const int StagesQK = 24 / sizeof(Element); // free parameter - static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; - static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + using TileShapeQK = + Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = + decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = + decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; static const int IterationsQK = IterationsQKLatent + IterationsQKRope; - using Schedule = cute::conditional_t; + using Schedule = cute::conditional_t< + kIs2Sm, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>; using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - Element, TensorStride, Alignment, - Element, TensorStride, Alignment, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, + Element, + TensorStride, + Alignment, + Element, + TensorStride, + Alignment, ElementAcc, - TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + TileShapeQK, + ClusterShape, + cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using TiledMmaQK = typename CollectiveMmaQK::TiledMma; using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; // chosen for unified smem staging between K and V using TileShapePV = Shape; - using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); - static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes - static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; - static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + using TransposeTensorStride = decltype(select<1, 0, 2>(TensorStride{})); + static const int StagesPV = + StagesQK; // not sure why, but must be at least two. 
check pipes + static const int IterationsPV_K = + decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = + decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - Element, TensorStride, Alignment, - Element, TransposeTensorStride, Alignment, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, + Element, + TensorStride, + Alignment, + Element, + TransposeTensorStride, + Alignment, ElementAcc, - TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + TileShapePV, + ClusterShape, + cutlass::gemm::collective::StageCount, Schedule>::CollectiveOp; using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; - static_assert(std::is_same_v); + static_assert( + std::is_same_v); using TiledMmaPV = typename CollectiveMmaPV::TiledMma; using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; - static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + static_assert( + typename CollectiveMmaQK::AtomThrShapeMNK{} == + typename CollectiveMmaPV::AtomThrShapeMNK{}, + "schedule must match"); static const int StagesPageTable = kIsCpAsync ? StagesPV : 1; // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd // use expect_tx for Q load - using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadQK = cute::conditional_t< + kIsCpAsync, + PipelineUmmaConsumerAsync, + PipelineTmaUmmaAsync>; using PipelineLoadPV = PipelineLoadQK; // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages using PipelineS = PipelineUmmaAsync; @@ -175,21 +213,35 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { alignas(16) typename PipelinePT::SharedStorage load_page_table; }; - template - static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { - return composition(layout, make_tuple(_, _, _, make_layout(stages))); + template + static CUTE_DEVICE constexpr auto unstageSmemLayout( + Layout const& layout, + Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); } - using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutQ = decltype(unstageSmemLayout( + typename CollectiveMmaQK::SmemLayoutA{}, + Int{})); using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; - using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); - using SmemLayoutOut = decltype(take<0,2>(typename CollectiveMmaQK::CtaShape_MNK{})); - using TileShapeAcc = decltype(take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{})); - - static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); - static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); - static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + using SmemLayoutP = decltype(unstageSmemLayout( + typename CollectiveMmaPV::SmemLayoutA{}, + make_shape(Int{}, _2{}))); + using SmemLayoutOut = + decltype(take<0, 2>(typename CollectiveMmaQK::CtaShape_MNK{})); + using TileShapeAcc = + decltype(take<0, 2>(typename 
CollectiveMmaPV::CtaShape_MNK{})); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutQ{})) * + cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutKC{})) * + cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * + cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutVC{})) * + cute::sizeof_bits_v); // pre-condition for overlapped smem staging static_assert(kBytesLoadKC == kBytesLoadVC); @@ -199,12 +251,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; static const int kTransactionsBytesLoadPV = kBytesLoadVC; - static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; - // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent - // tile scheduler for FP8 MLA. - static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; - // - static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + static const int kNamedBarrierExchange = + (int)cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue + // when enable persistent tile scheduler for FP8 MLA. + static const int kNamedBarrierEpilogue = + (int)cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = + (int)cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; enum class TmemAllocation : uint32_t { kSizeS = TileShapeS::value / kWarpsInN, @@ -221,11 +276,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { kTotal = kO0 + kSizeO }; - static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + static_assert( + static_cast(TmemAllocation::kTotal) <= + TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem"); struct TensorStorage { // to communicate max and row_sum - cute::array smem_exchange; + cute::array + smem_exchange; cute::array smem_page_table; alignas(2048) cute::array> smem_q; union { @@ -245,11 +304,13 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { }; static const int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + static_assert( + SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, + "using too much smem"); struct MainloopArguments { ElementAcc softmax_scale; - + // all tensors strides are (num_heads or seqlen, head_dim, batch) // head_dim stride is always 1 Element* ptr_q_latent; @@ -269,9 +330,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // page table is [batch, seqlen or similar] Stride<_1, int> stride_page_table = {}; int page_count = 0; - int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + int page_size = + TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS }; - + struct EpilogueArguments { ElementOut* ptr_o = nullptr; TensorStride stride_o; @@ -291,18 +353,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { int* ptr_split_kv = nullptr; bool is_fused_reduction = false; }; - + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; 
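  // -------------------------------------------------------------------------
  // [Editor's note -- illustrative sketch, not part of this patch.]
  // The kWarpAssignment constant defined earlier in this struct packs one
  // 4-bit WarpRole per warp index, and warp_idx_to_role() extracts the nibble
  // for the calling warp. Below is a minimal host-only model of that decode,
  // reusing the enum values from above; the main/printf scaffolding is
  // illustrative only. With the cp-async table 0x4221'3333 it yields warps
  // 0-3 = kCompute, warp 4 = kMma, warps 5-6 = kLoad, warp 7 = kLoadPageTable,
  // and kEmpty for any higher warp, consistent with kNumLoadWarps == 2 in
  // that configuration:
  //
  //   #include <cstdio>
  //
  //   enum class WarpRole : unsigned {
  //     kEmpty = 0x0,
  //     kMma = 0x1,
  //     kLoad = 0x2,
  //     kCompute = 0x3,
  //     kLoadPageTable = 0x4,
  //   };
  //
  //   // One nibble per warp; the least-significant nibble is warp 0.
  //   constexpr unsigned long long kWarpAssignment = 0x4221'3333ull;
  //
  //   constexpr WarpRole warp_idx_to_role(int warp_idx) {
  //     return static_cast<WarpRole>((kWarpAssignment >> (4 * warp_idx)) & 0xF);
  //   }
  //
  //   int main() {
  //     for (int w = 0; w < 8; ++w) {
  //       std::printf("warp %d -> role 0x%x\n",
  //                   w, static_cast<unsigned>(warp_idx_to_role(w)));
  //     }
  //     return 0;
  //   }
  // -------------------------------------------------------------------------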
using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; - using GmemLayout = decltype(make_layout(Shape{}, Stride{})); + using GmemLayout = decltype(make_layout( + Shape{}, + Stride{})); using SmemLayout = decltype(make_layout(TileShapeAcc{}, LayoutRight{})); - using TmaReduceSum = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, - make_tensor(recast_ptr(nullptr), GmemLayout{}), SmemLayout{})); + using TmaReduceSum = decltype(make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(recast_ptr(nullptr), GmemLayout{}), + SmemLayout{})); struct MainloopParams { TmaLoadQLatent tma_load_q_latent; @@ -324,8 +390,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { ElementAcc output_scale = 1.0f; ElementLSE* ptr_lse_exchange_buff = nullptr; int* ptr_lse_max_exchange_buff = nullptr; - int* ptr_lock = nullptr; // semaphore - TmaReduceSum tma_reduce_sum; + int* ptr_lock = nullptr; // semaphore + TmaReduceSum tma_reduce_sum; }; struct Params { @@ -339,8 +405,11 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { bool is_fused_reduction = false; }; - static Params to_underlying_arguments(Arguments const& args, void* workspace) { - //workspace = nullptr; // let's get an error if one of these needs workspace + static Params to_underlying_arguments( + Arguments const& args, + void* workspace) { + // workspace = nullptr; // let's get an error if one of these needs + // workspace auto [H, K, D, B] = args.problem_shape; auto [L, R] = D; @@ -354,48 +423,62 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( make_shape(H, K, L, B), - typename CollectiveMmaQK::Arguments { - args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, - args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, - }, nullptr); + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_latent, + args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, + args.mainloop.stride_c_latent, + }, + nullptr); auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( make_shape(H, paged_K, L, paged_B), - typename CollectiveMmaQK::Arguments { - args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, - args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, - }, nullptr); + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_latent, + args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, + args.mainloop.stride_c_latent, + }, + nullptr); auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( make_shape(H, K, R, B), - typename CollectiveMmaQK::Arguments { - args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, - args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, - }, nullptr); + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_rope, + args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, + args.mainloop.stride_k_rope, + }, + nullptr); auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( make_shape(H, paged_K, R, paged_B), - typename CollectiveMmaQK::Arguments { - args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, - args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, - }, nullptr); - - - auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + typename CollectiveMmaQK::Arguments{ + args.mainloop.ptr_q_rope, + args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, + args.mainloop.stride_k_rope, + }, + nullptr); + + auto stride_c_latent_transpose = + select<1, 0, 2>(args.mainloop.stride_c_latent); auto 
params_pv_latent = CollectiveMmaPV::to_underlying_arguments( make_shape(H, L, paged_K, paged_B), - typename CollectiveMmaPV::Arguments { - args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used - args.mainloop.ptr_c_latent, stride_c_latent_transpose, - }, nullptr); - - MainloopParams mainloop_params { - params_qk_latent.tma_load_a, - params_qk_rope.tma_load_a, - params_qk_latent_paged.tma_load_b, - params_qk_rope_paged.tma_load_b, - params_pv_latent.tma_load_b - }; + typename CollectiveMmaPV::Arguments{ + args.mainloop.ptr_q_latent, + args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, + stride_c_latent_transpose, + }, + nullptr); + + MainloopParams mainloop_params{ + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b}; EpilogueParams epilogue_params; @@ -404,56 +487,79 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { epilogue_params.ptr_lse = args.epilogue.ptr_lse; epilogue_params.stride_lse = args.epilogue.stride_lse; epilogue_params.output_scale = args.epilogue.output_scale; - epilogue_params.tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(recast_ptr(args.epilogue.ptr_o), make_layout(make_shape(H, L, B), args.epilogue.stride_o)), SmemLayout{}); + epilogue_params.tma_reduce_sum = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor( + recast_ptr(args.epilogue.ptr_o), + make_layout(make_shape(H, L, B), args.epilogue.stride_o)), + SmemLayout{}); if (!args.is_fused_reduction && args.split_kv > 1) { - ElementAcc* ptr_o_acc = reinterpret_cast(workspace); - ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); - epilogue_params.ptr_o_acc = ptr_o_acc; + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = + reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; epilogue_params.ptr_lse_acc = ptr_lse_acc; - epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); - epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + epilogue_params.stride_o_acc = make_tuple( + static_cast(0 + L) * args.split_kv, + _1{}, + static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = + make_tuple(_1{}, (0 + H) * args.split_kv); } else if (args.is_fused_reduction && args.split_kv > 1) { - ElementLSE* ptr_lse_exchange_buff = reinterpret_cast(workspace); + ElementLSE* ptr_lse_exchange_buff = + reinterpret_cast(workspace); epilogue_params.ptr_lse_exchange_buff = ptr_lse_exchange_buff; - int* ptr_lse_max_exchange_buff = reinterpret_cast(ptr_lse_exchange_buff + H * B); + int* ptr_lse_max_exchange_buff = + reinterpret_cast(ptr_lse_exchange_buff + H * B); epilogue_params.ptr_lse_max_exchange_buff = ptr_lse_max_exchange_buff; int* ptr_lock = ptr_lse_max_exchange_buff + H * B; epilogue_params.ptr_lock = ptr_lock; } - return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, - TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), - args.split_kv, args.ptr_split_kv, args.is_fused_reduction}; + return { + args.problem_shape, + args.mainloop, + epilogue_params, + mainloop_params, + TileScheduler::to_underlying_arguments( + args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), + args.split_kv, + args.ptr_split_kv, + args.is_fused_reduction}; } - static size_t 
get_workspace_size(Arguments const& args) {
+  static size_t get_workspace_size(Arguments const& args) {
     ProblemShape problem_shape = args.problem_shape;
     auto [H, K, D, B] = problem_shape;
     auto [D_latent, D_rope] = D;
     auto split_kv = args.split_kv;
-    size_t workspace_size {0};
+    size_t workspace_size{0};
     if (args.is_fused_reduction && args.split_kv > 1) {
-      // one exchange buffer for LSE max and another buffer for total LSE
-      // two locks per batch, frist lock is for CTA0 / H=0..63 and the second is for CTA1 / H=64..127
-      workspace_size = H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int);
+      // one exchange buffer for LSE max and another buffer for total LSE
+      // two locks per batch, the first lock is for CTA0 / H=0..63 and the
+      // second is for CTA1 / H=64..127
+      workspace_size =
+          H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int);
     } else if (!args.is_fused_reduction && args.split_kv > 1) {
-      workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
+      workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) *
+          H * split_kv * B;
     }
     return workspace_size;
   }

-  static Status initialize_workspace(
-      Arguments const& args, void* ws, cudaStream_t stream) {
+  static Status
+  initialize_workspace(Arguments const& args, void* ws, cudaStream_t stream) {
     auto workspace_size = get_workspace_size(args);
     if (args.is_fused_reduction && args.split_kv > 1) {
       auto result = cudaMemsetAsync(ws, 0, workspace_size);
       if (cudaSuccess != result) {
         result = cudaGetLastError(); // to clear the error bit
         CUTLASS_TRACE_HOST(
-          " cudaMemsetAsync() returned error: "
-          << cudaGetErrorString(result));
-        return Status::kErrorInternal;;
+            " cudaMemsetAsync() returned error: "
+            << cudaGetErrorString(result));
+        return Status::kErrorInternal;
       }
     }
     return Status::kSuccess;
@@ -471,16 +577,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
   static bool can_implement(Arguments const& args) {
     if (kIsCpAsync) {
       if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) {
-        std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size pow2\n";
+        std::cerr << __FILE__ << "(" << __LINE__
+                  << "): cpasync page size pow2\n";
         return false;
       }
       if (args.mainloop.page_size > TileShapeS{}) {
-        std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size too big\n";
+        std::cerr << __FILE__ << "(" << __LINE__
+                  << "): cpasync page size too big\n";
         return false;
       }
-    }
-    else {
-      if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) {
+    } else {
+      if (args.mainloop.ptr_page_table != nullptr &&
+          args.mainloop.page_size != TileShapeS{}) {
         std::cerr << __FILE__ << "(" << __LINE__ << "): tma page size off\n";
         return false;
       }
@@ -498,17 +606,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
       return false;
     }
     if (args.is_fused_reduction && args.split_kv > 1) {
-      if (2 * args.split_kv > args.hw_info.sm_count ||
-        std::is_same_v) {
+      if (2 * args.split_kv > args.hw_info.sm_count ||
+          std::is_same_v) {
         return false;
       }
     }
     return true;
   }

-
   CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
-
     TileScheduler tile_scheduler(params.tile_scheduler);

     int warp_idx = cutlass::canonical_warp_idx_sync();
@@ -519,12 +625,17 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{});
     bool is_mma_leader_cta = cta_coord_v == 0;

-    if (role == WarpRole::kLoad && lane_predicate && !
kIsCpAsync) { - prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); - prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); - prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); - prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); - prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + if (role == WarpRole::kLoad && lane_predicate && !kIsCpAsync) { + prefetch_tma_descriptor( + params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor( + params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor( + params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor( + params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose + .get_tma_descriptor()); } SharedStorage& shared_storage = *reinterpret_cast(smem_raw); @@ -538,15 +649,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if constexpr (kIsCpAsync) { // we can make our life easier by unconditionally loading blocks // since we know it'll always be legal - pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); - } - else { - pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.producer_arv_count = + kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } else { + pipeline_load_qk_params.is_leader = + lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; } pipeline_load_qk_params.initializing_warp = 0; - PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineLoadQK pipeline_load_qk( + shared_storage.pipelines.load_qk, + pipeline_load_qk_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename PipelineS::Params pipeline_mma_s_params; if (role == WarpRole::kMma) { @@ -555,12 +671,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (role == WarpRole::kCompute) { pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; } - pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.consumer_arv_count = + kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); pipeline_mma_s_params.initializing_warp = 1; PipelineS pipeline_mma_s( - shared_storage.pipelines.mma_s, - pipeline_mma_s_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename PipelineP::Params pipeline_p_mma_params; if (role == WarpRole::kMma) { @@ -569,13 +688,16 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (role == WarpRole::kCompute) { pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; } - pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.producer_arv_count = + kNumComputeWarps * cutlass::NumThreadsPerWarp * 
size(AtomThrShapeMNK{}); pipeline_p_mma_params.consumer_arv_count = 1; pipeline_p_mma_params.initializing_warp = 2; PipelineP pipeline_p_mma( - shared_storage.pipelines.p_mma, - pipeline_p_mma_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename PipelineO::Params pipeline_mma_o_params; if (role == WarpRole::kMma) { @@ -584,12 +706,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (role == WarpRole::kCompute) { pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; } - pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.consumer_arv_count = + kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); pipeline_mma_o_params.initializing_warp = 3; PipelineO pipeline_mma_o( - shared_storage.pipelines.mma_o, - pipeline_mma_o_params, - ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, + /*barrier init*/ cute::true_type{}, + /*mask calc*/ cute::false_type{}); typename PipelinePT::Params pipeline_pt_params; if (role == WarpRole::kLoad) { @@ -598,36 +723,42 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (role == WarpRole::kLoadPageTable) { pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; } - pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.consumer_arv_count = + kNumLoadWarps * cutlass::NumThreadsPerWarp; pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; pipeline_pt_params.initializing_warp = 4; PipelinePT pipeline_page_table( - shared_storage.pipelines.load_page_table, - pipeline_pt_params); + shared_storage.pipelines.load_page_table, pipeline_pt_params); TmemAllocator tmem_allocator; pipeline_init_arrive_relaxed(size(ClusterShape{})); - pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? + pipeline_load_qk.init_masks( + ClusterShape{}); // do we need an update here for 2Sm? 
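+    // A brief worked example of the barrier arithmetic used when building the
+    // pipelines above (illustrative only; kNumComputeWarps and kNumLoadWarps
+    // are defined earlier in this header and may differ):
+    //
+    //   // assuming kNumComputeWarps == 4, kNumLoadWarps == 2,
+    //   // cutlass::NumThreadsPerWarp == 32, size(AtomThrShapeMNK{}) == 2
+    //   constexpr int consumer_arv = 4 * 32 * 2; // 256 arrivals gate S/O stages
+    //   constexpr int producer_arv = 2 * 32 * 2; // 128 arrivals on the cp.async path
+    //
+    // i.e. a pipeline stage is only released once every thread of the consumer
+    // (or producer) warp group has arrived at its mbarrier.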
pipeline_mma_s.init_masks(ClusterShape{}); pipeline_p_mma.init_masks(ClusterShape{}); pipeline_mma_o.init_masks(ClusterShape{}); typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; - typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = + cutlass::make_producer_start_state(); typename PipelineS::PipelineState pipeline_mma_s_consumer_state; - typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + typename PipelineS::PipelineState pipeline_mma_s_producer_state = + cutlass::make_producer_start_state(); typename PipelineP::PipelineState pipeline_p_mma_consumer_state; - typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + typename PipelineP::PipelineState pipeline_p_mma_producer_state = + cutlass::make_producer_start_state(); typename PipelineO::PipelineState pipeline_mma_o_consumer_state; - typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + typename PipelineO::PipelineState pipeline_mma_o_producer_state = + cutlass::make_producer_start_state(); typename PipelinePT::PipelineState pipeline_pt_consumer_state; - typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + typename PipelinePT::PipelineState pipeline_pt_producer_state = + cutlass::make_producer_start_state(); pipeline_init_wait(size(ClusterShape{})); @@ -636,208 +767,233 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) - continue; + if (local_split_kv <= get<3>(blk_coord)) + continue; load_page_table( - blk_coord, - problem_shape, - params.mainloop, - shared_storage.tensors, - pipeline_page_table, pipeline_pt_producer_state, - local_split_kv - ); + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, + pipeline_pt_producer_state, + local_split_kv); } - } - else if (role == WarpRole::kLoad) { + } else if (role == WarpRole::kLoad) { if constexpr (kIsCpAsync) { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_cpasync( - blk_coord, - problem_shape, - params.mainloop, - params.mainloop_params, - shared_storage.tensors, - pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv, - /* must be shared pipe */ - pipeline_page_table, 
pipeline_pt_consumer_state - ); - cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, + pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, + pipeline_pt_consumer_state); + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive_and_wait(); } - } - else { + } else { if (params.mainloop.ptr_page_table != nullptr) { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { - local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; - } + get<1>(problem_shape) = + params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_tma( - blk_coord, - problem_shape, - params.mainloop, - params.mainloop_params, - shared_storage.tensors, - pipeline_load_qk, pipeline_load_qk_producer_state, - pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv - ); - cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, + pipeline_load_qk_producer_state, + pipeline_load_qk, + pipeline_load_qk_producer_state, + local_split_kv); + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive_and_wait(); } - } - else { + } else { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); - auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { - get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + get<1>(problem_shape) = + params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; - } + } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; load_tma( - blk_coord, - problem_shape, - params.mainloop, - params.mainloop_params, - shared_storage.tensors, - pipeline_load_qk, pipeline_load_qk_producer_state, - pipeline_load_qk, pipeline_load_qk_producer_state, - local_split_kv - ); - cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, + pipeline_load_qk_producer_state, + pipeline_load_qk, + pipeline_load_qk_producer_state, + local_split_kv); + cutlass::arch::NamedBarrier( + (kNumComputeWarps + 
kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive_and_wait(); } } } - } - else if (role == WarpRole::kMma) { - tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + } else if (role == WarpRole::kMma) { + tmem_allocator.allocate( + TmemAllocator::Sm100TmemCapacityColumns, + &shared_storage.tmem_base_ptr); __syncwarp(); - + if (is_mma_leader_cta) { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto local_split_kv = params.split_kv; + auto local_split_kv = params.split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; if (params.ptr_split_kv != nullptr) { - local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; mma(blk_coord, - problem_shape, - shared_storage.tensors, - pipeline_load_qk, pipeline_load_qk_consumer_state, - pipeline_load_qk, pipeline_load_qk_consumer_state, - pipeline_mma_s, pipeline_mma_s_producer_state, - pipeline_p_mma, pipeline_p_mma_consumer_state, - pipeline_mma_o, pipeline_mma_o_producer_state, - local_split_kv - ); + problem_shape, + shared_storage.tensors, + pipeline_load_qk, + pipeline_load_qk_consumer_state, + pipeline_load_qk, + pipeline_load_qk_consumer_state, + pipeline_mma_s, + pipeline_mma_s_producer_state, + pipeline_p_mma, + pipeline_p_mma_consumer_state, + pipeline_mma_o, + pipeline_mma_o_producer_state, + local_split_kv); } } - //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + // cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, + // kNamedBarrierTmemDealloc).arrive_and_wait(); - //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); - } - else if (role == WarpRole::kCompute) { + // uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + // tmem_allocator.free(free_stage_ptr, + // TmemAllocator::Sm100TmemCapacityColumns); + } else if (role == WarpRole::kCompute) { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { auto blk_coord = tile_scheduler.get_block_coord(); auto problem_shape = params.problem_shape; - auto split_kv = params.split_kv; - auto local_split_kv = split_kv; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; if (params.mainloop.ptr_seq != nullptr) { get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; - if (params.ptr_split_kv != nullptr) { + if (params.ptr_split_kv != nullptr) { local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; } } - if (local_split_kv <= get<3>(blk_coord)) + if (local_split_kv <= get<3>(blk_coord)) continue; compute( - blk_coord, - problem_shape, - params.mainloop, // for softmax_scale - params.epilogue, - shared_storage.tensors, // for smem_comm - pipeline_mma_s, pipeline_mma_s_consumer_state, - pipeline_p_mma, pipeline_p_mma_producer_state, - pipeline_mma_o, pipeline_mma_o_consumer_state, - local_split_kv, - params.is_fused_reduction - ); + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, + pipeline_mma_s_consumer_state, + pipeline_p_mma, + 
pipeline_p_mma_producer_state, + pipeline_mma_o, + pipeline_mma_o_consumer_state, + local_split_kv, + params.is_fused_reduction); } - //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + // cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, + // kNamedBarrierTmemDealloc).arrive(); } cute::cluster_sync(); - cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + cutlass::arch::NamedBarrier( + (kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc) + .arrive(); if (role == WarpRole::kMma) { uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; - tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + tmem_allocator.free( + free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); } } - template + template CUTLASS_DEVICE void load_page_table( BlkCoord const& blk_coord, ProblemShape const& problem_shape, MainloopArguments const& mainloop_args, TensorStorage& shared_tensors, PipelinePT& pipeline_page_table, - typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { - + typename PipelinePT::PipelineState& pipeline_pt_producer_state, + int const& split_kv) { auto [H, K, D, B] = problem_shape; int batch_coord = get<2>(blk_coord); - auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), - make_shape(mainloop_args.page_count, B), - mainloop_args.stride_page_table); + auto mPT_l = make_tensor( + make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); auto mPT = mPT_l(_, batch_coord); - + int k_tile_total = ceil_div(K, TileShapeS{}); int k_tile_per_cta = ceil_div(k_tile_total, split_kv); int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit - int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + int k_tile_count = + max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); if (k_tile_count == 0) { return; } @@ -848,44 +1004,44 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (; k_tile_count > 0; ++k_index, --k_tile_count) { pipeline_page_table.producer_acquire(pipeline_pt_producer_state); - + // assume a single warp CUTLASS_PRAGMA_UNROLL for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { int idx = i + thread_idx; bool guard = idx < pages_per_tile; - int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int smem_idx = + pipeline_pt_producer_state.index() * TileShapeS::value + idx; int pt_idx = pages_per_tile * k_index + idx; - cutlass::arch::cp_async_zfill( - &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard - ); + cutlass::arch:: + cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard); } - - pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + + pipeline_page_table.producer_commit( + pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_pt_producer_state; } } - struct Gather { int& page_table_stage; Pow2 pages_per_tile; - const int * __restrict__ smem_page_table; + const int* __restrict__ smem_page_table; CUTLASS_DEVICE int operator()(int idx) const { - return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + return smem_page_table + [page_table_stage * TileShapeS::value + idx % pages_per_tile]; } CUTLASS_DEVICE friend void print(Gather const&) { printf(""); } - }; - - template + 
template CUTLASS_DEVICE void load_cpasync( BlkCoord const& blk_coord, ProblemShape const& problem_shape, @@ -897,7 +1053,6 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { int const& split_kv, PipelinePT& pipeline_page_table, typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { - auto [H, K, D, B] = problem_shape; auto [D_latent, D_rope] = D; @@ -906,34 +1061,51 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { int k_tile_total = ceil_div(K, TileShapeS{}); int k_tile_per_cta = ceil_div(k_tile_total, split_kv); int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit - int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + int k_tile_count = + max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); if (k_tile_count == 0) { return; } // partition all tensors - auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); - auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); - - int paged_B = mainloop_args.page_count; + auto mQL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_q_latent), + make_shape(H, D_latent, B), + mainloop_args.stride_q_latent); + auto mQR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_q_rope), + make_shape(H, D_rope, B), + mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; auto paged_K = Pow2{mainloop_args.page_size}; - auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + auto mPT_l = make_tensor( + make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(paged_B, B), + mainloop_args.stride_page_table); int batch_coord = get<2>(blk_coord); auto mPT = mPT_l(_, batch_coord); - auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); - auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQL = + local_tile(mQL, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + auto gQR = + local_tile(mQR, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); - ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); - ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_qk = + TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = + TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); auto tSgQL = cta_mma_qk.partition_A(gQL); auto tSgQR = cta_mma_qk.partition_A(gQR); - Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); - Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); - Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sQ = make_tensor( + make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor( + make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor( + make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); auto make_copy_for = [](auto sT) { auto rT_a = sT.layout()(_, _, _, _0{}); @@ -943,8 +1115,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { return make_cotiled_copy( Copy_Atom, Element>{}, make_ordered_layout( - make_shape(threads, values), - make_stride(_1{}, _0{})), + make_shape(threads, values), make_stride(_1{}, _0{})), rT); }; @@ -958,13 +1129,11 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { 
src_v_ptrs(i) = &src_v(_0{}, i); } - for (int i = 0; i < size<1>(src_v); i++) { auto src_v_i = make_tensor( make_gmem_ptr(src_v_ptrs(i)), make_shape(shape<0>(src_v)), - make_stride(make_stride(_1{}, _0{})) - ); + make_stride(make_stride(_1{}, _0{}))); atom.call(src_v_i, dst_v(_, i)); } }; @@ -973,9 +1142,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto tiled_copy_kc = make_copy_for(sKC); auto tiled_copy_vc = make_copy_for(sVC); - auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); - auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); - auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_q = tiled_copy_q.get_thread_slice( + threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice( + threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice( + threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); auto tQsQ = thr_copy_q.partition_D(sQ); auto tQgQL = thr_copy_q.partition_S(tSgQL); @@ -988,7 +1160,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { int page_table_stage = -1; Pow2 pages_per_tile{TileShapeS{} / paged_K}; - const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + const int* __restrict__ smem_page_table = + shared_tensors.smem_page_table.begin(); Gather gather{page_table_stage, pages_per_tile, smem_page_table}; auto mCL = make_tensor( @@ -996,7 +1169,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { ComposedLayout{ make_layout( make_shape(make_shape(paged_K, paged_B), _1{}), - make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_stride( + make_stride( + get<0>(mainloop_args.stride_c_latent), + example::CustomStride( + gather, get<2>(mainloop_args.stride_c_latent))), + get<1>(mainloop_args.stride_c_latent))), make_coord(_0{}, _0{}), make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); @@ -1005,7 +1183,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { ComposedLayout{ make_layout( make_shape(make_shape(paged_K, paged_B), _1{}), - make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_stride( + make_stride( + get<0>(mainloop_args.stride_k_rope), + example::CustomStride( + gather, get<2>(mainloop_args.stride_k_rope))), + get<1>(mainloop_args.stride_k_rope))), make_coord(_0{}, _0{}), make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); @@ -1014,13 +1197,21 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { ComposedLayout{ make_layout( make_shape(_1{}, make_shape(paged_K, paged_B)), - make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_stride( + get<1>(mainloop_args.stride_c_latent), + make_stride( + get<0>(mainloop_args.stride_c_latent), + example::CustomStride( + gather, get<2>(mainloop_args.stride_c_latent))))), make_coord(_0{}, _0{}), make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); - auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); - auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); - auto 
gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + auto gCL = + local_tile(mCL, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gKR = + local_tile(mKR, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gCLT = + local_tile(mCLT, TileShapePV{}, make_coord(_, _, _), Step{}); auto tSgCL = cta_mma_qk.partition_B(gCL); auto tSgKR = cta_mma_qk.partition_B(gKR); @@ -1063,15 +1254,27 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // each Q/K tile consists of rope and latent for (int i = 0; i < IterationsQKLatent; i++) { load_stage([&](int index) { - cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); - copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + cute::copy( + tiled_copy_q, + tQgQL(_, _, _, _, _0{}, i, batch_coord), + tQsQ(_, _, _, _, i)); + copy_split( + tiled_copy_kc, + tKCgCL(_, _, _, _, k_index, i), + tKCsKC(_, _, _, _, index)); }); } for (int i = 0; i < IterationsQKRope; i++) { load_stage([&](int index) { - cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); - copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + cute::copy( + tiled_copy_q, + tQgQR(_, _, _, _, _0{}, i, batch_coord), + tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split( + tiled_copy_kc, + tKCgKR(_, _, _, _, k_index, i), + tKCsKC(_, _, _, _, index)); }); } @@ -1082,20 +1285,25 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // perform K+Q load here CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { - pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); page_table_stage = pipeline_pt_consumer_state.index(); ++pipeline_pt_consumer_state; for (int i = 0; i < IterationsQKLatent; i++) { load_stage([&](int index) { - copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + copy_split( + tiled_copy_kc, + tKCgCL(_, _, _, _, k_index, i), + tKCsKC(_, _, _, _, index)); }); } for (int i = 0; i < IterationsQKRope; i++) { load_stage([&](int index) { - copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + copy_split( + tiled_copy_kc, + tKCgKR(_, _, _, _, k_index, i), + tKCsKC(_, _, _, _, index)); }); } @@ -1104,7 +1312,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (int i = 0; i < IterationsPV_K; i++) { for (int j = 0; j < IterationsPV_N; j++) { load_stage([&](int index) { - copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + copy_split( + tiled_copy_vc, + tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), + tVCsVC(_, _, _, _, index)); }); } } @@ -1121,7 +1332,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (int i = 0; i < IterationsPV_K; i++) { for (int j = 0; j < IterationsPV_N; j++) { load_stage([&](int index) { - copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + copy_split( + tiled_copy_vc, + tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), + tVCsVC(_, _, _, _, index)); }); } } @@ -1131,7 +1345,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { while (pipeline_offset > 0) { cutlass::arch::cp_async_fence(); - + cutlass::arch::cp_async_wait(); pipeline_load.producer_commit(pipeline_commit_state); ++pipeline_commit_state; @@ -1139,11 +1353,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { } cutlass::arch::cp_async_wait<0>(); - } - - template + template CUTLASS_DEVICE void load_tma( BlkCoord const& 
blk_coord, ProblemShape const& problem_shape, @@ -1155,23 +1367,25 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { PipelineLoadPV& pipeline_load_pv, typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, int const& split_kv) { - auto [H, K, D, B] = problem_shape; auto [D_latent, D_rope] = D; int k_tile_total = ceil_div(K, TileShapeS{}); int k_tile_per_cta = ceil_div(k_tile_total, split_kv); int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit - int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + int k_tile_count = + max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); if (k_tile_count == 0) { return; } using X = Underscore; - + // partition all tensors - auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); - auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor( + make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor( + make_shape(H, D_rope, B)); int paged_B = B; int paged_K = K; @@ -1179,22 +1393,35 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { paged_B = mainloop_args.page_count; paged_K = mainloop_args.page_size; } - auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + auto mPT_l = make_tensor( + make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(paged_B, B), + mainloop_args.stride_page_table); - auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); - auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor( + make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor( + make_shape(paged_K, D_rope, paged_B)); - auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor( + make_shape(D_latent, paged_K, paged_B)); - auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); - auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQL = + local_tile(mQL, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + auto gQR = + local_tile(mQR, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); - auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); - auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); - auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + auto gCL = + local_tile(mCL, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gKR = + local_tile(mKR, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gCLT = + local_tile(mCLT, TileShapePV{}, make_coord(_, _, _), Step{}); - ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); - ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_qk = + TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = + TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); auto tSgQL = cta_mma_qk.partition_A(gQL); auto tSgQR = cta_mma_qk.partition_A(gQR); @@ -1204,29 +1431,47 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto tOgCLT = cta_mma_pv.partition_B(gCLT); - Tensor sQ = 
make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); - Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); - Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sQ = make_tensor( + make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor( + make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor( + make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); auto [tQLgQL_mkl, tQsQ] = tma_partition( - mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), - group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + mainloop_params.tma_load_q_latent, + _0{}, + make_layout(_1{}), + group_modes<0, 3>(sQ), + group_modes<0, 3>(tSgQL)); auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( - mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), - group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + mainloop_params.tma_load_q_rope, + _0{}, + make_layout(_1{}), + group_modes<0, 3>(sQ), + group_modes<0, 3>(tSgQR)); auto [tCLgCL_nkl, tKCsKC] = tma_partition( - mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), - group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + mainloop_params.tma_load_c_latent, + _0{}, + make_layout(_1{}), + group_modes<0, 3>(sKC), + group_modes<0, 3>(tSgCL)); auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( - mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), - group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + mainloop_params.tma_load_k_rope, + _0{}, + make_layout(_1{}), + group_modes<0, 3>(sKC), + group_modes<0, 3>(tSgKR)); auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( - mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), - group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + mainloop_params.tma_load_c_latent_transpose, + _0{}, + make_layout(_1{}), + group_modes<0, 3>(sVC), + group_modes<0, 3>(tOgCLT)); uint16_t mcast_mask = 0; @@ -1247,23 +1492,26 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // each Q/K tile consists of rope and latent for (int i = 0; i < IterationsQKLatent; i++) { - pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_expect_transaction( + pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); - auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier( + pipeline_load_qk_producer_state); if (cute::elect_one_sync()) { // expect the extra bytes // load_qk ql - cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + cute::copy( + mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), + tQLgQL(_, _0{}, i), + tQsQ(_, i)); // load_qk cl if constexpr (kIsPaged) { cute::copy( mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), tCLgCL(_, _0{}, i, mPT(k_index)), - tKCsKC(_, pipeline_load_qk_producer_state.index()) - ); - } - else { + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { cute::copy( mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), tCLgCL(_, k_index, i, batch_coord), @@ -1274,23 +1522,26 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { } for (int i = 0; i < IterationsQKRope; i++) { - pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + 
pipeline_load_qk.producer_expect_transaction( + pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); - auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier( + pipeline_load_qk_producer_state); if (cute::elect_one_sync()) { // expect the extra bytes // load_qk ql - cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + cute::copy( + mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), + tQRgQR(_, _0{}, i), + tQsQ(_, i + IterationsQKLatent)); // load_qk cl if constexpr (kIsPaged) { cute::copy( mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), tKRgKR(_, _0{}, i, mPT(k_index)), - tKCsKC(_, pipeline_load_qk_producer_state.index()) - ); - } - else { + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { cute::copy( mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), tKRgKR(_, k_index, i, batch_coord), @@ -1307,24 +1558,24 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // perform K+Q load here CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { - // perform K load for (int i = 0; i < IterationsQKLatent; i++) { pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); - auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier( + pipeline_load_qk_producer_state); if (cute::elect_one_sync()) { // load_qk cl if constexpr (kIsPaged) { cute::copy( - mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + mainloop_params.tma_load_c_latent.with( + *tma_barrier, mcast_mask), tCLgCL(_, _0{}, i, mPT(k_index)), - tKCsKC(_, pipeline_load_qk_producer_state.index()) - ); - } - else { + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { cute::copy( - mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + mainloop_params.tma_load_c_latent.with( + *tma_barrier, mcast_mask), tCLgCL(_, k_index, i, batch_coord), tKCsKC(_, pipeline_load_qk_producer_state.index())); } @@ -1334,7 +1585,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (int i = 0; i < IterationsQKRope; i++) { pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); - auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier( + pipeline_load_qk_producer_state); if (cute::elect_one_sync()) { // load_qk cl @@ -1342,10 +1594,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { cute::copy( mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), tKRgKR(_, _0{}, i, mPT(k_index)), - tKCsKC(_, pipeline_load_qk_producer_state.index()) - ); - } - else { + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } else { cute::copy( mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), tKRgKR(_, k_index, i, batch_coord), @@ -1363,15 +1613,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (k_tile_count > kPrefetchDistance) { cute::prefetch( mainloop_params.tma_load_c_latent, - tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) - ); + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance))); } - } - else { + } else { cute::prefetch( mainloop_params.tma_load_c_latent, - tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) - ); + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord)); } } } @@ -1382,15 
+1629,12 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (k_tile_count > kPrefetchDistance) { cute::prefetch( mainloop_params.tma_load_k_rope, - tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) - ); + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance))); } - } - else { + } else { cute::prefetch( mainloop_params.tma_load_k_rope, - tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) - ); + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord)); } } } @@ -1400,7 +1644,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (int i = 0; i < IterationsPV_K; i++) { for (int j = 0; j < IterationsPV_N; j++) { pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); - auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier( + pipeline_load_pv_producer_state); if (cute::elect_one_sync()) { // load_pv cl @@ -1408,17 +1653,21 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // note we are off-by-one on k_index if constexpr (kIsPaged) { cute::copy( - mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, + mcast_mask, + cute::TMA::CacheHintSm100::EVICT_FIRST), tCLTgCLT(_, j, i, mPT(k_index - 1)), - tVCsVC(_, pipeline_load_pv_producer_state.index()) - ); - } - else { + tVCsVC(_, pipeline_load_pv_producer_state.index())); + } else { cute::copy( - mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), - tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), - tVCsVC(_, pipeline_load_pv_producer_state.index()) - ); + mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, + mcast_mask, + cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT( + _, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index())); } } ++pipeline_load_pv_producer_state; @@ -1432,7 +1681,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { for (int i = 0; i < IterationsPV_K; i++) { for (int j = 0; j < IterationsPV_N; j++) { pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); - auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier( + pipeline_load_pv_producer_state); if (cute::elect_one_sync()) { // load_pv cl @@ -1441,17 +1691,20 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if constexpr (kIsPaged) { cute::copy( - mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, + mcast_mask, + cute::TMA::CacheHintSm100::EVICT_FIRST), tCLTgCLT(_, j, i, mPT(k_index - 1)), - tVCsVC(_, pipeline_load_pv_producer_state.index()) - ); - } - else { + tVCsVC(_, pipeline_load_pv_producer_state.index())); + } else { cute::copy( - mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + mainloop_params.tma_load_c_latent_transpose.with( + *tma_barrier, + mcast_mask, + cute::TMA::CacheHintSm100::EVICT_FIRST), tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), - tVCsVC(_, pipeline_load_pv_producer_state.index()) - ); + tVCsVC(_, pipeline_load_pv_producer_state.index())); } } ++pipeline_load_pv_producer_state; @@ -1459,7 +1712,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { } } - 
template + template CUTLASS_DEVICE void mma( BlkCoord const& blk_coord, ProblemShape const& problem_shape, @@ -1475,22 +1728,26 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { PipelineO& pipeline_mma_o, typename PipelineO::PipelineState& pipeline_mma_o_producer_state, int const& split_kv) { - auto [H, K, D, B] = problem_shape; int k_tile_total = ceil_div(K, TileShapeS{}); int k_tile_per_cta = ceil_div(k_tile_total, split_kv); int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit - int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + int k_tile_count = + max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); if (k_tile_count == 0) { return; } // mma init - Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); - Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); - Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); - Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + Tensor sQ = make_tensor( + make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor( + make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor( + make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor( + make_smem_ptr((Element*)shared_tensors.smem_p.begin()), SmemLayoutP{}); Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); @@ -1500,8 +1757,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { TiledMmaQK tiled_mma_qk; TiledMmaPV tiled_mma_pv; - Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); - Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + Tensor tStS = + partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShapeQK{})); + Tensor tOtO = + partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShapePV{})); tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; @@ -1517,14 +1776,17 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); int read_stage = pipeline_load_qk_consumer_state.index(); - tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + tStS.data() = uint32_t( + pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 + : TmemAllocation::kS1); CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { - cute::gemm(tiled_mma_qk, - tSrQ(_,_,k_block,i), - tSrKC(_,_,k_block,read_stage), - tStS); + cute::gemm( + tiled_mma_qk, + tSrQ(_, _, k_block, i), + tSrKC(_, _, k_block, read_stage), + tStS); tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; } @@ -1539,21 +1801,23 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { - pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; for (int i = 0; i < IterationsQK; i++) { pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); int read_stage = pipeline_load_qk_consumer_state.index(); - tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + tStS.data() = uint32_t( + pipeline_mma_s_producer_state.index() == 0 ? 
TmemAllocation::kS0 + : TmemAllocation::kS1); CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { - cute::gemm(tiled_mma_qk, - tSrQ(_,_,k_block,i), - tSrKC(_,_,k_block,read_stage), - tStS); + cute::gemm( + tiled_mma_qk, + tSrQ(_, _, k_block, i), + tSrKC(_, _, k_block, read_stage), + tStS); tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; } @@ -1574,15 +1838,21 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { int read_stage = pipeline_load_pv_consumer_state.index(); - tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tOtO.data() = uint32_t(TmemAllocation::kO0) + + j * uint32_t(TmemAllocation::kSizeAccO); tiled_mma_pv.accumulate_ = acc_flag; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { - cute::gemm(tiled_mma_pv, - tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), - tOrVC(_,_,k_block,read_stage), - tOtO); + cute::gemm( + tiled_mma_pv, + tOrP( + _, + _, + k_block, + make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_, _, k_block, read_stage), + tOtO); tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; } @@ -1609,15 +1879,21 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { int read_stage = pipeline_load_pv_consumer_state.index(); - tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tOtO.data() = uint32_t(TmemAllocation::kO0) + + j * uint32_t(TmemAllocation::kSizeAccO); tiled_mma_pv.accumulate_ = acc_flag; CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { - cute::gemm(tiled_mma_pv, - tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), - tOrVC(_,_,k_block,read_stage), - tOtO); + cute::gemm( + tiled_mma_pv, + tOrP( + _, + _, + k_block, + make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_, _, k_block, read_stage), + tOtO); tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; } @@ -1632,8 +1908,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { ++pipeline_mma_o_producer_state; } - - template + template CUTLASS_DEVICE void softmax( IsLastTile const& is_last_tile, ElementAcc& row_max, @@ -1645,25 +1920,25 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { int k_index, uint32_t tmem_s, int smem_p_index) { - auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; TiledMmaQK tiled_mma_qk; - Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tStS = + partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShapeQK{})); tStS.data() = tmem_s; CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); - Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + Tensor tAcc = tStS(make_coord(_, _), _0{}, _0{}); - Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + Tensor cS = make_identity_tensor(take<0, 2>(CtaShapeQK{})); auto tiled_t2r = make_tmem_copy(load_op, tAcc); auto thread_idx = threadIdx.x % size(tiled_t2r); auto thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_cS = thread_t2r.partition_D(cS); Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); @@ -1677,7 +1952,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (is_last_tile) { for (int i = 0; i < size(tTR_rAcc); i++) { - if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= + get<1>(problem_shape)) { tTR_rAcc(i) = -std::numeric_limits::infinity(); } } @@ -1693,14 +1969,18 @@ struct 
Sm100FmhaMlaKernelTmaWarpspecialized { // for 2x2 dp, reduce here if constexpr (kWarpsInN > 1) { shared_tensors.smem_exchange[threadIdx.x] = row_max_new; - cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, kNamedBarrierExchange) + .sync(); // (64, 2) shape int peer_index = (threadIdx.x + 64) % 128; - row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + row_max_new = + cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); } // find correction factor - ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + ElementAcc softmax_scale_log2 = + mainloop_args.softmax_scale * static_cast(M_LOG2E); correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); row_max = row_max_new; @@ -1708,7 +1988,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rAcc); i++) { - tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + tTR_rAcc(i) = + ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); } // quantize @@ -1719,13 +2000,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); } - Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + Tensor sP = make_tensor( + make_smem_ptr((Element*)shared_tensors.smem_p.begin()), SmemLayoutP{})( + _, _, _, make_coord(_, smem_p_index)); Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); // have a mapping for each thread to coord // find identical mapping to coords for the MMA - auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto l = make_ordered_layout( + make_shape( + make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), + make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); auto sP_ = as_position_independent_swizzle_tensor(sP); copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); @@ -1750,33 +2036,31 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_UNROLL for (int i = 1; i < size(sums); i *= 2) { CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size(sums); j += 2*i) { - cute::add(sums(j), sums(j), sums(j+i)); + for (int j = 0; j < size(sums); j += 2 * i) { + cute::add(sums(j), sums(j), sums(j + i)); } } row_sum += sums(0).x + sums(0).y; } - - CUTLASS_DEVICE void rescale( - ElementAcc correction_factor, - uint32_t tmem_o) { - + CUTLASS_DEVICE void rescale(ElementAcc correction_factor, uint32_t tmem_o) { // for b2b gemm, do nothing auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; auto store_op = TMEM::tmem_load_to_store(load_op); TiledMmaPV tiled_mma_pv; - Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + Tensor tOtO = + partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShapePV{})); tOtO.data() = tmem_o; CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); - Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + Tensor tAcc = tOtO(make_coord(_, _), _0{}, _0{}); - auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); - Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + auto cta_tiler_pv = take<0, 
2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor( + make_gmem_ptr((ElementAcc*)nullptr), cta_tiler_pv, make_stride(0, 0)); auto tiled_t2r = make_tmem_copy(load_op, tAcc); auto tiled_r2t = make_tmem_copy(store_op, tAcc); @@ -1784,7 +2068,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto thread_t2r = tiled_t2r.get_slice(thread_idx); auto thread_r2t = tiled_r2t.get_slice(thread_idx); - Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_gO = thread_t2r.partition_D(gO); Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); @@ -1793,7 +2077,8 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { copy(tiled_t2r, tTR_tAcc, tTR_rAcc); // multiply by correction factor - float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + float2 correction_factor_vec = + make_float2(correction_factor, correction_factor); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rAcc); i += 2) { float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); @@ -1805,10 +2090,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // store o copy(tiled_r2t, tTR_rAcc, tTR_tAcc); - } - + } - template + template CUTLASS_DEVICE void epilogue( ElementAcc& row_max, ElementAcc& row_sum, @@ -1819,43 +2103,54 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { TensorStorage& shared_tensors, uint32_t tmem_o, int const& split_kv) { - auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; TiledMmaPV tiled_mma_pv; - - Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + + Tensor tOtO = TiledMmaPV::make_fragment_C( + partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); tOtO.data() = tmem_o; CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); - Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + Tensor tAcc = tOtO(make_coord(_, _), _0{}, _0{}); auto [H, K, D, B] = problem_shape; auto [D_latent, D_rope] = D; - if (split_kv > 1) { + if (split_kv > 1) { using ElementOutAcc = ElementAcc; constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; - Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); - auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); - Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + Tensor mO = make_tensor( + make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), + make_shape(H, D_latent, B), + epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0, 3>(cta_coord)); auto tiled_t2r = make_tmem_copy(load_op, tAcc); auto thread_idx = threadIdx.x % size(tiled_t2r); auto thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_gO = thread_t2r.partition_D(gO); Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); - Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); - Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_rO_src = + recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = + recast>(coalesce(tTR_gO)); Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); copy(tiled_t2r, tTR_tAcc, tTR_rAcc); - cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + cutlass::epilogue::thread::LinearCombination< + ElementOutAcc, + 1, + ElementAcc, + ElementAcc, + 
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling> + epilogue_op({epilogue_args.output_scale / row_sum}); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rAcc); i++) { tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); @@ -1864,41 +2159,58 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { copy(tTR_rO_src, tR2G_rO_dst); if (get<1>(cta_coord) == 0) { - if (epilogue_args.ptr_lse != nullptr) { + if (epilogue_args.ptr_lse != nullptr) { // compute LSE - ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + ElementAcc lse = cutlass::fast_log(row_sum) + + mainloop_args.softmax_scale * row_max; // store LSE - Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc); - Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + Tensor mLSE = make_tensor( + make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), + make_shape(H, B), + epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile( + mLSE, + append<3>(cta_tiler_pv, _1{}), + take<0, 3>(cta_coord), + Step<_1, Underscore, _1>{}); // for 2x2 dp, this must be conditional and the index is wrong - if (! kIs2Sm || (threadIdx.x < 64)) - { + if (!kIs2Sm || (threadIdx.x < 64)) { gLSE(threadIdx.x) = lse; } - } + } } - } - else { - Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); - auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); - Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + } else { + Tensor mO = make_tensor( + make_gmem_ptr(epilogue_args.ptr_o), + make_shape(H, D_latent, B), + epilogue_args.stride_o); + auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0, 3>(cta_coord)); auto tiled_t2r = make_tmem_copy(load_op, tAcc); auto thread_idx = threadIdx.x % size(tiled_t2r); auto thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_gO = thread_t2r.partition_D(gO); Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); - Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); - Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_rO_src = + recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = + recast>(coalesce(tTR_gO)); Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); copy(tiled_t2r, tTR_tAcc, tTR_rAcc); - cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + cutlass::epilogue::thread::LinearCombination< + ElementOut, + 1, + ElementAcc, + ElementAcc, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling> + epilogue_op({epilogue_args.output_scale / row_sum}); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rAcc); i++) { tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); @@ -1909,15 +2221,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if (get<1>(cta_coord) == 0) { if (epilogue_args.ptr_lse != nullptr) { // compute LSE - ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + ElementAcc lse = cutlass::fast_log(row_sum) + + mainloop_args.softmax_scale * row_max; // store LSE - Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); - Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + Tensor mLSE = make_tensor( + 
make_gmem_ptr(epilogue_args.ptr_lse), + make_shape(H, B), + epilogue_args.stride_lse); + Tensor gLSE = local_tile( + mLSE, + append<3>(cta_tiler_pv, _1{}), + take<0, 3>(cta_coord), + Step<_1, Underscore, _1>{}); // for 2x2 dp, this must be conditional and the index is wrong - if (! kIs2Sm || (threadIdx.x < 64)) - { + if (!kIs2Sm || (threadIdx.x < 64)) { gLSE(threadIdx.x) = lse; } } @@ -1925,7 +2244,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { } } - template + template CUTLASS_DEVICE ElementLSE epilogue_lse_reduction( ElementAcc& row_max, ElementAcc& row_sum, @@ -1934,64 +2253,89 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { MainloopArguments const& mainloop_args, EpilogueParams const& epilogue_args, int const& local_split_kv) { - auto [H, K, D, B] = problem_shape; auto [D_latent, D_rope] = D; - auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{}); constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp; - using Sync = cutlass::detail::NamedBarrierSync; + using Sync = + cutlass::detail::NamedBarrierSync; auto wait = [](int* lock, int count) { __threadfence(); if (threadIdx.x == 0) { atomicAdd(lock, 1); - while (atomicCAS(lock, count, count) != count) {}; + while (atomicCAS(lock, count, count) != count) { + }; } __threadfence(); Sync::sync(); }; - const ElementLSE lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; - Tensor mLSE_max_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), make_shape(H, B), epilogue_args.stride_lse); - Tensor gLSE_max_buff = local_tile(mLSE_max_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); - - int* local_lock = epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord); - - if (! kIs2Sm || (threadIdx.x < 64)) { + const ElementLSE lse = + cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + Tensor mLSE_max_buff = make_tensor( + make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), + make_shape(H, B), + epilogue_args.stride_lse); + Tensor gLSE_max_buff = local_tile( + mLSE_max_buff, + append<3>(cta_tiler_pv, _1{}), + take<0, 3>(cta_coord), + Step<_1, Underscore, _1>{}); + + int* local_lock = + epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord); + + if (!kIs2Sm || (threadIdx.x < 64)) { atomicMax(&(gLSE_max_buff(threadIdx.x)), __float2int_rn(lse)); } wait(local_lock, local_split_kv); - auto global_lse_max = static_cast(gLSE_max_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x)); + auto global_lse_max = static_cast( + gLSE_max_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x)); - Tensor mLSE_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), make_shape(H, B), epilogue_args.stride_lse); - Tensor gLSE_buff = local_tile(mLSE_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + Tensor mLSE_buff = make_tensor( + make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), + make_shape(H, B), + epilogue_args.stride_lse); + Tensor gLSE_buff = local_tile( + mLSE_buff, + append<3>(cta_tiler_pv, _1{}), + take<0, 3>(cta_coord), + Step<_1, Underscore, _1>{}); - if (! kIs2Sm || (threadIdx.x < 64)) { + if (!kIs2Sm || (threadIdx.x < 64)) { atomicAdd(&(gLSE_buff(threadIdx.x)), expf(lse - global_lse_max)); } - wait(local_lock, 2*local_split_kv); + wait(local_lock, 2 * local_split_kv); const auto sum_lse = gLSE_buff(kIs2Sm ? 
threadIdx.x % 64 : threadIdx.x); - const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : - cutlass::fast_log(sum_lse) + global_lse_max; + const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) + ? std::numeric_limits::infinity() + : cutlass::fast_log(sum_lse) + global_lse_max; const auto lse_scale = expf(lse - global_lse); if (epilogue_args.ptr_lse != nullptr) { - Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); - Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + Tensor mLSE = make_tensor( + make_gmem_ptr(epilogue_args.ptr_lse), + make_shape(H, B), + epilogue_args.stride_lse); + Tensor gLSE = local_tile( + mLSE, + append<3>(cta_tiler_pv, _1{}), + take<0, 3>(cta_coord), + Step<_1, Underscore, _1>{}); // write out the global LSE - if (! kIs2Sm || (threadIdx.x < 64)) { + if (!kIs2Sm || (threadIdx.x < 64)) { gLSE(threadIdx.x) = global_lse; } } return lse_scale; } - - template + template CUTLASS_DEVICE void epilogue_reduction( ElementAcc& row_max, ElementAcc& row_sum, @@ -2002,9 +2346,9 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { TensorStorage& shared_tensors, int const& local_split_kv, ElementLSE const& lse_scale) { - constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp; - using Sync = cutlass::detail::NamedBarrierSync; + using Sync = + cutlass::detail::NamedBarrierSync; auto [H, K, D, B] = problem_shape; auto [D_latent, D_rope] = D; @@ -2012,37 +2356,52 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; TiledMmaPV tiled_mma_pv; - Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + Tensor tOtO = TiledMmaPV::make_fragment_C( + partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); - using EpilogueLinearCombination = cutlass::epilogue::thread::LinearCombination; - EpilogueLinearCombination epilogue_op({epilogue_args.output_scale / row_sum * lse_scale}); + using EpilogueLinearCombination = + cutlass::epilogue::thread::LinearCombination< + ElementOut, + 1, + ElementAcc, + ElementAcc, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; + EpilogueLinearCombination epilogue_op( + {epilogue_args.output_scale / row_sum * lse_scale}); CUTLASS_PRAGMA_UNROLL - for(int k = 0; k < IterationsPV_N; ++k) { + for (int k = 0; k < IterationsPV_N; ++k) { auto cta_coord = replace<1>(blk_coord, k); - uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + k * uint32_t(TmemAllocation::kSizeAccO); + uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + + k * uint32_t(TmemAllocation::kSizeAccO); tOtO.data() = tmem_o; - Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + Tensor tAcc = tOtO(make_coord(_, _), _0{}, _0{}); + + Tensor mO = make_tensor( + make_gmem_ptr(epilogue_args.ptr_o), + make_shape(H, D_latent, B), + epilogue_args.stride_o); + Tensor gO = local_tile(mO, TileShapeAcc{}, take<0, 3>(cta_coord)); - Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); - Tensor gO = local_tile(mO, TileShapeAcc{}, take<0,3>(cta_coord)); - auto tiled_t2r = make_tmem_copy(load_op, tAcc); auto thread_idx = threadIdx.x % size(tiled_t2r); auto thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_gO = thread_t2r.partition_D(gO); Tensor 
tTR_rAcc = make_tensor(shape(tTR_gO)); Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); copy(tiled_t2r, tTR_tAcc, tTR_rAcc); - Tensor sO = make_tensor(make_smem_ptr(reinterpret_cast(shared_tensors.smem_acc.begin())), SmemLayout{}); + Tensor sO = make_tensor( + make_smem_ptr( + reinterpret_cast(shared_tensors.smem_acc.begin())), + SmemLayout{}); Tensor tTR_sO = thread_t2r.partition_D(sO); Sync::sync(); @@ -2053,20 +2412,24 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { tma_store_fence(); Sync::sync(); - auto tma_reduce_sum_per_cta = epilogue_args.tma_reduce_sum.get_slice(_0{}); - auto gmem_tensor_coord = epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO)); - auto gmem_tensor_coord_per_cta = local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0,3>(cta_coord)); + auto tma_reduce_sum_per_cta = + epilogue_args.tma_reduce_sum.get_slice(_0{}); + auto gmem_tensor_coord = + epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO)); + auto gmem_tensor_coord_per_cta = + local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0, 3>(cta_coord)); if (threadIdx.x % kNumThreads == 0) { - copy(epilogue_args.tma_reduce_sum, - tma_reduce_sum_per_cta.partition_S(sO), - tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta)); - tma_store_arrive(); + copy( + epilogue_args.tma_reduce_sum, + tma_reduce_sum_per_cta.partition_S(sO), + tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta)); + tma_store_arrive(); } tma_store_wait<0>(); } } - template + template CUTLASS_DEVICE void compute( CtaCoord const& cta_coord, ProblemShape const& problem_shape, @@ -2081,20 +2444,19 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, int const& split_kv, bool const& is_fused_reduction) { - auto [H, K, D, B] = problem_shape; int k_tile_total = ceil_div(K, TileShapeS{}); int k_tile_per_cta = ceil_div(k_tile_total, split_kv); int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit - int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + int k_tile_count = + max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); if (k_tile_count == 0) { - // if we return early, we have to make sure we release the load warp cutlass::arch::NamedBarrier( (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, - kNamedBarrierEpilogue - ).arrive(); + kNamedBarrierEpilogue) + .arrive(); return; } @@ -2110,8 +2472,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { auto dispatch_bool = [](bool b, auto fn) { if (b) { fn(cute::true_type{}); - } - else { + } else { fn(cute::false_type{}); } }; @@ -2120,13 +2481,19 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { softmax( is_last_tile, - row_max, row_sum, correction_factor, - problem_shape, mainloop_args, shared_tensors, k_index, - uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), - pipeline_p_mma_producer_state.index() - ); + row_max, + row_sum, + correction_factor, + problem_shape, + mainloop_args, + shared_tensors, + k_index, + uint32_t( + pipeline_mma_s_consumer_state.index() == 0 ? 
TmemAllocation::kS0 + : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index()); }); - + k_index += 1; cutlass::arch::fence_view_async_tmem_load(); @@ -2147,11 +2514,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { softmax( is_last_tile, - row_max, row_sum, correction_factor, - problem_shape, mainloop_args, shared_tensors, k_index, - uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), - pipeline_p_mma_producer_state.index() - ); + row_max, + row_sum, + correction_factor, + problem_shape, + mainloop_args, + shared_tensors, + k_index, + uint32_t( + pipeline_mma_s_consumer_state.index() == 0 + ? TmemAllocation::kS0 + : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index()); }); cutlass::arch::fence_view_async_tmem_load(); @@ -2166,7 +2540,10 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { // rescale CUTLASS_PRAGMA_UNROLL for (int j = 0; j < IterationsPV_N; j++) { - rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + rescale( + correction_factor, + uint32_t(TmemAllocation::kO0) + + j * uint32_t(TmemAllocation::kSizeAccO)); } cutlass::arch::fence_view_async_tmem_store(); @@ -2182,13 +2559,18 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { if constexpr (kWarpsInN > 1) { // reduce row_sum if needed (for 2x2 dp) shared_tensors.smem_exchange[threadIdx.x] = row_sum; - cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, kNamedBarrierExchange) + .sync(); // (64, 2) shape int peer_index = (threadIdx.x + 64) % 128; row_sum += shared_tensors.smem_exchange[peer_index]; } - cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue) + .arrive(); const int actual_split_kv = ceil_div(k_tile_total, k_tile_per_cta); if (!is_fused_reduction || actual_split_kv == 1) { @@ -2196,37 +2578,42 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized { CUTLASS_PRAGMA_UNROLL for (int j = 0; j < IterationsPV_N; j++) { epilogue( - row_max, row_sum, - replace<1>(cta_coord, j), problem_shape, - mainloop_args, epilogue_args, shared_tensors, - uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), - actual_split_kv - ); + row_max, + row_sum, + replace<1>(cta_coord, j), + problem_shape, + mainloop_args, + epilogue_args, + shared_tensors, + uint32_t(TmemAllocation::kO0) + + j * uint32_t(TmemAllocation::kSizeAccO), + actual_split_kv); } } else { - const ElementLSE lse_scale = - epilogue_lse_reduction( - row_max, row_sum, - cta_coord, - problem_shape, - mainloop_args, epilogue_args, - actual_split_kv - ); - - epilogue_reduction(row_max, row_sum, - cta_coord, - problem_shape, - mainloop_args, epilogue_args, - shared_tensors, - actual_split_kv, - lse_scale - ); + const ElementLSE lse_scale = epilogue_lse_reduction( + row_max, + row_sum, + cta_coord, + problem_shape, + mainloop_args, + epilogue_args, + actual_split_kv); + + epilogue_reduction( + row_max, + row_sum, + cta_coord, + problem_shape, + mainloop_args, + epilogue_args, + shared_tensors, + actual_split_kv, + lse_scale); } cutlass::arch::fence_view_async_tmem_load(); pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); ++pipeline_mma_o_consumer_state; } - }; 
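For reference, the split-KV path above combines per-split partial results in two steps: epilogue_lse_reduction computes a global log-sum-exp across splits (max-shifted through the gLSE_max_buff exchange buffer) and hands back each split's rescale factor, and epilogue_reduction folds the rescaled partial outputs together via the TMA reduce-sum. The sketch below restates that merge in NumPy for a single output row; the names are illustrative, and it ignores the on-device integer-rounded max and the atomics, so treat it as a numerical reference rather than the kernel's implementation.

    import numpy as np

    def merge_split_kv(partial_acc, row_max, row_sum, softmax_scale, output_scale=1.0):
        """Merge per-split partial attention results for one output row.

        partial_acc: [splits, D_latent] unnormalized P*V accumulators per split
        row_max:     [splits]           per-split running max of the scores S
        row_sum:     [splits]           per-split sum of exp(softmax_scale * (S - row_max))
        """
        # Per-split LSE, as in the kernel: log(row_sum) + softmax_scale * row_max.
        lse = np.log(row_sum) + softmax_scale * row_max
        # Global LSE via max-shifted log-sum-exp (the kernel shifts by an
        # integer-rounded max exchanged through gLSE_max_buff; any shift works).
        lse_max = lse.max()
        global_lse = np.log(np.exp(lse - lse_max).sum()) + lse_max
        # Each split contributes output_scale / row_sum_i * exp(lse_i - global_lse),
        # matching the LinearCombination alpha used in epilogue_reduction.
        scale = output_scale * np.exp(lse - global_lse) / row_sum
        out = (scale[:, None] * partial_acc).sum(axis=0)
        return out, global_lse

With a single split this reduces to out = output_scale / row_sum * partial_acc and global_lse = lse, which is the scaling applied by the non-fused epilogue path above.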
/////////////////////////////////////////////////////////////////////////////// From 00d297bd5b59b0cfb7dbde514ca04c2d54632bb3 Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Mon, 13 Oct 2025 15:16:31 -0700 Subject: [PATCH 2/4] Entry point --- ...m100_fmha_gen_mainloop_warpspecialized.hpp | 4 + .../test/attention/decode_kernel_runner.py | 202 ++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 fbgemm_gpu/experimental/gen_ai/test/attention/decode_kernel_runner.py diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index b4a73f8aee..194031297b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -32,11 +32,15 @@ **************************************************************************************************/ #pragma once +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/cutlass.h" #include "cutlass/arch/memory_sm80.h" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cute/arch/simd_sm100.hpp" #include "cute/tensor.hpp" +#include "cute/tensor.hpp" #include "cute/layout.hpp" #include "collective/fmha_common.hpp" diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/decode_kernel_runner.py b/fbgemm_gpu/experimental/gen_ai/test/attention/decode_kernel_runner.py new file mode 100644 index 0000000000..4b0b43fa9d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/decode_kernel_runner.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Standalone script to run the decode kernel for Blackwell FMHA. + +This script runs the decode (generation) kernel for attention, which is used +during inference when generating tokens one at a time (seqlen_q = 1). 
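+
+The script builds random Q/K/V tensors (FP8 by default, optionally FP16 or
+BF16), calls _cutlass_blackwell_fmha_gen directly with GenKernelType.UMMA_I,
+and runs basic shape/NaN/Inf sanity checks on the output.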
+ +Usage: + buck run fbcode//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:decode_kernel_entry -- \ + --batch_size 2 --seqlen_k 128 --q_heads 8 --head_dim 128 +""" + +import argparse + +import torch +from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import ( + _cutlass_blackwell_fmha_gen, + GenKernelType, +) + + +def run_decode_kernel( + batch_size: int, + seqlen_k: int, + q_heads: int, + kv_heads: int, + head_dim: int, + dtype: torch.dtype = torch.float8_e4m3fn, +) -> None: + """Run the decode kernel with specified parameters.""" + device = torch.accelerator.current_accelerator() + assert device is not None, "No GPU device available" + + # Decode kernel always has seqlen_q = 1 (generating one token at a time) + seqlen_q = 1 + + print(f"Running decode kernel with:") + print(f" Batch size: {batch_size}") + print(f" Sequence length (K/V): {seqlen_k}") + print(f" Query heads: {q_heads}") + print(f" KV heads: {kv_heads}") + print(f" Head dimension: {head_dim}") + print(f" Data type: {dtype}") + print(f" Device: {device}") + + # Generate random Q, K, V tensors + q = torch.randn( + batch_size, + seqlen_q, + q_heads, + head_dim, + dtype=torch.float if dtype == torch.float8_e4m3fn else dtype, + device=device, + ) + k = torch.randn( + batch_size, + seqlen_k, + kv_heads, + head_dim, + dtype=torch.float if dtype == torch.float8_e4m3fn else dtype, + device=device, + ) + v = torch.randn( + batch_size, + seqlen_k, + kv_heads, + head_dim, + dtype=torch.float if dtype == torch.float8_e4m3fn else dtype, + device=device, + ) + + # Convert to FP8 if needed + if dtype == torch.float8_e4m3fn: + q = q.to(torch.float8_e4m3fn) + k = k.to(torch.float8_e4m3fn) + v = v.to(torch.float8_e4m3fn) + + # Make tensors contiguous as required by _cutlass_blackwell_fmha_gen + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + # Initialize seqlen_kv for generation phase + seqlen_kv = torch.full( + (batch_size,), + seqlen_k, + dtype=torch.int32, + device=device, + ) + + # Create batch_idx tensor + batch_idx = torch.arange(batch_size, dtype=torch.int32, device=device) + + print("\nRunning decode kernel (_cutlass_blackwell_fmha_gen)...") + print(f" Kernel Type: GenKernelType.UMMA_I") + + # Run the decode kernel directly + out = _cutlass_blackwell_fmha_gen( + q, + k, + v, + seqlen_kv, + batch_idx, + kernel_type=GenKernelType.UMMA_I, + ) + + print(f"Decode kernel completed successfully!") + print(f"Output shape: {out.shape}") + print(f"Output dtype: {out.dtype}") + print(f"Output device: {out.device}") + + # Basic sanity checks + assert out.shape == (batch_size, seqlen_q, q_heads, head_dim) + assert not torch.isnan(out).any(), "Output contains NaN values" + assert not torch.isinf(out).any(), "Output contains Inf values" + + print("\nAll sanity checks passed!") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run the decode kernel for Blackwell FMHA" + ) + parser.add_argument( + "--batch_size", + type=int, + default=2, + help="Batch size (default: 2)", + ) + parser.add_argument( + "--seqlen_k", + type=int, + default=128, + help="Sequence length for K/V (default: 128)", + ) + parser.add_argument( + "--q_heads", + type=int, + default=8, + help="Number of query heads (default: 8)", + ) + parser.add_argument( + "--kv_heads", + type=int, + default=1, + help="Number of KV heads, use 1 for MQA (default: 1)", + ) + parser.add_argument( + "--head_dim", + type=int, + default=128, + help="Head dimension (default: 128)", + ) + parser.add_argument( 
+ "--dtype", + type=str, + default="fp8", + choices=["fp8", "fp16", "bf16"], + help="Data type (default: fp8)", + ) + + args = parser.parse_args() + + # Convert dtype string to torch dtype + dtype_map = { + "fp8": torch.float8_e4m3fn, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + dtype = dtype_map[args.dtype] + + # Check GPU availability + if not torch.cuda.is_available(): + print("ERROR: No CUDA device available") + return + + compute_capability = torch.cuda.get_device_capability("cuda") + if compute_capability < (10, 0): + print( + f"ERROR: Decode kernel requires SM100+ (Blackwell), found SM{compute_capability[0]}{compute_capability[1]}" + ) + return + + # Run the decode kernel + run_decode_kernel( + batch_size=args.batch_size, + seqlen_k=args.seqlen_k, + q_heads=args.q_heads, + kv_heads=args.kv_heads, + head_dim=args.head_dim, + dtype=dtype, + ) + + +if __name__ == "__main__": + main() From bd8eec0733f193ae36c933016a7189ce6ea0e61e Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Tue, 14 Oct 2025 15:53:22 -0700 Subject: [PATCH 3/4] BF16 re-enabled w/ Cutlass update Summary: This diff updates the code to enable BF16 enablement with latest Cutlass version. The changes include updating the code in the `blackwell_gen_impl.cu` and `collective/sm100_fmha_gen_mainloop_warpspecialized.hpp` files to support BF16 data type. The `fmha.hpp` file also includes a check to ensure that the SMEM usage does not exceed the capacity. Differential Revision: D84624233 --- .../cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu | 4 +++- .../sm100_fmha_gen_mainloop_warpspecialized.hpp | 4 ++-- .../cuda/cutlass_blackwell_fmha/device/fmha.hpp | 4 ++-- .../gen_ai/test/attention/blackwell_fmha_test.py | 9 ++++----- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 543e575b88..56bbe96b95 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -258,12 +258,14 @@ struct GenRunner { }; // Dispatch macros for different element types -// TODO(henrylhtsang / ayaoibrahim1123): Add support for other data types. #define DISPATCH_ELEMENT_TYPE(DTYPE, ELEMENT_TYPE, ...) \ [&] { \ if (DTYPE == at::kFloat8_e4m3fn) { \ using ELEMENT_TYPE = cutlass::float_e4m3_t; \ return __VA_ARGS__(); \ + } else if (DTYPE == at::kBFloat16) { \ + using ELEMENT_TYPE = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ } else { \ throw std::runtime_error( \ "Unsupported dtype: " + std::to_string(static_cast(DTYPE))); \ diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index 194031297b..9612e1d7c1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -90,7 +90,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { using Mask = Mask_; static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 
1 : 2; - static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); + static constexpr int StageCountKV = StageCountQ * (sizeof(Element) == 1 ? 11 : 5) ; using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; @@ -622,7 +622,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE = conditional_t< size<1>(TileShapeQK{}) < _128{}, - SM100_TMEM_STORE_32dp32b8x, + SM100_TMEM_STORE_32dp32b16x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp index d0f4331cea..d84cc9e0a1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha.hpp @@ -39,7 +39,7 @@ // common #include "cutlass/cutlass.h" #include "cutlass/device_kernel.h" - +#include "cutlass/arch/arch.h" #if !defined(__CUDACC_RTC__) #include "cutlass/cluster_launch.hpp" #include "cutlass/trace.h" @@ -57,7 +57,7 @@ template class FMHA { public: using Kernel = Kernel_; - + static_assert(Kernel::SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); static int const kThreadCount = Kernel::MaxThreadsPerBlock; /// Argument structure: User API diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index fcc12a12d9..c98a018c13 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -435,6 +435,7 @@ def _execute_cutlass_blackwell_attn_varlen( @parameterized.expand( [ ( + dtype, seqlen_k, batch_size, is_mqa, @@ -442,9 +443,10 @@ def _execute_cutlass_blackwell_attn_varlen( head_dim, sm_scale, ) + for dtype in [torch.bfloat16, torch.float8_e4m3fn] for seqlen_k in [64, 128, 256, 1024] for batch_size in [1, 2] - for is_mqa in [True] + for is_mqa in [True, False] for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)] for head_dim in [128] for sm_scale in [None, 1.0 / head_dim] @@ -452,6 +454,7 @@ def _execute_cutlass_blackwell_attn_varlen( ) def test_decode( self, + dtype: torch.dtype, seqlen_k: int, batch_size: int, is_mqa: bool, @@ -459,13 +462,9 @@ def test_decode( head_dim: int, sm_scale: Optional[float], q_heads: int = 8, - dtype: torch.dtype = torch.float8_e4m3fn, ) -> None: seqlen_q = 1 causal = True - assert ( - dtype == torch.float8_e4m3fn - ), "Gen Kernel only supports float8_e4m3fn for now" self._execute_cutlass_blackwell_attn_dense( batch_size, seqlen_q, From 7e3b0dbdf14cf8890c81d62f6a3e0b141bb3c04e Mon Sep 17 00:00:00 2001 From: Aya Ibrahim Date: Thu, 16 Oct 2025 09:37:36 -0700 Subject: [PATCH 4/4] Blackwell decode Op (#5004) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5004 X-link: https://github.com/facebookresearch/FBGEMM/pull/2017 Add stand-alone blackwell decode op. 
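A minimal call sketch for the new op (shapes follow the wrapper added below; the import path and the BF16 dtype come from the earlier patches in this stack, and the example is illustrative rather than part of the op's documentation):

    import torch
    from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
        _cutlass_blackwell_fmha_decode_forward,
    )

    B, S_k, H_q, H_kv, D = 2, 1024, 8, 1, 128
    q = torch.randn(B, 1, H_q, D, dtype=torch.bfloat16, device="cuda")     # one query token per sequence
    k = torch.randn(B, S_k, H_kv, D, dtype=torch.bfloat16, device="cuda")  # MQA: single KV head
    v = torch.randn(B, S_k, H_kv, D, dtype=torch.bfloat16, device="cuda")
    seqlen_kv = torch.full((B,), S_k, dtype=torch.int32, device="cuda")    # valid KV length per batch entry

    out, lse = _cutlass_blackwell_fmha_decode_forward(q, k, v, seqlen_kv=seqlen_kv, causal=True)
    # out: [B, 1, H_q, D]; lse is a zero placeholder, since the gen kernel does not return LSE.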
Supported mask: BlockDiagonalCausalWithOffsetPaddedKeysMask Differential Revision: D84630701 --- .../cutlass_blackwell_fmha_interface.py | 222 ++++++++++++++---- 1 file changed, 170 insertions(+), 52 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py index 6f7c684cb5..690a121567 100644 --- a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py @@ -129,6 +129,20 @@ def _cutlass_blackwell_fmha_backward( ) +def _validate_decode_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlen_kv: torch.Tensor | None, +) -> None: + assert seqlen_kv is not None, "seqlen_kv must be provided for decode" + tensors = {"q": q, "k": k, "v": v, "seqlen_kv": seqlen_kv} + + for name, tensor in tensors.items(): + # assert tensor.is_contiguous(), f"{name} is not contiguous" + assert tensor.is_cuda, f"{name} must be on GPU" + + def _cutlass_blackwell_fmha_gen( q: torch.Tensor, k: torch.Tensor, @@ -136,17 +150,10 @@ def _cutlass_blackwell_fmha_gen( seqlen_kv: torch.Tensor, batch_idx: torch.Tensor, kernel_type: GenKernelType = GenKernelType.UMMA_I, + window_left: int = -1, + window_right: int = -1, ) -> torch.Tensor: - assert q.is_contiguous(), "q is not contiguous" - assert k.is_contiguous(), "k is not contiguous" - assert v.is_contiguous(), "v is not contiguous" - assert seqlen_kv.is_contiguous(), "seqlen_kv is not contiguous" - assert batch_idx.is_contiguous(), "batch_idx is not contiguous" - assert q.is_cuda, "q must be on GPU" - assert k.is_cuda, "k must be on GPU" - assert v.is_cuda, "v must be on GPU" - assert seqlen_kv.is_cuda, "seqlen_kv must be on GPU" - assert batch_idx.is_cuda, "batch_idx must be on GPU" + _validate_decode_inputs(q, k, v, seqlen_kv) return torch.ops.fbgemm.fmha_gen_fwd( q, k, @@ -157,6 +164,118 @@ def _cutlass_blackwell_fmha_gen( ) +def _prepare_decode_inputs( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, bool, tuple[int, ...]]: + """ + Prepare inputs for decode kernel by handling both varlen and batch formats. + + Returns: + - Reshaped q, k, v tensors in batch format [B, 1, H, D] + - batch_size + - needs_reshape_output flag + - original_shape of q + """ + original_shape = tuple(q.shape) + needs_reshape_output = False + batch_size = q.shape[0] + + if q.dim() == 3: + # Varlen format: [total_queries, num_heads, head_dim] + q = q.view(batch_size, 1, q.shape[1], q.shape[2]) + needs_reshape_output = True + + if q.dim() != 4: + raise ValueError( + f"Invalid query shape: {q.shape}. Expected [B, 1, H, D] or [total_queries, H, D]" + ) + assert q.shape[1] == 1, "Kernel have sq=1" + + k = k.view(batch_size, -1, k.shape[1], k.shape[2]) if k.dim() == 3 else k + v = v.view(batch_size, -1, v.shape[1], v.shape[2]) if v.dim() == 3 else v + + return q, k, v, batch_size, needs_reshape_output, original_shape + + +def _create_decode_lse( + out: torch.Tensor, + batch_size: int, + needs_reshape_output: bool, + q_shape: tuple[int, ...], +) -> torch.Tensor: + """ + Create dummy LSE tensor for decode output compatibility. + Gen kernel doesn't return LSE, so we create a zero tensor. 
+ """ + if needs_reshape_output: + # For varlen output format + lse_shape = [batch_size, q_shape[-1]] # [B, H] + else: + # For batch output format + lse_shape = [batch_size, q_shape[-2], q_shape[1]] # [B, H, 1] + + return torch.zeros(*lse_shape, dtype=torch.float32, device=out.device) + + +def _cutlass_blackwell_fmha_decode_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seqlen_kv: torch.Tensor | None = None, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_k: torch.Tensor | None = None, + max_seq_len_q: int | None = None, + max_seq_len_k: int | None = None, + softmax_scale: float | None = None, + causal: bool = False, + window_left: int = -1, + window_right: int = -1, + bottom_right: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Decode-optimized forward pass using the gen kernel. + This wrapper adapts the variable-length batch interface to use the gen kernel + which is optimized for decode (query length = 1). + + Accepts inputs in two formats: + - Varlen format: [total_queries, num_heads, head_dim] (3D) + - Batch format: [batch_size, 1, num_heads, head_dim] (4D) + """ + _validate_decode_inputs(q, k, v, seqlen_kv) + # Handle window size for causal attention + if causal and window_left >= 0: + window_right = 0 + + # Prepare inputs and handle format conversion + q, k, v, batch_size, needs_reshape_output, original_shape = _prepare_decode_inputs( + q, k, v + ) + + # Create batch_idx tensor + batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device) + + # Call the gen kernel (optimized for decode) + out = torch.ops.fbgemm.fmha_gen_fwd( + q, + k, + v, + seqlen_kv, + batch_idx, + kernel_type=GenKernelType.UMMA_I, + # window_left=window_left, + # window_right=window_right, + ) + + # Reshape output back to original format if needed + if needs_reshape_output: + out = out.view(*original_shape) + + # Create dummy LSE for compatibility + lse = _create_decode_lse(out, batch_size, needs_reshape_output, original_shape) + + return out, lse + + class CutlassBlackwellFmhaFunc(torch.autograd.Function): @staticmethod def forward( # type: ignore @@ -175,67 +294,66 @@ def forward( # type: ignore bottom_right: bool = True, deterministic: bool = False, ) -> torch.Tensor: + window_left, window_right = window_size # Check if this is generation phase (sq = 1) sq = q.shape[1] - # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided - if cu_seqlens_q is not None and cu_seqlens_k is not None: - assert ( - cu_seqlens_q.dtype == torch.int32 - and cu_seqlens_q.dtype == cu_seqlens_k.dtype - ), "cu_seqlens_q and cu_seqlens_k must be int32" - - # handle window_size - window_left, window_right = window_size - if causal and window_left >= 0: - window_right = 0 - if q.dim() == 4 and sq == 1: - batch_size = q.shape[0] - - # Use provided seqlen_kv - assert ( - seqlen_kv is not None - ), "seqlen_kv must be provided for generation phase" - - # Create batch_idx tensor - batch_idx = torch.arange(batch_size, dtype=torch.int32, device=q.device) - - # Use gen forward (no backward needed for generation) - out = _cutlass_blackwell_fmha_gen( - q, k, v, seqlen_kv, batch_idx, kernel_type=GenKernelType.UMMA_I - ) # For gen case, we don't need to save tensors for backward ctx.is_gen = True - return out - else: - # Use regular FMHA for non-generation case - out, softmax_lse = _cutlass_blackwell_fmha_forward( + out, _ = _cutlass_blackwell_fmha_decode_forward( q, k, v, + seqlen_kv, cu_seqlens_q, cu_seqlens_k, max_seq_len_q, max_seq_len_k, softmax_scale, causal, - seqlen_kv, 
window_left, window_right, bottom_right, ) - ctx.save_for_backward(q, k, v, out, softmax_lse) - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.max_seq_len_q = max_seq_len_q - ctx.max_seq_len_k = max_seq_len_k - ctx.cu_seqlens_q = cu_seqlens_q - ctx.cu_seqlens_k = cu_seqlens_k - ctx.is_gen = False - ctx.bottom_right = bottom_right - ctx.deterministic = deterministic return out + # Only check dtype if cu_seqlens_q and cu_seqlens_k are provided + if cu_seqlens_q is not None and cu_seqlens_k is not None: + assert ( + cu_seqlens_q.dtype == torch.int32 + and cu_seqlens_q.dtype == cu_seqlens_k.dtype + ), "cu_seqlens_q and cu_seqlens_k must be int32" + + # handle window_size + if causal and window_left >= 0: + window_right = 0 + # Use regular FMHA for non-generation case + out, softmax_lse = _cutlass_blackwell_fmha_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seq_len_q, + max_seq_len_k, + softmax_scale, + causal, + seqlen_kv, + window_left, + window_right, + bottom_right, + ) + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.max_seq_len_q = max_seq_len_q + ctx.max_seq_len_k = max_seq_len_k + ctx.cu_seqlens_q = cu_seqlens_q + ctx.cu_seqlens_k = cu_seqlens_k + ctx.is_gen = False + ctx.bottom_right = bottom_right + ctx.deterministic = deterministic + return out @staticmethod def backward(ctx, dout: torch.Tensor, *args: Any) -> tuple[ # type: ignore