diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 71f5473bc9de..18e5908d6ef1 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -89,10 +89,9 @@ def __init__(
         self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
+        self.max_cudagraph_size = self.compilation_config.max_capture_size
 
         if self.use_full_cuda_graph and self.fa_aot_schedule:
-            self.max_cudagraph_size = self.compilation_config.max_capture_size
-
             if self.max_cudagraph_size > 992:
                 # This condition derives from FA3's internal heuristic.
                 # TODO(woosuk): Support larger cudagraph sizes.
@@ -114,7 +113,14 @@ def __init__(
             self.max_num_splits = 1
 
     def _schedule_decode(
-        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
+        self,
+        num_reqs,
+        cu_query_lens,
+        max_query_len,
+        seqlens,
+        max_seq_len,
+        causal,
+        max_num_splits,
     ):
         if self.fa_aot_schedule:
             return get_scheduler_metadata(
@@ -130,7 +136,7 @@ def _schedule_decode(
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
-                num_splits=self.max_num_splits,
+                num_splits=max_num_splits,
             )
         return None
 
@@ -148,6 +154,15 @@ def _build_decode(
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_device.max().item()
 
+        # For Flash Attention MLA + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
@@ -155,10 +170,9 @@ def _build_decode(
             seqlens=seq_lens_device,
             max_seq_len=max_seq_len,
             causal=True,
+            max_num_splits=max_num_splits,
         )
 
-        # For FA3 + full cudagraph
-        max_num_splits = 0
         if self.use_full_cuda_graph and scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             # Ensure the persistent buffer is large enough
@@ -174,13 +188,6 @@ def _build_decode(
             self.scheduler_metadata[n:] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
-        if num_decode_tokens <= self.max_cudagraph_size:
-            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
-            # usage, because the intermediate buffers of size [num_splits,
-            # num_heads, num_tokens, head_size] are allocated. Therefore,
-            # we only set num_splits when using cuda graphs.
-            max_num_splits = self.max_num_splits
-
         if vllm_is_batch_invariant():
             max_num_splits = 1
 
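
For reviewers, a minimal standalone sketch of the ordering problem this patch fixes (hypothetical helper names, not vLLM's actual API): previously `max_num_splits` was computed *after* `_schedule_decode` had already run with `self.max_num_splits`, so the FA3 AOT scheduler metadata could be planned for a different split count than the one used later. The patch computes the effective value once, up front, and threads it through as a parameter.

```python
# Hypothetical sketch of the control-flow change; `schedule` stands in for
# the FA3 AOT scheduler call and is not vLLM's real signature.

def build_decode_before(schedule, cap, num_decode_tokens, max_cudagraph_size):
    # Old flow: the scheduler always saw the capped value...
    metadata = schedule(num_splits=cap)
    # ...while the effective value was decided afterwards, so the two
    # could disagree whenever the batch exceeded the cudagraph size.
    max_num_splits = cap if num_decode_tokens <= max_cudagraph_size else 0
    return metadata, max_num_splits

def build_decode_after(schedule, cap, num_decode_tokens, max_cudagraph_size):
    # New flow: decide the effective split count first, then pass the same
    # value to the scheduler, keeping scheduling and launch consistent.
    max_num_splits = cap if num_decode_tokens <= max_cudagraph_size else 0
    metadata = schedule(num_splits=max_num_splits)
    return metadata, max_num_splits
```

The `[num_splits, num_heads, num_tokens, head_size]` intermediate buffers mentioned in the NOTE are why `max_num_splits` stays 0 outside the full-cudagraph path: splits trade extra memory for parallelism, and that cost is only paid when capturing cuda graphs.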