@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""

@@ -201,8 +202,22 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  # Use a different sharding strategy for self-attention vs. cross-attention.
+  if is_self_attention is not None:
+    if is_self_attention:
+      # Self-attention: context parallelism (sharding along num_heads).
+      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+    else:
+      # Cross-attention: sequence parallelism for Q.
+      # Q's sequence dimension is sharded; K/V are replicated.
+      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
+      kv_axis_names = PartitionSpec("data", None, None, None)
+  else:
+    # Fall back to the original maxdiffusion behavior if the flag isn't provided.
+    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

   @functools.partial(
       shard_map.shard_map,
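For context on what these two strategies mean on the (batch, num_heads, seq_len, head_dim) layout produced by `_reshape_data_for_flash`: the self-attention spec splits the heads axis across the combined ("fsdp", "tensor") mesh axes, while the cross-attention spec splits Q's sequence axis instead and leaves K/V fully replicated. The standalone sketch below (not part of this patch) visualizes the two placements; the mesh shape, the array shapes, and the assumption that the device count evenly divides the sharded axis are illustrative only.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n = jax.device_count()
# Assumed layout: every device on the "tensor" axis; real runs take the mesh
# from the maxdiffusion config.
mesh = Mesh(mesh_utils.create_device_mesh((1, 1, n)), ("data", "fsdp", "tensor"))

x = jnp.zeros((2, 16, 1024, 64))  # (batch, num_heads, seq_len, head_dim)

self_attn_spec = PartitionSpec("data", ("fsdp", "tensor"), None, None)     # shard heads
cross_attn_q_spec = PartitionSpec("data", None, ("fsdp", "tensor"), None)  # shard Q sequence
cross_attn_kv_spec = PartitionSpec("data", None, None, None)               # replicate K/V

x_self = jax.device_put(x, NamedSharding(mesh, self_attn_spec))
x_cross_q = jax.device_put(x, NamedSharding(mesh, cross_attn_q_spec))

# 2-D slices (heads x seq_len) make the per-device tiling easy to see.
jax.debug.visualize_array_sharding(x_self[0, :, :, 0])     # split along heads
jax.debug.visualize_array_sharding(x_cross_q[0, :, :, 0])  # split along sequence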
@@ -419,6 +434,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -439,7 +455,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
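Inside `_tpu_flash_attention`, the q/kv specs selected above feed the `shard_map` wrapper shown in the trailing context of the earlier hunk, so each device runs the kernel on its local shard of heads (self-attention) or of the query sequence (cross-attention). Below is a toy, self-contained sketch of that pattern, with plain dot-product attention standing in for the Pallas flash kernel and the same assumed ("data", "fsdp", "tensor") mesh; it is an illustration, not the maxdiffusion implementation.

import functools
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec

n = jax.device_count()
mesh = Mesh(mesh_utils.create_device_mesh((1, 1, n)), ("data", "fsdp", "tensor"))

# Self-attention case: Q and K/V share the head-sharded spec.
q_spec = PartitionSpec("data", ("fsdp", "tensor"), None, None)
kv_spec = PartitionSpec("data", ("fsdp", "tensor"), None, None)

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(q_spec, kv_spec, kv_spec),
    out_specs=q_spec,
    check_rep=False,
)
def toy_attention(q, k, v):
  # Each program sees only its local slice of heads, so the computation is
  # independent per device and needs no cross-device communication.
  logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)

q = k = v = jnp.zeros((2, 16, 256, 64))  # num_heads (16) assumed divisible by device count
out = toy_attention(q, k, v)             # same shape, sharded like q_spec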
@@ -574,6 +590,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -593,6 +610,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.is_self_attention = is_self_attention

   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -613,6 +631,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        is_self_attention=self.is_self_attention,
     )

@@ -701,6 +720,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -730,6 +750,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        is_self_attention=is_self_attention,
     )
     # None axes correspond to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
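One note on the design choice this flag encodes: in the cross-attention layers the query comes from the long video-latent sequence while K/V typically come from a much shorter conditioning sequence, so sharding Q's sequence captures essentially all of the memory benefit and replicating K/V stays cheap while avoiding any gather of K/V across devices. The back-of-the-envelope sketch below makes that concrete; all shapes, the bfloat16 element size, and the 8-way shard count are assumptions for illustration, not Wan 2.1's actual dimensions.

# Illustrative activation-memory estimate for the cross-attention case above.
def tensor_mib(shape, bytes_per_elem=2):  # bytes_per_elem=2 assumes bfloat16
  size = bytes_per_elem
  for dim in shape:
    size *= dim
  return size / 2**20

batch, heads, head_dim = 1, 16, 128
q_seq = 75_600   # assumed long video-latent sequence length
kv_seq = 512     # assumed short conditioning sequence length
shards = 8       # assumed ("fsdp", "tensor") device count

q_local = tensor_mib((batch, heads, q_seq // shards, head_dim))  # sequence-sharded Q
kv_local = 2 * tensor_mib((batch, heads, kv_seq, head_dim))      # replicated K and V

print(f"Q per device (sequence-sharded): {q_local:.1f} MiB")   # ~36.9 MiB
print(f"K+V per device (replicated):     {kv_local:.1f} MiB")  # ~4.0 MiB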