From e5e94af1ef979f6b728ef09c109cc9afd196f152 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Tue, 20 Jan 2026 08:08:02 +0000 Subject: [PATCH 01/14] fix data implicit transposed bugs --- backend/npu.py | 1 + .../Dialect/LinalgExt/Transforms/Passes.h | 2 + .../Dialect/LinalgExt/Transforms/Passes.td | 14 ++ compiler/include/dicp/Utils/Utils.h | 94 +++++++- .../Transforms/AnnotateTransposePass.cpp | 220 ++++++++++++++++++ .../LinalgExt/Transforms/CMakeLists.txt | 3 +- test/ascend/mlir/annotate_transpose_pass.mlir | 57 +++++ triton_dicp_triton.cc | 4 + 8 files changed, 393 insertions(+), 2 deletions(-) create mode 100644 compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp create mode 100644 test/ascend/mlir/annotate_transpose_pass.mlir diff --git a/backend/npu.py b/backend/npu.py index 74c18580..7278060e 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -434,6 +434,7 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm) dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) + dicp_triton.passes.linked_npu.add_annotate_transpose(pm) dicp_triton.passes.linked_npu.add_linked_to_hivm(pm) pm.run(mod) diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h index 7ae43b6c..d82378c1 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h @@ -20,6 +20,8 @@ std::unique_ptr> createLinalgGenericToSCFPass(); std::unique_ptr> createScalarTo1DTensorPass(); +std::unique_ptr> createAnnotateTransposePass(); + std::unique_ptr> createNormalizeSliceOpsPass(); diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td index c486210a..d6a8224d 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td @@ -68,4 +68,18 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> { let dependentDialects = ["mlir::tensor::TensorDialect"]; } +def AnnotateTransposePass : Pass<"annotate-transpose", "func::FuncOp"> { + let summary = "Annotate operations with permuted memref type"; + let description = [{ + Adds MayImplicitTransposeWithLastAxis annotations to operations with permuted memref type. + }]; + let constructor = "mlir::dicp::LinalgExt::createAnnotateTransposePass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::linalg::LinalgDialect", + "mlir::memref::MemRefDialect", + "mlir::bufferization::BufferizationDialect" + ]; +} + #endif diff --git a/compiler/include/dicp/Utils/Utils.h b/compiler/include/dicp/Utils/Utils.h index e50253dd..dd6964ce 100644 --- a/compiler/include/dicp/Utils/Utils.h +++ b/compiler/include/dicp/Utils/Utils.h @@ -3,19 +3,95 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" + + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringSwitch.h" #include #include +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" + +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include +#include +#include +#include +#include +#include + // Dispatch conversion pattern handlers based on backend string. Executes // ASCEND_HANDLER when backend == "ascend", otherwise DEFAULT_HANDLER. #define DISPATCH_BACKEND_CONVERSION_PATTERNS(BACKEND_STR, ASCEND_HANDLER, \ @@ -43,7 +119,23 @@ llvm::StringRef getBackend(ModuleOp module); bool isAscendBackend(ModuleOp module); -bool isaPermutedMemRefType(MemRefType); +// bool isaPermutedMemRefType(MemRefType); +inline bool isaPermutedMemRefType(MemRefType memRefType) { + auto [ptrStrides, ptrOffsets] = memRefType.getStridesAndOffset(); + // LLVM_DEBUG({ + // llvm::dbgs()<<"---------- [BEG] ptrStrides ----------\n"; + // for(auto stride: ptrStrides)llvm::dbgs()< + +#define DEBUG_TYPE "annotate-transpose-pass" + +using namespace mlir; +using namespace mlir::dicp; + +namespace mlir { +namespace dicp { +namespace LinalgExt { +#define GEN_PASS_DEF_ANNOTATETRANSPOSEPASS +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + +namespace { + +struct AnnotateTransposePass + : public mlir::dicp::LinalgExt::impl::AnnotateTransposePassBase { + +void runOnOperation() override { + auto funcOp = getOperation(); + + llvm::outs() << "[INFO] Starting AnnotateTransposePass on function: " + << funcOp.getName() << "\n"; + + + llvm::outs() << "[INFO] Function body:\n"; + funcOp.print(llvm::outs()); + llvm::outs() << "\n"; + // 首先收集所有需要标记的bufferization.to_tensor操作 + SmallVector toTensorOpsToMark; + + // 检查 memref.copy 操作,看是否会传播非标准 stride + funcOp.walk([&](memref::CopyOp copyOp) { + auto source = copyOp.getSource(); + auto target = copyOp.getTarget(); + + llvm::outs() << "[MEMREF_COPY] Copy operation: " << copyOp << "\n"; + + // 检查源的 stride 信息 + if (auto sourceType = dyn_cast(source.getType())) { + auto [sourceStrides, sourceOffsets] = sourceType.getStridesAndOffset(); + llvm::outs() << " Source strides: ["; + for (size_t i = 0; i < sourceStrides.size(); ++i) { + llvm::outs() << sourceStrides[i] << (i < sourceStrides.size()-1 ? ", " : ""); + } + llvm::outs() << "]\n"; + + bool isSourcePermuted = mlir::dicp::isaPermutedMemRefType(sourceType); + llvm::outs() << " Source is permuted: " << isSourcePermuted << "\n"; + + // 检查目标的 stride 信息 + if (auto targetType = dyn_cast(target.getType())) { + auto [targetStrides, targetOffsets] = targetType.getStridesAndOffset(); + llvm::outs() << " Target strides: ["; + for (size_t i = 0; i < targetStrides.size(); ++i) { + llvm::outs() << targetStrides[i] << (i < targetStrides.size()-1 ? ", " : ""); + } + llvm::outs() << "]\n"; + + bool isTargetPermuted = mlir::dicp::isaPermutedMemRefType(targetType); + llvm::outs() << " Target is permuted: " << isTargetPermuted << "\n"; + + // 关键:如果源是置换的,即使目标不是,也要追踪目标的使用者 + if (isSourcePermuted) { + // 检查目标的使用者 + for (auto user : target.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " [COPY_SRC_PERMUTED] Marked bufferization.to_tensor for annotation (source was permuted)\n"; + } + } + + // 检查目标是否是某个分配的子视图,如果是,则追踪分配的使用者 + if (auto sourceDefOp = target.getDefiningOp()) { + if (auto subviewOp = dyn_cast(sourceDefOp)) { + Value parentMemRef = subviewOp.getSource(); + + // 检查父memref的使用者 + for (auto user : parentMemRef.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + llvm::outs() << " [COPY_SRC_PERMUTED_PARENT] Found bufferization.to_tensor user of parent of copy target: " << toTensorOp << "\n"; + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " [COPY_SRC_PERMUTED_PARENT] Marked bufferization.to_tensor for annotation from parent of copy target (source was permuted)\n"; + } + } + } + } + } else if (isTargetPermuted) { + // 如果目标是置换的 + for (auto user : target.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " Marked bufferization.to_tensor for annotation from copy target\n"; + } + } + } + } + } + }); + + // 遍历所有 bufferization.to_tensor 操作 + funcOp.walk([&](bufferization::ToTensorOp toTensorOp) { + // 检查源 memref 是否来自于具有非标准 stride 的操作 + Value sourceMemRef = toTensorOp.getOperand(); + + // 检查源 memref 是否为置换类型 + if (auto memRefType = dyn_cast(sourceMemRef.getType())) { + bool isPermuted = mlir::dicp::isaPermutedMemRefType(memRefType); + llvm::outs() << "[TO_TENSOR_CHECK] bufferization.to_tensor: " << toTensorOp << "\n"; + llvm::outs() << " Source memref: " << sourceMemRef << "\n"; + llvm::outs() << " MemRefType: "; + memRefType.dump(); + llvm::outs() << "\n"; + llvm::outs() << " Is permuted: " << isPermuted << "\n"; + + if (isPermuted) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " [MARK_ADDED] Marked bufferization.to_tensor for annotation\n"; + } + } + }); + + // 现在对所有标记的to_tensor操作添加annotation + // for (auto toTensorOp : toTensorOpsToMark) { + // llvm::outs() << " [MARK_ADDED] About to add transpose annotation to bufferization.to_tensor result\n"; + // OpBuilder builder(toTensorOp); // 使用toTensorOp的builder,在其后插入 + // auto markOp = builder.create( + // toTensorOp->getLoc(), toTensorOp.getResult()); + // markOp->setAttr("MayImplicitTransposeWithLastAxis", + // UnitAttr::get(&getContext())); + // llvm::outs() << " toTensorOp: " << toTensorOp << ", markOp: " << markOp << "\n"; + // llvm::outs() << " [MARK_ADDED] Added transpose annotation to bufferization.to_tensor result\n"; + // llvm::outs() << " Created annotation::MarkOp: " << markOp << "\n"; + // } + + for (auto toTensorOp : toTensorOpsToMark) { + llvm::outs() << " [MARK_ADDED] About to add transpose annotation to bufferization.to_tensor result\n"; + + // 修改点 1: 初始化 Builder (可以使用 context) + OpBuilder builder(toTensorOp->getContext()); + + // 修改点 2: 显式设置插入点在 toTensorOp 之后 + builder.setInsertionPointAfter(toTensorOp); + + auto markOp = builder.create( + toTensorOp->getLoc(), toTensorOp.getResult()); + + // 建议:使用 builder.getContext() 或者 toTensorOp->getContext() 获取上下文, + // 以防外层的 getContext() 在某些闭包或静态函数中不可用 + markOp->setAttr("MayImplicitTransposeWithLastAxis", + UnitAttr::get(builder.getContext())); + + llvm::outs() << " toTensorOp: " << toTensorOp << ", markOp: " << markOp << "\n"; + llvm::outs() << " [MARK_ADDED] Added transpose annotation to bufferization.to_tensor result\n"; + llvm::outs() << " Created annotation::MarkOp: " << markOp << "\n"; + } + + llvm::outs() << "[INFO] After Function body:\n"; + funcOp.print(llvm::outs()); + llvm::outs() << "\n"; + + llvm::outs() << "[INFO] Finished AnnotateTransposePass on function: " + << funcOp.getName() << "\n"; +} + +private: + bool needsImplicitTranspose(Value value) { + if (auto memRefType = dyn_cast(value.getType())) { + // 检查是否是置换类型的内存布局 + bool isPermuted = mlir::dicp::isaPermutedMemRefType(memRefType); + llvm::outs() << "zmz [DEBUG] isPermuted: " << isPermuted << "\n"; + llvm::outs() << "Detected permuted memref type: "; + memRefType.dump(); + llvm::outs() << "\n"; + LLVM_DEBUG({ + if(isPermuted) { + llvm::dbgs() << "Detected permuted memref type: "; + memRefType.dump(); + llvm::dbgs() << "\n"; + } + }); + return isPermuted; + } + llvm::outs() << "zmz [DEBUG] Not a MemRefType, skipping.\n"; + return false; + } +}; +} // namespace + +namespace mlir::dicp::LinalgExt { +std::unique_ptr> createAnnotateTransposePass() { + return std::make_unique(); +} +} // namespace mlir::dicp::LinalgExt diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt index 0b28548a..272103de 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(LinalgExtTransforms ScalarTo1DTensorPass.cpp RemoveSingleIterationLoop.cpp TensorTransform.cpp + AnnotateTransposePass.cpp DEPENDS LinalgExtTransformsIncGen @@ -26,4 +27,4 @@ add_triton_library(LinalgExtTransforms TritonArithToLinalg StructuredToMemref TritonToStructured -) +) \ No newline at end of file diff --git a/test/ascend/mlir/annotate_transpose_pass.mlir b/test/ascend/mlir/annotate_transpose_pass.mlir new file mode 100644 index 00000000..bb62f77e --- /dev/null +++ b/test/ascend/mlir/annotate_transpose_pass.mlir @@ -0,0 +1,57 @@ +// Test for AnnotateTransposePass - checks that the pass adds MayImplicitTransposeWithLastAxis annotations appropriately + +// RUN: bishengir-opt %s -annotate-transpose-pass | FileCheck %s + +func.func @test_linalg_copy_with_permuted_memref() { + // Original memref with permuted layout + %0 = memref.alloc() : memref<128x5xf32, strided<[8, 1]>> + %1 = memref.alloc() : memref<128x5xf32, strided<[1, 128], offset: ?>> + // linalg.copy should get annotated since target has permuted memref + linalg.copy %1, %0 : memref<128x5xf32, strided<[1, 128], offset: ?>> to memref<128x5xf32, strided<[8, 1]>> + // CHECK: linalg.copy + // CHECK: annotation.mark + // CHECK: "MayImplicitTransposeWithLastAxis" + return +} + +func.func @test_memref_copy_with_permuted_memref() { + // Original memref with permuted layout + %0 = memref.alloc() : memref<128x8xf32, strided<[8, 1]>> + %1 = memref.alloc() : memref<128x8xf32, strided<[1, 128], offset: ?>> + // memref.copy should get annotated since target has permuted memref + memref.copy %1, %0 : memref<128x8xf32, strided<[1, 128], offset: ?>> to memref<128x8xf32, strided<[8, 1]>> + // CHECK: memref.copy + // CHECK: annotation.mark + // CHECK: "MayImplicitTransposeWithLastAxis" + return +} + +func.func @test_bufferization_to_tensor_with_permuted_source() { + %0 = memref.alloc() : memref<128x8xf32, strided<[8, 1]>> + // bufferization.to_tensor should get annotated since source has permuted memref + %1 = bufferization.to_tensor %0 : memref<128x8xf32, strided<[8, 1]>> + // CHECK: bufferization.to_tensor + // CHECK: annotation.mark + // CHECK: "MayImplicitTransposeWithLastAxis" + return +} + +func.func @test_memref_subview_with_permuted_source() { + %0 = memref.alloc() : memref<128x8xf32, strided<[8, 1]>> + // memref.subview should get annotated since source has permuted memref + %1 = memref.subview %0[0, 0] to [64, 4] : memref<128x8xf32, strided<[8, 1]>> + // CHECK: memref.subview + // CHECK: annotation.mark + // CHECK: "MayImplicitTransposeWithLastAxis" + return +} + +func.func @test_non_permuted_memref_no_annotation() { + // Non-permuted memref should not get annotated + %0 = memref.alloc() : memref<128x5xf32> + %1 = memref.alloc() : memref<128x5xf32> + linalg.copy %1, %0 : memref<128x5xf32> to memref<128x5xf32> + // CHECK-NOT: annotation.mark + // CHECK-NOT: "MayImplicitTransposeWithLastAxis" + return +} \ No newline at end of file diff --git a/triton_dicp_triton.cc b/triton_dicp_triton.cc index 979e0071..077d10a0 100644 --- a/triton_dicp_triton.cc +++ b/triton_dicp_triton.cc @@ -70,6 +70,10 @@ void init_triton_dicp_triton_pass_linked_npu(py::module &&m) { pm.addNestedPass( dicp::LinalgExt::createScalarTo1DTensorPass()); }); + m.def("add_annotate_transpose", [](mlir::PassManager &pm) { + pm.addNestedPass( + dicp::LinalgExt::createAnnotateTransposePass()); + }); m.def("add_linalg_to_linked", [](mlir::PassManager &pm, bool globalKernel, bool namedOps) { pm.addPass(mlir::dicp::linked::createLinalgToLinkedPass(globalKernel, From 8dec5d682c224f3ae37b38b324f055b600ce4388 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Mon, 26 Jan 2026 08:54:58 +0000 Subject: [PATCH 02/14] pass 4/5 test && change shape for ub overflow --- .../Transforms/AnnotateTransposePass.cpp | 284 +++++++++++------- .../passed_tests/test_lightning_attn.py | 174 +++++++++++ 2 files changed, 356 insertions(+), 102 deletions(-) create mode 100644 test/ascend/passed_tests/test_lightning_attn.py diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp index af83ae67..cb4d4a51 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp +++ b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp @@ -5,6 +5,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "bishengir/Dialect/Annotation/IR/Annotation.h" #include "dicp/Utils/Utils.h" @@ -39,6 +40,53 @@ namespace LinalgExt { namespace { +// ============================================================================== +// 辅助函数定义 +// ============================================================================== + +/// 检查一个MemRefType是否是置换类型的辅助函数 +/// 判定标准:使用 dicp::isaPermutedMemRefType 或 检查 stride 是否非标准 +bool isPermutedOrHasNonUnitLastStride(MemRefType memRefType) { + if (!memRefType) return false; + + // 1. 使用现有的判定函数 + if (mlir::dicp::isaPermutedMemRefType(memRefType)) { + return true; + } + + // 2. 额外检查:最后维度的stride是否为1 + // 对于 Ascend 来说,如果最后维度 stride != 1,通常意味着不是连续内存,可能需要隐式转置 + auto [strides, offset] = memRefType.getStridesAndOffset(); + if (!strides.empty() && strides.back() != 1) { + return true; + } + + return false; +} + +/// 递归检查值的来源是否具有非标准stride +bool checkValueOriginHasNonStandardStride(Value value) { + if (auto memRefType = dyn_cast(value.getType())) { + if (isPermutedOrHasNonUnitLastStride(memRefType)) { + return true; + } + } + + // 检查定义操作 + if (Operation *defOp = value.getDefiningOp()) { + // 检查Subview操作 + if (auto subViewOp = dyn_cast(defOp)) { + return checkValueOriginHasNonStandardStride(subViewOp.getSource()); + } + // 检查ReinterpretCast操作 + if (auto castOp = dyn_cast(defOp)) { + return checkValueOriginHasNonStandardStride(castOp.getSource()); + } + } + + return false; +} + struct AnnotateTransposePass : public mlir::dicp::LinalgExt::impl::AnnotateTransposePassBase { @@ -48,75 +96,137 @@ void runOnOperation() override { llvm::outs() << "[INFO] Starting AnnotateTransposePass on function: " << funcOp.getName() << "\n"; - - llvm::outs() << "[INFO] Function body:\n"; - funcOp.print(llvm::outs()); - llvm::outs() << "\n"; - // 首先收集所有需要标记的bufferization.to_tensor操作 + // 待处理列表 SmallVector toTensorOpsToMark; - - // 检查 memref.copy 操作,看是否会传播非标准 stride + SmallVector opsToErase; // 用于存储被重写后需要删除的旧Op + + // ============================================================================== + // 1. 遍历 memref.copy 操作 + // 核心逻辑:检测 Dynamic Subview Copy -> 重写为 Static Full Copy + Annotation + // ============================================================================== funcOp.walk([&](memref::CopyOp copyOp) { auto source = copyOp.getSource(); auto target = copyOp.getTarget(); - llvm::outs() << "[MEMREF_COPY] Copy operation: " << copyOp << "\n"; + llvm::outs() << "[MEMREF_COPY_VISIT] " << copyOp << "\n"; + + // --- 尝试进行 IR 重写 (Rewrite) --- + // 目标:将 memref.copy(subview(A), subview(B)) 转换为 memref.copy(A, B) + // 条件:A 是静态 Permuted,B 是静态 Contiguous,且形状匹配 - // 检查源的 stride 信息 - if (auto sourceType = dyn_cast(source.getType())) { - auto [sourceStrides, sourceOffsets] = sourceType.getStridesAndOffset(); - llvm::outs() << " Source strides: ["; - for (size_t i = 0; i < sourceStrides.size(); ++i) { - llvm::outs() << sourceStrides[i] << (i < sourceStrides.size()-1 ? ", " : ""); - } - llvm::outs() << "]\n"; - - bool isSourcePermuted = mlir::dicp::isaPermutedMemRefType(sourceType); - llvm::outs() << " Source is permuted: " << isSourcePermuted << "\n"; + auto srcSubView = source.getDefiningOp(); + auto dstSubView = target.getDefiningOp(); + + if (srcSubView && dstSubView) { + Value baseSource = srcSubView.getSource(); + Value baseTarget = dstSubView.getSource(); - // 检查目标的 stride 信息 - if (auto targetType = dyn_cast(target.getType())) { - auto [targetStrides, targetOffsets] = targetType.getStridesAndOffset(); - llvm::outs() << " Target strides: ["; - for (size_t i = 0; i < targetStrides.size(); ++i) { - llvm::outs() << targetStrides[i] << (i < targetStrides.size()-1 ? ", " : ""); + auto baseSourceType = dyn_cast(baseSource.getType()); + auto baseTargetType = dyn_cast(baseTarget.getType()); + + if (baseSourceType && baseTargetType && + baseSourceType.hasStaticShape() && baseTargetType.hasStaticShape()) { + + bool isBaseSourcePermuted = isPermutedOrHasNonUnitLastStride(baseSourceType); + // 简化判定:如果不是 Permuted 且 stride 正常,视为 Contiguous + bool isBaseTargetContiguous = !isPermutedOrHasNonUnitLastStride(baseTargetType); + + // 检查 Static Shape 是否一致 (例如都是 2x8xf32) + if (isBaseSourcePermuted && isBaseTargetContiguous && + baseSourceType.getShape() == baseTargetType.getShape()) { + + llvm::outs() << " [REWRITE_MATCH] Found Dynamic Subview Copy candidate for Static Rewrite.\n"; + llvm::outs() << " Base Source (Permuted): " << baseSourceType << "\n"; + llvm::outs() << " Base Target (Contiguous): " << baseTargetType << "\n"; + + // 执行重写 + OpBuilder builder(copyOp->getContext()); + builder.setInsertionPoint(copyOp); + + // 1. 创建新的静态 Copy (Base -> Base) + auto newCopyOp = builder.create(copyOp.getLoc(), baseSource, baseTarget); + llvm::outs() << " -> Replaced with Static Copy: " << newCopyOp << "\n"; + + // 2. 关键:在 Base Target (MemRef) 上添加 Annotation + // 这指导 Ascend 编译器生成隐式转置指令 + builder.setInsertionPointAfter(newCopyOp); + auto markOp = builder.create(copyOp.getLoc(), baseTarget); + markOp->setAttr("MayImplicitTransposeWithLastAxis", UnitAttr::get(builder.getContext())); + llvm::outs() << " -> Added Annotation to Base Target MemRef: " << markOp << "\n"; + + // 3. 追踪 Base Target 的 Tensor 使用者 + // 我们需要标记 bufferization.to_tensor(BaseTarget),这样后续的 MatMul 才能识别到 Layout 变化 + for (auto user : baseTarget.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + // 去重检查 + bool exists = false; + for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; + + if(!exists) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " -> Scheduled Base Target's ToTensorOp for annotation: " << toTensorOp << "\n"; + } + } + } + + // 4. 标记旧的 Copy Op 待删除 + opsToErase.push_back(copyOp); + + // 重写完成,跳过后续分析 + return; } - llvm::outs() << "]\n"; + } + } + + // --- 如果没有触发重写,执行常规的传播分析 --- + // (针对代码中已经是静态 Copy 的情况,或者仅仅进行标记传播) + + if (auto sourceType = dyn_cast(source.getType())) { + bool isSourcePermuted = isPermutedOrHasNonUnitLastStride(sourceType); - bool isTargetPermuted = mlir::dicp::isaPermutedMemRefType(targetType); - llvm::outs() << " Target is permuted: " << isTargetPermuted << "\n"; + if (auto targetType = dyn_cast(target.getType())) { + bool isTargetPermuted = isPermutedOrHasNonUnitLastStride(targetType); - // 关键:如果源是置换的,即使目标不是,也要追踪目标的使用者 + // 如果源是置换的,追踪目标的使用者 if (isSourcePermuted) { // 检查目标的使用者 for (auto user : target.getUsers()) { if (auto toTensorOp = dyn_cast(user)) { - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " [COPY_SRC_PERMUTED] Marked bufferization.to_tensor for annotation (source was permuted)\n"; + bool exists = false; + for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; + if(!exists) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " [PROPAGATE] Marked bufferization.to_tensor (Source was permuted)\n"; + } } } - // 检查目标是否是某个分配的子视图,如果是,则追踪分配的使用者 + // 如果目标是 Subview,追踪其父 MemRef if (auto sourceDefOp = target.getDefiningOp()) { if (auto subviewOp = dyn_cast(sourceDefOp)) { Value parentMemRef = subviewOp.getSource(); - - // 检查父memref的使用者 for (auto user : parentMemRef.getUsers()) { if (auto toTensorOp = dyn_cast(user)) { - llvm::outs() << " [COPY_SRC_PERMUTED_PARENT] Found bufferization.to_tensor user of parent of copy target: " << toTensorOp << "\n"; - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " [COPY_SRC_PERMUTED_PARENT] Marked bufferization.to_tensor for annotation from parent of copy target (source was permuted)\n"; + bool exists = false; + for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; + if(!exists) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " [PROPAGATE_PARENT] Marked bufferization.to_tensor of Parent MemRef\n"; + } } } } } } else if (isTargetPermuted) { - // 如果目标是置换的 + // 如果目标本身就是置换的 for (auto user : target.getUsers()) { if (auto toTensorOp = dyn_cast(user)) { - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " Marked bufferization.to_tensor for annotation from copy target\n"; + bool exists = false; + for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; + if(!exists) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " [PROPAGATE_TARGET] Marked bufferization.to_tensor (Target is permuted)\n"; + } } } } @@ -124,92 +234,62 @@ void runOnOperation() override { } }); - // 遍历所有 bufferization.to_tensor 操作 + // 删除被重写的旧 Op + for (auto op : opsToErase) { + op->erase(); + } + + // ============================================================================== + // 2. 扫描所有 bufferization.to_tensor 操作 (查漏补缺) + // ============================================================================== funcOp.walk([&](bufferization::ToTensorOp toTensorOp) { - // 检查源 memref 是否来自于具有非标准 stride 的操作 + // 如果已经在列表中,跳过 + for(auto existing : toTensorOpsToMark) { if (existing == toTensorOp) return; } + Value sourceMemRef = toTensorOp.getOperand(); + bool hasNonStandardStride = checkValueOriginHasNonStandardStride(sourceMemRef); - // 检查源 memref 是否为置换类型 + bool shouldMark = false; if (auto memRefType = dyn_cast(sourceMemRef.getType())) { - bool isPermuted = mlir::dicp::isaPermutedMemRefType(memRefType); - llvm::outs() << "[TO_TENSOR_CHECK] bufferization.to_tensor: " << toTensorOp << "\n"; - llvm::outs() << " Source memref: " << sourceMemRef << "\n"; - llvm::outs() << " MemRefType: "; - memRefType.dump(); - llvm::outs() << "\n"; - llvm::outs() << " Is permuted: " << isPermuted << "\n"; - - if (isPermuted) { - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " [MARK_ADDED] Marked bufferization.to_tensor for annotation\n"; + if (isPermutedOrHasNonUnitLastStride(memRefType)) { + shouldMark = true; } } + + if (shouldMark || hasNonStandardStride) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << "[TO_TENSOR_CHECK] Found permuted/strided origin: " << toTensorOp << "\n"; + } }); - // 现在对所有标记的to_tensor操作添加annotation - // for (auto toTensorOp : toTensorOpsToMark) { - // llvm::outs() << " [MARK_ADDED] About to add transpose annotation to bufferization.to_tensor result\n"; - // OpBuilder builder(toTensorOp); // 使用toTensorOp的builder,在其后插入 - // auto markOp = builder.create( - // toTensorOp->getLoc(), toTensorOp.getResult()); - // markOp->setAttr("MayImplicitTransposeWithLastAxis", - // UnitAttr::get(&getContext())); - // llvm::outs() << " toTensorOp: " << toTensorOp << ", markOp: " << markOp << "\n"; - // llvm::outs() << " [MARK_ADDED] Added transpose annotation to bufferization.to_tensor result\n"; - // llvm::outs() << " Created annotation::MarkOp: " << markOp << "\n"; - // } - + // ============================================================================== + // 3. 执行最终标记:为收集到的 Tensor 添加 Annotation + // ============================================================================== for (auto toTensorOp : toTensorOpsToMark) { - llvm::outs() << " [MARK_ADDED] About to add transpose annotation to bufferization.to_tensor result\n"; + // 双重检查:防止重复添加 MarkOp (虽然 OpBuilder 会创建新的 Op,但逻辑上我们不希望冗余) + // 简单检查该 Value 是否已经被 MarkOp 使用 + bool alreadyMarked = false; + // 注意:annotation::MarkOp 通常不直接作为 User 挂在 Value 上,而是作为一个独立的 Op 存在。 + // 为了稳妥,这里我们假设 list 中可能有重复(如果 func.walk 逻辑有交集),去重已经在 push_back 时做了。 - // 修改点 1: 初始化 Builder (可以使用 context) - OpBuilder builder(toTensorOp->getContext()); + llvm::outs() << " [ANNOTATE_ACTION] Adding annotation to: " << toTensorOp << "\n"; - // 修改点 2: 显式设置插入点在 toTensorOp 之后 + OpBuilder builder(toTensorOp->getContext()); builder.setInsertionPointAfter(toTensorOp); auto markOp = builder.create( toTensorOp->getLoc(), toTensorOp.getResult()); - // 建议:使用 builder.getContext() 或者 toTensorOp->getContext() 获取上下文, - // 以防外层的 getContext() 在某些闭包或静态函数中不可用 markOp->setAttr("MayImplicitTransposeWithLastAxis", UnitAttr::get(builder.getContext())); - llvm::outs() << " toTensorOp: " << toTensorOp << ", markOp: " << markOp << "\n"; - llvm::outs() << " [MARK_ADDED] Added transpose annotation to bufferization.to_tensor result\n"; - llvm::outs() << " Created annotation::MarkOp: " << markOp << "\n"; + llvm::outs() << " -> Created annotation::MarkOp: " << markOp << "\n"; } - llvm::outs() << "[INFO] After Function body:\n"; - funcOp.print(llvm::outs()); - llvm::outs() << "\n"; - llvm::outs() << "[INFO] Finished AnnotateTransposePass on function: " << funcOp.getName() << "\n"; } -private: - bool needsImplicitTranspose(Value value) { - if (auto memRefType = dyn_cast(value.getType())) { - // 检查是否是置换类型的内存布局 - bool isPermuted = mlir::dicp::isaPermutedMemRefType(memRefType); - llvm::outs() << "zmz [DEBUG] isPermuted: " << isPermuted << "\n"; - llvm::outs() << "Detected permuted memref type: "; - memRefType.dump(); - llvm::outs() << "\n"; - LLVM_DEBUG({ - if(isPermuted) { - llvm::dbgs() << "Detected permuted memref type: "; - memRefType.dump(); - llvm::dbgs() << "\n"; - } - }); - return isPermuted; - } - llvm::outs() << "zmz [DEBUG] Not a MemRefType, skipping.\n"; - return false; - } }; } // namespace @@ -217,4 +297,4 @@ namespace mlir::dicp::LinalgExt { std::unique_ptr> createAnnotateTransposePass() { return std::make_unique(); } -} // namespace mlir::dicp::LinalgExt +} // namespace mlir::dicp::LinalgExt \ No newline at end of file diff --git a/test/ascend/passed_tests/test_lightning_attn.py b/test/ascend/passed_tests/test_lightning_attn.py new file mode 100644 index 00000000..2f8e542a --- /dev/null +++ b/test/ascend/passed_tests/test_lightning_attn.py @@ -0,0 +1,174 @@ +import math +import pytest +import torch + +from dlblas.utils.device_utils import infer_device +from dlblas.kernels.lightning_attn import lightning_attention_decode_forward +from dlblas.kernels.lightning_attn import lightning_attention_prefill_forward +from dlblas.kernels.lightning_attn import BackendType + + +class TestLightningAttn: + + @pytest.fixture + def B(self, request): + yield request.param + + @pytest.fixture + def H(self, request): + yield request.param + + @pytest.fixture + def N(self, request): + yield request.param + + @pytest.fixture + def D(self, request): + yield request.param + + @pytest.fixture + def E(self, request): + yield request.param + + @pytest.fixture + def dtype(self, request): + yield request.param + + @pytest.fixture + def BLOCK_SIZE(self, request): + yield request.param + + @pytest.fixture + def q_states(self, B, H, N, D, dtype): + yield torch.randn([B, H, N, D], dtype=dtype, device=infer_device()) + + @pytest.fixture + def k_states(self, B, H, N, D, dtype): + yield torch.randn([B, H, N, D], dtype=dtype, device=infer_device()) + + @pytest.fixture + def v_states(self, B, H, N, E, dtype): + yield torch.randn([B, H, N, E], dtype=dtype, device=infer_device()) + + @pytest.fixture + def past_key_value(self, B, H, D, E, dtype): + yield torch.randn([B, H, D, E], dtype=dtype, device=infer_device()) + + @pytest.fixture + def slope_rate(self, H, dtype): + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slope_rate = torch.tensor(get_slopes(H), dtype=dtype, device=infer_device()).reshape(H, 1, 1) + yield slope_rate * (1 + 1e-5) + + # float32 only + @pytest.mark.parametrize( + ['B', 'H', 'N', 'D', 'E', 'dtype', 'BLOCK_SIZE'], + [ + (1, 64, 5, 64, 128, torch.float32, 8), + (1, 64, 72, 64, 64, torch.float32, 8), + # (1, 64, 72, 64, 64, torch.float32, 16), + ], + indirect=True, + ) + def test_lightning_attention_prefill( + self, + q_states, + k_states, + v_states, + slope_rate, + past_key_value, + BLOCK_SIZE, + dtype, + ): + past_key_value_torch = torch.zeros_like(past_key_value) + past_key_value_triton = torch.zeros_like(past_key_value) + out_torch, _ = lightning_attention_prefill_forward(q_states, k_states, v_states, past_key_value_torch, slope_rate, BLOCK_SIZE, BackendType=BackendType.TORCH) + out_triton, _ = lightning_attention_prefill_forward(q_states, k_states, v_states, past_key_value_triton, slope_rate, BLOCK_SIZE, BackendType=BackendType.TRITON) + + if dtype == torch.float32: + rtol=1e-03 + atol=1 + else: + rtol=1e-03 + atol=1 + + kv_check = torch.allclose( + past_key_value_torch, + past_key_value_triton, + rtol=rtol, + atol=atol, + ) + output_check = torch.allclose( + out_torch, + out_triton, + rtol=rtol, + atol=atol, + ) + + assert kv_check, f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" + assert output_check, f"output torch:{out_torch}, output triton:{out_triton}" + print(f"zmz debug torch kv_check:{past_key_value_torch}, triton :{past_key_value_triton}") + + # float32 only + @pytest.mark.parametrize( + ['B', 'H', 'N', 'D', 'E', 'dtype', 'BLOCK_SIZE'], + [ + (8, 64, 1, 128, 128, torch.float32, 64), + (16, 64, 1, 128, 128, torch.float32, 64), + ], + indirect=True, + ) + def test_lightning_attention_decode( + self, + q_states, + k_states, + v_states, + slope_rate, + past_key_value, + BLOCK_SIZE, + dtype, + ): + past_key_value_torch = torch.zeros_like(past_key_value) + past_key_value_triton = torch.zeros_like(past_key_value) + out_torch, _ = lightning_attention_decode_forward(q_states, k_states, v_states, past_key_value_torch, slope_rate, BLOCK_SIZE, BackendType=BackendType.TORCH) + out_triton, _ = lightning_attention_decode_forward(q_states, k_states, v_states, past_key_value_triton, slope_rate, BLOCK_SIZE, BackendType=BackendType.TRITON) + + if dtype == torch.float32: + rtol=1e-03 + atol=1e-02 + else: + rtol=1e-03 + atol=1e-02 + + kv_check = torch.allclose( + past_key_value_torch, + past_key_value_triton, + rtol=rtol, + atol=atol, + ) + output_check = torch.allclose( + out_torch, + out_triton, + rtol=rtol, + atol=atol, + ) + + assert kv_check, f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" + assert output_check, f"output torch:{out_torch}, output triton:{out_triton}" + + +if __name__ == '__main__': + pytest.main([__file__]) From e1779381cc147e083ba3fe8c49b1ae1b11931d62 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Wed, 28 Jan 2026 03:27:18 +0000 Subject: [PATCH 03/14] fix format bug --- backend/device_utils.py | 138 ++++ .../Transforms/AnnotateTransposePass.cpp | 390 ++++++----- .../passed_tests/test_lightning_attn.py | 607 +++++++++++++++++- 3 files changed, 933 insertions(+), 202 deletions(-) create mode 100644 backend/device_utils.py diff --git a/backend/device_utils.py b/backend/device_utils.py new file mode 100644 index 00000000..25e42826 --- /dev/null +++ b/backend/device_utils.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025, DeepLink. +import functools +from typing import Optional + +import torch +import triton + +WARPS_PER_SM = { + (8, 0): 64, + (8, 6): 48, + (8, 7): 48, + (8, 9): 48, + (9, 0): 64, + (10, 0): 64, + (10, 1): 48, + (12, 0): 48, +} + + +@functools.lru_cache +def get_device_props(device=None): + if device is None: + device = torch.cuda.current_device() + + props = torch.cuda.get_device_properties(device) + + warps_per_sm = WARPS_PER_SM.get((props.major, props.minor), 32) + out = dict( + multi_processor_count=props.multi_processor_count, + warps_per_sm=warps_per_sm, + ) + return out + + +@functools.lru_cache +def get_number_cores(): + if is_npu(): + import triton.runtime.driver as driver + + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device)["num_aicore"] + elif is_cuda(): + return torch.cuda.get_device_properties("cuda").multi_processor_count + else: + raise RuntimeError("Please implement this function.") + + +def is_mlu_592(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == "mlu" and target.arch == 592 + + +def is_muxi(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == "maca" + + +@functools.lru_cache +def is_cuda(): + try: + return torch.cuda.is_available() + except Exception: + return False + + +@functools.lru_cache +def is_npu(): + try: + return torch.npu.is_available() + except Exception: + return False + + +@functools.lru_cache +def is_tesla(): + try: + return "Tesla" in torch.cuda.get_device_name(0) + except Exception: + return False + + +@functools.lru_cache +def is_nvidia_hopper(): + try: + return is_cuda() and ( + "NVIDIA H" in torch.cuda.get_device_name(0) + or torch.cuda.get_device_capability()[0] >= 9 + ) + except Exception: + return False + + +def set_allocator(device_: str): + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device=device_, dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + +@functools.lru_cache +def is_tma_supported(): + try: + is_tma_supported = ( + is_cuda() + and torch.cuda.get_device_capability(0)[0] >= 9 + and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") + ) + ) + if is_tma_supported: + set_allocator("cuda") + return is_tma_supported + except Exception: + return False + + +@functools.lru_cache +def infer_device(): + """ + Get current device name based on available devices + """ + if is_npu(): + return "npu" + elif is_mlu_592(): + return "mlu" + elif is_muxi(): + return "cuda" + elif is_nvidia_hopper(): + return "cuda" + elif is_cuda(): + return "cuda" + else: + return "cpu" + + +NUM_CORES = get_number_cores() +DEVICE = infer_device() diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp index cb4d4a51..bcfdcea7 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp +++ b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp @@ -1,13 +1,13 @@ #include "dicp/Dialect/LinalgExt/Transforms/Passes.h" +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "dicp/Utils/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "bishengir/Dialect/Annotation/IR/Annotation.h" -#include "dicp/Utils/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -47,20 +47,22 @@ namespace { /// 检查一个MemRefType是否是置换类型的辅助函数 /// 判定标准:使用 dicp::isaPermutedMemRefType 或 检查 stride 是否非标准 bool isPermutedOrHasNonUnitLastStride(MemRefType memRefType) { - if (!memRefType) return false; - + if (!memRefType) + return false; + // 1. 使用现有的判定函数 if (mlir::dicp::isaPermutedMemRefType(memRefType)) { return true; } - + // 2. 额外检查:最后维度的stride是否为1 - // 对于 Ascend 来说,如果最后维度 stride != 1,通常意味着不是连续内存,可能需要隐式转置 + // 对于 Ascend 来说,如果最后维度 stride != + // 1,通常意味着不是连续内存,可能需要隐式转置 auto [strides, offset] = memRefType.getStridesAndOffset(); if (!strides.empty() && strides.back() != 1) { return true; } - + return false; } @@ -71,7 +73,7 @@ bool checkValueOriginHasNonStandardStride(Value value) { return true; } } - + // 检查定义操作 if (Operation *defOp = value.getDefiningOp()) { // 检查Subview操作 @@ -83,213 +85,251 @@ bool checkValueOriginHasNonStandardStride(Value value) { return checkValueOriginHasNonStandardStride(castOp.getSource()); } } - + return false; } -struct AnnotateTransposePass - : public mlir::dicp::LinalgExt::impl::AnnotateTransposePassBase { - -void runOnOperation() override { - auto funcOp = getOperation(); - - llvm::outs() << "[INFO] Starting AnnotateTransposePass on function: " - << funcOp.getName() << "\n"; - - // 待处理列表 - SmallVector toTensorOpsToMark; - SmallVector opsToErase; // 用于存储被重写后需要删除的旧Op - - // ============================================================================== - // 1. 遍历 memref.copy 操作 - // 核心逻辑:检测 Dynamic Subview Copy -> 重写为 Static Full Copy + Annotation - // ============================================================================== - funcOp.walk([&](memref::CopyOp copyOp) { - auto source = copyOp.getSource(); - auto target = copyOp.getTarget(); - - llvm::outs() << "[MEMREF_COPY_VISIT] " << copyOp << "\n"; - - // --- 尝试进行 IR 重写 (Rewrite) --- - // 目标:将 memref.copy(subview(A), subview(B)) 转换为 memref.copy(A, B) - // 条件:A 是静态 Permuted,B 是静态 Contiguous,且形状匹配 - - auto srcSubView = source.getDefiningOp(); - auto dstSubView = target.getDefiningOp(); - - if (srcSubView && dstSubView) { - Value baseSource = srcSubView.getSource(); - Value baseTarget = dstSubView.getSource(); - - auto baseSourceType = dyn_cast(baseSource.getType()); - auto baseTargetType = dyn_cast(baseTarget.getType()); - - if (baseSourceType && baseTargetType && - baseSourceType.hasStaticShape() && baseTargetType.hasStaticShape()) { - - bool isBaseSourcePermuted = isPermutedOrHasNonUnitLastStride(baseSourceType); - // 简化判定:如果不是 Permuted 且 stride 正常,视为 Contiguous - bool isBaseTargetContiguous = !isPermutedOrHasNonUnitLastStride(baseTargetType); - - // 检查 Static Shape 是否一致 (例如都是 2x8xf32) - if (isBaseSourcePermuted && isBaseTargetContiguous && - baseSourceType.getShape() == baseTargetType.getShape()) { - - llvm::outs() << " [REWRITE_MATCH] Found Dynamic Subview Copy candidate for Static Rewrite.\n"; - llvm::outs() << " Base Source (Permuted): " << baseSourceType << "\n"; - llvm::outs() << " Base Target (Contiguous): " << baseTargetType << "\n"; - +struct AnnotateTransposePass + : public mlir::dicp::LinalgExt::impl::AnnotateTransposePassBase< + AnnotateTransposePass> { + + void runOnOperation() override { + auto funcOp = getOperation(); + + llvm::outs() << "[INFO] Starting AnnotateTransposePass on function: " + << funcOp.getName() << "\n"; + + // 待处理列表 + SmallVector toTensorOpsToMark; + SmallVector opsToErase; // 用于存储被重写后需要删除的旧Op + + // ============================================================================== + // 1. 遍历 memref.copy 操作 + // 核心逻辑:检测 Dynamic Subview Copy -> 重写为 Static Full Copy + + // Annotation + // ============================================================================== + funcOp.walk([&](memref::CopyOp copyOp) { + auto source = copyOp.getSource(); + auto target = copyOp.getTarget(); + + llvm::outs() << "[MEMREF_COPY_VISIT] " << copyOp << "\n"; + + // --- 尝试进行 IR 重写 (Rewrite) --- + // 目标:将 memref.copy(subview(A), subview(B)) 转换为 memref.copy(A, B) + // 条件:A 是静态 Permuted,B 是静态 Contiguous,且形状匹配 + + auto srcSubView = source.getDefiningOp(); + auto dstSubView = target.getDefiningOp(); + + if (srcSubView && dstSubView) { + Value baseSource = srcSubView.getSource(); + Value baseTarget = dstSubView.getSource(); + + auto baseSourceType = dyn_cast(baseSource.getType()); + auto baseTargetType = dyn_cast(baseTarget.getType()); + + if (baseSourceType && baseTargetType && + baseSourceType.hasStaticShape() && + baseTargetType.hasStaticShape()) { + + bool isBaseSourcePermuted = + isPermutedOrHasNonUnitLastStride(baseSourceType); + // 简化判定:如果不是 Permuted 且 stride 正常,视为 Contiguous + bool isBaseTargetContiguous = + !isPermutedOrHasNonUnitLastStride(baseTargetType); + + // 检查 Static Shape 是否一致 (例如都是 2x8xf32) + if (isBaseSourcePermuted && isBaseTargetContiguous && + baseSourceType.getShape() == baseTargetType.getShape()) { + + llvm::outs() << " [REWRITE_MATCH] Found Dynamic Subview Copy " + "candidate for Static Rewrite.\n"; + llvm::outs() << " Base Source (Permuted): " << baseSourceType + << "\n"; + llvm::outs() << " Base Target (Contiguous): " << baseTargetType + << "\n"; + // 执行重写 OpBuilder builder(copyOp->getContext()); builder.setInsertionPoint(copyOp); - + // 1. 创建新的静态 Copy (Base -> Base) - auto newCopyOp = builder.create(copyOp.getLoc(), baseSource, baseTarget); - llvm::outs() << " -> Replaced with Static Copy: " << newCopyOp << "\n"; + auto newCopyOp = builder.create( + copyOp.getLoc(), baseSource, baseTarget); + llvm::outs() << " -> Replaced with Static Copy: " << newCopyOp + << "\n"; // 2. 关键:在 Base Target (MemRef) 上添加 Annotation // 这指导 Ascend 编译器生成隐式转置指令 builder.setInsertionPointAfter(newCopyOp); - auto markOp = builder.create(copyOp.getLoc(), baseTarget); - markOp->setAttr("MayImplicitTransposeWithLastAxis", UnitAttr::get(builder.getContext())); - llvm::outs() << " -> Added Annotation to Base Target MemRef: " << markOp << "\n"; + auto markOp = + builder.create(copyOp.getLoc(), baseTarget); + markOp->setAttr("MayImplicitTransposeWithLastAxis", + UnitAttr::get(builder.getContext())); + llvm::outs() << " -> Added Annotation to Base Target MemRef: " + << markOp << "\n"; // 3. 追踪 Base Target 的 Tensor 使用者 - // 我们需要标记 bufferization.to_tensor(BaseTarget),这样后续的 MatMul 才能识别到 Layout 变化 + // 我们需要标记 bufferization.to_tensor(BaseTarget),这样后续的 + // MatMul 才能识别到 Layout 变化 for (auto user : baseTarget.getUsers()) { - if (auto toTensorOp = dyn_cast(user)) { - // 去重检查 - bool exists = false; - for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; - - if(!exists) { - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " -> Scheduled Base Target's ToTensorOp for annotation: " << toTensorOp << "\n"; - } + if (auto toTensorOp = dyn_cast(user)) { + // 去重检查 + bool exists = false; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + + if (!exists) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << " -> Scheduled Base Target's ToTensorOp " + "for annotation: " + << toTensorOp << "\n"; } + } } // 4. 标记旧的 Copy Op 待删除 opsToErase.push_back(copyOp); - + // 重写完成,跳过后续分析 - return; + return; + } } } - } - // --- 如果没有触发重写,执行常规的传播分析 --- - // (针对代码中已经是静态 Copy 的情况,或者仅仅进行标记传播) - - if (auto sourceType = dyn_cast(source.getType())) { - bool isSourcePermuted = isPermutedOrHasNonUnitLastStride(sourceType); - - if (auto targetType = dyn_cast(target.getType())) { - bool isTargetPermuted = isPermutedOrHasNonUnitLastStride(targetType); - - // 如果源是置换的,追踪目标的使用者 - if (isSourcePermuted) { - // 检查目标的使用者 - for (auto user : target.getUsers()) { - if (auto toTensorOp = dyn_cast(user)) { - bool exists = false; - for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; - if(!exists) { + // --- 如果没有触发重写,执行常规的传播分析 --- + // (针对代码中已经是静态 Copy 的情况,或者仅仅进行标记传播) + + if (auto sourceType = dyn_cast(source.getType())) { + bool isSourcePermuted = isPermutedOrHasNonUnitLastStride(sourceType); + + if (auto targetType = dyn_cast(target.getType())) { + bool isTargetPermuted = isPermutedOrHasNonUnitLastStride(targetType); + + // 如果源是置换的,追踪目标的使用者 + if (isSourcePermuted) { + // 检查目标的使用者 + for (auto user : target.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + bool exists = false; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " [PROPAGATE] Marked bufferization.to_tensor (Source was permuted)\n"; - } + llvm::outs() + << " [PROPAGATE] Marked bufferization.to_tensor (Source " + "was permuted)\n"; + } + } } - } - - // 如果目标是 Subview,追踪其父 MemRef - if (auto sourceDefOp = target.getDefiningOp()) { - if (auto subviewOp = dyn_cast(sourceDefOp)) { - Value parentMemRef = subviewOp.getSource(); - for (auto user : parentMemRef.getUsers()) { - if (auto toTensorOp = dyn_cast(user)) { + + // 如果目标是 Subview,追踪其父 MemRef + if (auto sourceDefOp = target.getDefiningOp()) { + if (auto subviewOp = dyn_cast(sourceDefOp)) { + Value parentMemRef = subviewOp.getSource(); + for (auto user : parentMemRef.getUsers()) { + if (auto toTensorOp = + dyn_cast(user)) { bool exists = false; - for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; - if(!exists) { - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " [PROPAGATE_PARENT] Marked bufferization.to_tensor of Parent MemRef\n"; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + if (!exists) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() + << " [PROPAGATE_PARENT] Marked " + "bufferization.to_tensor of Parent MemRef\n"; } + } } } } - } - } else if (isTargetPermuted) { - // 如果目标本身就是置换的 - for (auto user : target.getUsers()) { - if (auto toTensorOp = dyn_cast(user)) { + } else if (isTargetPermuted) { + // 如果目标本身就是置换的 + for (auto user : target.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { bool exists = false; - for(auto op : toTensorOpsToMark) if(op == toTensorOp) exists = true; - if(!exists) { - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " [PROPAGATE_TARGET] Marked bufferization.to_tensor (Target is permuted)\n"; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + if (!exists) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() + << " [PROPAGATE_TARGET] Marked bufferization.to_tensor " + "(Target is permuted)\n"; } + } } } } } + }); + + // 删除被重写的旧 Op + for (auto op : opsToErase) { + op->erase(); } - }); - // 删除被重写的旧 Op - for (auto op : opsToErase) { - op->erase(); - } + // ============================================================================== + // 2. 扫描所有 bufferization.to_tensor 操作 (查漏补缺) + // ============================================================================== + funcOp.walk([&](bufferization::ToTensorOp toTensorOp) { + // 如果已经在列表中,跳过 + for (auto existing : toTensorOpsToMark) { + if (existing == toTensorOp) + return; + } + + Value sourceMemRef = toTensorOp.getOperand(); + bool hasNonStandardStride = + checkValueOriginHasNonStandardStride(sourceMemRef); - // ============================================================================== - // 2. 扫描所有 bufferization.to_tensor 操作 (查漏补缺) - // ============================================================================== - funcOp.walk([&](bufferization::ToTensorOp toTensorOp) { - // 如果已经在列表中,跳过 - for(auto existing : toTensorOpsToMark) { if (existing == toTensorOp) return; } - - Value sourceMemRef = toTensorOp.getOperand(); - bool hasNonStandardStride = checkValueOriginHasNonStandardStride(sourceMemRef); - - bool shouldMark = false; - if (auto memRefType = dyn_cast(sourceMemRef.getType())) { - if (isPermutedOrHasNonUnitLastStride(memRefType)) { - shouldMark = true; + bool shouldMark = false; + if (auto memRefType = dyn_cast(sourceMemRef.getType())) { + if (isPermutedOrHasNonUnitLastStride(memRefType)) { + shouldMark = true; + } } - } - - if (shouldMark || hasNonStandardStride) { - toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << "[TO_TENSOR_CHECK] Found permuted/strided origin: " << toTensorOp << "\n"; - } - }); - - // ============================================================================== - // 3. 执行最终标记:为收集到的 Tensor 添加 Annotation - // ============================================================================== - for (auto toTensorOp : toTensorOpsToMark) { - // 双重检查:防止重复添加 MarkOp (虽然 OpBuilder 会创建新的 Op,但逻辑上我们不希望冗余) - // 简单检查该 Value 是否已经被 MarkOp 使用 - bool alreadyMarked = false; - // 注意:annotation::MarkOp 通常不直接作为 User 挂在 Value 上,而是作为一个独立的 Op 存在。 - // 为了稳妥,这里我们假设 list 中可能有重复(如果 func.walk 逻辑有交集),去重已经在 push_back 时做了。 - - llvm::outs() << " [ANNOTATE_ACTION] Adding annotation to: " << toTensorOp << "\n"; - - OpBuilder builder(toTensorOp->getContext()); - builder.setInsertionPointAfter(toTensorOp); - - auto markOp = builder.create( - toTensorOp->getLoc(), toTensorOp.getResult()); - - markOp->setAttr("MayImplicitTransposeWithLastAxis", - UnitAttr::get(builder.getContext())); - - llvm::outs() << " -> Created annotation::MarkOp: " << markOp << "\n"; - } - llvm::outs() << "[INFO] Finished AnnotateTransposePass on function: " - << funcOp.getName() << "\n"; -} + if (shouldMark || hasNonStandardStride) { + toTensorOpsToMark.push_back(toTensorOp); + llvm::outs() << "[TO_TENSOR_CHECK] Found permuted/strided origin: " + << toTensorOp << "\n"; + } + }); + // ============================================================================== + // 3. 执行最终标记:为收集到的 Tensor 添加 Annotation + // ============================================================================== + for (auto toTensorOp : toTensorOpsToMark) { + // 双重检查:防止重复添加 MarkOp (虽然 OpBuilder 会创建新的 + // Op,但逻辑上我们不希望冗余) 简单检查该 Value 是否已经被 MarkOp 使用 + bool alreadyMarked = false; + // 注意:annotation::MarkOp 通常不直接作为 User 挂在 Value + // 上,而是作为一个独立的 Op 存在。 为了稳妥,这里我们假设 list + // 中可能有重复(如果 func.walk 逻辑有交集),去重已经在 push_back + // 时做了。 + + llvm::outs() << " [ANNOTATE_ACTION] Adding annotation to: " << toTensorOp + << "\n"; + + OpBuilder builder(toTensorOp->getContext()); + builder.setInsertionPointAfter(toTensorOp); + + auto markOp = builder.create(toTensorOp->getLoc(), + toTensorOp.getResult()); + + markOp->setAttr("MayImplicitTransposeWithLastAxis", + UnitAttr::get(builder.getContext())); + + llvm::outs() << " -> Created annotation::MarkOp: " << markOp << "\n"; + } + + llvm::outs() << "[INFO] Finished AnnotateTransposePass on function: " + << funcOp.getName() << "\n"; + } }; } // namespace diff --git a/test/ascend/passed_tests/test_lightning_attn.py b/test/ascend/passed_tests/test_lightning_attn.py index 2f8e542a..899c4b4c 100644 --- a/test/ascend/passed_tests/test_lightning_attn.py +++ b/test/ascend/passed_tests/test_lightning_attn.py @@ -2,10 +2,521 @@ import pytest import torch -from dlblas.utils.device_utils import infer_device -from dlblas.kernels.lightning_attn import lightning_attention_decode_forward -from dlblas.kernels.lightning_attn import lightning_attention_prefill_forward -from dlblas.kernels.lightning_attn import BackendType +from triton.backends.dicp_triton.device_utils import infer_device + +import torch +import enum +from typing import Tuple + +import triton +import triton.language as tl + + +class BackendType(enum.Enum): + """Backend type.""" + + TORCH = enum.auto() + TRITON = enum.auto() + + +def lightning_attention_prefill_forward_torch( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE=64, + in_place=True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform lightning attention prefill. + modify from: https://github.com/MiniMax-AI/MiniMax-M1/blob/main/modeling_minimax_m1.py + + Args: + q: Query tensor of shape [B, H, N, D] + k: Key tensor of shape [B, H, N, D] + v: Value tensor of shape [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing + BLOCK_MODEL: Size of blocks for parallel processing + + Returns: + output: Attention output tensor [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + b, h, n, d = q.shape + e = v.shape[-1] + assert q.ndim == 4 + assert past_key_value.shape == (b, h, d, e) + + s = slope_rate.to(torch.float32) + NUM_BLOCK = (n + BLOCK_SIZE - 1) // BLOCK_SIZE + + array = torch.arange(BLOCK_SIZE).to(q) + 1 + q_decay = torch.exp(-s * array.reshape(-1, 1)) + k_decay = torch.exp(-s * (BLOCK_SIZE - array.reshape(-1, 1))) + index = array[:, None] - array[None, :] + s_index = ( + s + * index[ + None, + None, + ] + ) + s_index = torch.where(index >= 0, -s_index, float("-inf")) + diag_decay = torch.exp(s_index) + + if past_key_value is not None: + kv = past_key_value + else: + kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) + output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + + for i in range(NUM_BLOCK): + si = i * BLOCK_SIZE + ei = min(si + BLOCK_SIZE, n) + m = ei - si + qi = q[:, :, si:ei].contiguous() + ki = k[:, :, si:ei].contiguous() + vi = v[:, :, si:ei].contiguous() + qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32) + qk = ( + torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) + * diag_decay[:, :, :m, :m] + ) + qkv_diag = torch.matmul(qk, vi.to(torch.float32)) + output[:, :, si:ei] = qkv_none_diag + qkv_diag + block_decay = torch.exp(-s * m) + kv = block_decay * kv + torch.matmul( + (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi + ) + if in_place: + past_key_value.copy_(kv) + return output, past_key_value + else: + return output, kv + + +@triton.jit +def _fwd_loop_kernel( + q_ptr, + k_ptr, + v_ptr, + output_ptr, + slope_rate, + kv_cache_ptr, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK: tl.constexpr, + BLOCK_MODEL: tl.constexpr, +): + """ + Kernel for lightning attention prefill with KV cache. + """ + # get offset + off_bh = tl.program_id(0) + off_h = off_bh % h + off_e = tl.program_id(1) + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + e_offset = off_e * BLOCK_MODEL + kv_offset = off_bh * d * e + + # get block ptr + Q_block_ptr = q_ptr + qk_offset + tl.arange(0, d)[None, :] + K_trans_block_ptr = k_ptr + qk_offset + tl.arange(0, d)[:, None] + V_block_ptr = v_ptr + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + O_block_ptr = output_ptr + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + S_block_ptr = slope_rate + off_h + + # init decay + s = tl.load(S_block_ptr).to(tl.float32) + off_block = tl.arange(0, BLOCK) + q_decay = tl.exp(-s.to(tl.float32) * (off_block[:, None] + 1)) + k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - 1 - off_block[None, :])) + block_decay = tl.exp(-s.to(tl.float32) * BLOCK) + + index = off_block[:, None] - off_block[None, :] + s_index = s * index + s_index = tl.where(index >= 0, -s_index, float("-inf")) + diag_decay = tl.exp(s_index) + + kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32) + + # loop compute + for i in range(NUM_BLOCK): + if n < BLOCK * (i + 1): + block_decay = tl.exp(-s.to(tl.float32) * (n - BLOCK * i)) + # (BLOCK - 1 - off_block[None, :] + n - BLOCK) + k_trans_decay = tl.exp(-s.to(tl.float32) * (n - 1 - off_block[None, :])) + # load + q = tl.load( + Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + k_trans = tl.load( + K_trans_block_ptr + off_block[None, :] * d, + mask=off_block[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + + # compute + qk = tl.dot(q, k_trans) * diag_decay + o_intra = tl.dot(qk, v) + o_inter = tl.dot(q * q_decay, kv) + o = o_intra + o_inter + + # save and update + tl.store( + O_block_ptr + off_block[:, None] * e, + o.to(O_block_ptr.dtype.element_ty), + mask=off_block[:, None] < n, + ) + kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v) + off_block += BLOCK + + KV_block_ptr = ( + kv_cache_ptr + + kv_offset + + e_offset + + tl.arange(0, d)[:, None] * e + + tl.arange(0, BLOCK_MODEL)[None, :] + ) + tl.store( + KV_block_ptr, + kv.to(KV_block_ptr.dtype.element_ty), + ) + + +def lightning_attention_prefill_forward_triton_loop( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE=64, + BLOCK_MODEL=32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform lightning attention prefill. + modify from: https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py + + Args: + q: Query tensor of shape [B, H, N, D] + k: Key tensor of shape [B, H, N, D] + v: Value tensor of shape [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing + BLOCK_MODEL: Size of blocks for parallel processing + + Returns: + output: Attention output tensor [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = slope_rate.contiguous() + + b, h, n, d = q.shape + e = v.shape[-1] + o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + assert past_key_value.shape == (b, h, d, e) + assert o.shape == v.shape + assert o.dtype == v.dtype + + NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK_SIZE) + # parallel over channel + BLOCK_M = min(triton.next_power_of_2(e), BLOCK_MODEL) + assert e % BLOCK_M == 0 + grid = (b * h, triton.cdiv(e, BLOCK_M)) + + if past_key_value is not None: + kv = past_key_value + assert kv.dtype == torch.float32 + else: + kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) + + _fwd_loop_kernel[grid]( + q, + k, + v, + o, + s, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK_SIZE, + NUM_BLOCK=NUM_BLOCK, + BLOCK_MODEL=BLOCK_M, + ) + return o, kv + + +def lightning_attention_decode_forward_torch( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + in_place=True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform lightning attention decoding. + modify from: https://github.com/MiniMax-AI/MiniMax-M1/blob/main/modeling_minimax_m1.py + + Args: + q: Query tensor of shape [B, H, 1, D] + k: Key tensor of shape [B, H, 1, D] + v: Value tensor of shape [B, H, 1, E] + kv_caches: Key-value cache tensor [B, H, D, E] + slope_rate: Decay rate tensor + + Returns: + output: Attention output tensor [B, H, 1, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + assert q.ndim == 4 + B, H, _, D = q.shape + E = v.shape[-1] + assert k.shape == (B, H, 1, D) + assert v.shape == (B, H, 1, E) + assert past_key_value.shape == (B, H, D, E) + kv = past_key_value + s = torch.exp(-slope_rate) + kv = ( + torch.einsum( + "... n d, ... n e -> ... d e", + k, + v, + ) + + s * kv + ) + qkv = torch.einsum("... n d, ... d e -> ... n e", q, kv.to(q.dtype)) + past_key_value.copy_(kv) + if in_place: + past_key_value.copy_(kv) + return qkv, past_key_value + else: + return qkv, kv + + +@triton.jit +def _lightningattn_attn_decode_kernel( + q_ptr, + k_ptr, + v_ptr, + kv_cache_ptr, + slope_rate, + output_ptr, + D: tl.constexpr, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d_stride, + cache_e_stride, + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for lightning attention decoding with KV cache. + """ + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_d = tl.program_id(2) + + batch_id = pid_b + head_id = pid_h + + # Load decay rate for the current head + ratio = tl.load(slope_rate + pid_h) + + # Calculate offsets for dimensions + qk_d_offsets = tl.arange(0, D) + v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE + cache_d_offsets = ( + qk_d_offsets[:, None] * cache_d_stride + v_d_offsets[None, :] * cache_e_stride + ) + + # Calculate offsets for the current batch and head + q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + + cache_offset = batch_id * cache_b_stride + head_id * cache_h_stride + + # Create masks for loading tensors + qk_mask = qk_d_offsets < D + v_mask = v_d_offsets < D + + # Load query, key, and value tensors + q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0) + k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0) + v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0) + + # Compute key-value outer product + kv_outer = k[:, None] * v[None, :] + kv_mask = qk_mask[:, None] & v_mask[None, :] + + # Apply decay to previous KV cache + ratio = tl.exp(-ratio) + kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets + kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0) + kv_outer = kv_outer + ratio * kv_cache_old + + # Compute attention output + output = q[:, None].to(tl.float32) * kv_outer + output = tl.sum(output, axis=0) + tl.store(kv_ptr, kv_outer, mask=kv_mask) + tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask) + + +def lightning_attention_decode_forward_triton( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform lightning attention decoding using Triton kernels. + modify from: https://github.com/vllm-project/vllm/vllm/model_executor/layers/lightning_attn.py + + Args: + q: Query tensor of shape [B, H, 1, D] + k: Key tensor of shape [B, H, 1, D] + v: Value tensor of shape [B, H, 1, E] + kv_caches: Key-value cache tensor + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing + + Returns: + output: Attention output tensor [B, H, 1, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + assert q.ndim == 4 + B, H, _, D = q.shape + E = v.shape[-1] + assert k.shape == (B, H, 1, D) + assert v.shape == (B, H, 1, E) + assert past_key_value.shape == (B, H, D, E) + + # Initialize output tensor + o = torch.empty((B, H, 1, E), dtype=q.dtype, device=q.device) + + # Set grid dimensions for the kernel + grid = (B, H, D // BLOCK_SIZE) + + # Calculate strides for tensors + qkv_b_stride = q.stride(0) + qkv_h_stride = q.stride(1) + + cache_b_stride = past_key_value.stride(0) + cache_h_stride = past_key_value.stride(1) + cache_d_stride = past_key_value.stride(2) + cache_e_stride = past_key_value.stride(3) + + # Launch the kernel + _lightningattn_attn_decode_kernel[grid]( + q, + k, + v, + past_key_value, + slope_rate, + o, + D, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d_stride, + cache_e_stride, + BLOCK_SIZE=BLOCK_SIZE, + ) + return o, past_key_value + + +def lightning_attention_prefill_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE=64, + BackendType: int = BackendType.TORCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q: Query tensor of shape [B, H, N, D] + k: Key tensor of shape [B, H, N, D] + v: Value tensor of shape [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing + BLOCK_MODEL: Size of blocks for parallel processing + + Returns: + output: Attention output tensor [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + if BackendType == BackendType.TRITON: + return lightning_attention_prefill_forward_triton_loop( + q, k, v, past_key_value, slope_rate, BLOCK_SIZE + ) + else: + return lightning_attention_prefill_forward_torch( + q, + k, + v, + past_key_value, + slope_rate, + ) + + +def lightning_attention_decode_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE: int = 128, + BackendType: int = BackendType.TORCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q: Query tensor of shape [B, H, 1, D] + k: Key tensor of shape [B, H, 1, D] + v: Value tensor of shape [B, H, 1, E] + kv_caches: Key-value cache tensor [B, H, D, E] + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing in triton + BackendType: torch or triton + + Returns: + output: Attention output tensor [B, H, 1, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + if BackendType == BackendType.TRITON: + return lightning_attention_decode_forward_triton( + q, k, v, past_key_value, slope_rate, BLOCK_SIZE + ) + else: + return lightning_attention_decode_forward_torch( + q, k, v, past_key_value, slope_rate + ) class TestLightningAttn: @@ -59,23 +570,27 @@ def slope_rate(self, H, dtype): def get_slopes(n): def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) - slope_rate = torch.tensor(get_slopes(H), dtype=dtype, device=infer_device()).reshape(H, 1, 1) + slope_rate = torch.tensor( + get_slopes(H), dtype=dtype, device=infer_device() + ).reshape(H, 1, 1) yield slope_rate * (1 + 1e-5) # float32 only @pytest.mark.parametrize( - ['B', 'H', 'N', 'D', 'E', 'dtype', 'BLOCK_SIZE'], + ["B", "H", "N", "D", "E", "dtype", "BLOCK_SIZE"], [ (1, 64, 5, 64, 128, torch.float32, 8), (1, 64, 72, 64, 64, torch.float32, 8), @@ -95,15 +610,31 @@ def test_lightning_attention_prefill( ): past_key_value_torch = torch.zeros_like(past_key_value) past_key_value_triton = torch.zeros_like(past_key_value) - out_torch, _ = lightning_attention_prefill_forward(q_states, k_states, v_states, past_key_value_torch, slope_rate, BLOCK_SIZE, BackendType=BackendType.TORCH) - out_triton, _ = lightning_attention_prefill_forward(q_states, k_states, v_states, past_key_value_triton, slope_rate, BLOCK_SIZE, BackendType=BackendType.TRITON) + out_torch, _ = lightning_attention_prefill_forward( + q_states, + k_states, + v_states, + past_key_value_torch, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TORCH, + ) + out_triton, _ = lightning_attention_prefill_forward( + q_states, + k_states, + v_states, + past_key_value_triton, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TRITON, + ) if dtype == torch.float32: - rtol=1e-03 - atol=1 + rtol = 1e-03 + atol = 1 else: - rtol=1e-03 - atol=1 + rtol = 1e-03 + atol = 1 kv_check = torch.allclose( past_key_value_torch, @@ -118,13 +649,17 @@ def test_lightning_attention_prefill( atol=atol, ) - assert kv_check, f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" + assert ( + kv_check + ), f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" assert output_check, f"output torch:{out_torch}, output triton:{out_triton}" - print(f"zmz debug torch kv_check:{past_key_value_torch}, triton :{past_key_value_triton}") + print( + f"zmz debug torch kv_check:{past_key_value_torch}, triton :{past_key_value_triton}" + ) # float32 only @pytest.mark.parametrize( - ['B', 'H', 'N', 'D', 'E', 'dtype', 'BLOCK_SIZE'], + ["B", "H", "N", "D", "E", "dtype", "BLOCK_SIZE"], [ (8, 64, 1, 128, 128, torch.float32, 64), (16, 64, 1, 128, 128, torch.float32, 64), @@ -143,15 +678,31 @@ def test_lightning_attention_decode( ): past_key_value_torch = torch.zeros_like(past_key_value) past_key_value_triton = torch.zeros_like(past_key_value) - out_torch, _ = lightning_attention_decode_forward(q_states, k_states, v_states, past_key_value_torch, slope_rate, BLOCK_SIZE, BackendType=BackendType.TORCH) - out_triton, _ = lightning_attention_decode_forward(q_states, k_states, v_states, past_key_value_triton, slope_rate, BLOCK_SIZE, BackendType=BackendType.TRITON) + out_torch, _ = lightning_attention_decode_forward( + q_states, + k_states, + v_states, + past_key_value_torch, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TORCH, + ) + out_triton, _ = lightning_attention_decode_forward( + q_states, + k_states, + v_states, + past_key_value_triton, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TRITON, + ) if dtype == torch.float32: - rtol=1e-03 - atol=1e-02 + rtol = 1e-03 + atol = 1e-02 else: - rtol=1e-03 - atol=1e-02 + rtol = 1e-03 + atol = 1e-02 kv_check = torch.allclose( past_key_value_torch, @@ -166,9 +717,11 @@ def test_lightning_attention_decode( atol=atol, ) - assert kv_check, f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" + assert ( + kv_check + ), f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" assert output_check, f"output torch:{out_torch}, output triton:{out_triton}" -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) From 1233d3756dd95809fa10e9c500bb28d35ce57aa6 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Wed, 28 Jan 2026 05:51:01 +0000 Subject: [PATCH 04/14] fix format --- .../Dialect/LinalgExt/Transforms/Passes.h | 3 ++- compiler/include/dicp/Utils/Utils.h | 21 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h index d82378c1..0e18f26d 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h @@ -20,7 +20,8 @@ std::unique_ptr> createLinalgGenericToSCFPass(); std::unique_ptr> createScalarTo1DTensorPass(); -std::unique_ptr> createAnnotateTransposePass(); +std::unique_ptr> +createAnnotateTransposePass(); std::unique_ptr> createNormalizeSliceOpsPass(); diff --git a/compiler/include/dicp/Utils/Utils.h b/compiler/include/dicp/Utils/Utils.h index dd6964ce..6ba86c0d 100644 --- a/compiler/include/dicp/Utils/Utils.h +++ b/compiler/include/dicp/Utils/Utils.h @@ -3,18 +3,15 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" - +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -129,11 +126,13 @@ inline bool isaPermutedMemRefType(MemRefType memRefType) { // }); switch (ptrStrides.size()) { - case 0: return false; - case 1: return false; - default: { - return ptrStrides[ptrStrides.size()-1] != 1; - } + case 0: + return false; + case 1: + return false; + default: { + return ptrStrides[ptrStrides.size() - 1] != 1; + } } } From f7e84e071897957c192d5c05f390fd9bbf23a498 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Thu, 29 Jan 2026 11:54:49 +0000 Subject: [PATCH 05/14] add metedata --- backend/npu.py | 5 ++- compiler/include/dicp/Utils/Utils.h | 64 ++--------------------------- 2 files changed, 8 insertions(+), 61 deletions(-) diff --git a/backend/npu.py b/backend/npu.py index 7278060e..55b7ca82 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -434,7 +434,10 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm) dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) - dicp_triton.passes.linked_npu.add_annotate_transpose(pm) + # 当metadata 中有add_annotate_transpose 开关,再开启 + open_add_annotate_transpose = metadata["add_annotate_transpose"] + if open_add_annotate_transpose is not None and open_add_annotate_transpose is True: + dicp_triton.passes.linked_npu.add_annotate_transpose(pm) dicp_triton.passes.linked_npu.add_linked_to_hivm(pm) pm.run(mod) diff --git a/compiler/include/dicp/Utils/Utils.h b/compiler/include/dicp/Utils/Utils.h index 6ba86c0d..abaf9dc0 100644 --- a/compiler/include/dicp/Utils/Utils.h +++ b/compiler/include/dicp/Utils/Utils.h @@ -7,29 +7,22 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -38,56 +31,13 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringSwitch.h" - -#include -#include - -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/ArrayRef.h" #include #include - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/LogicalResult.h" - #include #include #include -#include #include -#include // Dispatch conversion pattern handlers based on backend string. Executes // ASCEND_HANDLER when backend == "ascend", otherwise DEFAULT_HANDLER. @@ -116,14 +66,8 @@ llvm::StringRef getBackend(ModuleOp module); bool isAscendBackend(ModuleOp module); -// bool isaPermutedMemRefType(MemRefType); inline bool isaPermutedMemRefType(MemRefType memRefType) { auto [ptrStrides, ptrOffsets] = memRefType.getStridesAndOffset(); - // LLVM_DEBUG({ - // llvm::dbgs()<<"---------- [BEG] ptrStrides ----------\n"; - // for(auto stride: ptrStrides)llvm::dbgs()< Date: Thu, 29 Jan 2026 12:02:32 +0000 Subject: [PATCH 06/14] fix bugs --- backend/npu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/npu.py b/backend/npu.py index 55b7ca82..54dfddcc 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -434,7 +434,6 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm) dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm) dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True) - # 当metadata 中有add_annotate_transpose 开关,再开启 open_add_annotate_transpose = metadata["add_annotate_transpose"] if open_add_annotate_transpose is not None and open_add_annotate_transpose is True: dicp_triton.passes.linked_npu.add_annotate_transpose(pm) @@ -798,6 +797,7 @@ class NPUOptions: tile_mix_cube_loop: int = None limit_auto_multi_buffer_only_for_local_buffer: bool = None set_workspace_multibuffer: int = None + add_annotate_transpose: bool = None stream: int = None From b898509fec1dae9439e8e31efbcfca900121bdfd Mon Sep 17 00:00:00 2001 From: zhaochaoxing <109726331+zhaochaoxing@users.noreply.github.com> Date: Wed, 21 Jan 2026 10:36:03 +0800 Subject: [PATCH 07/14] =?UTF-8?q?Zcx/dev=EF=BC=9Asupport=20common=20ir=20(?= =?UTF-8?q?#137)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix CI * fix ut_tests warning * support dump bisheng ir * support common ir * fix format --- backend/commonir/__init__.py | 0 backend/commonir/adapter.py | 2 + backend/commonir/backend.py | 214 +++++++++++++++++++++++++++++++++++ backend/commonir/compiler.py | 195 +++++++++++++++++++++++++++++++ backend/npu.py | 58 +++++++++- 5 files changed, 468 insertions(+), 1 deletion(-) create mode 100644 backend/commonir/__init__.py create mode 100644 backend/commonir/adapter.py create mode 100644 backend/commonir/backend.py create mode 100644 backend/commonir/compiler.py diff --git a/backend/commonir/__init__.py b/backend/commonir/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/commonir/adapter.py b/backend/commonir/adapter.py new file mode 100644 index 00000000..27acd085 --- /dev/null +++ b/backend/commonir/adapter.py @@ -0,0 +1,2 @@ +class Adapter(object): + pass diff --git a/backend/commonir/backend.py b/backend/commonir/backend.py new file mode 100644 index 00000000..7f49170c --- /dev/null +++ b/backend/commonir/backend.py @@ -0,0 +1,214 @@ +import functools +import os +from typing import Any +from ..compiler import DICPOptions +from ..driver import DICPDriver +from ..utils import get_current_backend + + +class CommonIRBackend: + binary_ext = "ttlinalgdir" + + def __init__(self) -> None: + target = get_current_backend() + self.driver = DICPDriver(target) + if self.driver.target == "dicp": + self.binary_ext = "ttlinalgdir" + elif self.driver.target == "mlu": + self.capability = target.arch + assert isinstance(self.capability, int) + self.binary_ext = "cnbin" + elif self.driver.target == "maca": + self.capability = 80 + self.binary_ext = "mcfatbin" + elif self.driver.target == "ascend": + self.binary_ext = "npubin" + else: + raise RuntimeError(f"Target '{self.target_type}' is not supported.") + + def get_attrs_descriptor(self, params, args): + if self.driver.target == "ascend": + from triton.backends.dicp_triton.npu import AscendAttrsDescriptor + + return AscendAttrsDescriptor(params, args) + else: + raise RuntimeError( + f"backend {self.driver.target} not supported for get_attrs_descriptor." + ) + + def add_stages(self, stages, options, language=None): + + if self.driver.target == "ascend": + from triton.backends.dicp_triton.npu import ( + commonir_to_linkedir, + linalg_to_bin_enable_npu_compile, + ) + + stages["linkedir"] = lambda src, metadata: commonir_to_linkedir( + src, metadata, options, named_ops=True + ) + stages["npubin"] = lambda src, metadata: linalg_to_bin_enable_npu_compile( + src, metadata, options + ) + else: + raise RuntimeError("backend not supported") + + def load_dialects(self, ctx): + if self.driver.target == "mlu": + from triton._C.libtriton import mlu + + mlu.load_dialects(ctx) + return + + def get_driver(self): + return self.driver + + # parse add_kernel[(16,)](x, y, output, n_elements, BLOCK_SIZE=1024) + def parse_options(self, options: dict) -> Any: + if self.driver.target == "ascend": + from triton.backends.dicp_triton.npu import NPUOptions + + args = { + k: options[k] + for k in NPUOptions.__dataclass_fields__.keys() + if k in options + } + options = NPUOptions(**args) + return options + elif self.driver.target == "mlu": + from triton.backends.dicp_triton.mlu import MLUOptions + + args = { + k: options[k] + for k in MLUOptions.__dataclass_fields__.keys() + if k in options + } + # When arch is less than mtp_5xx, tf32 is not supported, use fp32 for calculation. + if "allowed_dot_input_precisions" not in args: + if self.capability < 500: + args["allowed_dot_input_precisions"] = "ieee" + + if "supported_fp8_dtypes" not in args: + supported_fp8_dtypes = set(MLUOptions.supported_fp8_dtypes) + if self.capability >= 600: + supported_fp8_dtypes = supported_fp8_dtypes.union( + ("fp8e5", "fp8e4nv") + ) + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) + + args["max_num_imprecise_acc_default"] = 0 + + if "enable_fp_fusion" not in args: + args["enable_fp_fusion"] = ( + os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" + ) + + if "enable_mlu_bound_check" not in args: + args["enable_mlu_bound_check"] = ( + os.getenv("TRITON_ENABLE_MLU_BOUND_CHECK", "0") == "1" + ) + return MLUOptions(**args) + elif self.driver.target == "maca": + from triton.backends.dicp_triton.maca import MACAOptions + + # args = {k: options[k] for k in MACAOptions.__dataclass_fields__.keys() if k in options} + # return MACAOptions(**args) + args = { + k: options[k] + for k in MACAOptions.__dataclass_fields__.keys() + if k in options + } + # USE_MACA: support allow_fp8e4nv(i.e. float8_e4m3fn) + args["allow_fp8e4nv"] = True + # args["allow_fp8e4nv"] = False + args["allow_fp8e4b15"] = False + args["max_num_imprecise_acc_default"] = ( + 2**30 if self.capability == 90 else 0 + ) + return MACAOptions(**args) + else: + args = {"arch": self.target} + args.update( + { + k: options[k] + for k in DICPOptions.__dataclass_fields__.keys() + if k in options + } + ) + return DICPOptions(**args) + + def get_codegen_implementation(self, options=None): + codegen_fns = dict() + if self.driver.target == "ascend": + from triton.backends.dicp_triton.npu import min_dot_size + + codegen_fns = {"min_dot_size": min_dot_size(self.target)} + elif self.driver.target == "mlu": + from triton.backends.dicp_triton.mlu import min_dot_size + + codegen_fns = { + "convert_custom_types": lambda arg, dst_ty: arg, + "min_dot_size": min_dot_size(self.target), + } + elif self.driver.target == "maca": + import triton.language.extra.cuda as cuda + + codegen_fns = { + "convert_custom_types": ( + cuda.convert_custom_float8_sm80 + if self.capability >= 80 + else cuda.convert_custom_float8_sm70 + ) + } + return codegen_fns + + def pack_metadata(self, metadata): + if self.driver.target == "ascend": + from triton.backends.dicp_triton.npu import TRITON_PROFILER_REGISTERED + + # collect necessary metadata to launch kernels + # TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 could set unique name. + # Get this name as the kernel_name to CANN runtime. + # kernel_name is unique to Ascend backend and should not be public. + # CANN runtime limits the length of kernel name <= 50. + # Considering '\n' is appended, thus the real kernel name <= 49. + KERNEL_NAME_MAX_LEN = 49 + kernel_name_orig, mix_mode = metadata.name.split() + if len(kernel_name_orig) > KERNEL_NAME_MAX_LEN: + kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:] + # import warnings + # # red = "\x1b[31;20m" + # # reset = "\x1b[0m" + # warnings.warn(kernel_name_orig + " is truncated to " + kernel_name) + # warnings.warn("because '" + kernel_name_orig + "' exceeds torchnpu profiler's length limit < 50") + else: + kernel_name = kernel_name_orig + return { + "kernel_name": kernel_name, + "hash": metadata.hash, + "debug": metadata.debug, + "profiler_registered": TRITON_PROFILER_REGISTERED, + } + elif self.driver.target == "mlu": + return (metadata.num_warps,) + return ( + metadata.num_warps, + metadata.num_ctas, + metadata.shared, + metadata.cluster_dims[0], + metadata.cluster_dims[1], + metadata.cluster_dims[2], + ) + + @functools.lru_cache() + def hash(self): + if self.driver.target == "mlu": + from triton.backends.dicp_triton.mlu import get_cnas_version + + version = get_cnas_version() + return f"{version}-{self.capability}" + version_key = self.driver.target + return str(version_key) + + +commonir_backend = CommonIRBackend() diff --git a/backend/commonir/compiler.py b/backend/commonir/compiler.py new file mode 100644 index 00000000..740dd823 --- /dev/null +++ b/backend/commonir/compiler.py @@ -0,0 +1,195 @@ +import functools +import hashlib +import json +from pathlib import Path +from typing import Any, List +from triton._C.libtriton import get_cache_invalidating_env_vars + +from triton.runtime.cache import triton_key +from .backend import commonir_backend +from triton.backends.compiler import GPUTarget +from triton.compiler.compiler import AsmDict, _raise_error +from triton.compiler.compiler import LazyDict +from triton.runtime.cache import get_cache_manager + + +class CommonIRSource: + def __init__(self, src: str, grid: List[int], signature: dict): + self.src = src + self.grid = grid + self.signature = signature + + +class CompiledKernel: + def __init__(self, src: CommonIRSource, metadata_group, hash): + from collections import namedtuple + + metadata_path = next( + (Path(p) for c, p in metadata_group.items() if c.endswith(".json")) + ) + metadata = json.loads(metadata_path.read_text()) + metadata["cluster_dims"] = tuple(metadata["cluster_dims"]) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata["target"] + metadata["target"] = GPUTarget( + target["backend"], target["arch"], target["warp_size"] + ) + KernelMetadata = namedtuple("KernelMetadata", sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + self.packed_metadata = commonir_backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + self.grid = src.grid + # stores the text of each level of IR that was generated during compilation + asm_files = [ + Path(p) for c, p in metadata_group.items() if not c.endswith(".json") + ] + binary_ext = commonir_backend.binary_ext + self.asm = AsmDict( + { + file.suffix[1:]: ( + file.read_bytes() + if file.suffix[1:] == binary_ext + else file.read_text() + ) + for file in asm_files + } + ) + self.metadata_group = metadata_group + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + self._run = None + + def _init_handles(self): + if self.module is not None: + return + + def raise_(err): + self._run = functools.partial(_raise_error, err) + raise err + + device = commonir_backend.get_driver().get_current_device() + # create launcher + self._run = commonir_backend.get_driver().launcher_cls(self.src, self.metadata) + ( + self.module, + self.function, + self.n_regs, + self.n_spills, + ) = commonir_backend.get_driver().utils.load_binary( + self.name, self.kernel, self.metadata.shared, device + ) + + @property + def run(self): + if self._run is None: + self._init_handles() + return self._run + + def launch_metadata(self, grid, stream, *args): + self._init_handles() + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + return ret + + def __call__(self, *args: Any) -> Any: + device = commonir_backend.get_driver().get_current_device() + stream = commonir_backend.get_driver().get_current_stream(device) + # launch kernel + + launch_metadata = self.launch_metadata(self.grid, stream, *args) + self.run( + self.grid[0], + self.grid[1], + self.grid[2], + stream, + self.function, + self.packed_metadata, + launch_metadata, + None, # knobs.runtime.launch_enter_hook, + None, # knobs.runtime.launch_exit_hook, + *args, + ) + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + device = commonir_backend.get_driver().get_current_device() + stream = commonir_backend.get_driver().get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run( + grid[0], + grid[1], + grid[2], + stream, + self.function, + self.packed_metadata, + launch_metadata, + None, # knobs.runtime.launch_enter_hook, + None, # knobs.runtime.launch_exit_hook, + *args, + ) + + return runner + + +class CommonIRCompiler(object): + + def compile(self, commonir_src: CommonIRSource, options=None, _env_vars=None): + + target = commonir_backend.get_driver().get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + + extra_options = {} + options = commonir_backend.parse_options( + dict(options or dict(), **extra_options) + ) + # create cache manager + env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars + + src_hash = hashlib.sha256(commonir_src.src.encode("utf-8")).hexdigest() + key = f"{triton_key()}-{src_hash}-{commonir_backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + store_only_binary = False + file_name = "tilelang-commonir" + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + commonir_backend.add_stages(stages, options) + module = commonir_src.src + ir_filename = f"{file_name}.source" + metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename) + + for ext, compile_ir in list(stages.items()): + next_module = compile_ir(module, metadata) + ir_filename = f"{file_name}.{ext}" + if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")): + metadata_group[ir_filename] = fn_cache_manager.put( + next_module, ir_filename + ) + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put( + json.dumps(metadata, default=vars), metadata_filename, binary=False + ) + fn_cache_manager.put_group(metadata_filename, metadata_group) + return CompiledKernel(commonir_src, metadata_group, hash) + + @functools.lru_cache() + def hash(self): + return "CommonIRCompiler" diff --git a/backend/npu.py b/backend/npu.py index 54dfddcc..578546f8 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -17,7 +17,6 @@ import pybind11 import shutil - ###################### utils.py start ###################### TRITON_PROFILER_REGISTERED = False @@ -426,6 +425,63 @@ def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False): return mod +def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False): + assert isinstance(commonir, str) + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "kernel.commonir.mlir") + dst_path = os.path.join(tmpdir, "kernel.linked.mlir") + Path(src_path).write_text(commonir) + cmd_list = [ + _get_dicp_opt_path(), + src_path, + "--lower-affine", + "--normalize-slice-ops", + "--linalg-if-to-select", + "--linalg-generic-to-scf", + "--scalar-to-1d-tensor", + f"--linalg-to-linked=global-kernel=false named-ops=true", + "--linked-to-hivm", + "-o", + dst_path, + ] + try: + ret = subprocess.run(cmd_list, capture_output=True, check=True) + except subprocess.CalledProcessError as e: + print(f"Error: code={e.returncode}, stdout:{e.stdout},stderr: {e.stderr}") + content = Path(dst_path).read_text() + + # TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。 + # 将"*xfxxx"替换成"?xfxxx" + content = content.replace("*xf", "?xf") + content = content.replace("*xi", "?xi") + content = content.replace("*xbf", "?xbf") + # 匹配形如 "memref<...> to tensor<...>" 的模式 + pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)" + # 使用正则替换,保留memref和tensor类型,中间插入注释 + content = re.sub(pattern, r"\1 // to \2", content) + + if opt.debug or dump_ir: + cmd_list = [ + _get_dicp_opt_path(), + "kernel.ttshared.mlir", + "--lower-affine", + "--normalize-slice-ops", + "--linalg-if-to-select", + "--linalg-generic-to-scf", + "--scalar-to-1d-tensor", + f"--linalg-to-linked=global-kernel=false named-ops=true", + "--linked-to-hivm", + ] + dicp_utils._dump_stage_ir( + content, metadata["hash"], "kernel.linkedir.mlir", cmd_list + ) + + if replace_linked_ir is not None: + print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}") + return Path(replace_linked_ir).read_text() + return content + + def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False): pm = ir.pass_manager(mod.context) dicp_triton.passes.linked_npu.add_lower_affine(pm) From 455b2667a3c4fb8be327a7e621bdbd48177f02e3 Mon Sep 17 00:00:00 2001 From: zhaochaoxing <109726331+zhaochaoxing@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:57:44 +0800 Subject: [PATCH 08/14] support tilelang kernel cache (#140) --- backend/commonir/adapter.py | 220 ++++++++++++++++++++++++++++++++++- backend/commonir/compiler.py | 14 +-- backend/npu.py | 4 +- 3 files changed, 222 insertions(+), 16 deletions(-) diff --git a/backend/commonir/adapter.py b/backend/commonir/adapter.py index 27acd085..4d99202a 100644 --- a/backend/commonir/adapter.py +++ b/backend/commonir/adapter.py @@ -1,2 +1,218 @@ -class Adapter(object): - pass +import os +import re +from typing import Callable, List +from triton.backends.dicp_triton.commonir.compiler import ( + CommonIRCompiler, + CommonIRSource, + CompiledKernel, +) + + +class AdapterWrapper: + def __init__(self) -> None: + from tilelang import tvm as tvm + from tvm import tir + from tilelang.engine.param import KernelParam + from tilelang.jit.adapter import BaseKernelAdapter + + class Artifact: + def __init__(self) -> None: + self.kernel_source: str = None + self.params: List[KernelParam] = None + + def set_kernel_source(self, kernel_source) -> None: + self.kernel_source = str(kernel_source) + self.params = self._extrac_params(kernel_source) + + def _extrac_params(self, func: tir.PrimFunc) -> List[KernelParam]: + tensor_types = [] + for var in func.params: + if var in func.buffer_map: + tensor_types.append( + KernelParam.from_buffer(func.buffer_map[var]) + ) + else: + tensor_types.append(KernelParam.from_var(var)) + return tensor_types + + class Adapter(BaseKernelAdapter): + def __init__(self) -> None: + self.mod = None + self.func = None + self.libpath = None + self.kernel_source = None + + def set_info(self, mod, kernel_source, func: CompiledKernel) -> None: + self.mod = mod + self.func = func + self.libpath = func._run.so_launcher_path + self.kernel_source = str(kernel_source) + + def _convert_torch_func(self) -> Callable: + return self.func + + def get_kernel_source(self) -> str: + return self.kernel_source + + self.adapter = Adapter() + self.artifact = Artifact() + + @classmethod + def compile_and_create_adapter(cls, tilelang_module): + adapter_wrapper = AdapterWrapper() + adapter_wrapper.artifact.set_kernel_source(tilelang_module) + mlir_content = cls._tilelang_to_commonir(tilelang_module) + grid = cls._parse_grid(tilelang_module) + signature = cls._parse_signature(mlir_content) + + commonir_compiler = CommonIRCompiler() + func = commonir_compiler.compile(CommonIRSource(mlir_content, grid, signature)) + adapter_wrapper.adapter.set_info(mlir_content, tilelang_module, func) + + return adapter_wrapper + + @classmethod + def from_database( + cls, + params, + result_idx, + target, + func_or_mod, + kernel_global_source, + kernel_lib_path, + pass_configs, + ): + return cls.compile_and_create_adapter(func_or_mod) + + @classmethod + def _tilelang_to_commonir(cls, tilelang_module): + from tilelang.engine import lower + from tilelang import tvm as tvm + from tvm.ir.instrument import PrintAfterAll, PrintBeforeAll + + debug_enabled = os.environ.get("TILELANG_PRINT_COMMONIR", "0") in ( + "1", + "true", + "on", + ) + + instruments = [PrintAfterAll(), PrintBeforeAll()] if debug_enabled else [] + with tvm.transform.PassContext(instruments=instruments): + mlir_path = lower(tilelang_module) + if mlir_path.endswith(".mlir"): + mlir_content = cls._read_mlir_file(mlir_path) + else: + mlir_content = mlir_path + return mlir_content + + @classmethod + def _parse_grid(cls, tilelang_module): + patterns = { + "x": r'T\.launch_thread\("blockIdx\.x",\s*(\d+)\)', + "y": r'T\.launch_thread\("blockIdx\.y",\s*(\d+)\)', + "z": r'T\.launch_thread\("blockIdx\.z",\s*(\d+)\)', + } + block_indices = {"x": None, "y": None, "z": None} + for dim, pattern in patterns.items(): + match = re.search(pattern, str(str(tilelang_module))) + if match: + block_indices[dim] = int(match.group(1)) + return [ + block_indices["x"] if block_indices["x"] is not None else 1, + block_indices["y"] if block_indices["y"] is not None else 1, + block_indices["z"] if block_indices["z"] is not None else 1, + ] + + @classmethod + def _read_mlir_file(cls, file_path) -> str: + try: + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + return content + except FileNotFoundError: + print(f"Error: File '{file_path}' does not exist") + return None + except Exception as e: + print(f"Error occurred while reading the file: {e}") + return None + + @classmethod + def _parse_signature(cls, mlir_content) -> dict: + target_types = { + "i1", + "i8", + "i16", + "i32", + "i64", + "u32", + "u64", + "fp16", + "bf16", + "fp32", + "f32", + "fp64", + "f16", + } + + pattern = r"func\.func\s*@[^(]*\(([^)]*)\)" + match = re.search(pattern, mlir_content) + + if not match: + return {} + + params_str = match.group(1) + + params = [] + current_param = "" + brace_count = 0 + angle_count = 0 + + for char in params_str: + if char == "," and brace_count == 0 and angle_count == 0: + params.append(current_param.strip()) + current_param = "" + else: + current_param += char + if char == "{": + brace_count += 1 + elif char == "}": + brace_count -= 1 + elif char == "<": + angle_count += 1 + elif char == ">": + angle_count -= 1 + + if current_param: + params.append(current_param.strip()) + + result = {} + index = 0 + + for param in params: + if re.match(r"%args\d+", param.strip()): + continue + + found_type = None + for t_type in target_types: + x_pattern = r"\bx" + t_type + r"\b" + if re.search(x_pattern, param): + found_type = "*" + t_type + break + elif re.search(r"\b" + t_type + r"\b", param): + found_type = t_type + break + + if found_type: + if found_type == "f16": + found_type = "fp16" + elif found_type == "*f16": + found_type = "*fp16" + elif found_type == "f32": + found_type = "fp32" + elif found_type == "*f32": + found_type = "*fp32" + + result[index] = found_type + index += 1 + + return result diff --git a/backend/commonir/compiler.py b/backend/commonir/compiler.py index 740dd823..10e22d2d 100644 --- a/backend/commonir/compiler.py +++ b/backend/commonir/compiler.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, List from triton._C.libtriton import get_cache_invalidating_env_vars - +from collections import namedtuple from triton.runtime.cache import triton_key from .backend import commonir_backend from triton.backends.compiler import GPUTarget @@ -22,8 +22,6 @@ def __init__(self, src: str, grid: List[int], signature: dict): class CompiledKernel: def __init__(self, src: CommonIRSource, metadata_group, hash): - from collections import namedtuple - metadata_path = next( (Path(p) for c, p in metadata_group.items() if c.endswith(".json")) ) @@ -58,21 +56,13 @@ def __init__(self, src: CommonIRSource, metadata_group, hash): ) self.metadata_group = metadata_group self.kernel = self.asm[binary_ext] - # binaries are lazily initialized - # because it involves doing runtime things - # (e.g., checking amount of shared memory on current device) self.module = None - self.function = None - self._run = None + self._init_handles() def _init_handles(self): if self.module is not None: return - def raise_(err): - self._run = functools.partial(_raise_error, err) - raise err - device = commonir_backend.get_driver().get_current_device() # create launcher self._run = commonir_backend.get_driver().launcher_cls(self.src, self.metadata) diff --git a/backend/npu.py b/backend/npu.py index 578546f8..7838b641 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -961,12 +961,12 @@ def __init__(self, src, metadata): wrapper_src = generate_npu_wrapper_src( constants, signature, workspace_size, mix_mode, lock_num, lock_init_value ) - so_launcher_path = make_npu_launcher_stub(wrapper_src, debug_mode) + self.so_launcher_path = make_npu_launcher_stub(wrapper_src, debug_mode) # initialize launcher import importlib.util spec = importlib.util.spec_from_file_location( - "__triton_launcher", so_launcher_path + "__triton_launcher", self.so_launcher_path ) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) From 6d685dae90047ee965ec7c05af14e37edf7e4d1b Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Fri, 30 Jan 2026 02:53:21 +0000 Subject: [PATCH 09/14] format --- compiler/include/dicp/Utils/Utils.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/compiler/include/dicp/Utils/Utils.h b/compiler/include/dicp/Utils/Utils.h index abaf9dc0..23e53351 100644 --- a/compiler/include/dicp/Utils/Utils.h +++ b/compiler/include/dicp/Utils/Utils.h @@ -1,27 +1,27 @@ #ifndef TRITON_UTILS_H #define TRITON_UTILS_H +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/TypeUtilities.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" @@ -32,12 +32,12 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" -#include -#include #include #include #include +#include #include +#include // Dispatch conversion pattern handlers based on backend string. Executes // ASCEND_HANDLER when backend == "ascend", otherwise DEFAULT_HANDLER. From f8bd0aabb6f398ca34932ff2c0c0da21174c01a9 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Fri, 30 Jan 2026 03:19:54 +0000 Subject: [PATCH 10/14] fix bugs --- test/ascend/passed_tests/test_lightning_attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/ascend/passed_tests/test_lightning_attn.py b/test/ascend/passed_tests/test_lightning_attn.py index 899c4b4c..ff00b75e 100644 --- a/test/ascend/passed_tests/test_lightning_attn.py +++ b/test/ascend/passed_tests/test_lightning_attn.py @@ -261,6 +261,7 @@ def lightning_attention_prefill_forward_triton_loop( BLOCK=BLOCK_SIZE, NUM_BLOCK=NUM_BLOCK, BLOCK_MODEL=BLOCK_M, + add_annotate_transpose=True, ) return o, kv @@ -653,9 +654,9 @@ def test_lightning_attention_prefill( kv_check ), f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" assert output_check, f"output torch:{out_torch}, output triton:{out_triton}" - print( - f"zmz debug torch kv_check:{past_key_value_torch}, triton :{past_key_value_triton}" - ) + # print( + # f"debug torch kv_check:{past_key_value_torch}, triton :{past_key_value_triton}" + # ) # float32 only @pytest.mark.parametrize( From 86b2793f49b59c55df52c0243911dd143c8abee3 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Fri, 30 Jan 2026 07:11:06 +0000 Subject: [PATCH 11/14] change ut --- .../ascend/{passed_tests => attention}/test_lightning_attn.py | 3 +++ test/ascend/run_tests.sh | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) rename test/ascend/{passed_tests => attention}/test_lightning_attn.py (99%) diff --git a/test/ascend/passed_tests/test_lightning_attn.py b/test/ascend/attention/test_lightning_attn.py similarity index 99% rename from test/ascend/passed_tests/test_lightning_attn.py rename to test/ascend/attention/test_lightning_attn.py index ff00b75e..9af5243b 100644 --- a/test/ascend/passed_tests/test_lightning_attn.py +++ b/test/ascend/attention/test_lightning_attn.py @@ -246,6 +246,9 @@ def lightning_attention_prefill_forward_triton_loop( else: kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) + import os + bishengir_path = os.environ.get("BISHENG_INSTALL_PATH", None) + print(f"zmz debug BISHENG_INSTALL_PATH: {bishengir_path}") _fwd_loop_kernel[grid]( q, k, diff --git a/test/ascend/run_tests.sh b/test/ascend/run_tests.sh index f4bb54cc..ab0dd697 100644 --- a/test/ascend/run_tests.sh +++ b/test/ascend/run_tests.sh @@ -26,4 +26,6 @@ for test_dir in "${pytestcase_dir[@]}"; do done - +export BISHENG_INSTALL_PATH=/mnt/data01/CI/DLCompiler/data/bishengir_20251215/bin/ +export PATH=$BISHENG_INSTALL_PATH:$PATH +run_pytestcases "attention" From 94311c72e29859c2090ddb4bcd05dcc087d8e581 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Fri, 30 Jan 2026 07:22:30 +0000 Subject: [PATCH 12/14] format --- test/ascend/attention/test_lightning_attn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/ascend/attention/test_lightning_attn.py b/test/ascend/attention/test_lightning_attn.py index 9af5243b..ff00b75e 100644 --- a/test/ascend/attention/test_lightning_attn.py +++ b/test/ascend/attention/test_lightning_attn.py @@ -246,9 +246,6 @@ def lightning_attention_prefill_forward_triton_loop( else: kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) - import os - bishengir_path = os.environ.get("BISHENG_INSTALL_PATH", None) - print(f"zmz debug BISHENG_INSTALL_PATH: {bishengir_path}") _fwd_loop_kernel[grid]( q, k, From ec56756908197780a48fcbcb67da896bf421ab88 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Fri, 30 Jan 2026 09:26:30 +0000 Subject: [PATCH 13/14] add cmd to dicp_opt --- .../Dialect/LinalgExt/Transforms/Passes.td | 2 +- .../Transforms/AnnotateTransposePass.cpp | 62 +++++++++---------- test/ascend/mlir/annotate_transpose_pass.mlir | 57 ----------------- tools/dicp_triton_opt/dicp_triton_opt.cpp | 1 + triton_dicp_triton.cc | 1 + 5 files changed, 34 insertions(+), 89 deletions(-) delete mode 100644 test/ascend/mlir/annotate_transpose_pass.mlir diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td index d6a8224d..528e6166 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td @@ -68,7 +68,7 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> { let dependentDialects = ["mlir::tensor::TensorDialect"]; } -def AnnotateTransposePass : Pass<"annotate-transpose", "func::FuncOp"> { +def AnnotateTranspose : Pass<"annotate-transpose", "func::FuncOp"> { let summary = "Annotate operations with permuted memref type"; let description = [{ Adds MayImplicitTransposeWithLastAxis annotations to operations with permuted memref type. diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp index bcfdcea7..46c3afca 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp +++ b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp @@ -32,7 +32,7 @@ using namespace mlir::dicp; namespace mlir { namespace dicp { namespace LinalgExt { -#define GEN_PASS_DEF_ANNOTATETRANSPOSEPASS +#define GEN_PASS_DEF_ANNOTATETRANSPOSE #include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" } // namespace LinalgExt } // namespace dicp @@ -90,14 +90,14 @@ bool checkValueOriginHasNonStandardStride(Value value) { } struct AnnotateTransposePass - : public mlir::dicp::LinalgExt::impl::AnnotateTransposePassBase< + : public mlir::dicp::LinalgExt::impl::AnnotateTransposeBase< AnnotateTransposePass> { void runOnOperation() override { auto funcOp = getOperation(); - llvm::outs() << "[INFO] Starting AnnotateTransposePass on function: " - << funcOp.getName() << "\n"; + LLVM_DEBUG(llvm::dbgs() << "[INFO] Starting AnnotateTransposePass on function: " + << funcOp.getName() << "\n"); // 待处理列表 SmallVector toTensorOpsToMark; @@ -112,7 +112,7 @@ struct AnnotateTransposePass auto source = copyOp.getSource(); auto target = copyOp.getTarget(); - llvm::outs() << "[MEMREF_COPY_VISIT] " << copyOp << "\n"; + LLVM_DEBUG(llvm::dbgs() << "[MEMREF_COPY_VISIT] " << copyOp << "\n"); // --- 尝试进行 IR 重写 (Rewrite) --- // 目标:将 memref.copy(subview(A), subview(B)) 转换为 memref.copy(A, B) @@ -142,12 +142,12 @@ struct AnnotateTransposePass if (isBaseSourcePermuted && isBaseTargetContiguous && baseSourceType.getShape() == baseTargetType.getShape()) { - llvm::outs() << " [REWRITE_MATCH] Found Dynamic Subview Copy " - "candidate for Static Rewrite.\n"; - llvm::outs() << " Base Source (Permuted): " << baseSourceType - << "\n"; - llvm::outs() << " Base Target (Contiguous): " << baseTargetType - << "\n"; + LLVM_DEBUG(llvm::dbgs() << " [REWRITE_MATCH] Found Dynamic Subview Copy " + "candidate for Static Rewrite.\n"); + LLVM_DEBUG(llvm::dbgs() << " Base Source (Permuted): " << baseSourceType + << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Base Target (Contiguous): " << baseTargetType + << "\n"); // 执行重写 OpBuilder builder(copyOp->getContext()); @@ -156,8 +156,8 @@ struct AnnotateTransposePass // 1. 创建新的静态 Copy (Base -> Base) auto newCopyOp = builder.create( copyOp.getLoc(), baseSource, baseTarget); - llvm::outs() << " -> Replaced with Static Copy: " << newCopyOp - << "\n"; + LLVM_DEBUG(llvm::dbgs() << " -> Replaced with Static Copy: " << newCopyOp + << "\n"); // 2. 关键:在 Base Target (MemRef) 上添加 Annotation // 这指导 Ascend 编译器生成隐式转置指令 @@ -166,8 +166,8 @@ struct AnnotateTransposePass builder.create(copyOp.getLoc(), baseTarget); markOp->setAttr("MayImplicitTransposeWithLastAxis", UnitAttr::get(builder.getContext())); - llvm::outs() << " -> Added Annotation to Base Target MemRef: " - << markOp << "\n"; + LLVM_DEBUG(llvm::dbgs() << " -> Added Annotation to Base Target MemRef: " + << markOp << "\n"); // 3. 追踪 Base Target 的 Tensor 使用者 // 我们需要标记 bufferization.to_tensor(BaseTarget),这样后续的 @@ -182,9 +182,9 @@ struct AnnotateTransposePass if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << " -> Scheduled Base Target's ToTensorOp " + LLVM_DEBUG(llvm::dbgs() << " -> Scheduled Base Target's ToTensorOp " "for annotation: " - << toTensorOp << "\n"; + << toTensorOp << "\n"); } } } @@ -218,9 +218,9 @@ struct AnnotateTransposePass exists = true; if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() + LLVM_DEBUG(llvm::dbgs() << " [PROPAGATE] Marked bufferization.to_tensor (Source " - "was permuted)\n"; + "was permuted)\n"); } } } @@ -238,9 +238,9 @@ struct AnnotateTransposePass exists = true; if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() + LLVM_DEBUG(llvm::dbgs() << " [PROPAGATE_PARENT] Marked " - "bufferization.to_tensor of Parent MemRef\n"; + "bufferization.to_tensor of Parent MemRef\n"); } } } @@ -256,9 +256,9 @@ struct AnnotateTransposePass exists = true; if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() + LLVM_DEBUG(llvm::dbgs() << " [PROPAGATE_TARGET] Marked bufferization.to_tensor " - "(Target is permuted)\n"; + "(Target is permuted)\n"); } } } @@ -295,8 +295,8 @@ struct AnnotateTransposePass if (shouldMark || hasNonStandardStride) { toTensorOpsToMark.push_back(toTensorOp); - llvm::outs() << "[TO_TENSOR_CHECK] Found permuted/strided origin: " - << toTensorOp << "\n"; + LLVM_DEBUG(llvm::dbgs() << "[TO_TENSOR_CHECK] Found permuted/strided origin: " + << toTensorOp << "\n"); } }); @@ -312,8 +312,8 @@ struct AnnotateTransposePass // 中可能有重复(如果 func.walk 逻辑有交集),去重已经在 push_back // 时做了。 - llvm::outs() << " [ANNOTATE_ACTION] Adding annotation to: " << toTensorOp - << "\n"; + LLVM_DEBUG(llvm::dbgs() << " [ANNOTATE_ACTION] Adding annotation to: " << toTensorOp + << "\n"); OpBuilder builder(toTensorOp->getContext()); builder.setInsertionPointAfter(toTensorOp); @@ -324,11 +324,11 @@ struct AnnotateTransposePass markOp->setAttr("MayImplicitTransposeWithLastAxis", UnitAttr::get(builder.getContext())); - llvm::outs() << " -> Created annotation::MarkOp: " << markOp << "\n"; + LLVM_DEBUG(llvm::dbgs() << " -> Created annotation::MarkOp: " << markOp << "\n"); } - llvm::outs() << "[INFO] Finished AnnotateTransposePass on function: " - << funcOp.getName() << "\n"; + LLVM_DEBUG(llvm::dbgs() << "[INFO] Finished AnnotateTransposePass on function: " + << funcOp.getName() << "\n"); } }; } // namespace @@ -337,4 +337,4 @@ namespace mlir::dicp::LinalgExt { std::unique_ptr> createAnnotateTransposePass() { return std::make_unique(); } -} // namespace mlir::dicp::LinalgExt \ No newline at end of file +} // namespace mlir::dicp::LinalgExt diff --git a/test/ascend/mlir/annotate_transpose_pass.mlir b/test/ascend/mlir/annotate_transpose_pass.mlir deleted file mode 100644 index bb62f77e..00000000 --- a/test/ascend/mlir/annotate_transpose_pass.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// Test for AnnotateTransposePass - checks that the pass adds MayImplicitTransposeWithLastAxis annotations appropriately - -// RUN: bishengir-opt %s -annotate-transpose-pass | FileCheck %s - -func.func @test_linalg_copy_with_permuted_memref() { - // Original memref with permuted layout - %0 = memref.alloc() : memref<128x5xf32, strided<[8, 1]>> - %1 = memref.alloc() : memref<128x5xf32, strided<[1, 128], offset: ?>> - // linalg.copy should get annotated since target has permuted memref - linalg.copy %1, %0 : memref<128x5xf32, strided<[1, 128], offset: ?>> to memref<128x5xf32, strided<[8, 1]>> - // CHECK: linalg.copy - // CHECK: annotation.mark - // CHECK: "MayImplicitTransposeWithLastAxis" - return -} - -func.func @test_memref_copy_with_permuted_memref() { - // Original memref with permuted layout - %0 = memref.alloc() : memref<128x8xf32, strided<[8, 1]>> - %1 = memref.alloc() : memref<128x8xf32, strided<[1, 128], offset: ?>> - // memref.copy should get annotated since target has permuted memref - memref.copy %1, %0 : memref<128x8xf32, strided<[1, 128], offset: ?>> to memref<128x8xf32, strided<[8, 1]>> - // CHECK: memref.copy - // CHECK: annotation.mark - // CHECK: "MayImplicitTransposeWithLastAxis" - return -} - -func.func @test_bufferization_to_tensor_with_permuted_source() { - %0 = memref.alloc() : memref<128x8xf32, strided<[8, 1]>> - // bufferization.to_tensor should get annotated since source has permuted memref - %1 = bufferization.to_tensor %0 : memref<128x8xf32, strided<[8, 1]>> - // CHECK: bufferization.to_tensor - // CHECK: annotation.mark - // CHECK: "MayImplicitTransposeWithLastAxis" - return -} - -func.func @test_memref_subview_with_permuted_source() { - %0 = memref.alloc() : memref<128x8xf32, strided<[8, 1]>> - // memref.subview should get annotated since source has permuted memref - %1 = memref.subview %0[0, 0] to [64, 4] : memref<128x8xf32, strided<[8, 1]>> - // CHECK: memref.subview - // CHECK: annotation.mark - // CHECK: "MayImplicitTransposeWithLastAxis" - return -} - -func.func @test_non_permuted_memref_no_annotation() { - // Non-permuted memref should not get annotated - %0 = memref.alloc() : memref<128x5xf32> - %1 = memref.alloc() : memref<128x5xf32> - linalg.copy %1, %0 : memref<128x5xf32> to memref<128x5xf32> - // CHECK-NOT: annotation.mark - // CHECK-NOT: "MayImplicitTransposeWithLastAxis" - return -} \ No newline at end of file diff --git a/tools/dicp_triton_opt/dicp_triton_opt.cpp b/tools/dicp_triton_opt/dicp_triton_opt.cpp index dccd884e..a0482c95 100644 --- a/tools/dicp_triton_opt/dicp_triton_opt.cpp +++ b/tools/dicp_triton_opt/dicp_triton_opt.cpp @@ -105,6 +105,7 @@ inline void registerDICPDialects(mlir::DialectRegistry ®istry) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerAnnotateTransposePass(); registry.insert Date: Fri, 30 Jan 2026 09:30:28 +0000 Subject: [PATCH 14/14] format --- .../Transforms/AnnotateTransposePass.cpp | 60 +++++++++++-------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp index 46c3afca..c8820d06 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp +++ b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp @@ -96,8 +96,9 @@ struct AnnotateTransposePass void runOnOperation() override { auto funcOp = getOperation(); - LLVM_DEBUG(llvm::dbgs() << "[INFO] Starting AnnotateTransposePass on function: " - << funcOp.getName() << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "[INFO] Starting AnnotateTransposePass on function: " + << funcOp.getName() << "\n"); // 待处理列表 SmallVector toTensorOpsToMark; @@ -142,12 +143,13 @@ struct AnnotateTransposePass if (isBaseSourcePermuted && isBaseTargetContiguous && baseSourceType.getShape() == baseTargetType.getShape()) { - LLVM_DEBUG(llvm::dbgs() << " [REWRITE_MATCH] Found Dynamic Subview Copy " - "candidate for Static Rewrite.\n"); - LLVM_DEBUG(llvm::dbgs() << " Base Source (Permuted): " << baseSourceType - << "\n"); - LLVM_DEBUG(llvm::dbgs() << " Base Target (Contiguous): " << baseTargetType - << "\n"); + LLVM_DEBUG(llvm::dbgs() + << " [REWRITE_MATCH] Found Dynamic Subview Copy " + "candidate for Static Rewrite.\n"); + LLVM_DEBUG(llvm::dbgs() << " Base Source (Permuted): " + << baseSourceType << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Base Target (Contiguous): " + << baseTargetType << "\n"); // 执行重写 OpBuilder builder(copyOp->getContext()); @@ -156,8 +158,8 @@ struct AnnotateTransposePass // 1. 创建新的静态 Copy (Base -> Base) auto newCopyOp = builder.create( copyOp.getLoc(), baseSource, baseTarget); - LLVM_DEBUG(llvm::dbgs() << " -> Replaced with Static Copy: " << newCopyOp - << "\n"); + LLVM_DEBUG(llvm::dbgs() << " -> Replaced with Static Copy: " + << newCopyOp << "\n"); // 2. 关键:在 Base Target (MemRef) 上添加 Annotation // 这指导 Ascend 编译器生成隐式转置指令 @@ -166,8 +168,9 @@ struct AnnotateTransposePass builder.create(copyOp.getLoc(), baseTarget); markOp->setAttr("MayImplicitTransposeWithLastAxis", UnitAttr::get(builder.getContext())); - LLVM_DEBUG(llvm::dbgs() << " -> Added Annotation to Base Target MemRef: " - << markOp << "\n"); + LLVM_DEBUG(llvm::dbgs() + << " -> Added Annotation to Base Target MemRef: " + << markOp << "\n"); // 3. 追踪 Base Target 的 Tensor 使用者 // 我们需要标记 bufferization.to_tensor(BaseTarget),这样后续的 @@ -182,9 +185,10 @@ struct AnnotateTransposePass if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - LLVM_DEBUG(llvm::dbgs() << " -> Scheduled Base Target's ToTensorOp " - "for annotation: " - << toTensorOp << "\n"); + LLVM_DEBUG(llvm::dbgs() + << " -> Scheduled Base Target's ToTensorOp " + "for annotation: " + << toTensorOp << "\n"); } } } @@ -218,7 +222,8 @@ struct AnnotateTransposePass exists = true; if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - LLVM_DEBUG(llvm::dbgs() + LLVM_DEBUG( + llvm::dbgs() << " [PROPAGATE] Marked bufferization.to_tensor (Source " "was permuted)\n"); } @@ -238,7 +243,8 @@ struct AnnotateTransposePass exists = true; if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - LLVM_DEBUG(llvm::dbgs() + LLVM_DEBUG( + llvm::dbgs() << " [PROPAGATE_PARENT] Marked " "bufferization.to_tensor of Parent MemRef\n"); } @@ -256,7 +262,8 @@ struct AnnotateTransposePass exists = true; if (!exists) { toTensorOpsToMark.push_back(toTensorOp); - LLVM_DEBUG(llvm::dbgs() + LLVM_DEBUG( + llvm::dbgs() << " [PROPAGATE_TARGET] Marked bufferization.to_tensor " "(Target is permuted)\n"); } @@ -295,8 +302,9 @@ struct AnnotateTransposePass if (shouldMark || hasNonStandardStride) { toTensorOpsToMark.push_back(toTensorOp); - LLVM_DEBUG(llvm::dbgs() << "[TO_TENSOR_CHECK] Found permuted/strided origin: " - << toTensorOp << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "[TO_TENSOR_CHECK] Found permuted/strided origin: " + << toTensorOp << "\n"); } }); @@ -312,8 +320,8 @@ struct AnnotateTransposePass // 中可能有重复(如果 func.walk 逻辑有交集),去重已经在 push_back // 时做了。 - LLVM_DEBUG(llvm::dbgs() << " [ANNOTATE_ACTION] Adding annotation to: " << toTensorOp - << "\n"); + LLVM_DEBUG(llvm::dbgs() << " [ANNOTATE_ACTION] Adding annotation to: " + << toTensorOp << "\n"); OpBuilder builder(toTensorOp->getContext()); builder.setInsertionPointAfter(toTensorOp); @@ -324,11 +332,13 @@ struct AnnotateTransposePass markOp->setAttr("MayImplicitTransposeWithLastAxis", UnitAttr::get(builder.getContext())); - LLVM_DEBUG(llvm::dbgs() << " -> Created annotation::MarkOp: " << markOp << "\n"); + LLVM_DEBUG(llvm::dbgs() + << " -> Created annotation::MarkOp: " << markOp << "\n"); } - LLVM_DEBUG(llvm::dbgs() << "[INFO] Finished AnnotateTransposePass on function: " - << funcOp.getName() << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "[INFO] Finished AnnotateTransposePass on function: " + << funcOp.getName() << "\n"); } }; } // namespace