diff --git a/aie_kernels/aie2p/softmax.cc b/aie_kernels/aie2p/softmax.cc index 64cca202..5778682a 100644 --- a/aie_kernels/aie2p/softmax.cc +++ b/aie_kernels/aie2p/softmax.cc @@ -3,6 +3,7 @@ #include #include +#include #define SM_VEC_LEN 64 // 32 #define log2e 1.4453125 // 1.44269504089 @@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out aie::vector in_elems, exp_val, input_bf16, log2e_vec, max_val_vec; aie::accum out_vals, exp_val_accum, scaled_accum, exp_in_accum; - float max_val = 0; + float max_val = -INFINITY; float accum_exp_val = 0; float running_max = 0; bfloat16 col_sum_inv; diff --git a/iron/common/compilation/base.py b/iron/common/compilation/base.py index 53b1fd3c..4a1d7e44 100644 --- a/iron/common/compilation/base.py +++ b/iron/common/compilation/base.py @@ -496,6 +496,7 @@ def compile(self, graph): str(self.aiecc_path), "-v", "-j1", + "--dynamic-objFifos", "--no-compile-host", "--no-xchesscc", "--no-xbridge", diff --git a/iron/common/fusion.py b/iron/common/fusion.py index 99219848..292b26e9 100644 --- a/iron/common/fusion.py +++ b/iron/common/fusion.py @@ -5,6 +5,7 @@ import ml_dtypes import pyxrt import ctypes +import time from . import compilation as comp from .base import AIEOperatorBase, MLIROperator from .utils import XRTSubBuffer @@ -42,8 +43,7 @@ def get_kernel_artifacts(self): """Collect all kernel artifacts from child operators. Returns: - List of KernelObjectArtifact instances from all unique child operators, - with filenames and symbol prefixes disambiguated per operator index. + List of KernelObjectArtifact instances from all unique child operators. """ kernel_artifacts = [] seen: dict[int, object] = {} @@ -52,9 +52,6 @@ def get_kernel_artifacts(self): ] for idx, op in enumerate(unique_operators): objs = op.get_kernel_artifacts() - for obj in objs: - obj.filename = f"op{idx}_{obj.filename}" - obj.prefix_symbols = f"op{idx}_" kernel_artifacts.extend(objs) return kernel_artifacts @@ -82,8 +79,6 @@ def get_mlir_artifact(self): ] for idx, op in enumerate(unique_operators): mlir_artifact = op.get_mlir_artifact() - if len(op.get_kernel_artifacts()) > 0: - mlir_artifact.generator.kwargs["func_prefix"] = f"op{idx}_" op_name = f"op{idx}_{op.__class__.__name__}" op_names[id(op)] = op_name operator_mlir_map[op_name] = mlir_artifact @@ -290,8 +285,10 @@ def __call__(self, *args): for i, arg in enumerate(args): assert isinstance(arg, pyxrt.bo), f"Argument {i} is not a pyxrt.bo" run.set_arg(i, arg) + t0 = time.perf_counter() run.start() ret_code = run.wait() + self.last_elapsed = time.perf_counter() - t0 if ret_code != pyxrt.ert_cmd_state.ERT_CMD_STATE_COMPLETED: raise RuntimeError(f"Kernel execution failed with return code {ret_code}") @@ -371,10 +368,10 @@ def get_buffer(self, buffer_name): return sub_buffer def __call__(self): - self.input_buffer.to("npu") + self.input_buffer._sync_to_device() super().__call__( self.input_buffer.buffer_object(), self.output_buffer.buffer_object(), self.scratch_buffer.buffer_object(), ) - self.output_buffer.to("cpu") + self.output_buffer._sync_from_device() diff --git a/iron/operators/gemm/design.py b/iron/operators/gemm/design.py index a8ed8ad3..8b717dcf 100644 --- a/iron/operators/gemm/design.py +++ b/iron/operators/gemm/design.py @@ -299,7 +299,7 @@ def my_matmul( gemm_object, [C_l1_ty_internal], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_f32" + matmul_func_name = f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_f32" matmul_kernel = Kernel( matmul_func_name, gemm_object, @@ -314,7 +314,9 @@ def my_matmul( gemm_object, [C_l1_ty], ) - matmul_func_name = f"matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + matmul_func_name = ( + f"{func_prefix}matmul{scalar_suffix}_{dtype_in_str}_{dtype_out_str}" + ) matmul_kernel = Kernel( matmul_func_name, gemm_object, diff --git a/iron/operators/mha_prefill_lxl_sd/__init__.py b/iron/operators/mha_prefill_lxl_sd/__init__.py new file mode 100644 index 00000000..82f09a67 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/iron/operators/mha_prefill_lxl_sd/op.py b/iron/operators/mha_prefill_lxl_sd/op.py new file mode 100644 index 00000000..aba8a115 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/op.py @@ -0,0 +1,399 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +A layer-by-layer (LxL) single-dispatch (SD) implementation of multi-head attention (MHA). +""" + +from iron.common.context import AIEContext +from iron.common.fusion import FusedMLIROperator +from iron.operators.gemm.op import GEMM +from iron.operators.rope.op import RoPE +from iron.operators.strided_copy.op import StridedCopy +from iron.operators.repeat.op import Repeat +from iron.operators.softmax.op import Softmax +from iron.operators.transpose.op import Transpose +from iron.operators.elementwise_mul.op import ElementwiseMul +from iron.operators.elementwise_add.op import ElementwiseAdd + + +def _pick_tile_n(N, num_cols, max_tile_n=64): + tile_n = N // num_cols + while tile_n > max_tile_n: + tile_n //= 2 + assert N % (tile_n * num_cols) == 0 + return tile_n + + +def _build_core_ops(H, G, d, S, elf_ctx, causal_mask=True): + """Build core attention sub-ops and runlist (no projections/RoPE/GQA). + + Expects pre-processed inputs: + queries: (H, S, d) deinterleaved, contiguous per head + keys: (H, d, S) transposed and GQA-repeated + values: (H, S, d) GQA-repeated + + Produces: + attn_context: (H, S, d) — per-head context vectors + + If causal_mask=False, the elementwise-add masking step is omitted. + """ + B = 2 # bytes per bf16 element + + gemm_scores = GEMM( + M=S, + K=d, + N=S, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(S, 8), + context=elf_ctx, + ) + scale = ElementwiseMul( + size=H * S * S, + tile_size=S * S // 8, + num_aie_columns=8, + context=elf_ctx, + ) + if causal_mask: + mask = ElementwiseAdd( + size=H * S * S, + tile_size=S * S // 8, + num_aie_columns=8, + context=elf_ctx, + ) + softmax = Softmax( + rows=H * S, + cols=S, + num_aie_columns=1, + num_channels=1, + rtp_vector_size=S, + context=elf_ctx, + ) + gemm_context = GEMM( + M=S, + K=S, + N=d, + num_aie_columns=4, + tile_m=16, + tile_k=64, + tile_n=16, + context=elf_ctx, + prio_accuracy=True, + ) + + qh = S * d * B + kdS = d * S * B + kSd = S * d * B + sh = S * S * B + ch = S * d * B + + runlist = [ + *[ + ( + gemm_scores, + f"queries[{h*qh}:{(h+1)*qh}]", + f"keys[{h*kdS}:{(h+1)*kdS}]", + f"attn_scores[{h*sh}:{(h+1)*sh}]", + ) + for h in range(H) + ], + (scale, "attn_scores", "attn_scale_factor", "attn_scores_scaled"), + ] + + if causal_mask: + runlist += [ + (mask, "attn_scores_scaled", "causal_mask", "attn_scores_masked"), + (softmax, "attn_scores_masked", "attn_weights"), + ] + else: + runlist += [ + (softmax, "attn_scores_scaled", "attn_weights"), + ] + + runlist += [ + *[ + ( + gemm_context, + f"attn_weights[{h*sh}:{(h+1)*sh}]", + f"values[{h*kSd}:{(h+1)*kSd}]", + f"attn_context[{h*ch}:{(h+1)*ch}]", + ) + for h in range(H) + ], + ] + + buffer_sizes = { + "queries": H * S * d * B, + "keys": H * d * S * B, + "values": H * S * d * B, + "attn_scores": H * S * S * B, + "attn_scores_scaled": H * S * S * B, + "attn_weights": H * S * S * B, + "attn_context": H * S * d * B, + } + if causal_mask: + buffer_sizes["attn_scores_masked"] = H * S * S * B + + return runlist, buffer_sizes + + +class AttentionPrefillFused(FusedMLIROperator): + """Fused attention prefill (core, no projections/RoPE). + + Accepts pre-projected Q (S*H,d), K (S*G,d), V (S*G,d) in interleaved layout. + """ + + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + ): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + elf_ctx = context or AIEContext() + runlist, buffer_sizes = _build_core_ops( + num_heads, + num_kv_groups, + head_dim, + seq_len, + elf_ctx, + causal_mask=causal_mask, + ) + + mask_suffix = "_causal" if causal_mask else "_nomask" + input_args = ["queries", "keys", "values", "attn_scale_factor"] + if causal_mask: + input_args.append("causal_mask") + + super().__init__( + name=f"attention_prefill_fused_{num_heads}h{num_kv_groups}g{head_dim}d{embedding_dim}e{seq_len}s{mask_suffix}", + runlist=runlist, + input_args=input_args, + output_args=["attn_context"], + buffer_sizes=buffer_sizes, + context=elf_ctx, + ) + + +class AttentionPrefillProjectedFused(FusedMLIROperator): + """Fused attention prefill with Q/K/V projections and RoPE. + + Accepts raw input (S, E) and rope_angles (S, d). + """ + + def __init__( + self, + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + causal_mask=True, + context=None, + ): + assert head_dim == 64 + assert num_heads % num_kv_groups == 0 + assert seq_len % 256 == 0 + assert (num_heads * seq_len) % 16 == 0 + + self.num_heads = num_heads + self.num_kv_groups = num_kv_groups + self.head_dim = head_dim + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + H, G, d, E, S = num_heads, num_kv_groups, head_dim, embedding_dim, seq_len + group_size = H // G + B = 2 + + elf_ctx = context or AIEContext() + + # ---- Projection + RoPE ---- + gemm_query = GEMM( + M=S, + K=E, + N=H * d, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(H * d, 8), + context=elf_ctx, + ) + gemm_kv = GEMM( + M=S, + K=E, + N=G * d, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(G * d, 8), + context=elf_ctx, + ) + rope_queries = RoPE(rows=S * H, cols=d, angle_rows=S, context=elf_ctx) + rope_keys = RoPE(rows=S * G, cols=d, angle_rows=S, context=elf_ctx) + + # ---- Deinterleave ---- + deinterleave_q = StridedCopy( + input_sizes=(H, S, d), + input_strides=(d, H * d, 1), + input_offset=0, + output_sizes=(H, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * H * d, + output_buffer_size=H * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, + ) + deinterleave_kv = StridedCopy( + input_sizes=(G, S, d), + input_strides=(d, G * d, 1), + input_offset=0, + output_sizes=(G, S, d), + output_strides=(S * d, d, 1), + output_offset=0, + input_buffer_size=S * G * d, + output_buffer_size=G * S * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, + ) + + # ---- Transpose keys + GQA repeat ---- + transpose_keys = Transpose( + M=S, + N=d, + num_aie_columns=2, + num_channels=1, + m=256, + n=32, + s=8, + context=elf_ctx, + ) + repeat_kv = Repeat( + rows=G, + cols=d * S, + repeat=group_size, + transfer_size=d, + context=elf_ctx, + ) + + kSd = S * d * B + kdS = d * S * B + + prefix_runlist = [ + (gemm_query, "input", "W_query", "queries_projected"), + (gemm_kv, "input", "W_key", "keys_projected"), + (gemm_kv, "input", "W_value", "values_projected"), + (rope_queries, "queries_projected", "rope_angles", "queries_roped"), + (rope_keys, "keys_projected", "rope_angles", "keys_roped"), + (deinterleave_q, "queries_roped", "queries"), + (deinterleave_kv, "keys_roped", "keys_deint"), + (deinterleave_kv, "values_projected", "values_deint"), + *[ + ( + transpose_keys, + f"keys_deint[{g*kSd}:{(g+1)*kSd}]", + f"keys_transposed[{g*kdS}:{(g+1)*kdS}]", + ) + for g in range(G) + ], + (repeat_kv, "keys_transposed", "keys"), + (repeat_kv, "values_deint", "values"), + ] + prefix_buffer_sizes = { + "queries_projected": S * H * d * B, + "keys_projected": S * G * d * B, + "values_projected": S * G * d * B, + "queries_roped": S * H * d * B, + "keys_roped": S * G * d * B, + "keys_deint": G * S * d * B, + "values_deint": G * S * d * B, + "keys_transposed": G * d * S * B, + } + + core_runlist, core_buffer_sizes = _build_core_ops( + H, + G, + d, + S, + elf_ctx, + causal_mask=causal_mask, + ) + + # ---- Reinterleave + output projection ---- + reinterleave = StridedCopy( + input_sizes=(1, 1, 1, H * S * d), + input_strides=(0, 0, 0, 1), + input_offset=0, + output_sizes=(H, 256, S // 256, d), + output_strides=(d, 256 * H * d, H * d, 1), + output_offset=0, + input_buffer_size=H * S * d, + output_buffer_size=S * H * d, + transfer_size=S * d, + num_aie_channels=1, + context=elf_ctx, + ) + gemm_output = GEMM( + M=S, + K=H * d, + N=E, + num_aie_columns=8, + tile_m=16, + tile_k=64, + tile_n=_pick_tile_n(E, 8), + context=elf_ctx, + prio_accuracy=True, + ) + + suffix_runlist = [ + (reinterleave, "attn_context", "context_interleaved"), + (gemm_output, "context_interleaved", "W_output", "attn_output"), + ] + suffix_buffer_sizes = { + "context_interleaved": S * H * d * B, + } + + mask_suffix = "_causal" if causal_mask else "_nomask" + input_args = [ + "input", + "rope_angles", + "W_query", + "W_key", + "W_value", + "W_output", + "attn_scale_factor", + ] + if causal_mask: + input_args.append("causal_mask") + + super().__init__( + name=f"attention_prefill_projected_fused_{H}h{G}g{d}d{E}e{S}s{mask_suffix}", + runlist=prefix_runlist + core_runlist + suffix_runlist, + input_args=input_args, + output_args=["attn_output"], + buffer_sizes={ + **prefix_buffer_sizes, + **core_buffer_sizes, + **suffix_buffer_sizes, + }, + context=elf_ctx, + ) diff --git a/iron/operators/mha_prefill_lxl_sd/reference.py b/iron/operators/mha_prefill_lxl_sd/reference.py new file mode 100644 index 00000000..3343fa8d --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/reference.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def _apply_rope_4d(x, angles): + """Apply RoPE to a 4D tensor using interleaved cos/sin angles. + + x: (batch, heads, seq_len, head_dim) + angles: (seq_len, head_dim) with interleaved [cos_0, sin_0, cos_1, sin_1, ...] + Returns: same shape as x with RoPE applied (two-halves method). + """ + half = x.shape[-1] // 2 + cos = angles[:, ::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + sin = angles[:, 1::2].unsqueeze(0).unsqueeze(0) # (1, 1, S, half) + x1, x2 = x[..., :half], x[..., half:] + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + +def _bf16_matmul(a, b): + """(float32 matmul) → bfloat16, matching NPU accumulation.""" + return (a.float() @ b.float()).to(torch.bfloat16) + + +def generate_golden_reference( + num_heads, + num_kv_groups, + head_dim, + embedding_dim, + seq_len, + seed=42, +): + """Generate golden reference for fused attention prefill. + + Parameters: + num_heads (H): number of query attention heads + num_kv_groups (G): number of KV heads (G=H for MHA, G 1: + keys_for_scores = ( + keys_transposed.reshape(G, d * S) + .repeat_interleave(group_size, dim=0) + .reshape(H, d, S) + ) + values_for_context = ( + values_deinterleaved.reshape(G, S * d) + .repeat_interleave(group_size, dim=0) + .reshape(H, S, d) + ) + else: + keys_for_scores = keys_transposed # (H, d, S) + values_for_context = values_deinterleaved # (H, S, d) + + # ---- Score GEMM per head ---- + attn_scores = torch.stack( + [_bf16_matmul(queries_deinterleaved[h], keys_for_scores[h]) for h in range(H)] + ) # (H, S, S) + + # ---- Scale ---- + attn_scores_scaled = (attn_scores.float() * scale).to(torch.bfloat16) + + # ---- Causal mask ---- + attn_scores_masked = ( + attn_scores_scaled.reshape(H * S, S).float() + causal_mask.float() + ).to(torch.bfloat16) + + # ---- Softmax ---- + attn_weights = torch.nn.functional.softmax( + attn_scores_masked.float().reshape(H, S, S), dim=-1 + ).to( + torch.bfloat16 + ) # (H, S, S) + + # ---- Context GEMM per head ---- + attn_context = torch.stack( + [_bf16_matmul(attn_weights[h], values_for_context[h]) for h in range(H)] + ) # (H, S, d) + + # ---- Re-interleave context: (H, S, d) → (S, H*d) ---- + context_interleaved = attn_context.transpose(0, 1).contiguous().reshape(S, H * d) + + # ---- Output projection ---- + attn_output = _bf16_matmul(context_interleaved, W_output) + + return { + "input": x, + "rope_angles": rope_angles, + "W_query": W_query, + "W_key": W_key, + "W_value": W_value, + "W_output": W_output, + "attn_scale_factor": attn_scale_factor, + "causal_mask": causal_mask, + "queries_raw": queries_raw, + "keys_raw": keys_raw, + "values_raw": values_raw, + "queries_roped": queries_roped, + "keys_roped": keys_roped, + "queries_deinterleaved": queries_deinterleaved, + "keys_deinterleaved": keys_deinterleaved, + "keys_transposed": keys_transposed, + "values_deinterleaved": values_deinterleaved, + "keys_for_scores": keys_for_scores, + "values_for_context": values_for_context, + "attn_scores": attn_scores, + "attn_scores_scaled": attn_scores_scaled, + "attn_scores_masked": attn_scores_masked, + "attn_weights": attn_weights, + "attn_context": attn_context, + "context_interleaved": context_interleaved, + "attn_output": attn_output, + } diff --git a/iron/operators/mha_prefill_lxl_sd/test.py b/iron/operators/mha_prefill_lxl_sd/test.py new file mode 100644 index 00000000..7e90b748 --- /dev/null +++ b/iron/operators/mha_prefill_lxl_sd/test.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch +from ml_dtypes import bfloat16 + +from iron.common.test_utils import verify_buffer + +from iron.operators.mha_prefill_lxl_sd.op import ( + AttentionPrefillFused, + AttentionPrefillProjectedFused, +) +from iron.operators.mha_prefill_lxl_sd.reference import generate_golden_reference + +REL_TOL = 0.08 +ABS_TOL = 2.0 +MAX_ERROR_RATE = 0.03 + + +def get_params(): + return [ + pytest.param(2, 2, 64, 256, 256, id="H2"), + pytest.param(32, 8, 64, 2048, 256, id="Llama3.2-256seq"), + pytest.param(12, 12, 64, 768, 256, id="GPT2-Small-256seq"), + ] + + +def get_benchmark_params(): + """GPT-2 Small across sequence lengths 256..32768, with/without causal mask.""" + params = [] + S = 256 + while S <= 32768: + for mask in [True, False]: + tag = "causal" if mask else "nomask" + params.append(pytest.param(12, 12, 64, 768, S, mask, id=f"GPT2-S{S}-{tag}")) + S *= 2 + return params + + +def _load_input(fc, name, tensor): + """Load a tensor into a named sub-buffer of the fused callable.""" + np_buf = tensor.contiguous().view(torch.uint16).numpy().view(bfloat16) + fc.get_buffer(name).data[:] = np_buf.flatten() + + +def _get_scratch_tensor(fc, name, shape): + """Read a named buffer from the fused callable's scratch space.""" + fc.scratch_buffer._sync_from_device() + sub = fc.get_buffer(name) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) + + +def _get_output_tensor(fc, name, shape): + """Read a named buffer from the fused callable's output space.""" + fc.output_buffer._sync_from_device() + sub = fc.get_buffer(name) + return sub.data[: int(np.prod(shape))].reshape(shape).astype(np.float32) + + +def _verify_output(fc, golden, H, d, S, E): + """Chain-consistent output verification shared by both test variants.""" + npu_context = torch.from_numpy( + _get_scratch_tensor(fc, "context_interleaved", (S, H * d)) + ).bfloat16() + chain_ref = (npu_context.float() @ golden["W_output"].float()).to(torch.bfloat16) + + fc.output_buffer._sync_from_device() + output_np = fc.get_buffer("attn_output").data + output = torch.from_numpy(output_np.reshape(S, E).astype(np.float32)).bfloat16() + + errors = verify_buffer( + output, + "attn_output", + chain_ref.reshape(S, E), + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" + + +def _core_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the core attention operator.""" + score_flops = H * 2 * S * d * S # H x (S,d)@(d,S) + context_flops = H * 2 * S * S * d # H x (S,S)@(S,d) + return score_flops + context_flops + + +def _projected_gemm_flops(H, G, d, E, S): + """Count GEMM FLOPs for the projected attention operator.""" + query_proj = 2 * S * E * (H * d) # (S,E)@(E,H*d) + kv_proj = 2 * (2 * S * E * (G * d)) # key + value: (S,E)@(E,G*d) each + output_proj = 2 * S * (H * d) * E # (S,H*d)@(H*d,E) + return query_proj + kv_proj + _core_gemm_flops(H, G, d, E, S) + output_proj + + +# --------------------------------------------------------------------------- +# Core attention tests (pre-projected Q, K, V) +# --------------------------------------------------------------------------- + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd(H, G, d, E, S): + """Core attention: score GEMM -> scale -> mask -> softmax -> context GEMM.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + actual = _get_output_tensor(fc, "attn_context", (H, S, d)) + expected = golden["attn_context"].float().numpy().reshape(H, S, d) + errors = verify_buffer( + torch.from_numpy(actual).bfloat16(), + "attn_context", + torch.from_numpy(expected).bfloat16().reshape(H, S, d), + rel_tol=REL_TOL, + abs_tol=ABS_TOL, + max_error_rate=MAX_ERROR_RATE, + ) + assert not errors, f"Output verification failed with {len(errors)} errors" + + +# --------------------------------------------------------------------------- +# Projected attention tests (with Q/K/V projections + RoPE) +# --------------------------------------------------------------------------- + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_attention_prefill_projected_fused(H, G, d, E, S): + """Projected attention: Q/K/V proj -> RoPE -> GQA -> attention -> output proj.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AttentionPrefillProjectedFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "input", golden["input"]) + _load_input(fc, "rope_angles", golden["rope_angles"]) + _load_input(fc, "W_query", golden["W_query"]) + _load_input(fc, "W_key", golden["W_key"]) + _load_input(fc, "W_value", golden["W_value"]) + _load_input(fc, "W_output", golden["W_output"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _projected_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + _verify_output(fc, golden, H, d, S, E) + + +# --------------------------------------------------------------------------- +# Benchmark: GPT-2 Small core MHA across sequence lengths, +/- causal mask +# --------------------------------------------------------------------------- + + +@pytest.mark.benchmark +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize("H,G,d,E,S,causal", get_benchmark_params()) +def test_mha_prefill_benchmark(H, G, d, E, S, causal): + """Benchmark core MHA for GPT-2 Small across sequence lengths.""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AttentionPrefillFused(H, G, d, E, S, causal_mask=causal) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + if causal: + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + latency_us = fc.last_elapsed * 1e6 + gflops = _core_gemm_flops(H, G, d, E, S) / (fc.last_elapsed) / 1e9 + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Throughput: {gflops:.6e} GFLOP/s") + + +# --------------------------------------------------------------------------- +# Intermediate checks (extensive, not run by default) +# --------------------------------------------------------------------------- + +INTERMEDIATE_CHECKS = [ + ("attn_scores", "attn_scores", lambda H, G, S, d: (H, S, S), "scratch"), + ( + "attn_scores_masked", + "attn_scores_masked", + lambda H, G, S, d: (H, S, S), + "scratch", + ), + ("attn_weights", "attn_weights", lambda H, G, S, d: (H, S, S), "scratch"), + ("attn_context", "attn_context", lambda H, G, S, d: (H, S, d), "output"), +] + + +@pytest.mark.extensive +@pytest.mark.parametrize("H,G,d,E,S", get_params()) +def test_mha_pefill_lxl_sd_intermediates(H, G, d, E, S): + """Check intermediate buffers of core attention (for debugging).""" + golden = generate_golden_reference(H, G, d, E, S) + + op = AttentionPrefillFused(H, G, d, E, S) + op.compile() + fc = op.get_callable() + + _load_input(fc, "queries", golden["queries_deinterleaved"]) + _load_input(fc, "keys", golden["keys_for_scores"]) + _load_input(fc, "values", golden["values_for_context"]) + _load_input(fc, "attn_scale_factor", golden["attn_scale_factor"]) + _load_input(fc, "causal_mask", golden["causal_mask"]) + + fc() + + for buf_name, golden_key, shape_fn, buf_type in INTERMEDIATE_CHECKS: + shape = shape_fn(H, G, S, d) + if buf_type == "output": + actual = _get_output_tensor(fc, buf_name, shape) + else: + actual = _get_scratch_tensor(fc, buf_name, shape) + expected = golden[golden_key].float().numpy().reshape(shape) + diff = np.abs(actual - expected) + print( + f" [{buf_name}] shape={shape} " + f"nan={int(np.isnan(actual).sum())} " + f"max_abs_err={diff.max():.4f} mean_abs_err={diff.mean():.6f}" + ) diff --git a/pytest.ini b/pytest.ini index 44f08847..a3566ee2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,4 +9,5 @@ python_functions = test_* markers = extensive: extensive test suite (deselect with '-m "not extensive"') supported_devices(*devices): mark test as only supported on the given devices (e.g. "npu1", "npu2"). All devices supported by default. + benchmark: benchmark-only tests (select with '-m benchmark') addopts = -v --tb=short --import-mode=importlib