From ee0b03936e3facef6ec63f1b19e864a381b35fcc Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Mon, 2 Mar 2026 08:30:41 +0000 Subject: [PATCH 01/10] fix and add testcases --- backend/commonir/adapter.py | 24 +++ commonir/src/target/codegen_commonir.cc | 223 +++++++++++++++--------- commonir/src/target/codegen_commonir.h | 7 +- test/commonir/add_vector.py | 62 +++++++ test/commonir/gemm.py | 75 ++++++++ 5 files changed, 311 insertions(+), 80 deletions(-) create mode 100644 test/commonir/add_vector.py create mode 100644 test/commonir/gemm.py diff --git a/backend/commonir/adapter.py b/backend/commonir/adapter.py index e4c8cc3..d136a59 100644 --- a/backend/commonir/adapter.py +++ b/backend/commonir/adapter.py @@ -1,4 +1,6 @@ import os +import tempfile +import shutil import re from typing import Callable, List from triton.backends.dicp_triton.commonir.compiler import ( @@ -62,6 +64,16 @@ def compile_and_create_adapter(cls, tilelang_module): adapter_wrapper = AdapterWrapper() adapter_wrapper.artifact.set_kernel_source(tilelang_module) mlir_content = cls._tilelang_to_commonir(tilelang_module) + dump_ir = os.environ.get("DUMP_COMMON_IR", "0") == "1" + if dump_ir: + with tempfile.TemporaryDirectory() as tmpdir: + print(mlir_content) + dst_path = os.path.join(tmpdir, "kernel.commonir.mlir") + print(dst_path) + cls._write_mlir_file(dst_path, mlir_content) + if not os.path.exists("./tmp"): + os.makedirs("./tmp") + shutil.copy(dst_path, "./tmp/kernel.commonir.mlir") grid = cls._parse_grid(tilelang_module) signature = cls._parse_signature(mlir_content) @@ -187,6 +199,18 @@ def _read_mlir_file(cls, file_path) -> str: except Exception as e: print(f"Error occurred while reading the file: {e}") return None + @classmethod + def _write_mlir_file(cls, file_path, mlir_content): + try: + with open(file_path, "w", encoding="utf-8") as file: + file.write(mlir_content) + return True + except FileNotFoundError: + print(f"Error: Directory for '{file_path}' does not exist") + return False + except Exception as e: + print(f"Error occurred while writing to the file: {e}") + return False @classmethod def _parse_signature(cls, mlir_content) -> dict: diff --git a/commonir/src/target/codegen_commonir.cc b/commonir/src/target/codegen_commonir.cc index 9c43858..baa15ef 100644 --- a/commonir/src/target/codegen_commonir.cc +++ b/commonir/src/target/codegen_commonir.cc @@ -7,6 +7,7 @@ #include "../op/copy.h" #include "../op/fill.h" #include "../op/gemm.h" +#include "../op/reduce.h" #include "../op/region.h" #include "arith/pattern_match.h" #include "tvm/ir/expr.h" @@ -36,14 +37,20 @@ using ffi::Array; using ffi::String; template -inline void PrintBinary(const T *op, const char *opstr, std::ostream &os, - CodeGenC *CG) { - auto PrintOp = [op, &os, CG](auto Operand) { +inline void CodeGenTileLangCOMMONIR::PrintBinary(const T *op, const char *opstr, + std::ostream &os) { + auto PrintOp = [op, &os, this](auto Operand) { std::ostringstream tmpos; - CG->PrintExpr(Operand, tmpos << "%"); + if (Operand.template as() || + Operand.template as() || + Operand.template as()) { + PrintExpr(Operand, tmpos << "%"); + } else { + std::string op_name = SSAGetID(PrintExpr(Operand), Operand->dtype); + tmpos << "%" << op_name; + } return tmpos.str(); }; - if (op->dtype.lanes() == 1) { // left op os << "arith." << opstr << " "; @@ -52,7 +59,7 @@ inline void PrintBinary(const T *op, const char *opstr, std::ostream &os, // right op os << PrintOp(op->b); os << " : "; - CG->PrintType(op->a->dtype, os); + PrintType(op->a->dtype, os); } else { os << "<<>>\n"; } @@ -170,6 +177,7 @@ static std::string broadcastAttrCodegen(Array &buffer_shape0, return temp.str(); } + void PrintBufferMap(const Map &buffer_map) { for (const auto &kv : buffer_map) { const Var &var = kv.first; @@ -348,6 +356,7 @@ void CodeGenTileLangCOMMONIR::VisitStmt_(const tir::IfThenElseNode *op) { void CodeGenTileLangCOMMONIR::PrintSSAAssign(const std::string &target, const std::string &src, DataType t) { + PrintIndent(); stream << "%" << target << " = " << src << "\n"; } @@ -487,74 +496,74 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const FloorDivNode *op, // FIXME: The floor div in python is not the same as arith.divsi in negative // scenarios. if (op->dtype.is_int() || op->dtype.is_uint()) { - PrintBinary(op, "divsi", os, this); + PrintBinary(op, "divsi", os); } else if (op->dtype.is_float()) { - PrintBinary(op, "divf", os, this); + PrintBinary(op, "divf", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const FloorModNode *op, std::ostream &os) { if (op->dtype.is_int() || op->dtype.is_uint()) { - PrintBinary(op, "remsi", os, this); + PrintBinary(op, "remsi", os); } else if (op->dtype.is_float()) { - PrintBinary(op, "remf", os, this); + PrintBinary(op, "remf", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const LTNode *op, std::ostream &os) { if (op->a->dtype.is_int()) { - PrintBinary(op, "cmpi slt,", os, this); + PrintBinary(op, "cmpi slt,", os); } else if (op->a->dtype.is_uint()) { - PrintBinary(op, "cmpi ult,", os, this); + PrintBinary(op, "cmpi ult,", os); } else { - PrintBinary(op, "cmpf olt,", os, this); + PrintBinary(op, "cmpf olt,", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const NENode *op, std::ostream &os) { if (op->a->dtype.is_int() || op->a->dtype.is_uint()) { - PrintBinary(op, "cmpi ne,", os, this); + PrintBinary(op, "cmpi ne,", os); } else { - PrintBinary(op, "cmpf one,", os, this); + PrintBinary(op, "cmpf one,", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const EQNode *op, std::ostream &os) { if (op->a->dtype.is_int() || op->a->dtype.is_uint()) { - PrintBinary(op, "cmpi eq,", os, this); + PrintBinary(op, "cmpi eq,", os); } else { - PrintBinary(op, "cmpf oeq,", os, this); + PrintBinary(op, "cmpf oeq,", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const LENode *op, std::ostream &os) { if (op->a->dtype.is_int()) { - PrintBinary(op, "cmpi sle,", os, this); + PrintBinary(op, "cmpi sle,", os); } else if (op->a->dtype.is_uint()) { - PrintBinary(op, "cmpi ule,", os, this); + PrintBinary(op, "cmpi ule,", os); } else { - PrintBinary(op, "cmpf ole,", os, this); + PrintBinary(op, "cmpf ole,", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const GENode *op, std::ostream &os) { if (op->a->dtype.is_int()) { - PrintBinary(op, "cmpi sge,", os, this); + PrintBinary(op, "cmpi sge,", os); } else if (op->a->dtype.is_uint()) { - PrintBinary(op, "cmpi uge,", os, this); + PrintBinary(op, "cmpi uge,", os); } else { - PrintBinary(op, "cmpf oge,", os, this); + PrintBinary(op, "cmpf oge,", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const GTNode *op, std::ostream &os) { if (op->a->dtype.is_int()) { - PrintBinary(op, "cmpi sgt,", os, this); + PrintBinary(op, "cmpi sgt,", os); } else if (op->a->dtype.is_uint()) { - PrintBinary(op, "cmpi ugt,", os, this); + PrintBinary(op, "cmpi ugt,", os); } else { - PrintBinary(op, "cmpf ogt,", os, this); + PrintBinary(op, "cmpf ogt,", os); } } @@ -760,6 +769,7 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data, return new_buffer_name; } + String CodeGenTileLangCOMMONIR::CreateMemrefToTensor(String src_data_name) { if (!dynamic_cast(type_info[src_data_name])) { LOG(FATAL) << src_data_name << " should be a memref"; @@ -771,11 +781,11 @@ String CodeGenTileLangCOMMONIR::CreateMemrefToTensor(String src_data_name) { std::ostringstream temp; temp << "bufferization.to_tensor %" << src_data_name << " restrict writable : " << GetMemrefInfo(src_data_name); - temp << " to " << GetTensorInfo(tempTensor); + temp << " to " << GetTensorInfo(tempTensor); new_tensor_name = SSAGetID(temp.str(), src_dtype); tempTensor->var_id = new_tensor_name; this->type_info_tensor[new_tensor_name] = tempTensor; - + return new_tensor_name; } @@ -829,13 +839,28 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const CallNode *op, std::ostream &os) { FillCodegen(op, os); } else if (op->op.same_as(Op::Get("tl.tileop.copy"))) { CopyCodegen(op, os); - } else if (op->op.same_as(Op::Get("tl.tileop.gemm")) || - op->op.same_as(Op::Get("tl.tileop.gemm_py"))) { + } else if (op->op.same_as(Op::Get("tl.tileop.gemm")) || op->op.same_as(Op::Get("tl.tileop.gemm_py"))) { GemmCodegen(op, os); + } else if (op->op.same_as(Op::Get("tl.infinity"))) { + InfinityCodegen(op, os); + } else if (op->op.same_as(Op::Get("tl.tileop.reduce"))) { + ReduceCodegen(op, os); + } else if (op->op.same_as(Op::Get("tir.rsqrt"))) { + StubCodegen(op, os, "tir.rsqrt"); + } else if (op->op.same_as(Op::Get("tir.sigmoid"))) { + StubCodegen(op, os, "tir.sigmoid"); + } else if (op->op.same_as(Op::Get("tir.exp"))) { + StubCodegen(op, os, "tir.exp"); } else { CodeGenC::VisitExpr_(op, os); } } +void CodeGenTileLangCOMMONIR::StubCodegen(const CallNode *op, + std::ostream &os, + String stubname) { + this->PrintIndent(); + this->stream << stubname << "\n"; +} void CodeGenTileLangCOMMONIR::FillCodegen(const CallNode *op, std::ostream &os) { @@ -895,11 +920,6 @@ void CodeGenTileLangCOMMONIR::CopyCodegen(const CallNode *op, void CodeGenTileLangCOMMONIR::GemmCodegen(const CallNode *op, std::ostream &os) { tvm::tl::Gemm gemmop(op->args); - // todo(dkx): support transpose indexing_maps - ICHECK(!gemmop->transA_) - << "Currently we only support: transA_ must be false"; - ICHECK(!gemmop->transB_) - << "Currently we only support: transB_ must be false"; // todo(dkx): support clearAccum_ = True ICHECK(is_zero(gemmop->clearAccum_)) << "Currently we only support: clearAccum_ must be const_false"; @@ -907,8 +927,7 @@ void CodeGenTileLangCOMMONIR::GemmCodegen(const CallNode *op, // ICHECK(gemmop->policy_ == tvm::tl::GemmWarpPolicyType::kSquare) // << "Currently we only support: policy_ must be kSquare"; ICHECK(gemmop->kPack_ == 1) << "Currently we only support: kPack_ must be 1"; - ICHECK(gemmop->wgWait_ == 0) - << "Currently we only support: wgWait_ must be 0"; + ICHECK(gemmop->wgWait_ == 0) << "Currently we only support: wgWait_ must be 0"; Buffer a_buffer = gemmop->a_; Buffer b_buffer = gemmop->b_; @@ -928,6 +947,23 @@ void CodeGenTileLangCOMMONIR::GemmCodegen(const CallNode *op, temp << "outs(\%" << new_tensor_name << " : " << GetTensorInfo(new_tensor_name) << ") -> " << GetTensorInfo(new_tensor_name); + + // todo(dkx): support transpose ops + std::string annotations = ""; + if (gemmop->transA_ || gemmop->transB_) { + annotations += " {"; + if (gemmop->transA_) { + annotations += "transA = 1"; + if (gemmop->transB_) { + annotations += ", transB = 1"; + } + } else { + annotations += "transB = 1"; + } + annotations += "}"; + } + temp << annotations; + String C_tensor_out = SSAGetID(temp.str(), dst_dtype); temp.str(""); temp.clear(); @@ -946,6 +982,37 @@ void CodeGenTileLangCOMMONIR::GemmCodegen(const CallNode *op, << GetMemrefInfo(c_buffer_name) << ") -> ()\n"; } +void CodeGenTileLangCOMMONIR::InfinityCodegen(const CallNode *op, + std::ostream &os) { + const DataType &dtype = op->dtype; + ICHECK_EQ(dtype.lanes(), 1); + if (dtype.is_float()) { + if (dtype.bits() == 64 || dtype.bits() == 32 || dtype.bits() == 16) { + PrimExpr imm_exp = + FloatImm(dtype, std::numeric_limits::infinity(), op->span); + os << SSAGetID(PrintExpr(imm_exp), dtype); + return; + } + } else if (dtype.is_bfloat16()) { + PrimExpr imm_exp = + FloatImm(dtype, std::numeric_limits::infinity(), op->span); + os << SSAGetID(PrintExpr(imm_exp), dtype); + return; + } + LOG(FATAL) << "Cannot decide infinity for type " << dtype; + throw; +} + +void CodeGenTileLangCOMMONIR::ReduceCodegen(const CallNode *op, + std::ostream &os) { + tvm::tl::ReduceOp reduceop(op->args); + // todo(dkx): support other reduce type + ICHECK(reduceop->type->isSum() || reduceop->type->isMax()) << "Currently we only support: sum or max"; + this->PrintIndent(); + this->stream << "linalg.reduce \n"; +} + + void CodeGenTileLangCOMMONIR::VisitStmt_(const LetStmtNode *op) { std::string value = PrintExpr(op->value); PrintIndent(); @@ -956,8 +1023,6 @@ void CodeGenTileLangCOMMONIR::VisitStmt_(const LetStmtNode *op) { void CodeGenTileLangCOMMONIR::VisitStmt_(const BufferStoreNode *op) { std::string value = SSAGetID(PrintExpr(op->value), op->value->dtype); - PrintIndent(); - Buffer buffer_data = op->buffer; DataType buffer_type = buffer_data->dtype; String buffer_name = buffer_data->name; @@ -976,6 +1041,7 @@ void CodeGenTileLangCOMMONIR::VisitStmt_(const BufferStoreNode *op) { } Array cast_index_array = GenConvertIndex(op->indices); + PrintIndent(); this->stream << "memref.store \%" + value + ", \%" + buffer_name_val; this->stream << "["; for (int i = 0; i < dim; i++) { @@ -993,39 +1059,31 @@ void CodeGenTileLangCOMMONIR::VisitStmt_(const BufferStoreNode *op) { void CodeGenTileLangCOMMONIR::VisitStmt_(const AttrStmtNode *op) { if (op->attr_key == "thread_extent") { IterVar iv = Downcast(op->node); - if (iv->thread_tag == "blockIdx.x" && iv->var->name_hint != "_") { - - std::ostringstream temp; - temp << "arith.constant 0" - << " : "; - PrintType(iv->var->dtype, temp); - std::string vid = SSAGetID(temp.str(), iv->var->dtype); - - auto block_id_ = AllocVarID(iv->var.get()); - this->PrintIndent(); - this->stream << "%" << block_id_ << " = arith.addi %" << vid << ", " - << this->thread_context_args[3] << ": i32\n"; - } else if (iv->thread_tag == "blockIdx.y" && iv->var->name_hint != "_") { + if ((iv->thread_tag == "blockIdx.x" || iv->thread_tag == "blockIdx.y" || + iv->thread_tag == "blockIdx.z") && + iv->var->name_hint != "_") { + int arg_index = -1; + if (iv->thread_tag == "blockIdx.x") { + arg_index = 3; + } else if (iv->thread_tag == "blockIdx.y") { + arg_index = 4; + } else if (iv->thread_tag == "blockIdx.z") { + arg_index = 5; + } std::ostringstream temp; temp << "arith.constant 0" << " : "; PrintType(iv->var->dtype, temp); std::string vid = SSAGetID(temp.str(), iv->var->dtype); - auto block_id_ = AllocVarID(iv->var.get()); this->PrintIndent(); this->stream << "%" << block_id_ << " = arith.addi %" << vid << ", " - << this->thread_context_args[4] << ": i32\n"; - } else if (iv->thread_tag == "blockIdx.z" && iv->var->name_hint != "_") { - std::ostringstream temp; - temp << "arith.constant 0" - << " : "; - PrintType(iv->var->dtype, temp); - std::string vid = SSAGetID(temp.str(), iv->var->dtype); + << this->thread_context_args[arg_index] << ": i32\n"; + } else if ((iv->thread_tag == "threadIdx.x" || iv->thread_tag == "threadIdx.y" || + iv->thread_tag == "threadIdx.z") && + iv->var->name_hint != "_") { + // todo(dkx): should handle this dilemma on npu auto block_id_ = AllocVarID(iv->var.get()); - this->PrintIndent(); - this->stream << "%" << block_id_ << " = arith.addi %" << vid << ", " - << this->thread_context_args[5] << ": i32\n"; } this->VisitStmt(op->body); return; @@ -1066,37 +1124,37 @@ void CodeGenTileLangCOMMONIR::VisitStmt_(const AllocateNode *op) { void CodeGenTileLangCOMMONIR::VisitExpr_(const MinNode *op, std::ostream &os) { if (op->dtype.is_int()) { - PrintBinary(op, "minsi", os, this); + PrintBinary(op, "minsi", os); } else if (op->dtype.is_uint()) { - PrintBinary(op, "minui", os, this); + PrintBinary(op, "minui", os); } else if (op->dtype.is_float()) { - PrintBinary(op, "minnumf", os, this); + PrintBinary(op, "minnumf", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const MaxNode *op, std::ostream &os) { if (op->dtype.is_int()) { - PrintBinary(op, "maxsi", os, this); + PrintBinary(op, "maxsi", os); } else if (op->dtype.is_uint()) { - PrintBinary(op, "maxui", os, this); + PrintBinary(op, "maxui", os); } else if (op->dtype.is_float()) { - PrintBinary(op, "maxnumf", os, this); + PrintBinary(op, "maxnumf", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const AddNode *op, std::ostream &os) { if (op->dtype.is_int() || op->dtype.is_uint()) { - PrintBinary(op, "addi", os, this); + PrintBinary(op, "addi", os); } else if (op->dtype.is_float()) { - PrintBinary(op, "addf", os, this); + PrintBinary(op, "addf", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const SubNode *op, std::ostream &os) { if (op->dtype.is_int() || op->dtype.is_uint()) { - PrintBinary(op, "subi", os, this); + PrintBinary(op, "subi", os); } else if (op->dtype.is_float()) { - PrintBinary(op, "subf", os, this); + PrintBinary(op, "subf", os); } } @@ -1108,7 +1166,14 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const FloatImmNode *op, } else if (op->value == std::numeric_limits::infinity()) { temp << "arith.constant 0x7F800000 : "; } else { - temp << "arith.constant " << op->value << " : "; + temp << "arith.constant "; + double val = op->value; + if (std::floor(val) == val && std::isfinite(val)) { + temp << static_cast(val) << ".0"; + } else { + temp << val; + } + temp << " : "; } PrintType(op->dtype, temp); os << SSAGetID(temp.str(), op->dtype); @@ -1129,26 +1194,26 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const IntImmNode *op, void CodeGenTileLangCOMMONIR::VisitExpr_(const MulNode *op, std::ostream &os) { if (op->dtype.is_int() || op->dtype.is_uint()) { - PrintBinary(op, "muli", os, this); + PrintBinary(op, "muli", os); } else if (op->dtype.is_float()) { - PrintBinary(op, "mulf", os, this); + PrintBinary(op, "mulf", os); } } void CodeGenTileLangCOMMONIR::VisitExpr_(const AndNode *op, std::ostream &os) { assert(op->a.dtype().is_int() || op->a.dtype().is_uint()); assert(op->b.dtype().is_int() || op->b.dtype().is_uint()); - PrintBinary(op, "andi", os, this); + PrintBinary(op, "andi", os); } void CodeGenTileLangCOMMONIR::VisitExpr_(const OrNode *op, std::ostream &os) { assert(op->a.dtype().is_int() || op->a.dtype().is_uint()); assert(op->b.dtype().is_int() || op->b.dtype().is_uint()); - PrintBinary(op, "ori", os, this); + PrintBinary(op, "ori", os); } void CodeGenTileLangCOMMONIR::VisitExpr_(const DivNode *op, std::ostream &os) { - PrintBinary(op, "<<>>", os, this); + PrintBinary(op, "<<>>", os); } void CodeGenTileLangCOMMONIR::VisitExpr_(const SelectNode *op, diff --git a/commonir/src/target/codegen_commonir.h b/commonir/src/target/codegen_commonir.h index 4664c7b..6a6d0f0 100644 --- a/commonir/src/target/codegen_commonir.h +++ b/commonir/src/target/codegen_commonir.h @@ -18,8 +18,8 @@ namespace tvm { namespace codegen { -using ffi::Array; using ffi::String; +using ffi::Array; class SSAType { public: @@ -125,6 +125,8 @@ class CodeGenTileLangCOMMONIR final : public CodeGenC { void AddFunction(const GlobalVar &gvar, const PrimFunc &f); private: + template + inline void PrintBinary(const T *op, const char *opstr, std::ostream &os); Array GenConvertIndex(Array exprs); String GenSubviewFromRegion(const CallNode *region_node); String GenSubviewFromRegion(Buffer buffer_data, Array range); @@ -141,6 +143,9 @@ class CodeGenTileLangCOMMONIR final : public CodeGenC { void FillCodegen(const CallNode *op, std::ostream &os); void CopyCodegen(const CallNode *op, std::ostream &os); void GemmCodegen(const CallNode *op, std::ostream &os); + void InfinityCodegen(const CallNode *op, std::ostream &os); + void ReduceCodegen(const CallNode *op, std::ostream &os); + void StubCodegen(const CallNode *op, std::ostream &os, String stubname); // save memref name and type std::map type_info; diff --git a/test/commonir/add_vector.py b/test/commonir/add_vector.py new file mode 100644 index 0000000..fdfbe80 --- /dev/null +++ b/test/commonir/add_vector.py @@ -0,0 +1,62 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. +import os + +import tilelang +import tilelang.language as T + +import torch + +# tilelang.cache.clear_cache() + +dtype = "float32" +seq_len = 1024 + + +def vec_add(N, block_N, dtype="float32"): + n_num = N // block_N + + @T.prim_func + def main( + A: T.Tensor((N), dtype), + B: T.Tensor((N), dtype), + C: T.Tensor((N), dtype), + ): + with T.Kernel(n_num, 1) as (by, bx): + start_y1 = by * block_N + start_y = start_y1 + bx + for (local_y) in T.Parallel(block_N): + y = start_y + local_y + C[y] = A[y] + B[y] + + return main + + +def test_vec_add(): + func = vec_add(seq_len, seq_len // 4) + compiled_kernel = tilelang.compile(func, target='commonir') + + v1 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() + v2 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() + v3 = torch.zeros(size=[seq_len], dtype=eval("torch." + dtype)).npu() + + y_ref = v1 + v2 + compiled_kernel(v1, v2, v3) + + # print(y_ref) + # print(v3) + + print(f'The maximum difference between torch and Tilellang is ' + f'{torch.max(torch.abs(y_ref - v3))}') + + if torch.allclose(v3, y_ref, atol=1e-2, rtol=0): + print("✅ Tilellang and Torch match") + else: + print("❌ Tilellang and Torch differ") + diff = torch.abs(v3 - y_ref) + print(f"Max diff: {diff.max().item()}") + print(f"Mean diff: {diff.mean().item()}") + + +if __name__ == "__main__": + test_vec_add() \ No newline at end of file diff --git a/test/commonir/gemm.py b/test/commonir/gemm.py new file mode 100644 index 0000000..7b55db3 --- /dev/null +++ b/test/commonir/gemm.py @@ -0,0 +1,75 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + +import tilelang +import tilelang.language as T + +import torch +import torch_npu +device = torch.npu.current_device() +dtype = torch.float16 +# tilelang.cache.clear_cache() + +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + func = matmul(1024, 1024, 1024, 128, 128, 32) + kernel = tilelang.compile(func, target='commonir') + SIZEALL = 1024 + + torch.manual_seed(0) + a = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5 + b = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5 + result = torch.zeros((SIZEALL, SIZEALL), dtype=dtype, device=device) + + kernel(a, b, result) + golden = a @ b + mask = golden.abs() < 1.0 + tmpatol = tmprtol = 2**-6 + # try: + # torch.testing.assert_close(result[mask], golden[mask], atol=tmpatol, rtol=0) + # torch.testing.assert_close(result[~mask], golden[~mask], atol=0, rtol=tmprtol) + # print("run matmul success") + # except: + # print(f"[ERROR] 存在精度问题") + # # max diff + # max_diff = torch.max(torch.abs(result - golden)) + # print(f"[ERROR] max diff: {max_diff}") + # # max diff index + # max_diff_index = torch.argmax(torch.abs(result - golden)) + # print(f"[ERROR] max diff index: {max_diff_index}") + # print(f"[ERROR] result: {result}") + print(f"result is {result}, golden is {golden}") + if torch.allclose(result, golden, atol=1e-2, rtol=1e-2): + print("✅ Tilellang and Torch match") + else: + print("❌ Tilellang and Torch differ") + diff = torch.abs(result - golden) + print(f"Max diff: {diff.max().item()}") + print(f"Mean diff: {diff.mean().item()}") + + +if __name__ == "__main__": + main() \ No newline at end of file From 2a8a36d61b3250c9d4f6c821211a0018e9a1e248 Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Mon, 2 Mar 2026 09:15:34 +0000 Subject: [PATCH 02/10] fix formate --- commonir/src/target/codegen_commonir.cc | 26 ++++++++++++------------- commonir/src/target/codegen_commonir.h | 2 +- test/commonir/add_vector.py | 18 +++++++++-------- test/commonir/gemm.py | 17 ++++++++++------ 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/commonir/src/target/codegen_commonir.cc b/commonir/src/target/codegen_commonir.cc index baa15ef..8d5aa6a 100644 --- a/commonir/src/target/codegen_commonir.cc +++ b/commonir/src/target/codegen_commonir.cc @@ -177,7 +177,6 @@ static std::string broadcastAttrCodegen(Array &buffer_shape0, return temp.str(); } - void PrintBufferMap(const Map &buffer_map) { for (const auto &kv : buffer_map) { const Var &var = kv.first; @@ -769,7 +768,6 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data, return new_buffer_name; } - String CodeGenTileLangCOMMONIR::CreateMemrefToTensor(String src_data_name) { if (!dynamic_cast(type_info[src_data_name])) { LOG(FATAL) << src_data_name << " should be a memref"; @@ -781,11 +779,11 @@ String CodeGenTileLangCOMMONIR::CreateMemrefToTensor(String src_data_name) { std::ostringstream temp; temp << "bufferization.to_tensor %" << src_data_name << " restrict writable : " << GetMemrefInfo(src_data_name); - temp << " to " << GetTensorInfo(tempTensor); + temp << " to " << GetTensorInfo(tempTensor); new_tensor_name = SSAGetID(temp.str(), src_dtype); tempTensor->var_id = new_tensor_name; this->type_info_tensor[new_tensor_name] = tempTensor; - + return new_tensor_name; } @@ -839,7 +837,8 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const CallNode *op, std::ostream &os) { FillCodegen(op, os); } else if (op->op.same_as(Op::Get("tl.tileop.copy"))) { CopyCodegen(op, os); - } else if (op->op.same_as(Op::Get("tl.tileop.gemm")) || op->op.same_as(Op::Get("tl.tileop.gemm_py"))) { + } else if (op->op.same_as(Op::Get("tl.tileop.gemm")) || + op->op.same_as(Op::Get("tl.tileop.gemm_py"))) { GemmCodegen(op, os); } else if (op->op.same_as(Op::Get("tl.infinity"))) { InfinityCodegen(op, os); @@ -855,8 +854,7 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const CallNode *op, std::ostream &os) { CodeGenC::VisitExpr_(op, os); } } -void CodeGenTileLangCOMMONIR::StubCodegen(const CallNode *op, - std::ostream &os, +void CodeGenTileLangCOMMONIR::StubCodegen(const CallNode *op, std::ostream &os, String stubname) { this->PrintIndent(); this->stream << stubname << "\n"; @@ -927,7 +925,8 @@ void CodeGenTileLangCOMMONIR::GemmCodegen(const CallNode *op, // ICHECK(gemmop->policy_ == tvm::tl::GemmWarpPolicyType::kSquare) // << "Currently we only support: policy_ must be kSquare"; ICHECK(gemmop->kPack_ == 1) << "Currently we only support: kPack_ must be 1"; - ICHECK(gemmop->wgWait_ == 0) << "Currently we only support: wgWait_ must be 0"; + ICHECK(gemmop->wgWait_ == 0) + << "Currently we only support: wgWait_ must be 0"; Buffer a_buffer = gemmop->a_; Buffer b_buffer = gemmop->b_; @@ -1007,12 +1006,12 @@ void CodeGenTileLangCOMMONIR::ReduceCodegen(const CallNode *op, std::ostream &os) { tvm::tl::ReduceOp reduceop(op->args); // todo(dkx): support other reduce type - ICHECK(reduceop->type->isSum() || reduceop->type->isMax()) << "Currently we only support: sum or max"; + ICHECK(reduceop->type->isSum() || reduceop->type->isMax()) + << "Currently we only support: sum or max"; this->PrintIndent(); this->stream << "linalg.reduce \n"; } - void CodeGenTileLangCOMMONIR::VisitStmt_(const LetStmtNode *op) { std::string value = PrintExpr(op->value); PrintIndent(); @@ -1079,9 +1078,10 @@ void CodeGenTileLangCOMMONIR::VisitStmt_(const AttrStmtNode *op) { this->PrintIndent(); this->stream << "%" << block_id_ << " = arith.addi %" << vid << ", " << this->thread_context_args[arg_index] << ": i32\n"; - } else if ((iv->thread_tag == "threadIdx.x" || iv->thread_tag == "threadIdx.y" || - iv->thread_tag == "threadIdx.z") && - iv->var->name_hint != "_") { + } else if ((iv->thread_tag == "threadIdx.x" || + iv->thread_tag == "threadIdx.y" || + iv->thread_tag == "threadIdx.z") && + iv->var->name_hint != "_") { // todo(dkx): should handle this dilemma on npu auto block_id_ = AllocVarID(iv->var.get()); } diff --git a/commonir/src/target/codegen_commonir.h b/commonir/src/target/codegen_commonir.h index 6a6d0f0..d638813 100644 --- a/commonir/src/target/codegen_commonir.h +++ b/commonir/src/target/codegen_commonir.h @@ -18,8 +18,8 @@ namespace tvm { namespace codegen { -using ffi::String; using ffi::Array; +using ffi::String; class SSAType { public: diff --git a/test/commonir/add_vector.py b/test/commonir/add_vector.py index fdfbe80..03c5007 100644 --- a/test/commonir/add_vector.py +++ b/test/commonir/add_vector.py @@ -18,14 +18,14 @@ def vec_add(N, block_N, dtype="float32"): @T.prim_func def main( - A: T.Tensor((N), dtype), - B: T.Tensor((N), dtype), - C: T.Tensor((N), dtype), + A: T.Tensor((N), dtype), + B: T.Tensor((N), dtype), + C: T.Tensor((N), dtype), ): with T.Kernel(n_num, 1) as (by, bx): start_y1 = by * block_N start_y = start_y1 + bx - for (local_y) in T.Parallel(block_N): + for local_y in T.Parallel(block_N): y = start_y + local_y C[y] = A[y] + B[y] @@ -34,7 +34,7 @@ def main( def test_vec_add(): func = vec_add(seq_len, seq_len // 4) - compiled_kernel = tilelang.compile(func, target='commonir') + compiled_kernel = tilelang.compile(func, target="commonir") v1 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() v2 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu() @@ -46,8 +46,10 @@ def test_vec_add(): # print(y_ref) # print(v3) - print(f'The maximum difference between torch and Tilellang is ' - f'{torch.max(torch.abs(y_ref - v3))}') + print( + f"The maximum difference between torch and Tilellang is " + f"{torch.max(torch.abs(y_ref - v3))}" + ) if torch.allclose(v3, y_ref, atol=1e-2, rtol=0): print("✅ Tilellang and Torch match") @@ -59,4 +61,4 @@ def test_vec_add(): if __name__ == "__main__": - test_vec_add() \ No newline at end of file + test_vec_add() diff --git a/test/commonir/gemm.py b/test/commonir/gemm.py index 7b55db3..e6b9d19 100644 --- a/test/commonir/gemm.py +++ b/test/commonir/gemm.py @@ -6,19 +6,24 @@ import torch import torch_npu + device = torch.npu.current_device() dtype = torch.float16 # tilelang.cache.clear_cache() + def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func def gemm( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as ( + bx, + by, + ): A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -36,7 +41,7 @@ def gemm( def main(): func = matmul(1024, 1024, 1024, 128, 128, 32) - kernel = tilelang.compile(func, target='commonir') + kernel = tilelang.compile(func, target="commonir") SIZEALL = 1024 torch.manual_seed(0) @@ -72,4 +77,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 3b5e989a84bc8e596f0f908f79211fc6ed967368 Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Mon, 2 Mar 2026 09:21:51 +0000 Subject: [PATCH 03/10] fix --- backend/commonir/adapter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/commonir/adapter.py b/backend/commonir/adapter.py index d136a59..c142ad6 100644 --- a/backend/commonir/adapter.py +++ b/backend/commonir/adapter.py @@ -199,6 +199,7 @@ def _read_mlir_file(cls, file_path) -> str: except Exception as e: print(f"Error occurred while reading the file: {e}") return None + @classmethod def _write_mlir_file(cls, file_path, mlir_content): try: From 271e034b8d39734ef7dfd30f6b1975a2a7071fb1 Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Tue, 3 Mar 2026 01:25:32 +0000 Subject: [PATCH 04/10] fix --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7bbdde0..37ada4f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -80,7 +80,7 @@ jobs: run: | set -ex source /home/dlc_ci/.bashrc - source activate dlcompiler + conda activate dlcompiler source /usr/local/Ascend/ascend-toolkit/set_env.sh cd ${{env.CI_PATH }} echo "whoami? $(whoami)" From 1f03c96a4d668a902f59af6adae495070426a227 Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Tue, 3 Mar 2026 01:52:52 +0000 Subject: [PATCH 05/10] rerun --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 37ada4f..d786c60 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,7 +10,7 @@ on: - main env: - CI_PATH: "${{ vars.CI_BASE_PATH }}/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}" + CI_PATH: "${{ vars.CI_BASE_PATH }}/GitHub/${{ github.repository }}/${{ github.run_id }}" THIRD_PARTY_PATH: "${{ vars.CI_BASE_PATH }}/data/DLCompiler/third_party" concurrency: @@ -80,7 +80,7 @@ jobs: run: | set -ex source /home/dlc_ci/.bashrc - conda activate dlcompiler + source activate dlcompiler source /usr/local/Ascend/ascend-toolkit/set_env.sh cd ${{env.CI_PATH }} echo "whoami? $(whoami)" From 843ec6a8084d537659e3e3a968721328da7dce96 Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Tue, 3 Mar 2026 01:58:50 +0000 Subject: [PATCH 06/10] fix --- .github/workflows/main.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d786c60..a53ba73 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,7 +10,7 @@ on: - main env: - CI_PATH: "${{ vars.CI_BASE_PATH }}/GitHub/${{ github.repository }}/${{ github.run_id }}" + CI_PATH: "${{ vars.CI_BASE_PATH }}/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}/${{ github.run_id }}" THIRD_PARTY_PATH: "${{ vars.CI_BASE_PATH }}/data/DLCompiler/third_party" concurrency: @@ -31,6 +31,7 @@ jobs: - name: Create custom directory run: | set -ex + echo ${{ env.CI_BASE_PATH }} echo ${{ env.CI_PATH }} mkdir -p ${{ env.CI_PATH }} From fa10f83817a8f1f33fee7e95b888663ccbe99952 Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Tue, 3 Mar 2026 02:08:29 +0000 Subject: [PATCH 07/10] fix --- .github/workflows/main.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a53ba73..1d297be 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,7 +31,7 @@ jobs: - name: Create custom directory run: | set -ex - echo ${{ env.CI_BASE_PATH }} + echo ${{ vars.CI_BASE_PATH }} echo ${{ env.CI_PATH }} mkdir -p ${{ env.CI_PATH }} @@ -117,7 +117,7 @@ jobs: bash test/dsl/run_tests.sh - name: Clear workfile - if: always() + if: success() run: | export workdir=$(pwd) cd .. From a125db71f2e6b164b63348568697b4b40104eaef Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Tue, 3 Mar 2026 02:18:48 +0000 Subject: [PATCH 08/10] fix --- .github/workflows/main.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1d297be..d812e4c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -27,6 +27,8 @@ jobs: uses: actions/checkout@v4 with: ssh-key: ${{ secrets.SSH_PRIVATE_KEY_DLC_CI }} + timeout-minutes: 5 + persist-credentials: false - name: Create custom directory run: | From 7a00a19e47e66e59d55e978474c5e04546e4430f Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Tue, 3 Mar 2026 03:10:19 +0000 Subject: [PATCH 09/10] revert --- .github/workflows/main.yml | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d812e4c..a9592fd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,7 +10,7 @@ on: - main env: - CI_PATH: "${{ vars.CI_BASE_PATH }}/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}/${{ github.run_id }}" + CI_PATH: "${{ vars.CI_BASE_PATH }}/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}" THIRD_PARTY_PATH: "${{ vars.CI_BASE_PATH }}/data/DLCompiler/third_party" concurrency: @@ -27,13 +27,10 @@ jobs: uses: actions/checkout@v4 with: ssh-key: ${{ secrets.SSH_PRIVATE_KEY_DLC_CI }} - timeout-minutes: 5 - persist-credentials: false - name: Create custom directory run: | set -ex - echo ${{ vars.CI_BASE_PATH }} echo ${{ env.CI_PATH }} mkdir -p ${{ env.CI_PATH }} @@ -119,7 +116,7 @@ jobs: bash test/dsl/run_tests.sh - name: Clear workfile - if: success() + if: always() run: | export workdir=$(pwd) cd .. @@ -128,4 +125,4 @@ jobs: chmod -R 777 $workdir if [ -d "${{ env.CI_PATH }}" ]; then rm -rf ${{ env.CI_PATH }} - fi + fi \ No newline at end of file From 7668adc12d13c59562dbae701b7fd45a50c3b068 Mon Sep 17 00:00:00 2001 From: cokedong <408244909@qq.com> Date: Tue, 3 Mar 2026 03:11:37 +0000 Subject: [PATCH 10/10] fix --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a9592fd..7bbdde0 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -125,4 +125,4 @@ jobs: chmod -R 777 $workdir if [ -d "${{ env.CI_PATH }}" ]; then rm -rf ${{ env.CI_PATH }} - fi \ No newline at end of file + fi