diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index c9e7aec880b9..10777f894eae 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA - GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f + GIT_TAG 28417e516fcbf6257a422ba117ef5b6f44da5682 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS) ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu ) set(FlashMLA_INCLUDES diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 2de7f71b6e30..d8ab0b9097ef 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -102,6 +102,12 @@ def get_mla_metadata( (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. """ + if is_fp8_kvcache and topk is None: + return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8( + cache_seqlens, + num_q_tokens_per_head_k, + num_heads_k, + ) return torch.ops._flashmla_C.get_mla_decoding_metadata( cache_seqlens, num_q_tokens_per_head_k, diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 34d3c8ee1ba2..3e481af29544 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -91,6 +91,7 @@ def __init__( self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None + self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") device_properties = torch.cuda.get_device_properties(self.device) num_sms = device_properties.multi_processor_count @@ -123,6 +124,7 @@ def _build_decode( seq_lens_device, self.num_q_heads, 1, # MQA for the decode path + is_fp8_kvcache=self.is_fp8_kvcache, ) # TODO: we can disambiguate between decode and mixed-prefill decode here