diff --git a/flashnn/triton_kernels/paged_attn.py b/flashnn/triton_kernels/paged_attn.py
index 3b9e260..525bf12 100644
--- a/flashnn/triton_kernels/paged_attn.py
+++ b/flashnn/triton_kernels/paged_attn.py
@@ -5,9 +5,13 @@
 import torch
 import triton
 import triton.language as tl
+
+THREADS_PER_WARP = 64 if torch.version.hip is not None else 32
+
+PARTITION_SIZE = 512
+
-# Requires triton >= 2.2.0
 def paged_attention(
     out: torch.Tensor,  # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
     query: torch.Tensor,  # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
@@ -42,10 +48,10 @@ def paged_attention(
     num_splits = 1
     partition_size = 0
     if max_context_len >= 8192:
-        partition_size = max(512, kv_block_size)
+        partition_size = max(256, kv_block_size)
         num_splits = triton.cdiv(max_context_len, partition_size)
     else:
-        partition_size = max(512, kv_block_size)
+        partition_size = max(256, kv_block_size)
         num_splits = triton.cdiv(max_context_len, partition_size)
     if max_context_len <= 1024 or kv_block_size >= 256:
         num_splits = 1
@@ -74,6 +80,7 @@ def paged_attention(
     paged_attn_w_mma(*kwargs)
 
+
 def paged_attn_wo_mma(
     out: torch.Tensor,  # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
     query: torch.Tensor,  # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
@@ -136,7 +143,11 @@ def paged_attn_wo_mma(
         "PARTITION_SIZE": 0 if num_splits == 1 else partition_size,
         "POWER_OF_2_MAX_SEQ_LEN": triton.next_power_of_2(max_context_len),
         "USE_PARTITIONING": False if num_splits == 1 else True,
+        # "UNROLL_FACTOR": 1,
     }
+    method_name = "_paged_attn_wo_mma_kernel_" + "_".join(
+        str(value) for value in const_kwargs.values()
+    )
     _paged_attn_wo_mma_kernel[grid](*kwargs, **const_kwargs)
 
     if num_splits != 1:
@@ -156,456 +167,98 @@ def paged_attn_wo_mma(
             tmp_out.stride(1),
             tmp_out.stride(2),
         ]
-        reduce_grid = (num_q_heads, num_seqs, 1)
+        grid = (num_q_heads, num_seqs, 1)
         const_kwargs = {
             "HEAD_SIZE": head_size,
             "PADDED_NUM_SPLITS": padded_num_splits,
             "PARTITION_SIZE": partition_size,
         }
-        _paged_attn_wo_mma_v2_reduce_kernel[reduce_grid](*kwargs, **const_kwargs)
+        _paged_attn_wo_mma_v2_reduce_kernel[grid](*kwargs, **const_kwargs)
 
 
-def paged_attn_w_mma(
-    out: torch.Tensor,  # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
-    query: torch.Tensor,  # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
-    key_cache: torch.Tensor,  # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]
-    value_cache: torch.Tensor,  # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE], required same stride with key_cache
-    context_lens: torch.Tensor,  # [num_seqs]
-    block_tables: torch.Tensor,  # [num_seqs, max_num_blocks_per_seq]
-    attn_scale: float,
-    max_context_len: int,
-    num_splits: int,
-    partition_size: int,
-    device,
-    alibi_slope: torch.Tensor = None,
-) -> None:
-    num_seqs = query.shape[0]
-    num_kv_heads = key_cache.shape[1]
-    kv_block_size = key_cache.shape[2]
-    head_size = key_cache.shape[3]
-    query_group_size = query.shape[1] // num_kv_heads
-    if query_group_size == 1:
-        padded_group_size = 1
-    elif query_group_size < 16:
-        padded_group_size = 16
-    else:
-        padded_group_size = triton.next_power_of_2(query_group_size)
-    with torch.cuda.device(device):
-        assert alibi_slope is None
-        grid = (num_seqs, num_kv_heads, num_splits)
-        shape_info = (num_seqs, num_kv_heads, num_splits, query_group_size)
-        m_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device)
-        l_i = torch.empty(size=shape_info, dtype=torch.float32,
device=query.device) - tmp_out = torch.empty( - size=(*shape_info, head_size), dtype=out.dtype, device=out.device - ) - kwargs = [ - m_i, - l_i, - out if num_splits == 1 else tmp_out, - query, - key_cache, - value_cache, - context_lens, - block_tables, - attn_scale, - block_tables.stride(0), - block_tables.stride(1), - query.stride(0), - query.stride(1), - query.stride(2), - key_cache.stride(0), - key_cache.stride(1), - key_cache.stride(2), - key_cache.stride(3), - ] - if num_splits == 1: - kwargs += [ - out.stride(0), - out.stride(1), - out.stride(1), - out.stride(1), - out.stride(2), - ] - else: - kwargs += [ - tmp_out.stride(0), - tmp_out.stride(1), - tmp_out.stride(2), - tmp_out.stride(3), - tmp_out.stride(4), - ] - const_kwargs = { - "HEAD_SIZE": head_size, - "QUERY_GROUP_SIZE": query_group_size, - "PADDED_QUERY_GROUP_SIZE": padded_group_size, - "NUM_KV_HEADS": num_kv_heads, - "KV_BLOCK_SIZE": kv_block_size, - "PARTITION_SIZE": partition_size, - } - _paged_attn_w_mma_kernel[grid](*kwargs, **const_kwargs) +@triton.jit +def _inner_paged_attn_unroll_0_kernel( + q, + k_cache, + v_cache, + stride_km, + block_base_ptrs, + base_offs_kv, + alibi_slope, + block_offs, + seq_len, + qkv, + qk_max, + exp_sum, + BLOCK_SIZE: tl.constexpr, + LO: tl.constexpr, + HI: tl.constexpr, +): + for block_idx in range(LO, HI, 1): + offs_kv_0 = tl.load(block_base_ptrs + block_idx + 0) * stride_km + base_offs_kv + k_0 = tl.load(k_cache + offs_kv_0) + v_0 = tl.load(v_cache + offs_kv_0) + _qk_0 = tl.sum((q[None, :] * k_0).to(tl.float32), axis=1) - if num_splits != 1: - assert (partition_size >= kv_block_size) and ( - partition_size % kv_block_size == 0 - ), f"partition_size={partition_size}, kv_block_size={kv_block_size}" - reduce_grid = (num_seqs, num_kv_heads, 1) - kwargs = [ - out, - m_i, - l_i, - tmp_out, - context_lens, - num_splits, - out.stride(0), - out.stride(1), - out.stride(2), - ] - const_kwargs = { - "HEAD_SIZE": head_size, - "QUERY_GROUP_SIZE": query_group_size, - "PADDED_QUERY_GROUP_SIZE": padded_group_size, - "NUM_KV_HEADS": num_kv_heads, - "PARTITION_SIZE": partition_size, - "NUM_PARTITIONS": triton.next_power_of_2(num_splits), - } - _paged_attn_w_mma_v2_reduce_kernel[reduce_grid](*kwargs, **const_kwargs) + if alibi_slope is not None: + _qk_0 += alibi_slope * ( + (block_idx + 0) * BLOCK_SIZE + block_offs - seq_len + 1 + ) + + _qk_max = tl.maximum(tl.max(_qk_0, axis=0), qk_max) + exp_tmp = tl.exp(_qk_0 - _qk_max) + _exp_sum = exp_sum * tl.exp(qk_max - _qk_max) + tl.sum(exp_tmp, axis=0) + qkv_sum_tmp = (tl.exp(_qk_0[:, None] - _qk_max)).to( + v_cache.dtype.element_ty + ) * v_0 + qkv = (qkv * (exp_sum * tl.exp(qk_max - _qk_max)) + qkv_sum_tmp) / _exp_sum + qk_max = _qk_max + exp_sum = _exp_sum + return qkv, qk_max, exp_sum -@triton.autotune( - configs=[ - triton.Config({}, num_stages=stages, num_warps=warps) - for stages in [0, 1, 3, 4] - for warps in [4, 8, 16] - ], - key=["QUERY_GROUP_SIZE", "HEAD_SIZE", "KV_BLOCK_SIZE"], -) @triton.jit -def _paged_attn_w_mma_kernel( - m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] - l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] - out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE] - q_ptr, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] - k_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] - v_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] - context_lens_ptr, # [num_seqs] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - 
attn_scale, - stride_bt0, - stride_bt1, - stride_q0, - stride_q1, - stride_q2, - stride_kv0, - stride_kv1, - stride_kv2, - stride_kv3, - stride_o0, - stride_o1, - stride_o2, - stride_o3, - stride_o4, - HEAD_SIZE: tl.constexpr, - QUERY_GROUP_SIZE: tl.constexpr, - PADDED_QUERY_GROUP_SIZE: tl.constexpr, - NUM_KV_HEADS: tl.constexpr, - KV_BLOCK_SIZE: tl.constexpr, - PARTITION_SIZE: tl.constexpr, +def _inner_paged_attn_unroll_2_kernel( + q, + k_cache, + v_cache, + stride_km, + block_base_ptrs, + base_offs_kv, + alibi_slope, + block_offs, + seq_len, + qkv, + qk_max, + exp_sum, + BLOCK_SIZE: tl.constexpr, + LO: tl.constexpr, + HI: tl.constexpr, ): - seq_idx = tl.program_id(0) - kv_head_idx = tl.program_id(1) - part_idx = tl.program_id(2) - max_num_partitions = tl.num_programs(2) - - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - log2e: tl.constexpr = 1.4426950408889634 + for block_idx in range(LO, HI, 2): + offs_kv_0 = tl.load(block_base_ptrs + block_idx + 0) * stride_km + base_offs_kv + offs_kv_1 = tl.load(block_base_ptrs + block_idx + 1) * stride_km + base_offs_kv - USE_PARTITIONING = PARTITION_SIZE > 0 - context_len = tl.load(context_lens_ptr + seq_idx) - if USE_PARTITIONING: - context_start_idx = part_idx * PARTITION_SIZE - if context_start_idx >= context_len: - return - context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len) - num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE) - else: - num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE) + k_0 = tl.load(k_cache + offs_kv_0) + k_1 = tl.load(k_cache + offs_kv_1) - block_offset = tl.arange(0, KV_BLOCK_SIZE) - head_offset = tl.arange(0, HEAD_SIZE) - padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE) + v_0 = tl.load(v_cache + offs_kv_0) + v_1 = tl.load(v_cache + offs_kv_1) - kv_offset = ( - kv_head_idx * stride_kv1 - + block_offset[:, None] * stride_kv2 - + head_offset[None, :] * stride_kv3 - ) + _qk_0 = tl.sum((q[None, :] * k_0).to(tl.float32), axis=1) + _qk_1 = tl.sum((q[None, :] * k_1).to(tl.float32), axis=1) - # Load queries. - q_offset = ( - seq_idx * stride_q0 - + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1 - + head_offset[None, :] * stride_q2 - ) - group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE - # q: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] - q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0) - q = (q * attn_scale).to(q_ptr.dtype.element_ty) + if alibi_slope is not None: + _qk_0 += alibi_slope * ( + (block_idx + 0) * BLOCK_SIZE + block_offs - seq_len + 1 + ) + _qk_1 += alibi_slope * ( + (block_idx + 1) * BLOCK_SIZE + block_offs - seq_len + 1 + ) - m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float("inf") - l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32) - - num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE) - for i in range(num_blocks): - block_idx = num_prev_blocks + i - block_number = tl.load( - block_tables_ptr + seq_idx * stride_bt0 + block_idx * stride_bt1 - ) - - # Load a key block. 
- kv_block_offset = block_number * stride_kv0 + kv_offset - mask_offset = block_idx * KV_BLOCK_SIZE + block_offset - kv_mask = mask_offset[:, None] < context_len - - # k: [KV_BLOCK_SIZE, HEAD_SIZE] - k = tl.load(k_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0) - - # qk: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE] - if PADDED_QUERY_GROUP_SIZE == 1: - qk = tl.sum(q[:, None, :] * k[None, :, :], axis=2) - else: - qk = tl.dot(q, k.T, out_dtype=tl.float32) - - # qk *= attn_scale - qk = tl.where(mask_offset < context_len, qk, float("-inf")) - - m_i_new = tl.maximum(m_i, tl.max(qk, axis=1)) - - # p: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE] - p = tl.math.exp2((qk - m_i_new[:, None]) * log2e) - alpha = tl.math.exp2((m_i - m_i_new) * log2e) - acc *= alpha[:, None] - - # v: [KV_BLOCK_SIZE, HEAD_SIZE] - v = tl.load(v_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0) - - if PADDED_QUERY_GROUP_SIZE == 1: - acc += tl.sum(p.T[:, :, None] * v[:, None, :], axis=0) - else: - p = p.to(v.dtype) - acc += tl.dot(p, v, out_dtype=tl.float32) - - l_i = l_i * alpha + tl.sum(p, axis=1) - m_i = m_i_new - acc = acc / l_i[:, None] - - if USE_PARTITIONING: - part_offset = ( - (seq_idx * NUM_KV_HEADS + kv_head_idx) - * max_num_partitions - * QUERY_GROUP_SIZE - + part_idx * QUERY_GROUP_SIZE - + padding_group_offset - ) - mask = padding_group_offset < QUERY_GROUP_SIZE - tl.store(m_i_ptr + part_offset, m_i, mask=mask) - tl.store(l_i_ptr + part_offset, l_i, mask=mask) - - out_offset = seq_idx * stride_o0 - if USE_PARTITIONING: - out_offset += kv_head_idx * stride_o1 - else: - out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1 - out_offset += ( - part_idx * stride_o2 - + padding_group_offset[:, None] * stride_o3 - + head_offset[None, :] * stride_o4 - ) - - group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE - tl.store(out_ptr + out_offset, acc, mask=group_mask) - - -@triton.autotune( - configs=[triton.Config({}, num_warps=warps) for warps in [4, 8, 16]], - key=["QUERY_GROUP_SIZE", "HEAD_SIZE", "NUM_PARTITIONS", "PARTITION_SIZE"], -) -@triton.jit -def _paged_attn_w_mma_v2_reduce_kernel( - out_ptr, # [num_seqs, NUM_KV_HEADS, QUERY_GROUP_SIZE, HEAD_SIZE] - m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] - l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] - tmp_out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE] - context_lens_ptr, # [num_seqs] - max_num_partitions, # partition stride - stride_o0, - stride_o1, - stride_o2, - HEAD_SIZE: tl.constexpr, - QUERY_GROUP_SIZE: tl.constexpr, - PADDED_QUERY_GROUP_SIZE: tl.constexpr, - NUM_KV_HEADS: tl.constexpr, - PARTITION_SIZE: tl.constexpr, - NUM_PARTITIONS: tl.constexpr, -): - seq_idx = tl.program_id(0) - kv_head_idx = tl.program_id(1) - - context_len = tl.load(context_lens_ptr + seq_idx) - - num_partitions = tl.cdiv(context_len, PARTITION_SIZE) - group_head_offset = ( - tl.arange(0, PADDED_QUERY_GROUP_SIZE)[:, None] * HEAD_SIZE - + tl.arange(0, HEAD_SIZE)[None, :] - ) - group_mask = tl.arange(0, PADDED_QUERY_GROUP_SIZE)[:, None] < QUERY_GROUP_SIZE - if num_partitions == 1: - tmp_out_offset = ( - seq_idx * NUM_KV_HEADS + kv_head_idx - ) * max_num_partitions * QUERY_GROUP_SIZE * HEAD_SIZE + group_head_offset - tmp_out = tl.load(tmp_out_ptr + tmp_out_offset, mask=group_mask, other=0.0) - - out_offset = ( - seq_idx * stride_o0 - + kv_head_idx * QUERY_GROUP_SIZE * stride_o1 - + group_head_offset * stride_o2 - ) - tl.store(out_ptr + out_offset, tmp_out, mask=group_mask) - return - - # 
Get the global max logit. - ml_offset = ( - (seq_idx * NUM_KV_HEADS + kv_head_idx) * max_num_partitions * QUERY_GROUP_SIZE - + tl.arange(0, NUM_PARTITIONS)[:, None] * QUERY_GROUP_SIZE - + tl.arange(0, PADDED_QUERY_GROUP_SIZE)[None, :] - ) - - mask = (tl.arange(0, NUM_PARTITIONS)[:, None] < num_partitions) & ( - tl.arange(0, PADDED_QUERY_GROUP_SIZE)[None, :] < QUERY_GROUP_SIZE - ) - # m_i: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE] - m_i = tl.load(m_i_ptr + ml_offset, mask=mask, other=float("-inf")) - # m: [PADDED_QUERY_GROUP_SIZE] - m = tl.max(m_i, axis=0) - - # Rescale the exp sums and compute the global sum. - # l_i: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE] - l_i = tl.load(l_i_ptr + ml_offset, mask=mask, other=0.0) - l_i *= tl.exp(m_i - m[None, :]) - # l: [PADDED_QUERY_GROUP_SIZE] - l = tl.sum(l_i, axis=0) - # r: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE] - r = l_i / l[None, :] - r = tl.reshape(r, (NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE, 1)) - - tmp_out_offset = ( - (seq_idx * NUM_KV_HEADS + kv_head_idx) - * max_num_partitions - * QUERY_GROUP_SIZE - * HEAD_SIZE - + tl.arange(0, NUM_PARTITIONS)[:, None, None] * QUERY_GROUP_SIZE * HEAD_SIZE - + tl.arange(0, PADDED_QUERY_GROUP_SIZE)[None, :, None] * HEAD_SIZE - + tl.arange(0, HEAD_SIZE)[None, None, :] - ) - # tmp_out: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] - tmp_out = tl.load(tmp_out_ptr + tmp_out_offset, mask=mask[:, :, None], other=0.0) - # out: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] - out = tl.sum((tmp_out * r).to(tl.float32), axis=0) - - out_offset = ( - seq_idx * stride_o0 - + kv_head_idx * QUERY_GROUP_SIZE * stride_o1 - + group_head_offset * stride_o2 - ) - tl.store(out_ptr + out_offset, out, mask=group_mask) - - -@triton.jit -def _inner_paged_attn_unroll_0_kernel( - q, - k_cache, - v_cache, - stride_km, - block_base_ptrs, - base_offs_kv, - alibi_slope, - block_offs, - seq_len, - qkv, - qk_max, - exp_sum, - BLOCK_SIZE: tl.constexpr, - LO: tl.constexpr, - HI: tl.constexpr, -): - for block_idx in range(LO, HI, 1): - offs_kv_0 = tl.load(block_base_ptrs + block_idx + 0) * stride_km + base_offs_kv - k_0 = tl.load(k_cache + offs_kv_0) - v_0 = tl.load(v_cache + offs_kv_0) - _qk_0 = tl.sum((q[None, :] * k_0).to(tl.float32), axis=1) - - if alibi_slope is not None: - _qk_0 += alibi_slope * ( - (block_idx + 0) * BLOCK_SIZE + block_offs - seq_len + 1 - ) - - _qk_max = tl.maximum(tl.max(_qk_0, axis=0), qk_max) - exp_tmp = tl.exp(_qk_0 - _qk_max) - _exp_sum = exp_sum * tl.exp(qk_max - _qk_max) + tl.sum(exp_tmp, axis=0) - qkv_sum_tmp = (tl.exp(_qk_0[:, None] - _qk_max)).to( - v_cache.dtype.element_ty - ) * v_0 - qkv = (qkv * (exp_sum * tl.exp(qk_max - _qk_max)) + qkv_sum_tmp) / _exp_sum - qk_max = _qk_max - exp_sum = _exp_sum - return qkv, qk_max, exp_sum - - -@triton.jit -def _inner_paged_attn_unroll_2_kernel( - q, - k_cache, - v_cache, - stride_km, - block_base_ptrs, - base_offs_kv, - alibi_slope, - block_offs, - seq_len, - qkv, - qk_max, - exp_sum, - BLOCK_SIZE: tl.constexpr, - LO: tl.constexpr, - HI: tl.constexpr, -): - for block_idx in range(LO, HI, 2): - offs_kv_0 = tl.load(block_base_ptrs + block_idx + 0) * stride_km + base_offs_kv - offs_kv_1 = tl.load(block_base_ptrs + block_idx + 1) * stride_km + base_offs_kv - - k_0 = tl.load(k_cache + offs_kv_0) - k_1 = tl.load(k_cache + offs_kv_1) - - v_0 = tl.load(v_cache + offs_kv_0) - v_1 = tl.load(v_cache + offs_kv_1) - - _qk_0 = tl.sum((q[None, :] * k_0).to(tl.float32), axis=1) - _qk_1 = tl.sum((q[None, :] * k_1).to(tl.float32), axis=1) - - if alibi_slope is not None: - _qk_0 += 
alibi_slope * ( - (block_idx + 0) * BLOCK_SIZE + block_offs - seq_len + 1 - ) - _qk_1 += alibi_slope * ( - (block_idx + 1) * BLOCK_SIZE + block_offs - seq_len + 1 - ) - - _qk_max = tl.maximum(tl.max(_qk_0, axis=0), qk_max) - _qk_max = tl.maximum(tl.max(_qk_1, axis=0), _qk_max) + _qk_max = tl.maximum(tl.max(_qk_0, axis=0), qk_max) + _qk_max = tl.maximum(tl.max(_qk_1, axis=0), _qk_max) exp_tmp = tl.exp(_qk_0 - _qk_max) + tl.exp(_qk_1 - _qk_max) _exp_sum = exp_sum * tl.exp(qk_max - _qk_max) + tl.sum(exp_tmp, axis=0) @@ -1013,7 +666,7 @@ def _paged_attn_wo_mma_kernel( @triton.autotune( - configs=[triton.Config({}, num_warps=warps) for warps in [4, 8, 16]], + configs=[triton.Config({}, num_warps=warps) for warps in [2, 4, 8, 16]], key=["HEAD_SIZE", "PADDED_NUM_SPLITS", "PARTITION_SIZE"], ) @triton.jit @@ -1075,3 +728,1861 @@ def _paged_attn_wo_mma_v2_reduce_kernel( inv_sum = 1.0 / (global_exp_sum + 1e-6) tl.store(out + out_ptr, acc * inv_sum) + + + + +def paged_attn_w_mma_unroll4( + out: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + query: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + key_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + value_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE], required same stride with key_cache + context_lens: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs, max_num_blocks_per_seq] + attn_scale: float, + max_context_len: int, + num_splits: int, + partition_size: int, + device, + alibi_slope: torch.Tensor = None, +) -> None: + num_seqs = query.shape[0] + num_kv_heads = key_cache.shape[1] + kv_block_size = key_cache.shape[2] + head_size = key_cache.shape[3] + query_group_size = query.shape[1] // num_kv_heads + if query_group_size == 1: + padded_group_size = 1 + elif query_group_size < 16: + padded_group_size = 16 + # elif query_group_size < 32: + # padded_group_size = 32 + else: + padded_group_size = triton.next_power_of_2(query_group_size) + + with torch.cuda.device(device): + assert alibi_slope is None + grid = (num_seqs, num_kv_heads, num_splits) + shape_info = (num_seqs, num_kv_heads, num_splits, query_group_size) + m_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + l_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + tmp_out = torch.empty( + size=(*shape_info, head_size), dtype=out.dtype, device=out.device + ) + kwargs = [ + m_i, + l_i, + out if num_splits == 1 else tmp_out, + query, + key_cache, + value_cache, + context_lens, + block_tables, + attn_scale, + block_tables.stride(0), + block_tables.stride(1), + query.stride(0), + query.stride(1), + query.stride(2), + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + key_cache.stride(3), + ] + if num_splits == 1: + kwargs += [ + out.stride(0), + out.stride(1), + out.stride(1), + out.stride(1), + out.stride(2), + ] + else: + kwargs += [ + tmp_out.stride(0), + tmp_out.stride(1), + tmp_out.stride(2), + tmp_out.stride(3), + tmp_out.stride(4), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, + "NUM_KV_HEADS": num_kv_heads, + "KV_BLOCK_SIZE": kv_block_size, + "PARTITION_SIZE": partition_size, + # "num_stages":1, + # "num_warps":4, + # "num_ctas":1, + # "kpack":1, + # "matrix_instr_nonkdim":16 #32, + } + _paged_attn_w_mma_kernel_unroll4[grid](*kwargs, **const_kwargs) + if num_splits != 1: + assert (partition_size >= kv_block_size) 
and ( + partition_size % kv_block_size == 0 + ), f"partition_size={partition_size}, kv_block_size={kv_block_size}" + reduce_grid = (num_seqs, num_kv_heads, 1) + kwargs = [ + out, + m_i, + l_i, + tmp_out, + context_lens, + num_splits, + out.stride(0), + out.stride(1), + out.stride(2), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, + "NUM_KV_HEADS": num_kv_heads, + "PARTITION_SIZE": partition_size, + "NUM_PARTITIONS": triton.next_power_of_2(num_splits), + } + method_name = "paged_attn_w_mma_v2_reduce_kernel_" + "_".join( + str(value) for value in const_kwargs.values() + ) + _paged_attn_w_mma_v2_reduce_kernel[reduce_grid](*kwargs, **const_kwargs) + + + +def paged_attn_w_mma_unroll8( + out: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + query: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + key_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + value_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE], required same stride with key_cache + context_lens: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs, max_num_blocks_per_seq] + attn_scale: float, + max_context_len: int, + num_splits: int, + partition_size: int, + device, + alibi_slope: torch.Tensor = None, +) -> None: + num_seqs = query.shape[0] + num_kv_heads = key_cache.shape[1] + kv_block_size = key_cache.shape[2] + head_size = key_cache.shape[3] + query_group_size = query.shape[1] // num_kv_heads + if query_group_size == 1: + padded_group_size = 1 + elif query_group_size < 16: + padded_group_size = 16 + # elif query_group_size < 32: + # padded_group_size = 32 + else: + padded_group_size = triton.next_power_of_2(query_group_size) + + with torch.cuda.device(device): + assert alibi_slope is None + grid = (num_seqs, num_kv_heads, num_splits) + shape_info = (num_seqs, num_kv_heads, num_splits, query_group_size) + m_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + l_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + tmp_out = torch.empty( + size=(*shape_info, head_size), dtype=out.dtype, device=out.device + ) + kwargs = [ + m_i, + l_i, + out if num_splits == 1 else tmp_out, + query, + key_cache, + value_cache, + context_lens, + block_tables, + attn_scale, + block_tables.stride(0), + block_tables.stride(1), + query.stride(0), + query.stride(1), + query.stride(2), + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + key_cache.stride(3), + ] + if num_splits == 1: + kwargs += [ + out.stride(0), + out.stride(1), + out.stride(1), + out.stride(1), + out.stride(2), + ] + else: + kwargs += [ + tmp_out.stride(0), + tmp_out.stride(1), + tmp_out.stride(2), + tmp_out.stride(3), + tmp_out.stride(4), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, + "NUM_KV_HEADS": num_kv_heads, + "KV_BLOCK_SIZE": kv_block_size, + "PARTITION_SIZE": partition_size, + # "num_stages":1, + # "num_warps":4, + # "num_ctas":1, + # "kpack":1, + # "matrix_instr_nonkdim":16 #32, + } + _paged_attn_w_mma_kernel_unroll8[grid](*kwargs, **const_kwargs) + if num_splits != 1: + assert (partition_size >= kv_block_size) and ( + partition_size % kv_block_size == 0 + ), f"partition_size={partition_size}, kv_block_size={kv_block_size}" + reduce_grid = (num_seqs, num_kv_heads, 1) + kwargs = [ + out, + m_i, + l_i, + tmp_out, + 
context_lens, + num_splits, + out.stride(0), + out.stride(1), + out.stride(2), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, + "NUM_KV_HEADS": num_kv_heads, + "PARTITION_SIZE": partition_size, + "NUM_PARTITIONS": triton.next_power_of_2(num_splits), + } + method_name = "paged_attn_w_mma_v2_reduce_kernel_" + "_".join( + str(value) for value in const_kwargs.values() + ) + _paged_attn_w_mma_v2_reduce_kernel[reduce_grid](*kwargs, **const_kwargs) + + +def paged_attn_w_mma_unroll2( + out: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + query: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + key_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + value_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE], required same stride with key_cache + context_lens: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs, max_num_blocks_per_seq] + attn_scale: float, + max_context_len: int, + num_splits: int, + partition_size: int, + device, + alibi_slope: torch.Tensor = None, +) -> None: + num_seqs = query.shape[0] + num_kv_heads = key_cache.shape[1] + kv_block_size = key_cache.shape[2] + head_size = key_cache.shape[3] + query_group_size = query.shape[1] // num_kv_heads + if query_group_size == 1: + padded_group_size = 1 + elif query_group_size < 16: + padded_group_size = 16 + # elif query_group_size < 32: + # padded_group_size = 32 + else: + padded_group_size = triton.next_power_of_2(query_group_size) + + with torch.cuda.device(device): + assert alibi_slope is None + grid = (num_seqs, num_kv_heads, num_splits) + shape_info = (num_seqs, num_kv_heads, num_splits, query_group_size) + m_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + l_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + tmp_out = torch.empty( + size=(*shape_info, head_size), dtype=out.dtype, device=out.device + ) + kwargs = [ + m_i, + l_i, + out if num_splits == 1 else tmp_out, + query, + key_cache, + value_cache, + context_lens, + block_tables, + attn_scale, + block_tables.stride(0), + block_tables.stride(1), + query.stride(0), + query.stride(1), + query.stride(2), + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + key_cache.stride(3), + ] + if num_splits == 1: + kwargs += [ + out.stride(0), + out.stride(1), + out.stride(1), + out.stride(1), + out.stride(2), + ] + else: + kwargs += [ + tmp_out.stride(0), + tmp_out.stride(1), + tmp_out.stride(2), + tmp_out.stride(3), + tmp_out.stride(4), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, + "NUM_KV_HEADS": num_kv_heads, + "KV_BLOCK_SIZE": kv_block_size, + "PARTITION_SIZE": partition_size, + # "num_stages":1, + # "num_warps":4, + # "num_ctas":1, + # "kpack":1, + # "matrix_instr_nonkdim":16 #32, + } + _paged_attn_w_mma_kernel_unroll2[grid](*kwargs, **const_kwargs) + if num_splits != 1: + assert (partition_size >= kv_block_size) and ( + partition_size % kv_block_size == 0 + ), f"partition_size={partition_size}, kv_block_size={kv_block_size}" + reduce_grid = (num_seqs, num_kv_heads, 1) + kwargs = [ + out, + m_i, + l_i, + tmp_out, + context_lens, + num_splits, + out.stride(0), + out.stride(1), + out.stride(2), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, 
+ "NUM_KV_HEADS": num_kv_heads, + "PARTITION_SIZE": partition_size, + "NUM_PARTITIONS": triton.next_power_of_2(num_splits), + } + method_name = "paged_attn_w_mma_v2_reduce_kernel_" + "_".join( + str(value) for value in const_kwargs.values() + ) + _paged_attn_w_mma_v2_reduce_kernel[reduce_grid](*kwargs, **const_kwargs) + + +def paged_attn_w_mma( + out: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + query: torch.Tensor, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + key_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + value_cache: torch.Tensor, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE], required same stride with key_cache + context_lens: torch.Tensor, # [num_seqs] + block_tables: torch.Tensor, # [num_seqs, max_num_blocks_per_seq] + attn_scale: float, + max_context_len: int, + num_splits: int, + partition_size: int, + device, + alibi_slope: torch.Tensor = None, +) -> None: + num_seqs = query.shape[0] + num_kv_heads = key_cache.shape[1] + kv_block_size = key_cache.shape[2] + head_size = key_cache.shape[3] + query_group_size = query.shape[1] // num_kv_heads + if query_group_size == 1: + padded_group_size = 1 + elif query_group_size < 16: + padded_group_size = 16 + # elif query_group_size < 32: + # padded_group_size = 32 + else: + padded_group_size = triton.next_power_of_2(query_group_size) + + with torch.cuda.device(device): + assert alibi_slope is None + grid = (num_seqs, num_kv_heads, num_splits) + shape_info = (num_seqs, num_kv_heads, num_splits, query_group_size) + m_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + l_i = torch.empty(size=shape_info, dtype=torch.float32, device=query.device) + tmp_out = torch.empty( + size=(*shape_info, head_size), dtype=out.dtype, device=out.device + ) + kwargs = [ + m_i, + l_i, + out if num_splits == 1 else tmp_out, + query, + key_cache, + value_cache, + context_lens, + block_tables, + attn_scale, + block_tables.stride(0), + block_tables.stride(1), + query.stride(0), + query.stride(1), + query.stride(2), + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + key_cache.stride(3), + ] + if num_splits == 1: + kwargs += [ + out.stride(0), + out.stride(1), + out.stride(1), + out.stride(1), + out.stride(2), + ] + else: + kwargs += [ + tmp_out.stride(0), + tmp_out.stride(1), + tmp_out.stride(2), + tmp_out.stride(3), + tmp_out.stride(4), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, + "NUM_KV_HEADS": num_kv_heads, + "KV_BLOCK_SIZE": kv_block_size, + "PARTITION_SIZE": partition_size, + # "num_stages":1, + # "num_warps":4, + # "num_ctas":1, + # "kpack":1, + # "matrix_instr_nonkdim":16 #32, + } + # if 1 : + _paged_attn_w_mma_kernel[grid](*kwargs, **const_kwargs) + # _paged_attn_w_mma_kernel_unroll2[grid](*kwargs, **const_kwargs) + # else: + # key_cache = key_cache.transpose(2, 3) + # kwargs_trans = [ + # m_i, + # l_i, + # out if num_splits == 1 else tmp_out, + # query, + # key_cache, + # value_cache, + # context_lens, + # block_tables, + # attn_scale, + # block_tables.stride(0), + # block_tables.stride(1), + # query.stride(0), + # query.stride(1), + # query.stride(2), + # key_cache.stride(0), + # key_cache.stride(1), + # key_cache.stride(2), + # key_cache.stride(3), + # value_cache.stride(0), + # value_cache.stride(1), + # value_cache.stride(2), + # value_cache.stride(3), + # out.stride(0), + # out.stride(1), + # out.stride(1), + # out.stride(1), + # 
out.stride(2), + # ] + # _paged_attn_w_mma_transK_kernel[grid](*kwargs_trans, **const_kwargs) + + + if num_splits != 1: + assert (partition_size >= kv_block_size) and ( + partition_size % kv_block_size == 0 + ), f"partition_size={partition_size}, kv_block_size={kv_block_size}" + reduce_grid = (num_seqs, num_kv_heads, 1) + kwargs = [ + out, + m_i, + l_i, + tmp_out, + context_lens, + num_splits, + out.stride(0), + out.stride(1), + out.stride(2), + ] + const_kwargs = { + "HEAD_SIZE": head_size, + "QUERY_GROUP_SIZE": query_group_size, + "PADDED_QUERY_GROUP_SIZE": padded_group_size, + "NUM_KV_HEADS": num_kv_heads, + "PARTITION_SIZE": partition_size, + "NUM_PARTITIONS": triton.next_power_of_2(num_splits), + } + method_name = "paged_attn_w_mma_v2_reduce_kernel_" + "_".join( + str(value) for value in const_kwargs.values() + ) + _paged_attn_w_mma_v2_reduce_kernel[reduce_grid](*kwargs, **const_kwargs) + + + + +@triton.autotune( + configs=[ + triton.Config({ + # 'matrix_instr_nonkdim': kdim, + # 'kpack': kpack, + # 'waves_per_eu': waves, + }, + num_stages=stages, + num_warps=warps, + # maxnreg=maxnreg, + ) + # for kdim in [16,32] + # for kpack in [1] + # for maxnreg in [64, 128, 256] + # for waves in [0] + for stages in [0, 1] + for warps in [1, 2, 4, 8, 16] + ], + key=["QUERY_GROUP_SIZE", "HEAD_SIZE", "KV_BLOCK_SIZE"], +) +@triton.jit +def _paged_attn_w_mma_kernel( + m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] + l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] + out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE] + q_ptr, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + k_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + v_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + context_lens_ptr, # [num_seqs] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + attn_scale, + stride_bt0, + stride_bt1, + stride_q0, + stride_q1, + stride_q2, + stride_kv0, + stride_kv1, + stride_kv2, + stride_kv3, + stride_o0, + stride_o1, + stride_o2, + stride_o3, + stride_o4, + HEAD_SIZE: tl.constexpr, + QUERY_GROUP_SIZE: tl.constexpr, # RATIO + PADDED_QUERY_GROUP_SIZE: tl.constexpr, + NUM_KV_HEADS: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + PARTITION_SIZE: tl.constexpr, + # UNROLL_FACTOR: tl.constexpr, +): + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + part_idx = tl.program_id(2) + max_num_partitions = tl.num_programs(2) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + log2e: tl.constexpr = 1.4426950408889634 + + USE_PARTITIONING = PARTITION_SIZE > 0 + context_len = tl.load(context_lens_ptr + seq_idx) + if USE_PARTITIONING: + context_start_idx = part_idx * PARTITION_SIZE + if context_start_idx >= context_len: + return + context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len) + num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE) + else: + num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE) + + block_offset = tl.arange(0, KV_BLOCK_SIZE) + head_offset = tl.arange(0, HEAD_SIZE) + padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE) + + kv_offset = ( + kv_head_idx * stride_kv1 + + block_offset[:, None] * stride_kv2 + + head_offset[None, :] * stride_kv3 + ) + + # Load queries. 
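+    # The padded group load below packs every query head that shares this KV
+    # head into one [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] tile; rows beyond the
+    # real QUERY_GROUP_SIZE are masked to zero so they cannot affect the softmax.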
+    q_offset = (
+        seq_idx * stride_q0
+        + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1
+        + head_offset[None, :] * stride_q2
+    )
+    group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE
+    # q: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE]
+    q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0)
+    q = (q * attn_scale).to(q_ptr.dtype.element_ty)
+
+    m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32)
+    acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32)
+    num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE)
+
+    for i in range(num_blocks):
+        block_idx = num_prev_blocks + i
+        block_number = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + block_idx * stride_bt1
+        )
+        # Load a key block.
+        kv_block_offset = block_number * stride_kv0 + kv_offset
+        mask_offset = block_idx * KV_BLOCK_SIZE + block_offset
+        # kv_mask = mask_offset[:, None] < context_len
+
+        # k: [KV_BLOCK_SIZE, HEAD_SIZE]
+        k = tl.load(k_cache_ptr + kv_block_offset)
+        # k = tl.load(k_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)
+
+        # v: [KV_BLOCK_SIZE, HEAD_SIZE]
+        v = tl.load(v_cache_ptr + kv_block_offset)
+        # v = tl.load(v_cache_ptr + kv_block_offset, mask=kv_mask, other=0.0)
+
+        # qk: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE]
+        # tl.dot needs every dimension >= 16, so fall back to a broadcast-and-sum
+        # when the padded query group collapses to a single row.
+        if PADDED_QUERY_GROUP_SIZE == 1:
+            qk = tl.sum(q[:, None, :] * k[None, :, :], axis=2)
+        else:
+            qk = tl.dot(q, k.T, out_dtype=tl.float32)
+
+        # q is pre-scaled by attn_scale, so qk is already scaled here.
+        qk = tl.where(mask_offset < context_len, qk, float("-inf"))
+
+        m_i_new = tl.maximum(m_i, tl.max(qk, axis=1))
+
+        # p: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE]
+        p = tl.math.exp2((qk - m_i_new[:, None]) * log2e)
+        alpha = tl.math.exp2((m_i - m_i_new) * log2e)
+        acc *= alpha[:, None]
+
+        if PADDED_QUERY_GROUP_SIZE == 1:
+            acc += tl.sum(p.T[:, :, None] * v[:, None, :], axis=0)
+        else:
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v, out_dtype=tl.float32)
+
+        l_i = l_i * alpha + tl.sum(p, axis=1)
+        m_i = m_i_new
+    acc = acc / l_i[:, None]
+
+    if USE_PARTITIONING:
+        part_offset = (
+            (seq_idx * NUM_KV_HEADS + kv_head_idx)
+            * max_num_partitions
+            * QUERY_GROUP_SIZE
+            + part_idx * QUERY_GROUP_SIZE
+            + padding_group_offset
+        )
+        mask = padding_group_offset < QUERY_GROUP_SIZE
+        tl.store(m_i_ptr + part_offset, m_i, mask=mask)
+        tl.store(l_i_ptr + part_offset, l_i, mask=mask)
+
+    out_offset = seq_idx * stride_o0
+    if USE_PARTITIONING:
+        out_offset += kv_head_idx * stride_o1
+    else:
+        out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1
+    out_offset += (
+        part_idx * stride_o2
+        + padding_group_offset[:, None] * stride_o3
+        + head_offset[None, :] * stride_o4
+    )
+
+    group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE
+    tl.store(out_ptr + out_offset, acc, mask=group_mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({
+            # 'matrix_instr_nonkdim': kdim,
+            # 'kpack': kpack,
+            # 'waves_per_eu': waves,
+            },
+            num_stages=stages,
+            num_warps=warps,
+            # maxnreg=maxnreg,
+        )
+        # for kdim in [16, 32]
+        # for kpack in [1]
+        # for maxnreg in [64, 128, 256]
+        # for waves in [0]
+        for stages in [0, 1]
+        for warps in [1, 2, 4, 8, 16]
+    ],
+    key=["QUERY_GROUP_SIZE", "HEAD_SIZE", "KV_BLOCK_SIZE"],
+)
+@triton.jit
+def _paged_attn_w_mma_kernel_unroll2(
+    m_i_ptr,  # [num_seqs, NUM_KV_HEADS,
max_num_partitions, QUERY_GROUP_SIZE] + l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] + out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE] + q_ptr, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + k_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + v_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + context_lens_ptr, # [num_seqs] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + attn_scale, + stride_bt0, + stride_bt1, + stride_q0, + stride_q1, + stride_q2, + stride_kv0, + stride_kv1, + stride_kv2, + stride_kv3, + stride_o0, + stride_o1, + stride_o2, + stride_o3, + stride_o4, + HEAD_SIZE: tl.constexpr, + QUERY_GROUP_SIZE: tl.constexpr, # RATIO + PADDED_QUERY_GROUP_SIZE: tl.constexpr, + NUM_KV_HEADS: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + PARTITION_SIZE: tl.constexpr, +): + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + part_idx = tl.program_id(2) + max_num_partitions = tl.num_programs(2) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + log2e: tl.constexpr = 1.4426950408889634 + + USE_PARTITIONING = PARTITION_SIZE > 0 + context_len = tl.load(context_lens_ptr + seq_idx) + if USE_PARTITIONING: + context_start_idx = part_idx * PARTITION_SIZE + if context_start_idx >= context_len: + return + context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len) + num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE) + else: + num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE) + + block_offset = tl.arange(0, KV_BLOCK_SIZE) + head_offset = tl.arange(0, HEAD_SIZE) + padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE) + + kv_offset = ( + kv_head_idx * stride_kv1 + + block_offset[:, None] * stride_kv2 + + head_offset[None, :] * stride_kv3 + ) + + # Load queries. + q_offset = ( + seq_idx * stride_q0 + + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1 + + head_offset[None, :] * stride_q2 + ) + group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE + # q: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] + q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0) + # q = tl.load(q_ptr + q_offset) + q = (q * attn_scale).to(q_ptr.dtype.element_ty) + + m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float("inf") + l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) + acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32) + + num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE) + + # k_scalar = tl.full([HEAD_SIZE, KV_BLOCK_SIZE], 1.0, dtype=tl.float16) + + for i in range(tl.cdiv(num_blocks, 2)): + # block_idx = num_prev_blocks + i + block_number_0 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*2 + 0) * stride_bt1 + ) + block_number_1 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*2 + 1) * stride_bt1 + ) + # block_number_0 = tl.load( + # block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i + 0) * stride_bt1 + # ) + # block_number_0 = tl.load( + # block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i + 0) * stride_bt1 + # ) + # if (seq_idx==0 and kv_head_idx==0) and (part_idx==0 and i ==0): + # tl.device_print("block_number_new:", block_number_0, block_number_1) + + # Load a key block. 
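+        # NOTE: the unrolled K/V loads below are intentionally unmasked for
+        # speed; out-of-range positions are neutralized afterwards by the
+        # tl.where(..., -inf) on qk, so this assumes the block table has a
+        # valid entry for every unrolled slot.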
+        kv_block_offset_0 = block_number_0 * stride_kv0 + kv_offset
+        kv_block_offset_1 = block_number_1 * stride_kv0 + kv_offset
+        mask_offset_0 = (num_prev_blocks + i * 2 + 0) * KV_BLOCK_SIZE + block_offset
+        mask_offset_1 = (num_prev_blocks + i * 2 + 1) * KV_BLOCK_SIZE + block_offset
+        kv_mask_0 = mask_offset_0[:, None] < context_len
+        kv_mask_1 = mask_offset_1[:, None] < context_len
+
+        # k: [KV_BLOCK_SIZE, HEAD_SIZE]
+        k_0 = tl.load(k_cache_ptr + kv_block_offset_0)
+        k_1 = tl.load(k_cache_ptr + kv_block_offset_1)
+        # Masked variants, if out-of-bounds cache reads must be avoided:
+        # k_0 = tl.load(k_cache_ptr + kv_block_offset_0, mask=kv_mask_0, other=0.0)
+        # k_1 = tl.load(k_cache_ptr + kv_block_offset_1, mask=kv_mask_1, other=0.0)
+
+        # v: [KV_BLOCK_SIZE, HEAD_SIZE]
+        v_0 = tl.load(v_cache_ptr + kv_block_offset_0)
+        v_1 = tl.load(v_cache_ptr + kv_block_offset_1)
+        # v_0 = tl.load(v_cache_ptr + kv_block_offset_0, mask=kv_mask_0, other=0.0)
+        # v_1 = tl.load(v_cache_ptr + kv_block_offset_1, mask=kv_mask_1, other=0.0)
+
+        # qk: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE]
+        qk_0 = tl.dot(q, k_0.T, out_dtype=tl.float32)
+        qk_1 = tl.dot(q, k_1.T, out_dtype=tl.float32)
+
+        qk_0 = tl.where(mask_offset_0 < context_len, qk_0, float("-inf"))
+        qk_1 = tl.where(mask_offset_1 < context_len, qk_1, float("-inf"))
+
+        m_i_new = tl.maximum(m_i, tl.max(qk_0, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_1, axis=1))
+
+        # p: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE]
+        p_0 = tl.math.exp2((qk_0 - m_i_new[:, None]) * log2e)
+        p_1 = tl.math.exp2((qk_1 - m_i_new[:, None]) * log2e)
+        alpha = tl.math.exp2((m_i - m_i_new) * log2e)
+        acc *= alpha[:, None]
+
+        p_0 = p_0.to(v_0.dtype)
+        p_1 = p_1.to(v_1.dtype)
+        acc += tl.dot(p_0, v_0, out_dtype=tl.float32) + tl.dot(p_1, v_1, out_dtype=tl.float32)
+
+        l_i = l_i * alpha + tl.sum(p_0 + p_1, axis=1)
+        m_i = m_i_new
+
+    acc = acc / l_i[:, None]
+
+    if USE_PARTITIONING:
+        part_offset = (
+            (seq_idx * NUM_KV_HEADS + kv_head_idx)
+            * max_num_partitions
+            * QUERY_GROUP_SIZE
+            + part_idx * QUERY_GROUP_SIZE
+            + padding_group_offset
+        )
+        mask = padding_group_offset < QUERY_GROUP_SIZE
+        tl.store(m_i_ptr + part_offset, m_i, mask=mask)
+        tl.store(l_i_ptr + part_offset, l_i, mask=mask)
+
+    out_offset = seq_idx * stride_o0
+    if
USE_PARTITIONING: + out_offset += kv_head_idx * stride_o1 + else: + out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1 + out_offset += ( + part_idx * stride_o2 + + padding_group_offset[:, None] * stride_o3 + + head_offset[None, :] * stride_o4 + ) + + group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE + tl.store(out_ptr + out_offset, acc, mask=group_mask) + # tl.store(out_ptr + out_offset, acc) + + + +@triton.autotune( + configs=[ + triton.Config({ + # 'matrix_instr_nonkdim': kdim, + # 'kpack': kpack, + # 'waves_per_eu': waves, + }, + num_stages=stages, + num_warps=warps, + # maxnreg=maxnreg, + ) + # for kdim in [16,32] + # for kpack in [1] + # for maxnreg in [64, 128, 256] + # for waves in [0] + for stages in [0, 1] + for warps in [1, 2, 4, 8, 16] + ], + key=["QUERY_GROUP_SIZE", "HEAD_SIZE", "KV_BLOCK_SIZE"], +) +@triton.jit +def _paged_attn_w_mma_kernel_unroll4( + m_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] + l_i_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE] + out_ptr, # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE] + q_ptr, # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE] + k_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + v_cache_ptr, # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE] + context_lens_ptr, # [num_seqs] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + attn_scale, + stride_bt0, + stride_bt1, + stride_q0, + stride_q1, + stride_q2, + stride_kv0, + stride_kv1, + stride_kv2, + stride_kv3, + stride_o0, + stride_o1, + stride_o2, + stride_o3, + stride_o4, + HEAD_SIZE: tl.constexpr, + QUERY_GROUP_SIZE: tl.constexpr, # RATIO + PADDED_QUERY_GROUP_SIZE: tl.constexpr, + NUM_KV_HEADS: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + PARTITION_SIZE: tl.constexpr, +): + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + part_idx = tl.program_id(2) + max_num_partitions = tl.num_programs(2) + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + log2e: tl.constexpr = 1.4426950408889634 + + USE_PARTITIONING = PARTITION_SIZE > 0 + context_len = tl.load(context_lens_ptr + seq_idx) + if USE_PARTITIONING: + context_start_idx = part_idx * PARTITION_SIZE + if context_start_idx >= context_len: + return + context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len) + num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE) + else: + num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE) + + block_offset = tl.arange(0, KV_BLOCK_SIZE) + head_offset = tl.arange(0, HEAD_SIZE) + padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE) + + kv_offset = ( + kv_head_idx * stride_kv1 + + block_offset[:, None] * stride_kv2 + + head_offset[None, :] * stride_kv3 + ) + + # Load queries. 
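+    # Same padded, pre-scaled query load as the base kernel; the loop below
+    # walks four KV blocks per iteration and assumes the partition holds a
+    # multiple of four valid block-table entries.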
+ q_offset = ( + seq_idx * stride_q0 + + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1 + + head_offset[None, :] * stride_q2 + ) + group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE + # q: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] + q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0) + q = (q * attn_scale).to(q_ptr.dtype.element_ty) + + m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float("inf") + l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) + acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32) + + num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE) + + for i in range(tl.cdiv(num_blocks, 4)): + # block_idx = num_prev_blocks + i + block_number_0 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*4 + 0) * stride_bt1 + ) + block_number_1 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*4 + 1) * stride_bt1 + ) + block_number_2 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*4 + 2) * stride_bt1 + ) + block_number_3 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*4 + 3) * stride_bt1 + ) + # Load a key block. + kv_block_offset_0 = block_number_0 * stride_kv0 + kv_offset + kv_block_offset_1 = block_number_1 * stride_kv0 + kv_offset + kv_block_offset_2 = block_number_2 * stride_kv0 + kv_offset + kv_block_offset_3 = block_number_3 * stride_kv0 + kv_offset + mask_offset_0 = (num_prev_blocks + i*4 + 0) * KV_BLOCK_SIZE + block_offset + mask_offset_1 = (num_prev_blocks + i*4 + 1) * KV_BLOCK_SIZE + block_offset + mask_offset_2 = (num_prev_blocks + i*4 + 2) * KV_BLOCK_SIZE + block_offset + mask_offset_3 = (num_prev_blocks + i*4 + 3) * KV_BLOCK_SIZE + block_offset + + # k: [KV_BLOCK_SIZE, HEAD_SIZE] + k_0 = tl.load(k_cache_ptr + kv_block_offset_0) + k_1 = tl.load(k_cache_ptr + kv_block_offset_1) + k_2 = tl.load(k_cache_ptr + kv_block_offset_2) + k_3 = tl.load(k_cache_ptr + kv_block_offset_3) + + # v: [KV_BLOCK_SIZE, HEAD_SIZE] + v_0 = tl.load(v_cache_ptr + kv_block_offset_0) + v_1 = tl.load(v_cache_ptr + kv_block_offset_1) + v_2 = tl.load(v_cache_ptr + kv_block_offset_2) + v_3 = tl.load(v_cache_ptr + kv_block_offset_3) + + qk_0 = tl.dot(q, k_0.T, out_dtype=tl.float32) + qk_1 = tl.dot(q, k_1.T, out_dtype=tl.float32) + qk_2 = tl.dot(q, k_2.T, out_dtype=tl.float32) + qk_3 = tl.dot(q, k_3.T, out_dtype=tl.float32) + + qk_0 = tl.where(mask_offset_0 < context_len, qk_0, float("-inf")) + qk_1 = tl.where(mask_offset_1 < context_len, qk_1, float("-inf")) + qk_2 = tl.where(mask_offset_2 < context_len, qk_2, float("-inf")) + qk_3 = tl.where(mask_offset_3 < context_len, qk_3, float("-inf")) + + m_i_new = tl.maximum(m_i, tl.max(qk_0, axis=1)) + m_i_new = tl.maximum(m_i_new, tl.max(qk_1, axis=1)) + m_i_new = tl.maximum(m_i_new, tl.max(qk_2, axis=1)) + m_i_new = tl.maximum(m_i_new, tl.max(qk_3, axis=1)) + + # p: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE] + p_0 = tl.math.exp2((qk_0 - m_i_new[:, None]) * log2e).to(v_0.dtype) + p_1 = tl.math.exp2((qk_1 - m_i_new[:, None]) * log2e).to(v_1.dtype) + p_2 = tl.math.exp2((qk_2 - m_i_new[:, None]) * log2e).to(v_2.dtype) + p_3 = tl.math.exp2((qk_3 - m_i_new[:, None]) * log2e).to(v_3.dtype) + alpha = tl.math.exp2((m_i - m_i_new) * log2e) + acc *= alpha[:, None] + + acc += ( + tl.dot(p_0, v_0, out_dtype=tl.float32) + + tl.dot(p_1, v_1, out_dtype=tl.float32) + + tl.dot(p_2, v_2, out_dtype=tl.float32) + + tl.dot(p_3, v_3, out_dtype=tl.float32) + ) + + # exp_sum=exp_sum * 
tl.exp(qk_max - _qk_max) + tl.sum(exp_tmp, axis=0)
+        l_i = l_i * alpha + tl.sum(p_0 + p_1 + p_2 + p_3, axis=1)
+        m_i = m_i_new
+
+    acc = acc / l_i[:, None]
+
+    if USE_PARTITIONING:
+        part_offset = (
+            (seq_idx * NUM_KV_HEADS + kv_head_idx)
+            * max_num_partitions
+            * QUERY_GROUP_SIZE
+            + part_idx * QUERY_GROUP_SIZE
+            + padding_group_offset
+        )
+        mask = padding_group_offset < QUERY_GROUP_SIZE
+        tl.store(m_i_ptr + part_offset, m_i, mask=mask)
+        tl.store(l_i_ptr + part_offset, l_i, mask=mask)
+
+    out_offset = seq_idx * stride_o0
+    if USE_PARTITIONING:
+        out_offset += kv_head_idx * stride_o1
+    else:
+        out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1
+    out_offset += (
+        part_idx * stride_o2
+        + padding_group_offset[:, None] * stride_o3
+        + head_offset[None, :] * stride_o4
+    )
+
+    group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE
+    tl.store(out_ptr + out_offset, acc, mask=group_mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({
+            # 'matrix_instr_nonkdim': kdim,
+            # 'kpack': kpack,
+            # 'waves_per_eu': waves,
+            },
+            num_stages=stages,
+            num_warps=warps,
+            # maxnreg=maxnreg,
+        )
+        # for kdim in [16, 32]
+        # for kpack in [1]
+        # for maxnreg in [64, 128, 256]
+        # for waves in [0]
+        for stages in [0, 1]
+        for warps in [1, 2, 4, 8, 16]
+    ],
+    key=["QUERY_GROUP_SIZE", "HEAD_SIZE", "KV_BLOCK_SIZE"],
+)
+@triton.jit
+def _paged_attn_w_mma_kernel_unroll8(
+    m_i_ptr,  # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
+    l_i_ptr,  # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
+    out_ptr,  # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE]
+    q_ptr,  # [num_seqs, NUM_KV_HEADS * QUERY_GROUP_SIZE, HEAD_SIZE]
+    k_cache_ptr,  # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]
+    v_cache_ptr,  # [num_blocks, NUM_KV_HEADS, KV_BLOCK_SIZE, HEAD_SIZE]
+    context_lens_ptr,  # [num_seqs]
+    block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+    attn_scale,
+    stride_bt0,
+    stride_bt1,
+    stride_q0,
+    stride_q1,
+    stride_q2,
+    stride_kv0,
+    stride_kv1,
+    stride_kv2,
+    stride_kv3,
+    stride_o0,
+    stride_o1,
+    stride_o2,
+    stride_o3,
+    stride_o4,
+    HEAD_SIZE: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,  # RATIO
+    PADDED_QUERY_GROUP_SIZE: tl.constexpr,
+    NUM_KV_HEADS: tl.constexpr,
+    KV_BLOCK_SIZE: tl.constexpr,
+    PARTITION_SIZE: tl.constexpr,
+):
+    seq_idx = tl.program_id(0)
+    kv_head_idx = tl.program_id(1)
+    part_idx = tl.program_id(2)
+    max_num_partitions = tl.num_programs(2)
+
+    # scale sm_scale by log_2(e) and use
+    # 2^x instead of exp in the loop because CSE and LICM
+    # don't work as expected with `exp` in the loop
+    log2e: tl.constexpr = 1.4426950408889634
+
+    USE_PARTITIONING = PARTITION_SIZE > 0
+    context_len = tl.load(context_lens_ptr + seq_idx)
+    if USE_PARTITIONING:
+        context_start_idx = part_idx * PARTITION_SIZE
+        if context_start_idx >= context_len:
+            return
+        context_end_idx = tl.minimum(context_start_idx + PARTITION_SIZE, context_len)
+        num_blocks = tl.cdiv(context_end_idx - context_start_idx, KV_BLOCK_SIZE)
+    else:
+        num_blocks = tl.cdiv(context_len, KV_BLOCK_SIZE)
+
+    block_offset = tl.arange(0, KV_BLOCK_SIZE)
+    head_offset = tl.arange(0, HEAD_SIZE)
+    padding_group_offset = tl.arange(0, PADDED_QUERY_GROUP_SIZE)
+
+    kv_offset = (
+        kv_head_idx * stride_kv1
+        + block_offset[:, None] * stride_kv2
+        + head_offset[None, :] * stride_kv3
+    )
+
+    # Load queries.
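+    # Query load is identical to the 2x/4x variants; the differences are all in
+    # the main loop, which trades register pressure for fewer iterations and
+    # assumes PARTITION_SIZE / KV_BLOCK_SIZE is a multiple of 8 so each
+    # iteration can consume eight whole KV blocks.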
+ q_offset = ( + seq_idx * stride_q0 + + (kv_head_idx * QUERY_GROUP_SIZE + padding_group_offset[:, None]) * stride_q1 + + head_offset[None, :] * stride_q2 + ) + group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE + # q: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE] + q = tl.load(q_ptr + q_offset, mask=group_mask, other=0.0) + q = (q * attn_scale).to(q_ptr.dtype.element_ty) + + m_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) - float("inf") + l_i = tl.zeros([PADDED_QUERY_GROUP_SIZE], dtype=tl.float32) + acc = tl.zeros([PADDED_QUERY_GROUP_SIZE, HEAD_SIZE], dtype=tl.float32) + + num_prev_blocks = part_idx * (PARTITION_SIZE // KV_BLOCK_SIZE) + + for i in range(tl.cdiv(num_blocks, 8)): + # block_idx = num_prev_blocks + i + block_number_0 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 0) * stride_bt1 + ) + block_number_1 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 1) * stride_bt1 + ) + block_number_2 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 2) * stride_bt1 + ) + block_number_3 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 3) * stride_bt1 + ) + block_number_4 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 4) * stride_bt1 + ) + block_number_5 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 5) * stride_bt1 + ) + block_number_6 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 6) * stride_bt1 + ) + block_number_7 = tl.load( + block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i*8 + 7) * stride_bt1 + ) + # Load a key block. + kv_block_offset_0 = block_number_0 * stride_kv0 + kv_offset + kv_block_offset_1 = block_number_1 * stride_kv0 + kv_offset + kv_block_offset_2 = block_number_2 * stride_kv0 + kv_offset + kv_block_offset_3 = block_number_3 * stride_kv0 + kv_offset + kv_block_offset_4 = block_number_4 * stride_kv0 + kv_offset + kv_block_offset_5 = block_number_5 * stride_kv0 + kv_offset + kv_block_offset_6 = block_number_6 * stride_kv0 + kv_offset + kv_block_offset_7 = block_number_7 * stride_kv0 + kv_offset + mask_offset_0 = (num_prev_blocks + i*8 + 0) * KV_BLOCK_SIZE + block_offset + mask_offset_1 = (num_prev_blocks + i*8 + 1) * KV_BLOCK_SIZE + block_offset + mask_offset_2 = (num_prev_blocks + i*8 + 2) * KV_BLOCK_SIZE + block_offset + mask_offset_3 = (num_prev_blocks + i*8 + 3) * KV_BLOCK_SIZE + block_offset + mask_offset_4 = (num_prev_blocks + i*8 + 4) * KV_BLOCK_SIZE + block_offset + mask_offset_5 = (num_prev_blocks + i*8 + 5) * KV_BLOCK_SIZE + block_offset + mask_offset_6 = (num_prev_blocks + i*8 + 6) * KV_BLOCK_SIZE + block_offset + mask_offset_7 = (num_prev_blocks + i*8 + 7) * KV_BLOCK_SIZE + block_offset + + # k: [KV_BLOCK_SIZE, HEAD_SIZE] + k_0 = tl.load(k_cache_ptr + kv_block_offset_0) + k_1 = tl.load(k_cache_ptr + kv_block_offset_1) + k_2 = tl.load(k_cache_ptr + kv_block_offset_2) + k_3 = tl.load(k_cache_ptr + kv_block_offset_3) + k_4 = tl.load(k_cache_ptr + kv_block_offset_4) + k_5 = tl.load(k_cache_ptr + kv_block_offset_5) + k_6 = tl.load(k_cache_ptr + kv_block_offset_6) + k_7 = tl.load(k_cache_ptr + kv_block_offset_7) + + # v: [KV_BLOCK_SIZE, HEAD_SIZE] + v_0 = tl.load(v_cache_ptr + kv_block_offset_0) + v_1 = tl.load(v_cache_ptr + kv_block_offset_1) + v_2 = tl.load(v_cache_ptr + kv_block_offset_2) + v_3 = tl.load(v_cache_ptr + kv_block_offset_3) + v_4 = tl.load(v_cache_ptr + kv_block_offset_4) + v_5 = tl.load(v_cache_ptr + 
+    for i in range(tl.cdiv(num_blocks, 8)):
+        # block_idx = num_prev_blocks + i
+        block_number_0 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 0) * stride_bt1
+        )
+        block_number_1 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 1) * stride_bt1
+        )
+        block_number_2 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 2) * stride_bt1
+        )
+        block_number_3 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 3) * stride_bt1
+        )
+        block_number_4 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 4) * stride_bt1
+        )
+        block_number_5 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 5) * stride_bt1
+        )
+        block_number_6 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 6) * stride_bt1
+        )
+        block_number_7 = tl.load(
+            block_tables_ptr + seq_idx * stride_bt0 + (num_prev_blocks + i * 8 + 7) * stride_bt1
+        )
+        # Load a key block.
+        kv_block_offset_0 = block_number_0 * stride_kv0 + kv_offset
+        kv_block_offset_1 = block_number_1 * stride_kv0 + kv_offset
+        kv_block_offset_2 = block_number_2 * stride_kv0 + kv_offset
+        kv_block_offset_3 = block_number_3 * stride_kv0 + kv_offset
+        kv_block_offset_4 = block_number_4 * stride_kv0 + kv_offset
+        kv_block_offset_5 = block_number_5 * stride_kv0 + kv_offset
+        kv_block_offset_6 = block_number_6 * stride_kv0 + kv_offset
+        kv_block_offset_7 = block_number_7 * stride_kv0 + kv_offset
+        mask_offset_0 = (num_prev_blocks + i * 8 + 0) * KV_BLOCK_SIZE + block_offset
+        mask_offset_1 = (num_prev_blocks + i * 8 + 1) * KV_BLOCK_SIZE + block_offset
+        mask_offset_2 = (num_prev_blocks + i * 8 + 2) * KV_BLOCK_SIZE + block_offset
+        mask_offset_3 = (num_prev_blocks + i * 8 + 3) * KV_BLOCK_SIZE + block_offset
+        mask_offset_4 = (num_prev_blocks + i * 8 + 4) * KV_BLOCK_SIZE + block_offset
+        mask_offset_5 = (num_prev_blocks + i * 8 + 5) * KV_BLOCK_SIZE + block_offset
+        mask_offset_6 = (num_prev_blocks + i * 8 + 6) * KV_BLOCK_SIZE + block_offset
+        mask_offset_7 = (num_prev_blocks + i * 8 + 7) * KV_BLOCK_SIZE + block_offset
+
+        # k: [KV_BLOCK_SIZE, HEAD_SIZE]
+        k_0 = tl.load(k_cache_ptr + kv_block_offset_0)
+        k_1 = tl.load(k_cache_ptr + kv_block_offset_1)
+        k_2 = tl.load(k_cache_ptr + kv_block_offset_2)
+        k_3 = tl.load(k_cache_ptr + kv_block_offset_3)
+        k_4 = tl.load(k_cache_ptr + kv_block_offset_4)
+        k_5 = tl.load(k_cache_ptr + kv_block_offset_5)
+        k_6 = tl.load(k_cache_ptr + kv_block_offset_6)
+        k_7 = tl.load(k_cache_ptr + kv_block_offset_7)
+
+        # v: [KV_BLOCK_SIZE, HEAD_SIZE]
+        v_0 = tl.load(v_cache_ptr + kv_block_offset_0)
+        v_1 = tl.load(v_cache_ptr + kv_block_offset_1)
+        v_2 = tl.load(v_cache_ptr + kv_block_offset_2)
+        v_3 = tl.load(v_cache_ptr + kv_block_offset_3)
+        v_4 = tl.load(v_cache_ptr + kv_block_offset_4)
+        v_5 = tl.load(v_cache_ptr + kv_block_offset_5)
+        v_6 = tl.load(v_cache_ptr + kv_block_offset_6)
+        v_7 = tl.load(v_cache_ptr + kv_block_offset_7)
+
+        qk_0 = tl.dot(q, k_0.T, out_dtype=tl.float32)
+        qk_1 = tl.dot(q, k_1.T, out_dtype=tl.float32)
+        qk_2 = tl.dot(q, k_2.T, out_dtype=tl.float32)
+        qk_3 = tl.dot(q, k_3.T, out_dtype=tl.float32)
+        qk_4 = tl.dot(q, k_4.T, out_dtype=tl.float32)
+        qk_5 = tl.dot(q, k_5.T, out_dtype=tl.float32)
+        qk_6 = tl.dot(q, k_6.T, out_dtype=tl.float32)
+        qk_7 = tl.dot(q, k_7.T, out_dtype=tl.float32)
+
+        qk_0 = tl.where(mask_offset_0 < context_len, qk_0, float("-inf"))
+        qk_1 = tl.where(mask_offset_1 < context_len, qk_1, float("-inf"))
+        qk_2 = tl.where(mask_offset_2 < context_len, qk_2, float("-inf"))
+        qk_3 = tl.where(mask_offset_3 < context_len, qk_3, float("-inf"))
+        qk_4 = tl.where(mask_offset_4 < context_len, qk_4, float("-inf"))
+        qk_5 = tl.where(mask_offset_5 < context_len, qk_5, float("-inf"))
+        qk_6 = tl.where(mask_offset_6 < context_len, qk_6, float("-inf"))
+        qk_7 = tl.where(mask_offset_7 < context_len, qk_7, float("-inf"))
+
+        m_i_new = tl.maximum(m_i, tl.max(qk_0, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_1, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_2, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_3, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_4, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_5, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_6, axis=1))
+        m_i_new = tl.maximum(m_i_new, tl.max(qk_7, axis=1))
+
+        # p: [PADDED_QUERY_GROUP_SIZE, KV_BLOCK_SIZE]
+        p_0 = tl.math.exp2((qk_0 - m_i_new[:, None]) * log2e).to(v_0.dtype)
+        p_1 = tl.math.exp2((qk_1 - m_i_new[:, None]) * log2e).to(v_1.dtype)
+        p_2 = tl.math.exp2((qk_2 - m_i_new[:, None]) * log2e).to(v_2.dtype)
+        p_3 = tl.math.exp2((qk_3 - m_i_new[:, None]) * log2e).to(v_3.dtype)
+        p_4 = tl.math.exp2((qk_4 - m_i_new[:, None]) * log2e).to(v_4.dtype)
+        p_5 = tl.math.exp2((qk_5 - m_i_new[:, None]) * log2e).to(v_5.dtype)
+        p_6 = tl.math.exp2((qk_6 - m_i_new[:, None]) * log2e).to(v_6.dtype)
+        p_7 = tl.math.exp2((qk_7 - m_i_new[:, None]) * log2e).to(v_7.dtype)
+        alpha = tl.math.exp2((m_i - m_i_new) * log2e)
+        acc *= alpha[:, None]
+
+        acc += (
+            tl.dot(p_0, v_0, out_dtype=tl.float32)
+            + tl.dot(p_1, v_1, out_dtype=tl.float32)
+            + tl.dot(p_2, v_2, out_dtype=tl.float32)
+            + tl.dot(p_3, v_3, out_dtype=tl.float32)
+            + tl.dot(p_4, v_4, out_dtype=tl.float32)
+            + tl.dot(p_5, v_5, out_dtype=tl.float32)
+            + tl.dot(p_6, v_6, out_dtype=tl.float32)
+            + tl.dot(p_7, v_7, out_dtype=tl.float32)
+        )
+
+        # exp_sum = exp_sum * tl.exp(qk_max - _qk_max) + tl.sum(exp_tmp, axis=0)
+        l_i = l_i * alpha + tl.sum(p_0 + p_1 + p_2 + p_3 + p_4 + p_5 + p_6 + p_7, axis=1)
+        m_i = m_i_new  # _qk_max
+
+    acc = acc / l_i[:, None]
+
+    if USE_PARTITIONING:
+        part_offset = (
+            (seq_idx * NUM_KV_HEADS + kv_head_idx)
+            * max_num_partitions
+            * QUERY_GROUP_SIZE
+            + part_idx * QUERY_GROUP_SIZE
+            + padding_group_offset
+        )
+        tl.store(m_i_ptr + part_offset, m_i)  # per-partition max logits
+        tl.store(l_i_ptr + part_offset, l_i)  # per-partition exp sums
+
+    out_offset = seq_idx * stride_o0
+    if USE_PARTITIONING:
+        out_offset += kv_head_idx * stride_o1
+    else:
+        out_offset += kv_head_idx * QUERY_GROUP_SIZE * stride_o1
+    out_offset += (
+        part_idx * stride_o2
+        + padding_group_offset[:, None] * stride_o3
+        + head_offset[None, :] * stride_o4
+    )
+
+    group_mask = padding_group_offset[:, None] < QUERY_GROUP_SIZE
+    tl.store(out_ptr + out_offset, acc, mask=group_mask)
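+
+
+# NOTE: a commented-out NumPy sketch of the split-K combine performed by the
+# reduce kernel below (names illustrative). Each partition p contributes its
+# max logit m[p], exp-sum l[p], and normalized partial output o[p]:
+#
+#   import numpy as np
+#   m = np.random.randn(4)              # per-partition max logits
+#   l = np.random.rand(4) + 1.0         # per-partition exp sums
+#   o = np.random.randn(4, 8)           # per-partition outputs (acc / l[p])
+#   g = m.max()                         # global max logit
+#   l_resc = l * np.exp(m - g)          # rescale each partition's exp sum
+#   r = l_resc / l_resc.sum()           # each partition's share of the total
+#   out = (o * r[:, None]).sum(axis=0)  # weighted sum of partial outputs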
+
+
+@triton.autotune(
+    # configs=[triton.Config({}, num_warps=warps) for warps in [4, 8, 16]],
+    configs=[
+        triton.Config(
+            {
+                # 'matrix_instr_nonkdim': kdim,
+                # 'kpack': kpack,
+                # 'waves_per_eu': waves,
+            },
+            num_stages=stages,
+            num_warps=warps,
+        )
+        # for kdim in [16, 32]
+        # for kpack in [1, 2]
+        # for waves in [0]
+        for stages in [0, 1]
+        for warps in [2, 4, 8, 16]
+    ],
+    key=["QUERY_GROUP_SIZE", "HEAD_SIZE", "NUM_PARTITIONS", "PARTITION_SIZE"],
+)
+@triton.jit
+def _paged_attn_w_mma_v2_reduce_kernel(
+    out_ptr,  # [num_seqs, NUM_KV_HEADS, QUERY_GROUP_SIZE, HEAD_SIZE]
+    m_i_ptr,  # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
+    l_i_ptr,  # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE]
+    tmp_out_ptr,  # [num_seqs, NUM_KV_HEADS, max_num_partitions, QUERY_GROUP_SIZE, HEAD_SIZE]
+    context_lens_ptr,  # [num_seqs]
+    max_num_partitions,  # partition stride
+    stride_o0,
+    stride_o1,
+    stride_o2,
+    HEAD_SIZE: tl.constexpr,
+    QUERY_GROUP_SIZE: tl.constexpr,
+    PADDED_QUERY_GROUP_SIZE: tl.constexpr,
+    NUM_KV_HEADS: tl.constexpr,
+    PARTITION_SIZE: tl.constexpr,
+    NUM_PARTITIONS: tl.constexpr,
+):
+    seq_idx = tl.program_id(0)
+    kv_head_idx = tl.program_id(1)
+
+    context_len = tl.load(context_lens_ptr + seq_idx)
+
+    num_partitions = tl.cdiv(context_len, PARTITION_SIZE)
+    group_head_offset = (
+        tl.arange(0, PADDED_QUERY_GROUP_SIZE)[:, None] * HEAD_SIZE
+        + tl.arange(0, HEAD_SIZE)[None, :]
+    )
+    group_mask = tl.arange(0, PADDED_QUERY_GROUP_SIZE)[:, None] < QUERY_GROUP_SIZE
+    if num_partitions == 1:
+        tmp_out_offset = (
+            seq_idx * NUM_KV_HEADS + kv_head_idx
+        ) * max_num_partitions * QUERY_GROUP_SIZE * HEAD_SIZE + group_head_offset
+        tmp_out = tl.load(tmp_out_ptr + tmp_out_offset, mask=group_mask, other=0.0)
+
+        out_offset = (
+            seq_idx * stride_o0
+            + kv_head_idx * QUERY_GROUP_SIZE * stride_o1
+            + group_head_offset * stride_o2
+        )
+        tl.store(out_ptr + out_offset, tmp_out, mask=group_mask)
+        return
+
+    # Get the global max logit.
+    ml_offset = (
+        (seq_idx * NUM_KV_HEADS + kv_head_idx) * max_num_partitions * QUERY_GROUP_SIZE
+        + tl.arange(0, NUM_PARTITIONS)[:, None] * QUERY_GROUP_SIZE
+        + tl.arange(0, PADDED_QUERY_GROUP_SIZE)[None, :]
+    )
+
+    mask = (tl.arange(0, NUM_PARTITIONS)[:, None] < num_partitions) & (
+        tl.arange(0, PADDED_QUERY_GROUP_SIZE)[None, :] < QUERY_GROUP_SIZE
+    )
+    # m_i: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE]
+    m_i = tl.load(m_i_ptr + ml_offset, mask=mask, other=float("-inf"))
+    # m: [PADDED_QUERY_GROUP_SIZE]
+    m = tl.max(m_i, axis=0)
+
+    # Rescale the exp sums and compute the global sum.
+    # l_i: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE]
+    l_i = tl.load(l_i_ptr + ml_offset, mask=mask, other=0.0)
+    l_i *= tl.exp(m_i - m[None, :])
+    # l: [PADDED_QUERY_GROUP_SIZE]
+    l = tl.sum(l_i, axis=0)
+    # r: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE]
+    r = l_i / l[None, :]
+    r = tl.reshape(r, (NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE, 1))
+
+    tmp_out_offset = (
+        (seq_idx * NUM_KV_HEADS + kv_head_idx)
+        * max_num_partitions
+        * QUERY_GROUP_SIZE
+        * HEAD_SIZE
+        + tl.arange(0, NUM_PARTITIONS)[:, None, None] * QUERY_GROUP_SIZE * HEAD_SIZE
+        + tl.arange(0, PADDED_QUERY_GROUP_SIZE)[None, :, None] * HEAD_SIZE
+        + tl.arange(0, HEAD_SIZE)[None, None, :]
+    )
+    # tmp_out: [NUM_PARTITIONS, PADDED_QUERY_GROUP_SIZE, HEAD_SIZE]
+    tmp_out = tl.load(tmp_out_ptr + tmp_out_offset, mask=mask[:, :, None], other=0.0)
+    # out: [PADDED_QUERY_GROUP_SIZE, HEAD_SIZE]
+    out = tl.sum((tmp_out * r).to(tl.float32), axis=0)
+
+    out_offset = (
+        seq_idx * stride_o0
+        + kv_head_idx * QUERY_GROUP_SIZE * stride_o1
+        + group_head_offset * stride_o2
+    )
+    tl.store(out_ptr + out_offset, out, mask=group_mask)
+
+
+from blade_llm.module.triton.paged_attention_new import paged_attention as paged_attention1
+from flash_attn import flash_attn_with_kvcache
+
+configs = []
+HEAD_DIM = 128
+tmp = [
+    # (1, 16, 16),
+    # (1, 32, 32),
+    (1, 32, 4),
+    (64, 32, 4),
+    (1, 52, 4),
+    (64, 52, 4),
+    (1, 16, 2),
+    (64, 16, 2),
+    (1, 26, 2),
+    (64, 26, 2),
+    (1, 8, 1),
+    (64, 8, 1),
+    (1, 13, 1),
+    (64, 13, 1),
+]
+for bs, q_head, kv_head in tmp:
+    configs.append(
+        triton.testing.Benchmark(
+            x_names=["seq_len"],
+            x_vals=[2**i * 1024 for i in range(0, 5)],
+            line_arg="provider",
+            line_vals=["triton", "triton_opt", "triton_unroll2", "triton_unroll4", "triton_unroll8"],
+            # line_vals=["flash_attn"],
+            # line_vals=["triton_v1", "triton_v2", "triton", "vllm"],
+            line_names=["triton", "triton_opt", "triton_unroll2", "triton_unroll4", "triton_unroll8"],
+            # line_names=["flash_attn"],
+            # line_names=["triton_v1", "triton_v2", "Triton new", "vLLM"],
+            styles=[("red", "-"), ("yellow", "-"), ("blue", "-"), ("green", "-"), ("black", "-")],
+            ylabel="us",
+            plot_name=f"BS={bs},num_head_q={q_head},num_heads_kv={kv_head},"
+            + "head_size=128,block_size=32,num_blocks=1024",
+            args={
+                "num_seqs": bs,
+                "q_head": q_head,
+                "kv_head": kv_head,
+                "head_size": 128,
+                "block_size": 32,
+                "num_blocks": 1024,
+                "dtype": torch.float16,
+            },
+        )
+    )
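+
+
+# NOTE: triton.testing.do_bench reports times in milliseconds; perf_us in
+# benchmark() below converts the [0.5, 0.2, 0.8] quantiles to microseconds,
+# which is why ylabel is "us" above. Illustrative, commented-out standalone
+# invocation:
+#   benchmark.run(show_plots=False, print_data=True, save_path="./paged_attn_perf")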
+
+
+@triton.testing.perf_report(configs)
+def benchmark(
+    num_seqs,
+    seq_len,
+    q_head,
+    kv_head,
+    head_size,
+    block_size,
+    num_blocks,
+    dtype,
+    provider,
+    eps=1e-5,
+    device="cuda",
+):
+    qkv = torch.empty(num_seqs, 3, q_head, head_size, dtype=dtype, device="cuda")
+    qkv.uniform_(-3, 3)
+    query, key, value = qkv.unbind(dim=1)
+
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_block_shape = (kv_head, head_size // x, block_size, x)
+    key_cache = torch.randn((num_blocks, *key_block_shape), dtype=dtype, device="cuda")
+    value_block_shape = (kv_head, head_size, block_size)
+    value_cache = torch.randn(
+        (num_blocks, *value_block_shape), dtype=dtype, device="cuda"
+    )
+
+    context_lens = [seq_len for _ in range(num_seqs)]
+    max_context_len = max(context_lens)
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+    # max_context_len = 8192
+    # context_lens = [max_context_len for _ in range(num_seqs)]
+    # context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
+
+    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    block_tables = []
+    for i in range(num_seqs):
+        block_table = [
+            random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)
+        ]
+        block_tables.append(block_table)
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+
+    scale = float(1.0 / (head_size**0.5))
+
+    assert q_head % kv_head == 0
+    num_queries_per_kv = q_head // kv_head
+    head_mapping = torch.repeat_interleave(
+        torch.arange(kv_head, dtype=torch.int32, device="cuda"), num_queries_per_kv
+    )
+
+    num_slots = block_size * num_blocks
+    slot_mapping = random.sample(range(num_slots), num_seqs)
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
+    out = torch.empty_like(query)
+
+    key_cache_tri = key_cache.permute(0, 1, 3, 2, 4).flatten(3, 4).contiguous().cuda()
+    value_cache_tri = value_cache.permute(0, 1, 3, 2).contiguous().cuda()
+    quantiles = [0.5, 0.2, 0.8]
+    query_fa = query.unsqueeze(1)
+    # (num_blocks, page_block_size, nheads_k, headdim)
+    key_cache_fa = key_cache_tri.permute(0, 2, 1, 3).contiguous()
+    value_cache_fa = value_cache_tri.permute(0, 2, 1, 3).contiguous()
+    out_fa = torch.empty_like(query_fa)
+
+    if provider == "flash_attn":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: flash_attn_with_kvcache(
+                query_fa,
+                key_cache_fa,
+                value_cache_fa,
+                block_table=block_tables,
+                softmax_scale=scale,
+            ),
+            warmup=20,
+            rep=100,
+            quantiles=quantiles,
+        )
+
+    if provider == "triton_unroll8":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: paged_attn_w_mma_unroll8(
+                out,
+                query,
+                key_cache_tri,
+                value_cache_tri,
+                context_lens,
+                block_tables,
+                scale,
+                max_context_len,
+                num_splits=triton.cdiv(max_context_len, 256),
+                partition_size=256,
+                device=torch.cuda.device_of(query),
+            ),
+            warmup=200,
+            rep=500,
+            quantiles=quantiles,
+        )
+
+    if provider == "triton_unroll4":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: paged_attn_w_mma_unroll4(
+                out,
+                query,
+                key_cache_tri,
+                value_cache_tri,
+                context_lens,
+                block_tables,
+                scale,
+                max_context_len,
+                num_splits=triton.cdiv(max_context_len, 256),
+                partition_size=256,
+                device=torch.cuda.device_of(query),
+            ),
+            warmup=200,
+            rep=500,
+            quantiles=quantiles,
+        )
+
+    if provider == "triton_unroll2":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: paged_attn_w_mma_unroll2(
+                out,
+                query,
+                key_cache_tri,
+                value_cache_tri,
+                context_lens,
+                block_tables,
+                scale,
+                max_context_len,
+                num_splits=triton.cdiv(max_context_len, 256),
+                partition_size=256,
+                device=torch.cuda.device_of(query),
+            ),
+            warmup=200,
+            rep=500,
+            quantiles=quantiles,
+        )
+
+    if provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: paged_attention1(
+                out,
+                query,
+                key_cache_tri,
+                value_cache_tri,
+                context_lens,
+                block_tables,
+                scale,
+                max_context_len,
+                num_splits=0,
+                alibi_slope=None,
+            ),
+            warmup=200,
+            rep=500,
+            quantiles=quantiles,
+        )
+
+    if provider == "triton_opt":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: paged_attention(
+                out,
+                query,
+                key_cache_tri,
+                value_cache_tri,
+                context_lens,
+                block_tables,
+                scale,
+                max_context_len,
+                num_splits=0,
+                alibi_slope=None,
+            ),
+            warmup=200,
+            rep=500,
+            quantiles=quantiles,
+        )
+
+    perf_us = lambda x: round(x * 1e3, 2)  # do_bench returns ms; report us
+    return perf_us(ms), perf_us(min_ms), perf_us(max_ms)
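+
+
+# NOTE on cache layouts (illustrative shape check, commented out): the
+# benchmark builds vLLM-style caches,
+#   key_cache:   [num_blocks, kv_head, head_size // x, block_size, x]
+#   value_cache: [num_blocks, kv_head, head_size, block_size]
+# and converts them to the [num_blocks, kv_head, block_size, head_size] layout
+# these Triton kernels expect:
+#
+#   kc = torch.randn(2, 1, 128 // 8, 32, 8, dtype=torch.float16)  # x = 8 for fp16
+#   assert kc.permute(0, 1, 3, 2, 4).flatten(3, 4).shape == (2, 1, 32, 128)
+#   vc = torch.randn(2, 1, 128, 32, dtype=torch.float16)
+#   assert vc.permute(0, 1, 3, 2).shape == (2, 1, 32, 128)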
+
+
+# import ctypes
+# _hip = ctypes.CDLL('libroctracer64.so')
+
+
+# def hip_prof_start():
+#     ret = _hip.roctracer_start()
+#     if ret != 0:
+#         raise Exception('hipProfilerStart() returned %d' % ret)
+
+
+# def hip_prof_stop():
+#     ret = _hip.roctracer_stop()
+#     if ret != 0:
+#         raise Exception('hipProfilerStop() returned %d' % ret)
+
+
+def test_attention():
+    for (num_seqs, q_head, kv_head, head_size, block_size, num_blocks, dtype, seq_len) in [
+        (1, 16, 16, 128, 16, 10240, torch.float16, 8192),
+        (1, 32, 32, 128, 16, 10240, torch.float16, 8192),
+        (1, 32, 4, 128, 16, 10240, torch.float16, 8192),
+        (64, 32, 4, 128, 16, 1024, torch.float16, 8192),
+        (1, 52, 4, 128, 16, 10240, torch.float16, 8192),
+        (64, 52, 4, 128, 16, 10240, torch.float16, 1024),
+        (1, 16, 2, 128, 16, 10240, torch.float16, 2048),
+        (64, 16, 2, 128, 16, 10240, torch.float16, 8192),
+        (1, 26, 2, 128, 16, 10240, torch.float16, 8192),
+        (64, 26, 2, 128, 16, 10240, torch.float16, 8192),
+        (1, 8, 1, 128, 16, 10240, torch.float16, 8192),
+        (64, 8, 1, 128, 16, 10240, torch.float16, 8192),
+        (1, 13, 1, 128, 16, 10240, torch.float16, 8192),
+        (64, 13, 1, 128, 16, 10240, torch.float16, 8192),
+    ]:
+        print(
+            f"num_seqs={num_seqs}, q_head={q_head}, kv_head={kv_head}, "
+            f"head_size={head_size}, block_size={block_size}, "
+            f"num_blocks={num_blocks}, dtype={dtype}, seq_len={seq_len}"
+        )
+        qkv = torch.ones(num_seqs, 3, q_head, head_size, dtype=dtype, device="cuda")
+        # qkv.uniform_(-3, 3)
+        query, key, value = qkv.unbind(dim=1)
+
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_block_shape = (kv_head, head_size // x, block_size, x)
+        key_cache = torch.ones(
+            (num_blocks, *key_block_shape), dtype=dtype, device="cuda"
+        )
+        value_block_shape = (kv_head, head_size, block_size)
+        value_cache = torch.randn(
+            (num_blocks, *value_block_shape), dtype=dtype, device="cuda"
+        )
+
+        context_lens = [seq_len for _ in range(num_seqs)]
+        max_context_len = max(context_lens)
+        context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+        # max_context_len = 8192
+        # context_lens = [max_context_len for _ in range(num_seqs)]
+        # context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
+
+        max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+        block_tables = []
+        for i in range(num_seqs):
+            block_table = [
+                j for j in range(max_num_blocks_per_seq)
+                # random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)
+            ]
+            block_tables.append(block_table)
+        block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+
+        scale = float(1.0 / (head_size**0.5))
+
+        assert q_head % kv_head == 0
+        num_queries_per_kv = q_head // kv_head
+        head_mapping = torch.repeat_interleave(
+            torch.arange(kv_head, dtype=torch.int32, device="cuda"), num_queries_per_kv
+        )
+
+        num_slots = block_size * num_blocks
+        # slot_mapping = random.sample(range(num_slots), num_seqs)
+        # slot_mapping = [i for i in range(num_seqs)]
+        # slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
+
+        key_cache_tri = (
+            key_cache.permute(0, 1, 3, 2, 4).flatten(3, 4).contiguous().cuda()
+        )
+        value_cache_tri = value_cache.permute(0, 1, 3, 2).contiguous().cuda()
+
+        out = torch.empty_like(query)
+        paged_attention1(
+            out,
+            query,
+            key_cache_tri,
+            value_cache_tri,
+            context_lens,
+            block_tables,
+            scale,
+            max_context_len,
+            num_splits=0,
+            alibi_slope=None,
+        )
+
+        out1 = torch.empty_like(query)
+        paged_attention(
+            out1,
+            query,
+            key_cache_tri,
+            value_cache_tri,
+            context_lens,
+            block_tables,
+            scale,
+            max_context_len,
+            num_splits=0,
+            alibi_slope=None,
+        )
+        # print(f"new={out1}, ref={out}")
+
+        out2 = torch.empty_like(query)
+        paged_attn_w_mma_unroll2(
+            out2,
+            query,
+            key_cache_tri,
+            value_cache_tri,
+            context_lens,
+            block_tables,
+            scale,
+            max_context_len,
+            num_splits=triton.cdiv(max_context_len, 256),
+            partition_size=256,
+            device=torch.cuda.device_of(query),
+        )
+
+        out4 = torch.empty_like(query)
+        paged_attn_w_mma_unroll4(
+            out4,
+            query,
+            key_cache_tri,
+            value_cache_tri,
+            context_lens,
+            block_tables,
+            scale,
+            max_context_len,
+            num_splits=triton.cdiv(max_context_len, 256),
+            partition_size=256,
+            device=torch.cuda.device_of(query),
+        )
+
+        out8 = torch.empty_like(query)
+        paged_attn_w_mma_unroll8(
+            out8,
+            query,
+            key_cache_tri,
+            value_cache_tri,
+            context_lens,
+            block_tables,
+            scale,
+            max_context_len,
+            num_splits=triton.cdiv(max_context_len, 256),
+            partition_size=256,
+            device=torch.cuda.device_of(query),
+        )
+
+        assert torch.allclose(out1, out, atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out1, out2, atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out1, out4, atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out1, out8, atol=1e-2, rtol=1e-2)
+        # assert torch.allclose(out2, ref_out, atol=1e-3, rtol=1e-3)
+        # assert torch.allclose(out3, ref_out, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    # test_attention()
+    benchmark.run(show_plots=False, print_data=True)
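+
+
+# NOTE: a commented-out pure-PyTorch reference that test_attention() could
+# compare against (function and variable names are illustrative):
+#
+# def ref_paged_attention(query, key_cache_tri, value_cache_tri, context_lens,
+#                         block_tables, scale):
+#     num_seqs, num_q_heads, head_size = query.shape
+#     num_kv_heads, block_size = key_cache_tri.shape[1], key_cache_tri.shape[2]
+#     group = num_q_heads // num_kv_heads
+#     out = torch.empty_like(query)
+#     for s in range(num_seqs):
+#         ctx = int(context_lens[s])
+#         blocks = block_tables[s][: (ctx + block_size - 1) // block_size].long()
+#         # k, v: [num_kv_heads, ctx, head_size]
+#         k = key_cache_tri[blocks].permute(1, 0, 2, 3).reshape(num_kv_heads, -1, head_size)[:, :ctx]
+#         v = value_cache_tri[blocks].permute(1, 0, 2, 3).reshape(num_kv_heads, -1, head_size)[:, :ctx]
+#         q = query[s].view(num_kv_heads, group, head_size).float()
+#         attn = torch.softmax((q @ k.float().transpose(-1, -2)) * scale, dim=-1)
+#         out[s] = (attn @ v.float()).view(num_q_heads, head_size).to(query.dtype)
+#     return out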