From 07b263b9d74fd9d708434f6fd3a71ce07e0fec6e Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Mon, 9 Mar 2026 07:06:07 +0000 Subject: [PATCH] update passed_tests/test_fa_v1.py --- backend/npu.py | 1 - test/ascend/passed_tests/test_fa_v1.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/backend/npu.py b/backend/npu.py index 21afec2..a802d81 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -497,7 +497,6 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): dicp_triton.passes.linked_npu.add_linalg_if_to_select(pm) dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm) dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) - # dicp_triton.passes.linked_npu.add_vectorize_kernel(pm) # 添加vectorize-kernel pass dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) dicp_triton.passes.linked_npu.add_linked_to_hivm(pm) pm.run(mod) diff --git a/test/ascend/passed_tests/test_fa_v1.py b/test/ascend/passed_tests/test_fa_v1.py index eb421f6..dad458c 100644 --- a/test/ascend/passed_tests/test_fa_v1.py +++ b/test/ascend/passed_tests/test_fa_v1.py @@ -83,11 +83,11 @@ def _attn_fwd_inner( if STAGE == 2: mask = offs_m[:, None] >= (start_n + offs_n[None, :]) qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) + m_ij = tl.maximum(m_i, tl.max(qk, 1), tl.PropagateNan.ALL) qk -= m_ij[:, None] else: qk = qk * qk_scale - m_ij = tl.maximum(m_i, tl.max(qk, 1)) + m_ij = tl.maximum(m_i, tl.max(qk, 1), tl.PropagateNan.ALL) qk = qk - m_ij[:, None] # p = tl.math.exp2(qk) @@ -516,7 +516,7 @@ def bench_fn(op, provider): if __name__ == "__main__": # test_op(1,8,8192,128, causal=True, dtype=torch.float16, BM = 32,BN = 32) - test_op(1, 2, 2048, 64, causal=False, dtype=torch.float16, BM=128, BN=512) + test_op(1, 2, 64 * 1024, 64, causal=False, dtype=torch.float16, BM=128, BN=512) # test_op(4,32,1024,64, causal=False, dtype=torch.float16, BM = 64,BN = 256) # test_op(4,32,4096,64, causal=False, dtype=torch.float16, BM = 64,BN = 256) # test_op(4,32,8192,64, causal=False, dtype=torch.float16, BM =64,BN = 256)