From 820283794a39fbc4d752081dedb2d4f305838e5d Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Tue, 16 Sep 2025 22:55:00 +0000
Subject: [PATCH] Skip flash block sizes setting for cross attention.

---
 src/maxdiffusion/max_utils.py              | 16 ++++++++--------
 src/maxdiffusion/models/attention_flax.py  |  3 ++-
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 6638e0f8..47da450e 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -495,14 +495,14 @@ def get_flash_block_sizes(config):
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=config.flash_block_sizes["block_q"],
-        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
-        block_kv=config.flash_block_sizes["block_kv"],
-        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
-        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=config.flash_block_sizes["block_q_dq"],
-        block_kv_dq=config.flash_block_sizes["block_kv_dq"],
+        block_q=int(config.flash_block_sizes["block_q"]),
+        block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]),
+        block_kv=int(config.flash_block_sizes["block_kv"]),
+        block_q_dkv=int(config.flash_block_sizes["block_q_dkv"]),
+        block_kv_dkv=int(config.flash_block_sizes["block_kv_dkv"]),
+        block_kv_dkv_compute=int(config.flash_block_sizes["block_kv_dkv_compute"]),
+        block_q_dq=int(config.flash_block_sizes["block_q_dq"]),
+        block_kv_dq=int(config.flash_block_sizes["block_kv_dq"]),
     )
   return flash_block_sizes
 
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 5df5f334..32542792 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -184,7 +184,8 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # For cross attention (q and kv lengths differ), skip the user-configured block sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
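
A minimal, self-contained sketch of the behavior the attention_flax.py hunk introduces, for review context only. The names here (choose_block_sizes, the shape tuples, the default_block value, and the plain dicts standing in for jax arrays and splash_attention_kernel.BlockSizes) are hypothetical; the sketch assumes the sequence length sits on axis 1, as in the patched code. It only illustrates the guard: user-configured block sizes apply to self attention, while cross attention falls back to computed defaults.

# Hypothetical illustration; not part of the patch.
from typing import Optional, Tuple


def choose_block_sizes(
    query_shape: Tuple[int, int, int, int],  # assumed (batch, q_seq_len, heads, head_dim)
    key_shape: Tuple[int, int, int, int],    # assumed (batch, kv_seq_len, heads, head_dim)
    flash_block_sizes: Optional[dict],
    default_block: int = 512,
) -> dict:
  """Return user block sizes only for self attention; otherwise compute defaults."""
  q_max_block_size = query_shape[1]
  kv_max_block_size = key_shape[1]
  # Mirrors the patched guard: the user-configured sizes are skipped for cross
  # attention, i.e. whenever the key and query sequence lengths differ.
  if flash_block_sizes and key_shape[1] == query_shape[1]:
    return flash_block_sizes
  return {
      "block_q": min(default_block, q_max_block_size),
      "block_kv": min(default_block, kv_max_block_size),
  }


if __name__ == "__main__":
  user_sizes = {"block_q": 1024, "block_kv": 1024}
  # Self attention: q and kv lengths match, so the user-configured sizes are used.
  print(choose_block_sizes((1, 4096, 8, 64), (1, 4096, 8, 64), user_sizes))
  # Cross attention: kv has a different (e.g. text-embedding) length, so the user
  # sizes are skipped and defaults capped to each dimension apply instead.
  print(choose_block_sizes((1, 4096, 8, 64), (1, 77, 8, 64), user_sizes))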