From d9bcce46f6c25fa36a6e645a37a8ba69d8e10173 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Mon, 9 Mar 2026 06:57:09 +0000 Subject: [PATCH] tmp: new add --- backend/npu.py | 4 +- backend/npu_utils.cpp | 3 +- .../Transforms/VectorizeParallelLoopPass.cpp | 633 +++++++++--------- 3 files changed, 316 insertions(+), 324 deletions(-) diff --git a/backend/npu.py b/backend/npu.py index 21afec22..9d3fad2b 100644 --- a/backend/npu.py +++ b/backend/npu.py @@ -468,6 +468,7 @@ def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False): content = re.sub(pattern, r"\1 // to \2", content) if opt.debug or dump_ir: + # todo 调整commonir 路径。 cmd_list = [ _get_dicp_opt_path(), "kernel.ttshared.mlir", @@ -1378,7 +1379,8 @@ def _format_of(ty): #define PY_SSIZE_T_CLEAN #include {'#include ' if enable_taskqueue else ''} -#include "experiment/runtime/runtime/rt.h" +// #include "experiment/runtime/runtime/rt.h" +#include "/usr/local/Ascend/cann/pkg_inc/runtime/runtime/rt.h" {extract_device_print_code_from_cann() if enable_device_print else ''} #define TENSOR_KIND_INPUT 0 diff --git a/backend/npu_utils.cpp b/backend/npu_utils.cpp index 161086e0..482890e6 100644 --- a/backend/npu_utils.cpp +++ b/backend/npu_utils.cpp @@ -6,7 +6,8 @@ #include #include -#include "experiment/runtime/runtime/rt.h" +// #include "experiment/runtime/runtime/rt.h" +#include "/usr/local/Ascend/cann/pkg_inc/runtime/runtime/rt.h" // Use map to differentiate same name functions from different binary static std::unordered_map registered_names; diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp index e73ae778..6882ceac 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp +++ b/compiler/lib/Dialect/LinalgExt/Transforms/VectorizeParallelLoopPass.cpp @@ -1,6 +1,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -10,398 +12,385 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "vectorize-parallel-loop" using namespace mlir; -namespace mlir { -namespace dicp { -namespace LinalgExt { -#define GEN_PASS_DEF_VECTORIZEPARALLELLOOPPASS -#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" -} // namespace LinalgExt -} // namespace dicp -} // namespace mlir +namespace { -#define DEBUG_TYPE "vectorize-parallel-loop-pass" +// ============================================================================ +// 工具类:处理数据流映射和类型推导 +// ============================================================================ +class TensorizationContext { +public: + TensorizationContext(int64_t size, PatternRewriter &rewriter, Location loc) + : tileSize(size), rewriter(rewriter), loc(loc) {} + + // 记录标量到张量的映射 + void map(Value scalar, Value tensor) { + scalarToTensorMap[scalar] = tensor; + LLVM_DEBUG(llvm::dbgs() << " [Map] Scalar " << scalar << " -> Tensor " << tensor << "\n"); + } -namespace { + // 查找映射,如果不存在且是标量,则尝试广播 + Value lookupOrBroadcast(Value scalar) { + if (scalarToTensorMap.count(scalar)) { + return scalarToTensorMap[scalar]; + } + + // 如果是 Index 类型,通常用于地址计算,保持原样(不转 Tensor) + if (scalar.getType().isIndex()) { + return nullptr; + } + + // 尝试广播 (Splat) + // 检查是否是基础标量类型 (Float, Int) + if (auto scalarType = dyn_cast(scalar.getType())) { + // 已经是 Tensor 了,直接返回 + return scalar; + } + + LLVM_DEBUG(llvm::dbgs() << " [Broadcast] Creating splat for: " << scalar << "\n"); + auto tensorType = RankedTensorType::get({tileSize}, scalar.getType()); + Value splat = rewriter.create(loc, tensorType, scalar); + map(scalar, splat); + return splat; + } + + Value getMapped(Value scalar) { + return scalarToTensorMap.count(scalar) ? scalarToTensorMap[scalar] : nullptr; + } + +private: + int64_t tileSize; + PatternRewriter &rewriter; + Location loc; + DenseMap scalarToTensorMap; +}; -// 核心 Pattern:将标量并行循环展开为向量化的顺序操作 -struct VectorizeParallelLoopPattern : public OpRewritePattern { +// ============================================================================ +// 核心 Pattern:将 scf.parallel 转化为 Tensor + Bufferization 操作 +// ============================================================================ +struct TensorizeParallelLoopPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(scf::ParallelOp op, PatternRewriter &rewriter) const override { - LLVM_DEBUG( - llvm::dbgs() - << "\n[VectorizeParallelLoop] >>> Start matching scf.parallel at " - << op.getLoc() << "\n"); + LLVM_DEBUG(llvm::dbgs() << "\n=== [Tensorize] Start Processing scf.parallel at " << op.getLoc() << " ===\n"); - // 1. 检查循环结构 + // 1. 基本合法性检查 + // 只处理 1D 循环,且边界必须是常量,以确定静态 Shape if (op.getNumLoops() != 1) { - LLVM_DEBUG(llvm::dbgs() - << "[VectorizeParallelLoop] Skip: Multi-dimensional loop.\n"); + LLVM_DEBUG(llvm::dbgs() << "[Skip] Not a 1D loop.\n"); return failure(); } - Value lowerBound = op.getLowerBound()[0]; - Value upperBound = op.getUpperBound()[0]; - - auto lowerOp = lowerBound.getDefiningOp(); - auto upperOp = upperBound.getDefiningOp(); + auto lowerOp = op.getLowerBound()[0].getDefiningOp(); + auto upperOp = op.getUpperBound()[0].getDefiningOp(); + auto stepOp = op.getStep()[0].getDefiningOp(); - if (!lowerOp || !upperOp) { - LLVM_DEBUG(llvm::dbgs() - << "[VectorizeParallelLoop] Skip: Bounds are not constant.\n"); + if (!lowerOp || !upperOp || !stepOp) { + LLVM_DEBUG(llvm::dbgs() << "[Skip] Bounds or step are not constant.\n"); return failure(); } int64_t lowerVal = lowerOp.value(); int64_t upperVal = upperOp.value(); - int64_t size = upperVal - lowerVal; - - LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Loop Bounds: [" - << lowerVal << ", " << upperVal << ")\n"); - LLVM_DEBUG(llvm::dbgs() - << "[VectorizeParallelLoop] Calculated Vector Size: " << size - << "\n"); - - // 只有当有实际计算量时才处理 - if (size <= 0) { - LLVM_DEBUG(llvm::dbgs() << "[VectorizeParallelLoop] Skip: Size <= 0.\n"); - return failure(); - } - - // 2. 准备映射表 - // mapper: 用于处理索引计算 (将 Loop IV 映射为常数 LowerBound) - IRMapping mapper; + int64_t stepVal = stepOp.value(); + int64_t range = upperVal - lowerVal; + + // 我们假设这个 Pass 作用于已经 Tile 过的内部循环, + // 或者我们把整个循环范围当作一个 Tensor 处理 + int64_t tensorSize = range; + + LLVM_DEBUG(llvm::dbgs() << " Loop Range: [" << lowerVal << ", " << upperVal << "), Step: " << stepVal << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Target Tensor Size: " << tensorSize << "\n"); + + if (tensorSize <= 0) return failure(); + + // 2. 初始化上下文 + TensorizationContext ctx(tensorSize, rewriter, op.getLoc()); + + // 3. IV 映射 + // 在 Tensor 化模式下,原循环的 IV 通常映射为 Base Offset (LowerBound) + // 后续的地址计算都基于这个 Base Offset 进行 + IRMapping ivMap; + Value iv = op.getInductionVars()[0]; + ivMap.map(iv, op.getLowerBound()[0]); + LLVM_DEBUG(llvm::dbgs() << " Mapped IV " << iv << " -> LowerBound " << op.getLowerBound()[0] << "\n"); + + // 保存所有创建的新操作,以便替换yield操作 + SmallVector newOps; + + // 4. 遍历 Body 指令 Block *body = op.getBody(); - Value iv = body->getArgument(0); - - LLVM_DEBUG(llvm::dbgs() - << "[VectorizeParallelLoop] Mapping Induction Variable " << iv - << " -> Constant " << lowerBound << "\n"); - mapper.map(iv, lowerBound); // 关键修复:将 IV 替换为 Loop 起始值 - - // scalarToTensorMap: 用于数据流向量化 (标量 Value -> 向量 Tensor Value) - DenseMap scalarToTensorMap; + auto &ops = body->getOperations(); + // 收集所有要处理的操作,排除terminator + SmallVector opsToProcess; + for (auto &inst : ops) { + if (!isa(inst)) { + opsToProcess.push_back(&inst); + } + } - LLVM_DEBUG( - llvm::dbgs() - << "[VectorizeParallelLoop] Starting to process body operations...\n"); + // 处理所有收集的操作 + for (auto *inst : opsToProcess) { + LLVM_DEBUG(llvm::dbgs() << " -> Visiting: " << inst->getName() << "\n"); - // 3. 遍历原循环体,按顺序生成向量化代码 - for (Operation &inst : body->getOperations()) { - LLVM_DEBUG(llvm::dbgs() - << " -> Visiting Op: " << inst.getName() << "\n"); + // --- Case 0: 忽略 Terminator --- + if (isa(*inst)) continue; - // 跳过 terminator - if (isa(inst) || isa(inst)) { - LLVM_DEBUG(llvm::dbgs() << " Skipping terminator.\n"); + // --- Case 1: 地址/索引计算 (Index Arithmetic) --- + // 这些指令通常保持标量形式,用于计算 memref 的偏移 + if (isa(*inst) && + inst->getResult(0).getType().isIndex()) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Cloning Index Compute.\n"); + auto *clonedOp = rewriter.clone(*inst, ivMap); + newOps.push_back(clonedOp); continue; } - - // --- Case A: 索引计算 (Index Cast, Add, Mul 等) --- - // 直接克隆,但使用 mapper 将 IV 替换为常数 - if (isa(inst)) { - LLVM_DEBUG(llvm::dbgs() - << " [Action] Cloning index calculation.\n"); - Operation *newOp = rewriter.clone(inst, mapper); - LLVM_DEBUG(llvm::dbgs() - << " New Op result: " << newOp->getResult(0) << "\n"); - continue; + + // 如果是 i32 运算但用于地址计算的,也 clone (需要根据上下文,这里简化处理,假设所有 i32/index 混合运算都为了地址) + if (isa(*inst) && inst->getResult(0).getType().isInteger(32)) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Cloning Int32 Compute (Assuming Address Calc).\n"); + auto *clonedOp = rewriter.clone(*inst, ivMap); + newOps.push_back(clonedOp); + continue; } - // --- Case B: 读取内存 (Load -> Vectorize) --- - if (auto loadOp = dyn_cast(inst)) { - LLVM_DEBUG(llvm::dbgs() << " [Action] Vectorizing LoadOp.\n"); - Value memref = loadOp.getMemRef(); - // 获取计算好的索引 (通过 mapper 查找) - Value index = mapper.lookup(loadOp.getIndices()[0]); - LLVM_DEBUG(llvm::dbgs() << " Base MemRef: " << memref << "\n"); - LLVM_DEBUG(llvm::dbgs() << " Mapped Index: " << index << "\n"); - - // 1. Alloc Local Buffer - auto memrefType = dyn_cast(memref.getType()); - if (!memrefType) { - LLVM_DEBUG(llvm::dbgs() - << "[VectorizeParallelLoop] ERROR: MemRef type expected " - "but not found.\n"); - return failure(); + // --- Case 2: 内存读取 (Load -> Alloc + Copy + ToTensor) --- + if (auto loadOp = dyn_cast(*inst)) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Tensorizing MemRef Load.\n"); + + Value baseMemref = loadOp.getMemRef(); + // 计算偏移量:使用 ivMap 查找映射后的操作数 + SmallVector indices; + for (auto idx : loadOp.getIndices()) { + indices.push_back(ivMap.lookupOrDefault(idx)); } - auto localType = MemRefType::get({size}, memrefType.getElementType()); - Value localAlloc = - rewriter.create(op.getLoc(), localType); - LLVM_DEBUG(llvm::dbgs() << " Created Local Alloc: " - << localAlloc.getType() << "\n"); - - // 2. Subview Global Memory - SmallVector offsets = {index}; - SmallVector sizes = {rewriter.getIndexAttr(size)}; + + // 2.1 Alloc Local (UB) + auto elemType = dyn_cast(baseMemref.getType()).getElementType(); + auto localMemType = MemRefType::get({tensorSize}, elemType); + Value localAlloc = rewriter.create(op.getLoc(), localMemType); + + // 2.2 Create SubView (GM View) + // 假设 stride 为 1 (连续访问) + SmallVector offsets; + for(auto v : indices) offsets.push_back(v); + SmallVector sizes = {rewriter.getIndexAttr(tensorSize)}; SmallVector strides = {rewriter.getIndexAttr(1)}; + Value subview = rewriter.create( - op.getLoc(), memref, offsets, sizes, strides); - LLVM_DEBUG(llvm::dbgs() << " Created Subview.\n"); + op.getLoc(), baseMemref, offsets, sizes, strides); - // 3. Copy Global -> Local - rewriter.create(op.getLoc(), subview, localAlloc); - LLVM_DEBUG(llvm::dbgs() << " Created Copy (Global -> Local).\n"); + // 2.3 Copy GM -> UB (MTE) + auto copyOp = rewriter.create(op.getLoc(), subview, localAlloc); + newOps.push_back(copyOp.getOperation()); - // 4. Local Buffer -> Tensor - auto tensorType = - RankedTensorType::get({size}, memrefType.getElementType()); - auto toTensor = rewriter.create( + // 2.4 To Tensor + auto tensorType = RankedTensorType::get({tensorSize}, elemType); + Value tensorVal = rewriter.create( op.getLoc(), tensorType, localAlloc, /*restrict=*/true); - LLVM_DEBUG(llvm::dbgs() << " Created ToTensorOp (Result: " - << toTensor.getResult() << ").\n"); - // 5. 注册映射:原 Load 的标量结果 -> 新的 Tensor 结果 - scalarToTensorMap[loadOp.getResult()] = toTensor.getResult(); + // 注册映射 + ctx.map(loadOp.getResult(), tensorVal); continue; } - // --- Case C: 计算逻辑 (Generic Binary Operations -> Vector Binary - // Operations) --- 检查是否为二元运算操作 - bool isBinaryOp = - inst.getNumOperands() == 2 && - (isa(inst)); - - if (isBinaryOp) { - LLVM_DEBUG(llvm::dbgs() << " [Action] Processing Binary ArithOp: " - << inst.getName() << "\n"); - - Value lhs = inst.getOperand(0); - Value rhs = inst.getOperand(1); - - // 检查操作数是否已向量化 - Value vecLhs = - scalarToTensorMap.count(lhs) ? scalarToTensorMap[lhs] : nullptr; - Value vecRhs = - scalarToTensorMap.count(rhs) ? scalarToTensorMap[rhs] : nullptr; - - if (vecLhs) - LLVM_DEBUG(llvm::dbgs() << " LHS is vectorized.\n"); - if (vecRhs) - LLVM_DEBUG(llvm::dbgs() << " RHS is vectorized.\n"); - - // 如果两个输入都是向量,生成向量运算 - if (vecLhs && vecRhs) { - // 创建一个新的OperationState,使用与原操作相同的操作码 - OperationState state(op.getLoc(), inst.getName().getStringRef()); - - // 添加向量化的操作数 - state.addOperands({vecLhs, vecRhs}); - - // 从原操作复制结果类型,但转换为向量类型 - llvm::SmallVector resultTypes; - for (auto result : inst.getResults()) { - Type scalarType = result.getType(); - ShapedType vectorType; - - if (auto shapedType = dyn_cast(scalarType)) { - // 如果已经是shaped type,则保持形状但可能更新为tensor类型 - vectorType = RankedTensorType::get(shapedType.getShape(), - shapedType.getElementType()); - } else { - // 如果是标量类型,转换为对应元素类型的向量 - vectorType = RankedTensorType::get({size}, scalarType); - } + // --- Case 3: 内存写入 (Store -> Alloc + Materialize + Copy) --- + if (auto storeOp = dyn_cast(*inst)) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Tensorizing MemRef Store.\n"); - resultTypes.push_back(vectorType); - } - state.addTypes(resultTypes); - - // 创建新的向量化操作 - auto newOp = rewriter.create(state); - - // 将新操作的结果映射到scalarToTensorMap - for (size_t i = 0; i < inst.getNumResults(); ++i) { - scalarToTensorMap[inst.getResult(i)] = newOp->getResult(i); - } - - LLVM_DEBUG({ - llvm::dbgs() << " Created Vector Operation: " << inst.getName() - << "\n"; - llvm::dbgs() << " Result Type: " - << newOp->getResult(0).getType() << "\n"; - }); - } else { - // 如果不是向量操作(可能是索引计算的一部分),则回退到普通 clone - LLVM_DEBUG( - llvm::dbgs() - << " WARNING: Operands not vectorized, cloning scalar op.\n"); - rewriter.clone(inst, mapper); + Value valToStore = storeOp.getValue(); + Value tensorToStore = ctx.lookupOrBroadcast(valToStore); // 可能是计算结果,也可能是常量 + + if (!tensorToStore) { + LLVM_DEBUG(llvm::dbgs() << " [Error] Value to store not available in map.\n"); + return failure(); } + + Value baseMemref = storeOp.getMemRef(); + SmallVector indices; + for (auto idx : storeOp.getIndices()) { + indices.push_back(ivMap.lookupOrDefault(idx)); + } + + // 3.1 Alloc Local (Output Buffer) + auto tensorType = dyn_cast(tensorToStore.getType()); + auto localMemType = MemRefType::get({tensorSize}, tensorType.getElementType()); + Value localOut = rewriter.create(op.getLoc(), localMemType); + + // 3.2 Materialize (Tensor -> UB) + auto matOp = rewriter.create( + op.getLoc(), tensorToStore, localOut); + matOp.setWritable(true); + + // 3.3 SubView (GM Output View) + SmallVector offsets; + for(auto v : indices) offsets.push_back(v); + SmallVector sizes = {rewriter.getIndexAttr(tensorSize)}; + SmallVector strides = {rewriter.getIndexAttr(1)}; + + Value outSubview = rewriter.create( + op.getLoc(), baseMemref, offsets, sizes, strides); + + // 3.4 Copy UB -> GM (MTE) + auto copyOp = rewriter.create(op.getLoc(), localOut, outSubview); + newOps.push_back(copyOp.getOperation()); continue; } - // --- Case D: 写回逻辑 (Materialize) --- - if (auto matOp = - dyn_cast(inst)) { - LLVM_DEBUG(llvm::dbgs() - << " [Action] Processing MaterializeInDestinationOp.\n"); - Value source = matOp.getSource(); - Value destMemref = matOp.getDest(); - - Value vectorResult = nullptr; - - // 追踪数据来源 - if (auto insertOp = source.getDefiningOp()) { - LLVM_DEBUG( - llvm::dbgs() - << " Source is tensor.insert, tracing scalar input...\n"); - Value scalarInput = insertOp.getScalar(); - if (scalarToTensorMap.count(scalarInput)) { - vectorResult = scalarToTensorMap[scalarInput]; - LLVM_DEBUG(llvm::dbgs() << " Found vectorized source.\n"); - } - } else if (scalarToTensorMap.count(source)) { - vectorResult = scalarToTensorMap[source]; - LLVM_DEBUG(llvm::dbgs() - << " Found vectorized source directly.\n"); - } + // --- Case 4: Linalg Op (MatMul 等) --- + // Linalg Op 本身就是 Tensor 语义友好的,如果输入已经是 Tensor,直接 Clone 即可 + if (isa(*inst)) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Handling Linalg Op.\n"); + // 对于 Linalg,我们需要确保其 inputs 和 outputs 都已映射为 tensor + // 这里的处理比较简化:假设 linalg op 是纯计算,其输入来自于之前的 load + + SmallVector newOperands; + for (Value operand : inst->getOperands()) { + Value mapped = ctx.getMapped(operand); + if (mapped) newOperands.push_back(mapped); + else newOperands.push_back(operand); // 可能是 accumulator 或其他 + } + + // Clone 并替换操作数 + Operation* newOp = rewriter.clone(*inst); + newOp->setOperands(newOperands); + newOps.push_back(newOp); + + // 映射结果 + for(auto i=0; igetNumResults(); ++i) { + ctx.map(inst->getResult(i), newOp->getResult(i)); + } + continue; + } - if (vectorResult) { - // 1. Alloc Output Buffer - auto tensorType = dyn_cast(vectorResult.getType()); - if (!tensorType) { - LLVM_DEBUG(llvm::dbgs() - << "[VectorizeParallelLoop] ERROR: Expected " - "RankedTensorType for vector result.\n"); - continue; - } - auto elemType = tensorType.getElementType(); - auto localOutType = MemRefType::get({size}, elemType); - Value localOut = - rewriter.create(op.getLoc(), localOutType); - LLVM_DEBUG(llvm::dbgs() << " Created Local Output Alloc: " - << localOutType << "\n"); - - // 2. Materialize Tensor -> Local Buffer - // Fix: capture operation and set writable to true - auto newMatOp = - rewriter.create( - op.getLoc(), vectorResult, localOut); - newMatOp.setWritable(true); - LLVM_DEBUG( - llvm::dbgs() - << " Created Vectorized Materialize (writable=true).\n"); - - // 3. 处理输出地址 (ReinterpretCast -> Subview) - Value baseMemref = destMemref; - Value writeOffset = nullptr; - - if (auto castOp = - destMemref.getDefiningOp()) { - LLVM_DEBUG( - llvm::dbgs() - << " Dest is ReinterpretCast, resolving offset...\n"); - baseMemref = castOp.getSource(); - if (!castOp.getOffsets().empty()) { - // Fix: Directly use the Value, do not use dyn_cast - Value loopOffset = castOp.getOffsets()[0]; - writeOffset = mapper.lookup(loopOffset); - LLVM_DEBUG(llvm::dbgs() << " Resolved write offset: " - << writeOffset << "\n"); + // --- Case 5: 通用计算指令 (Arith, Math) --- + // 这是一个通用的处理逻辑,支持 Unary, Binary, Ternary 等 + // 只要是 NoSideEffect 的计算指令都可以尝试转换 + if (isPureComputeOp(*inst)) { + LLVM_DEBUG(llvm::dbgs() << " [Action] Generic Compute Tensorization.\n"); + + SmallVector vecOperands; + bool allMapped = true; + + for (Value operand : inst->getOperands()) { + Value vecOp = ctx.lookupOrBroadcast(operand); + if (!vecOp) { + // 如果操作数既不是 tensor 也没法 broadcast (比如是 index),则保留原样? + // 通常计算指令的操作数不应该是 index + allMapped = false; + break; } - } else { - LLVM_DEBUG(llvm::dbgs() - << " Dest is not ReinterpretCast. Handling logic " - "might be incomplete for simple memrefs.\n"); - } - - // 如果找到了写入位置,执行 Copy Local -> Global - if (baseMemref && writeOffset) { - SmallVector offsets = {writeOffset}; - SmallVector sizes = {rewriter.getIndexAttr(size)}; - SmallVector strides = {rewriter.getIndexAttr(1)}; - - Value outSubview = rewriter.create( - op.getLoc(), baseMemref, offsets, sizes, strides); - - rewriter.create(op.getLoc(), localOut, outSubview); - LLVM_DEBUG(llvm::dbgs() - << " Created Copy (Local -> Global).\n"); - } - } else { - LLVM_DEBUG(llvm::dbgs() - << " WARNING: Could not find vectorized source for " - "materialize.\n"); + vecOperands.push_back(vecOp); } - continue; - } - // 忽略不需要的操作 - if (isa(inst)) { - LLVM_DEBUG( - llvm::dbgs() - << " Skipping tensor.insert (handled in materialize).\n"); - continue; + if (allMapped) { + // 构建新的 Tensor 类型 + SmallVector resultTypes; + for (Type t : inst->getResultTypes()) { + resultTypes.push_back(RankedTensorType::get({tensorSize}, t)); + } + + // 使用 OperationState 通用构建 Op + OperationState state(op.getLoc(), inst->getName().getStringRef()); + state.addOperands(vecOperands); + state.addTypes(resultTypes); + state.addAttributes(inst->getAttrs()); + + Operation *newOp = rewriter.create(state); + newOps.push_back(newOp); + + // 映射结果 + for (size_t i = 0; i < inst->getNumResults(); ++i) { + ctx.map(inst->getResult(i), newOp->getResult(i)); + } + continue; + } } - if (isa(inst)) { - LLVM_DEBUG(llvm::dbgs() << " Skipping tensor.empty.\n"); - continue; + + // --- Case 6: scf.reduce (Reduction) --- + if (auto reduceOp = dyn_cast(*inst)) { + // 在 Tensor 模式下,reduce 通常意味着对整个 Tensor 进行归约 + // 这里需要引入 linalg.reduce 或类似机制,实现较复杂。 + // 简单策略:如果遇到 reduce,发出警告或尝试使用 arith/vector reduce (但题目要求不用 vector) + // 对于纯 Tensor + Linalg 体系,可以使用 linalg.reduce + LLVM_DEBUG(llvm::dbgs() << " [Warning] scf.reduce detected. Implementing basic collapse.\n"); + // TODO: Implement linalg.reduce generation + continue; } - LLVM_DEBUG(llvm::dbgs() - << " [Unhandled] Operation not handled specifically: " - << inst.getName() << "\n"); + LLVM_DEBUG(llvm::dbgs() << " [Unhandled] Skipping Op: " << inst->getName() << "\n"); } - // 打印当前op - LLVM_DEBUG({ - llvm::dbgs() << "[VectorizeParallelLoop] Current Op: "; - op.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - // 打印映射表 - LLVM_DEBUG({ - llvm::dbgs() << "[VectorizeParallelLoop] Scalar to Tensor Map:\n"; - for (auto &[scalar, tensor] : scalarToTensorMap) { - llvm::dbgs() << " " << scalar << " -> " << tensor << "\n"; - } - }); - - // 4. 删除原循环 - LLVM_DEBUG( - llvm::dbgs() - << "[VectorizeParallelLoop] Erasing original scf.parallel op.\n"); - rewriter.eraseOp(op); + // 5. 替换原循环的所有使用者 + // 获取原循环的 yield 操作及其结果 + SmallVector newYieldOperands; // 声明变量 + + auto yieldOp = op.getBody()->getTerminator(); + if (auto yield = dyn_cast(yieldOp)) { + // 创建一个新的 yield 操作,使用转换后的值 + for (Value operand : yield.getOperands()) { + Value mapped = ctx.getMapped(operand); + if (mapped) { + newYieldOperands.push_back(mapped); + } else { + // 如果没有映射,则需要广播或保留原始值 + Value broadcasted = ctx.lookupOrBroadcast(operand); + if (broadcasted) { + newYieldOperands.push_back(broadcasted); + } else { + newYieldOperands.push_back(operand); + } + } + } + rewriter.create(op.getLoc(), newYieldOperands); + } - LLVM_DEBUG(llvm::dbgs() - << "[VectorizeParallelLoop] <<< MatchAndRewrite Done.\n\n"); + // 正确替换原循环操作 + rewriter.replaceOp(op, newYieldOperands); + LLVM_DEBUG(llvm::dbgs() << "=== [Tensorize] Done ===\n"); + return success(); } + + // 辅助函数:判断是否为纯计算指令 + bool isPureComputeOp(Operation &op) const { + // 简单白名单 + return isa(op) || + op.getDialect()->getNamespace() == "math"; // 所有 math.* (exp, log, etc) + } }; +// ============================================================================ +// Pass 定义 +// ============================================================================ struct VectorizeParallelLoopPass - : public PassWrapper> { + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizeParallelLoopPass) + StringRef getArgument() const final { return "vectorize-parallel-loop"; } StringRef getDescription() const final { - return "Vectorize scf.parallel loops by unrolling and using bulk memory " - "ops."; + return "Convert scf.parallel to tensor operations with explicit GM-UB data movement."; } void runOnOperation() override { - LLVM_DEBUG(llvm::dbgs() - << "[Pass] Starting VectorizeParallelLoopPass on function...\n"); + LLVM_DEBUG(llvm::dbgs() << "--- Running TensorizeParallelLoopPass ---\n"); + RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { - LLVM_DEBUG(llvm::dbgs() << "[Pass] Pattern application failed.\n"); signalPassFailure(); - } else { - LLVM_DEBUG(llvm::dbgs() << "[Pass] Pattern application succeeded.\n"); } } - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizeParallelLoopPass) }; } // namespace