diff --git a/.gitignore b/.gitignore
index 340ae2678b8..7ffccb559d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@
 .nfs*
 tags
 MANIFEST
+CMakeLists.txt.user
 build/
 docs/_build/
diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt
new file mode 100644
index 00000000000..9d3b7b20117
--- /dev/null
+++ b/mlir-compiler/CMakeLists.txt
@@ -0,0 +1,25 @@
+cmake_minimum_required(VERSION 3.5)
+
+project(mlir_compiler LANGUAGES CXX C)
+
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+
+add_subdirectory(mlir-compiler)
+
+if(DEFINED DPCOMP_TREE)
+    message(STATUS "Out of tree PLIER is used")
+    add_subdirectory(${DPCOMP_TREE} plier)
+elseif(DEFINED DPCOMP_DIR)
+    message(STATUS "PLIER from DPCOMP_DIR is used")
+    target_include_directories(${PROJECT_NAME} PRIVATE
+        ${DPCOMP_DIR}/include
+        )
+    target_link_directories(${PROJECT_NAME} PRIVATE
+        ${DPCOMP_DIR}
+        ${DPCOMP_DIR}/Release
+        )
+else()
+    message(FATAL_ERROR "dpcomp not found")
+endif()
diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt
new file mode 100644
index 00000000000..0dea32271e1
--- /dev/null
+++ b/mlir-compiler/llvm-sha.txt
@@ -0,0 +1 @@
+7b153b43d3a14d76975039408c4b922beb576735
diff --git a/mlir-compiler/mlir-compiler/CMakeLists.txt b/mlir-compiler/mlir-compiler/CMakeLists.txt
new file mode 100644
index 00000000000..5458d260ab8
--- /dev/null
+++ b/mlir-compiler/mlir-compiler/CMakeLists.txt
@@ -0,0 +1,76 @@
+
+find_package(pybind11 REQUIRED)
+
+find_package(LLVM REQUIRED CONFIG)
+find_package(MLIR REQUIRED CONFIG)
+
+list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
+list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
+include(TableGen)
+include(AddLLVM)
+include(AddMLIR)
+include(HandleLLVMOptions)
+
+set(SOURCES_LIST
+    src/pipelines/base_pipeline.cpp
+    src/pipelines/lower_to_llvm.cpp
+    src/pipelines/parallel_to_tbb.cpp
+    src/pipelines/plier_to_linalg.cpp
+    src/pipelines/plier_to_std.cpp
+    src/lowering.cpp
+    src/mangle.cpp
+    src/py_func_resolver.cpp
+    src/py_linalg_resolver.cpp
+    src/py_map_types.cpp
+    src/py_module.cpp
+    )
+set(HEADERS_LIST
+    src/pipelines/base_pipeline.hpp
+    src/pipelines/lower_to_llvm.hpp
+    src/pipelines/parallel_to_tbb.hpp
+    src/pipelines/plier_to_linalg.hpp
+    src/pipelines/plier_to_std.hpp
+    src/lowering.hpp
+    src/mangle.hpp
+    src/py_func_resolver.hpp
+    src/py_linalg_resolver.hpp
+    src/py_map_types.hpp
+    src/py_module.hpp
+    )
+
+pybind11_add_module(${PROJECT_NAME} ${SOURCES_LIST} ${HEADERS_LIST})
+
+if (MSVC)
+    target_compile_options(${PROJECT_NAME} PRIVATE /EHsc)
+endif ()
+
+if (CMAKE_SYSTEM_NAME STREQUAL Linux)
+    target_link_options(${PROJECT_NAME} PRIVATE "LINKER:--version-script=export.txt")
+endif()
+
+if (CMAKE_SYSTEM_NAME STREQUAL Darwin)
+    target_link_libraries(${PROJECT_NAME} PRIVATE "-Wl,-exported_symbols_list,export_darwin.txt")
+endif()
+
+target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS})
+
+target_link_libraries(${PROJECT_NAME} PRIVATE
+    plier
+    LLVM${LLVM_NATIVE_ARCH}CodeGen
+    LLVM${LLVM_NATIVE_ARCH}Desc
+    LLVMTarget
+    MLIRIR
+    MLIRLLVMIR
+    MLIRLLVMToLLVMIRTranslation
+    MLIRTransforms
+    MLIRStandardOpsTransforms
+    MLIRLinalgTransforms
+    MLIRSCFToStandard
+    MLIRTensorTransforms
+    )
+
+target_include_directories(${PROJECT_NAME} PRIVATE
+    ./src
+    ${LLVM_INCLUDE_DIRS}
+    ${MLIR_INCLUDE_DIRS}
+    )
diff --git a/mlir-compiler/mlir-compiler/export.txt b/mlir-compiler/mlir-compiler/export.txt
new file mode 100644
index 00000000000..dce06e3f92c
--- /dev/null
+++ b/mlir-compiler/mlir-compiler/export.txt
@@ -0,0 +1,4 @@
+{
+  global: PyInit_mlir_compiler;
+  local: *;
+};
diff --git a/mlir-compiler/mlir-compiler/export_darwin.txt b/mlir-compiler/mlir-compiler/export_darwin.txt
new file mode 100644
index 00000000000..a5d5900af2d
--- /dev/null
+++ b/mlir-compiler/mlir-compiler/export_darwin.txt
@@ -0,0 +1 @@
+_PyInit_mlir_compiler
diff --git a/mlir-compiler/mlir-compiler/src/lowering.cpp b/mlir-compiler/mlir-compiler/src/lowering.cpp
new file mode 100644
index 00000000000..6c34ac75278
--- /dev/null
+++ b/mlir-compiler/mlir-compiler/src/lowering.cpp
@@ -0,0 +1,749 @@
+#include "lowering.hpp"
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+#include "plier/dialect.hpp"
+
+#include "plier/compiler/compiler.hpp"
+#include "plier/compiler/pipeline_registry.hpp"
+#include "plier/utils.hpp"
+
+#include "pipelines/base_pipeline.hpp"
+#include "pipelines/parallel_to_tbb.hpp"
+#include "pipelines/plier_to_std.hpp"
+#include "pipelines/plier_to_linalg.hpp"
+#include "pipelines/lower_to_llvm.hpp"
+
+namespace py = pybind11;
+namespace
+{
+
+std::string serialize_mod(const llvm::Module& mod)
+{
+    std::string ret;
+    llvm::raw_string_ostream stream(ret);
+    llvm::WriteBitcodeToFile(mod, stream);
+    stream.flush();
+    return ret;
+}
+
+std::vector<std::pair<int, py::handle>> get_blocks(const py::object& func)
+{
+    std::vector<std::pair<int, py::handle>> ret;
+    auto blocks = func.attr("blocks").cast<py::dict>();
+    ret.reserve(blocks.size());
+    for (auto it : blocks)
+    {
+        ret.push_back({it.first.cast<int>(), it.second});
+    }
+    return ret;
+}
+
+py::list get_body(const py::handle& block)
+{
+    return block.attr("body").cast<py::list>();
+}
+
+struct OpId
+{
+    llvm::StringRef op;
+    llvm::StringRef name;
+};
+
+static const constexpr OpId inst_ops_names[] = {
+    {"+", "add"}, // binary
+    {"+", "pos"}, // unary
+    {"-", "sub"}, // binary
+    {"-", "neg"}, // unary
+    {"*", "mul"},
+    {"/", "truediv"},
+    {"//", "floordiv"},
+    {"%", "mod"},
+
+    {">", "gt"},
+    {">=", "ge"},
+    {"<", "lt"},
+    {"<=", "le"},
+    {"!=", "ne"},
+    {"==", "eq"},
+};
+
+struct inst_handles
+{
+    inst_handles()
+    {
+        auto mod = py::module::import("numba.core.ir");
+        Assign = mod.attr("Assign");
+        Del = mod.attr("Del");
+        Return = mod.attr("Return");
+        Branch = mod.attr("Branch");
+        Jump = mod.attr("Jump");
+        SetItem = mod.attr("SetItem");
+        StaticSetItem = mod.attr("StaticSetItem");
+
+        Arg = mod.attr("Arg");
+        Expr = mod.attr("Expr");
+        Var = mod.attr("Var");
+        Const = mod.attr("Const");
+        Global = mod.attr("Global");
+        FreeVar = mod.attr("FreeVar");
+
+        auto ops = py::module::import("operator");
+
+        for (auto elem : llvm::zip(inst_ops_names, ops_handles))
+        {
+            auto name = std::get<0>(elem).name;
+            std::get<1>(elem) = ops.attr(name.data());
+        }
+    }
+
+    py::handle Assign;
+    py::handle Del;
+    py::handle Return;
+    py::handle Branch;
+    py::handle Jump;
+    py::handle SetItem;
+    py::handle StaticSetItem;
+
+    py::handle Arg;
+    py::handle Expr;
+    py::handle Var;
+    py::handle Const;
+    py::handle Global;
+    py::handle FreeVar;
+
+    std::array<py::handle, llvm::array_lengthof(inst_ops_names)> ops_handles;
+};
+
+struct plier_lowerer final
+{
+    plier_lowerer(mlir::MLIRContext& context):
+        ctx(context),
+        builder(&ctx)
+    {
+        ctx.loadDialect<mlir::StandardOpsDialect>();
+        ctx.loadDialect<plier::PlierDialect>();
+    }
+
+    mlir::FuncOp lower(const py::object& compilation_context, mlir::ModuleOp mod, const py::object& func_ir)
+    {
+
+        typemap = compilation_context["typemap"];
+        func_name_resolver = compilation_context["resolve_func"];
+        auto name = compilation_context["fnname"]().cast<std::string>();
+        auto typ = get_func_type(compilation_context["fnargs"], compilation_context["restype"]);
+        func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ);
+        if (compilation_context["fastmath"]().cast<bool>())
+        {
+            func->setAttr(plier::attributes::getFastmathName(), mlir::UnitAttr::get(&ctx));
+        }
+        auto max_concurrency = compilation_context["max_concurrency"]().cast<int>();
+        if (max_concurrency > 0)
+        {
+            mod->setAttr(plier::attributes::getMaxConcurrencyName(), builder.getI64IntegerAttr(max_concurrency));
+        }
+        lower_func_body(func_ir);
+        mod.push_back(func);
+        return func;
+    }
+private:
+    mlir::MLIRContext& ctx;
+    mlir::OpBuilder builder;
+    std::vector<mlir::Block*> blocks;
+    std::unordered_map<int, mlir::Block*> blocks_map;
+    inst_handles insts;
+    mlir::FuncOp func;
+    std::unordered_map<std::string, mlir::Value> vars_map;
+    struct BlockInfo
+    {
+        struct PhiDesc
+        {
+            mlir::Block* dest_block = nullptr;
+            std::string var_name;
+            unsigned arg_index = 0;
+        };
+        llvm::SmallVector<PhiDesc, 2> outgoing_phi_nodes;
+    };
+    py::handle current_instr;
+    py::handle typemap;
+    py::handle func_name_resolver;
+
+    std::unordered_map<mlir::Block*, BlockInfo> block_infos;
+
+    plier::PyType get_obj_type(const py::handle& obj) const
+    {
+        return plier::PyType::get(&ctx, py::str(obj).cast<std::string>());
+    }
+
+    plier::PyType get_type(const py::handle& inst) const
+    {
+        auto type = typemap(inst);
+        return get_obj_type(type);
+    }
+
+    void lower_func_body(const py::object& func_ir)
+    {
+        auto ir_blocks = get_blocks(func_ir);
+        assert(!ir_blocks.empty());
+        blocks.reserve(ir_blocks.size());
+        for (std::size_t i = 0; i < ir_blocks.size(); ++i)
+        {
+            auto block = (0 == i ? func.addEntryBlock() : func.addBlock());
+            blocks.push_back(block);
+            blocks_map[ir_blocks[i].first] = block;
+        }
+
+        for (std::size_t i = 0; i < ir_blocks.size(); ++i)
+        {
+            lower_block(blocks[i], ir_blocks[i].second);
+        }
+        fixup_phis();
+    }
+
+    void lower_block(mlir::Block* bb, const py::handle& ir_block)
+    {
+        assert(nullptr != bb);
+        builder.setInsertionPointToEnd(bb);
+        for (auto it : get_body(ir_block))
+        {
+            current_instr = it;
+            lower_inst(it);
+            current_instr = nullptr;
+        }
+    }
+
+    void lower_inst(const py::handle& inst)
+    {
+        if (py::isinstance(inst, insts.Assign))
+        {
+            auto target = inst.attr("target");
+            auto val = lower_assign(inst, target);
+            storevar(val, target);
+        }
+        else if (py::isinstance(inst, insts.SetItem) ||
+                 py::isinstance(inst, insts.StaticSetItem))
+        {
+            setitem(inst.attr("target"), inst.attr("index"), inst.attr("value"));
+        }
+        else if (py::isinstance(inst, insts.Del))
+        {
+            delvar(inst.attr("value"));
+        }
+        else if (py::isinstance(inst, insts.Return))
+        {
+            retvar(inst.attr("value"));
+        }
+        else if (py::isinstance(inst, insts.Branch))
+        {
+            branch(inst.attr("cond"), inst.attr("truebr"), inst.attr("falsebr"));
+        }
+        else if (py::isinstance(inst, insts.Jump))
+        {
+            jump(inst.attr("target"));
+        }
+        else
+        {
+            plier::report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast<std::string>() + "\"");
+        }
+    }
+
+    mlir::Value lower_assign(const py::handle& inst, const py::handle& target)
+    {
+        auto value = inst.attr("value");
+        if (py::isinstance(value, insts.Arg))
+        {
+            auto index = value.attr("index").cast<unsigned>();
+            return builder.create<plier::ArgOp>(get_current_loc(), index,
+                target.attr("name").cast<std::string>());
+        }
+        if(py::isinstance(value, insts.Expr))
+        {
+            return lower_expr(value);
+        }
+        if(py::isinstance(value, insts.Var))
+        {
+            return loadvar(value);
+        }
+        if (py::isinstance(value, insts.Const))
+        {
+            return get_const(value.attr("value"));
+        }
+        if (py::isinstance(value, insts.Global) ||
+            py::isinstance(value, insts.FreeVar))
+        {
+            auto name = value.attr("name").cast<std::string>();
+            return builder.create<plier::GlobalOp>(get_current_loc(),
+                name);
+        }
+
+        plier::report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast<std::string>() + "\"");
+    }
+
+    mlir::Value lower_expr(const py::handle& expr)
+    {
+        auto op = expr.attr("op").cast<std::string>();
+        using func_t = mlir::Value (plier_lowerer::*)(const py::handle&);
+        const std::pair<llvm::StringRef, func_t> handlers[] = {
+            {"binop", &plier_lowerer::lower_binop},
+            {"inplace_binop", &plier_lowerer::lower_inplace_binop},
+            {"unary", &plier_lowerer::lower_unary},
+            {"cast", &plier_lowerer::lower_cast},
+            {"call", &plier_lowerer::lower_call},
+            {"phi", &plier_lowerer::lower_phi},
+            {"build_tuple", &plier_lowerer::lower_build_tuple},
+            {"getitem", &plier_lowerer::lower_getitem},
+            {"static_getitem", &plier_lowerer::lower_static_getitem},
+            {"getiter", &plier_lowerer::lower_simple<plier::GetiterOp>},
+            {"iternext", &plier_lowerer::lower_simple<plier::IternextOp>},
+            {"pair_first", &plier_lowerer::lower_simple<plier::PairfirstOp>},
+            {"pair_second", &plier_lowerer::lower_simple<plier::PairsecondOp>},
+            {"getattr", &plier_lowerer::lower_getattr},
+        };
+        for (auto& h : handlers)
+        {
+            if (h.first == op)
+            {
+                return (this->*h.second)(expr);
+            }
+        }
+        plier::report_error(llvm::Twine("lower_expr not handled: \"") + op + "\"");
+    }
+
+    template <typename T>
+    mlir::Value lower_simple(const py::handle& inst)
+    {
+        auto value = loadvar(inst.attr("value"));
+        return builder.create<T>(get_current_loc(), value);
+    }
+
+    mlir::Value lower_cast(const py::handle& inst)
+    {
+        auto value = loadvar(inst.attr("value"));
+        auto res_type = get_type(current_instr.attr("target"));
+        return builder.create<plier::CastOp>(get_current_loc(), res_type, value);
+    }
+
+    mlir::Value lower_getitem(const py::handle& inst)
+    {
+        auto value = loadvar(inst.attr("value"));
+        auto index = loadvar(inst.attr("index"));
+        return builder.create<plier::GetItemOp>(get_current_loc(), value, index);
+    }
+
+    mlir::Value lower_static_getitem(const py::handle& inst)
+    {
+        auto value = loadvar(inst.attr("value"));
+        auto index_var = loadvar(inst.attr("index_var"));
+        auto index = inst.attr("index").cast<unsigned>();
+        return builder.create<plier::StaticGetItemOp>(get_current_loc(),
+            value, index_var, index);
+    }
+
+    mlir::Value lower_build_tuple(const py::handle& inst)
+    {
+        auto items = inst.attr("items").cast<py::list>();
+        mlir::SmallVector<mlir::Value> args;
+        for (auto item : items)
+        {
+            args.push_back(loadvar(item));
+        }
+        return builder.create<plier::BuildTupleOp>(get_current_loc(), args);
+    }
+
+    mlir::Value lower_phi(const py::handle& expr)
+    {
+        auto incoming_vals = expr.attr("incoming_values").cast<py::list>();
+        auto incoming_blocks = expr.attr("incoming_blocks").cast<py::list>();
+        assert(incoming_vals.size() == incoming_blocks.size());
+
+        auto current_block = builder.getBlock();
+        assert(nullptr != current_block);
+
+        auto arg_index = current_block->getNumArguments();
+        auto arg = current_block->addArgument(get_type(current_instr.attr("target")));
+
+        auto count = incoming_vals.size();
+        for (std::size_t i = 0; i < count; ++i)
+        {
+            auto var = incoming_vals[i].attr("name").cast<std::string>();
+            auto block = blocks_map.find(incoming_blocks[i].cast<int>())->second;
+            block_infos[block].outgoing_phi_nodes.push_back({current_block, std::move(var), arg_index});
+        }
+
+        return arg;
+    }
+
+
+    mlir::Value lower_call(const py::handle& expr)
+    {
+        auto py_func = expr.attr("func");
+        auto func = loadvar(py_func);
+        auto args = expr.attr("args").cast<py::list>();
+        auto kws = expr.attr("kws").cast<py::list>();
+        auto vararg = expr.attr("vararg");
+
+        mlir::SmallVector<mlir::Value> args_list;
+        mlir::SmallVector<std::pair<std::string, mlir::Value>> kwargs_list;
+        for (auto a : args)
+        {
+            args_list.push_back(loadvar(a));
+        }
+        for (auto a : kws)
+        {
+            auto item = a.cast<py::tuple>();
+            auto name = item[0];
+            auto val_name = item[1];
+            kwargs_list.push_back({name.cast<std::string>(), loadvar(val_name)});
+        }
+
+        auto py_func_name = func_name_resolver(typemap(py_func));
+        if (py_func_name.is_none())
+        {
+            plier::report_error(llvm::Twine("Can't resolve function: ") + py::str(typemap(py_func)).cast<std::string>());
+        }
+
+        auto func_name = py_func_name.cast<std::string>();
+
+        return builder.create<plier::PyCallOp>(get_current_loc(), func, func_name,
+                                               args_list, kwargs_list);
+    }
+
+    mlir::Value lower_binop(const py::handle& expr)
+    {
+        auto op = expr.attr("fn");
+        auto lhs_name = expr.attr("lhs");
+        auto rhs_name = expr.attr("rhs");
+        auto lhs = loadvar(lhs_name);
+        auto rhs = loadvar(rhs_name);
+        auto op_name = resolve_op(op);
+        return builder.create<plier::BinOp>(get_current_loc(), lhs, rhs, op_name);
+    }
+
+    mlir::Value lower_inplace_binop(const py::handle& expr)
+    {
+        auto op = expr.attr("immutable_fn");
+        auto lhs_name = expr.attr("lhs");
+        auto rhs_name = expr.attr("rhs");
+        auto lhs = loadvar(lhs_name);
+        auto rhs = loadvar(rhs_name);
+        auto op_name = resolve_op(op);
+        return builder.create<plier::BinOp>(get_current_loc(), lhs, rhs, op_name);
+    }
+
+    mlir::Value lower_unary(const py::handle& expr)
+    {
+        auto op = expr.attr("fn");
+        auto val_name = expr.attr("value");
+        auto val = loadvar(val_name);
+        auto op_name = resolve_op(op);
+        return builder.create<plier::UnaryOp>(get_current_loc(), val, op_name);
+    }
+
+    llvm::StringRef resolve_op(const py::handle& op)
+    {
+        for (auto elem : llvm::zip(inst_ops_names, insts.ops_handles))
+        {
+            if (op.is(std::get<1>(elem)))
+            {
+                return std::get<0>(elem).op;
+            }
+        }
+
+        plier::report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast<std::string>() + "\"");
+    }
+
+    mlir::Value lower_getattr(const py::handle& inst)
+    {
+        auto val = inst.attr("value");
+        auto value = loadvar(val);
+        auto name = inst.attr("attr").cast<std::string>();
+        return builder.create<plier::GetattrOp>(get_current_loc(), value, name);
+    }
+
+    void setitem(const py::handle& target, const py::handle& index, const py::handle& value)
+    {
+        auto ind = [&]()->mlir::Value
+        {
+            if (py::isinstance<py::int_>(index))
+            {
+                return builder.create<mlir::ConstantIndexOp>(get_current_loc(), index.cast<int64_t>());
+            }
+            return loadvar(index);
+        }();
+        builder.create<plier::SetItemOp>(get_current_loc(), loadvar(target), ind, loadvar(value));
+    }
+
+    void storevar(mlir::Value val, const py::handle& inst)
+    {
+        vars_map[inst.attr("name").cast<std::string>()] = val;
+        val.setType(get_type(inst));
+    }
+
+    mlir::Value loadvar(const py::handle& inst)
+    {
+        auto it = vars_map.find(inst.attr("name").cast<std::string>());
+        assert(vars_map.end() != it);
+        return it->second;
+    }
+
+    void delvar(const py::handle& inst)
+    {
+        auto var = loadvar(inst);
+        builder.create<plier::DelOp>(get_current_loc(), var);
+    }
+
+    void retvar(const py::handle& inst)
+    {
+        auto var = loadvar(inst);
+        auto func_type = func.getType();
+        auto ret_type = func_type.getResult(0);
+        auto var_type = var.getType();
+        if (ret_type != var_type)
+        {
+            var = builder.create<plier::CastOp>(get_current_loc(), ret_type, var);
+        }
+        builder.create<mlir::ReturnOp>(get_current_loc(), var);
+    }
+
+    void branch(const py::handle& cond, const py::handle& tr, const py::handle& fl)
+    {
+        auto c = loadvar(cond);
+        auto tr_block = blocks_map.find(tr.cast<int>())->second;
+        auto fl_block = blocks_map.find(fl.cast<int>())->second;
+        auto cond_val = builder.create<plier::CastOp>(get_current_loc(), mlir::IntegerType::get(&ctx, 1), c);
+        builder.create<mlir::CondBranchOp>(get_current_loc(), cond_val, tr_block, fl_block);
+    }
+
+    void jump(const py::handle& target)
+    {
+        auto block = blocks_map.find(target.cast<int>())->second;
+        builder.create<mlir::BranchOp>(get_current_loc(), mlir::None, block);
+    }
+
+    mlir::Value get_const(const py::handle& val)
+    {
+        auto get_val = [&](mlir::Attribute attr)
+        {
+            return builder.create<plier::ConstOp>(get_current_loc(), attr);
+        };
+        if (py::isinstance<py::int_>(val))
+        {
+            return get_val(builder.getI64IntegerAttr(val.cast<int64_t>()));
+        }
+        if (py::isinstance<py::float_>(val))
+        {
+            return get_val(builder.getF64FloatAttr(val.cast<double>()));
+        }
+        if (py::isinstance<py::none>(val))
+        {
+            return get_val(builder.getUnitAttr());
+        }
+        plier::report_error(llvm::Twine("get_const unhandled type \"") + py::str(val.get_type()).cast<std::string>() + "\"");
+    }
+
+    mlir::FunctionType get_func_type(const py::handle& fnargs, const py::handle& restype)
+    {
+        auto ret = get_obj_type(restype());
+        llvm::SmallVector<mlir::Type> args;
+        for (auto arg : fnargs())
+        {
+            args.push_back(get_obj_type(arg));
+        }
+        return mlir::FunctionType::get(&ctx, args, {ret});
+    }
+
+    mlir::Location get_current_loc()
+    {
+        return builder.getUnknownLoc(); // TODO
+    }
+
+    void fixup_phis()
+    {
+        auto build_arg_list = [&](mlir::Block* block, auto& outgoing_phi_nodes, auto& list)
+        {
+            for (auto& o : outgoing_phi_nodes)
+            {
+                if (o.dest_block == block)
+                {
+                    auto arg_index = o.arg_index;
+                    if (list.size() <= arg_index)
+                    {
+                        list.resize(arg_index + 1);
+                    }
+                    auto it = vars_map.find(o.var_name);
+                    assert(vars_map.end() != it);
+                    auto arg_type = block->getArgument(arg_index).getType();
+                    auto val = builder.create<plier::CastOp>(builder.getUnknownLoc(), arg_type, it->second);
+                    list[arg_index] = val;
+                }
+            }
+        };
+        for (auto& bb : func)
+        {
+            auto it = block_infos.find(&bb);
+            if (block_infos.end() != it)
+            {
+                auto& info = it->second;
+                auto term = bb.getTerminator();
+                if (nullptr == term)
+                {
+                    plier::report_error("broken ir: block without terminator");
+                }
+                builder.setInsertionPointToEnd(&bb);
+
+                if (auto op = mlir::dyn_cast<mlir::BranchOp>(term))
+                {
+                    auto dest = op.getDest();
+                    mlir::SmallVector<mlir::Value> args;
+                    build_arg_list(dest, info.outgoing_phi_nodes, args);
+                    op.erase();
+                    builder.create<mlir::BranchOp>(builder.getUnknownLoc(), dest, args);
+                }
+                else if (auto op = mlir::dyn_cast<mlir::CondBranchOp>(term))
+                {
+                    auto true_dest = op.trueDest();
+                    auto false_dest = op.falseDest();
+                    auto cond = op.getCondition();
+                    mlir::SmallVector<mlir::Value> true_args;
+                    mlir::SmallVector<mlir::Value> false_args;
+                    build_arg_list(true_dest, info.outgoing_phi_nodes, true_args);
+                    build_arg_list(false_dest, info.outgoing_phi_nodes, false_args);
+                    op.erase();
+                    builder.create<mlir::CondBranchOp>(builder.getUnknownLoc(), cond, true_dest, true_args, false_dest, false_args);
+                }
+                else
+                {
+                    plier::report_error(llvm::Twine("Unhandled terminator: ") + term->getName().getStringRef());
+                }
+            }
+        }
+    }
+
+};
+
+plier::CompilerContext::Settings get_settings(const py::handle& settings)
+{
+    plier::CompilerContext::Settings ret;
+    ret.verify = settings["verify"].cast<bool>();
+    ret.pass_statistics = settings["pass_statistics"].cast<bool>();
+    ret.pass_timings = settings["pass_timings"].cast<bool>();
+    ret.ir_printing = settings["ir_printing"].cast<bool>();
+    return ret;
+}
+
+py::bytes gen_ll_module(mlir::ModuleOp mod)
+{
+    std::string err;
+    llvm::raw_string_ostream err_stream(err);
+    auto diag_handler = [&](mlir::Diagnostic& diag)
+    {
+        if (diag.getSeverity() == mlir::DiagnosticSeverity::Error)
+        {
+            err_stream << diag;
+        }
+    };
+    llvm::LLVMContext ll_ctx;
+    std::unique_ptr<llvm::Module> ll_mod;
+    plier::scoped_diag_handler(*mod.getContext(), diag_handler, [&]()
+    {
+        mlir::registerLLVMDialectTranslation(*mod.getContext());
+        ll_mod = mlir::translateModuleToLLVMIR(mod, ll_ctx);
+        if (nullptr == ll_mod)
+        {
+            err_stream.flush();
+            plier::report_error(llvm::Twine("Cannot generate LLVM module\n") + err);
+        }
+    });
+    assert(nullptr != ll_mod);
+//    ll_mod->dump();
+    return serialize_mod(*ll_mod);
+}
+
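[Reviewer sketch, not part of the patch: how the four entry points defined below are intended to chain together. `ctx` and `func_ir` are placeholder names for the pybind11 objects the Numba side passes in; `ctx` must expose the keys used above ("typemap", "resolve_func", "fnname", "fnargs", "restype", "fastmath", "max_concurrency", "compiler_settings").]

    py::capsule mod = create_module();            // fresh MLIRContext + ModuleOp + pipeline registry
    lower_function(ctx, mod, func_ir);            // Numba IR -> plier dialect ops (plier_lowerer above)
    py::bytes bitcode = compile_module(ctx, mod); // run registered pipelines, serialize LLVM bitcode
    py::str ir_text = module_str(mod);            // textual MLIR, useful for debugging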
+void create_pipeline(plier::PipelineRegistry& registry)
+{
+    register_base_pipeline(registry);
+    register_lower_to_llvm_pipeline(registry);
+    register_plier_to_std_pipeline(registry);
+    register_plier_to_linalg_pipeline(registry);
+    register_parallel_to_tbb_pipeline(registry);
+}
+
+struct Module
+{
+    mlir::MLIRContext context;
+    plier::PipelineRegistry registry;
+    mlir::ModuleOp module;
+
+    Module()
+    {
+        create_pipeline(registry);
+    }
+};
+
+void run_compiler(Module& mod, const py::object& compilation_context)
+{
+    auto& context = mod.context;
+    auto& module = mod.module;
+    auto& registry = mod.registry;
+
+    auto settings = get_settings(compilation_context["compiler_settings"]);
+    plier::CompilerContext compiler(context, settings, registry);
+    compiler.run(module);
+}
+}
+
+py::capsule create_module()
+{
+    auto mod = std::make_unique<Module>();
+    {
+        mlir::OpBuilder builder(&mod->context);
+        mod->module = mlir::ModuleOp::create(builder.getUnknownLoc());
+    }
+    py::capsule capsule(mod.get(), [](void* ptr)
+    {
+        delete static_cast<Module*>(ptr);
+    });
+    mod.release();
+    return capsule;
+}
+
+py::capsule lower_function(const py::object& compilation_context, const py::capsule& py_mod, const py::object& func_ir)
+{
+    auto mod = static_cast<Module*>(py_mod);
+    auto& context = mod->context;
+    auto& module = mod->module;
+    auto func = plier_lowerer(context).lower(compilation_context, module, func_ir);
+    return py::capsule(func.getOperation()); // no dtor, func owned by module
+}
+
+py::bytes compile_module(const py::object& compilation_context, const py::capsule& py_mod)
+{
+    auto mod = static_cast<Module*>(py_mod);
+    run_compiler(*mod, compilation_context);
+    return gen_ll_module(mod->module);
+}
+
+py::str module_str(const py::capsule& py_mod)
+{
+    auto mod = static_cast<Module*>(py_mod);
+    std::string ret;
+    llvm::raw_string_ostream ss(ret);
+    mod->module.print(ss);
+    ss.flush();
+    return py::str(ss.str());
+}
diff --git a/mlir-compiler/mlir-compiler/src/lowering.hpp b/mlir-compiler/mlir-compiler/src/lowering.hpp
new file mode 100644
index 00000000000..0580e7e58f3
--- /dev/null
+++ b/mlir-compiler/mlir-compiler/src/lowering.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+namespace pybind11
+{
+class bytes;
+class capsule;
+class object;
+class str;
+}
+
+pybind11::capsule create_module();
+
+pybind11::capsule lower_function(const pybind11::object& compilation_context,
+                                 const pybind11::capsule& py_mod,
+                                 const pybind11::object& func_ir);
+
+pybind11::bytes compile_module(const pybind11::object& compilation_context,
+                               const pybind11::capsule& py_mod);
+
+pybind11::str module_str(const pybind11::capsule& py_mod);
diff --git a/mlir-compiler/mlir-compiler/src/mangle.cpp b/mlir-compiler/mlir-compiler/src/mangle.cpp
new file mode 100644
index 00000000000..cc06a538323
--- /dev/null
+++ b/mlir-compiler/mlir-compiler/src/mangle.cpp
@@ -0,0 +1,244 @@
+#include "mangle.hpp"
+
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+namespace
+{
+static const constexpr auto PREFIX = "_Z";
+
+template<unsigned Width, mlir::IntegerType::SignednessSemantics Sign, char Symbol>
+bool mangle_int(llvm::raw_ostream& res, mlir::Type type)
+{
+    if (auto i = type.dyn_cast<mlir::IntegerType>())
+    {
+        if (i.getWidth() == Width && i.getSignedness() == Sign)
+        {
+            res << Symbol;
+            return true;
+        }
+    }
+    return false;
+}
+
+template<unsigned Width, char Symbol>
+bool mangle_float(llvm::raw_ostream& res, mlir::Type type)
+{
+    if (auto i = type.dyn_cast<mlir::FloatType>())
+    {
+        if (i.getWidth() == Width)
+        {
+            res << Symbol;
+            return true;
+        }
+    }
+    return false;
+}
+
+void mangle_memref_impl(llvm::raw_ostream& res, mlir::MemRefType type);
+
+bool mangle_memref(llvm::raw_ostream& res, mlir::Type 
type) +{ + if (auto m = type.dyn_cast()) + { + mangle_memref_impl(res, m); + return true; + } + return false; +} + +using type_mangler_t = bool(*)(llvm::raw_ostream&, mlir::Type); + +static const constexpr type_mangler_t type_manglers[] = { + &mangle_int<1, mlir::IntegerType::Signed, 'b'>, + &mangle_int<1, mlir::IntegerType::Unsigned, 'b'>, + &mangle_int<1, mlir::IntegerType::Signless, 'b'>, + + &mangle_int<8, mlir::IntegerType::Signed, 'a'>, + &mangle_int<8, mlir::IntegerType::Unsigned, 'h'>, + &mangle_int<8, mlir::IntegerType::Signless, 'c'>, + + &mangle_int<16, mlir::IntegerType::Signed, 's'>, + &mangle_int<16, mlir::IntegerType::Unsigned, 't'>, + &mangle_int<16, mlir::IntegerType::Signless, 's'>, + + &mangle_int<32, mlir::IntegerType::Signed, 'i'>, + &mangle_int<32, mlir::IntegerType::Unsigned, 'j'>, + &mangle_int<32, mlir::IntegerType::Signless, 'i'>, + + &mangle_int<64, mlir::IntegerType::Signed, 'x'>, + &mangle_int<64, mlir::IntegerType::Unsigned, 'm'>, + &mangle_int<64, mlir::IntegerType::Signless, 'x'>, + + &mangle_int<128, mlir::IntegerType::Signed, 'n'>, + &mangle_int<128, mlir::IntegerType::Unsigned, 'o'>, + &mangle_int<128, mlir::IntegerType::Signless, 'n'>, + + &mangle_float<32, 'f'>, + &mangle_float<64, 'd'>, + &mangle_float<80, 'e'>, + &mangle_float<128, 'g'>, + + &mangle_memref, +}; + +bool check_type(mlir::Type type) +{ + llvm::raw_null_ostream ss; + for (auto mangler : type_manglers) + { + if (mangler(ss, type)) + { + return true; + } + } + return false; +} + +bool is_valid_char(char c) +{ + return (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + (c == '_'); +} + +std::string escape_string(llvm::StringRef str) +{ + std::string ret; + llvm::raw_string_ostream ss(ret); + for (auto c : str) + { + if (is_valid_char(c)) + { + ss << c; + } + else + { + ss << "$" << llvm::format_hex_no_prefix(static_cast(c), 2); + } + } + ss.flush(); + return ret; +} + +template +void mangle_ident_impl(llvm::raw_ostream& res, llvm::StringRef ident, F&& template_params) +{ + assert(!ident.empty()); + llvm::SmallVector parts; + ident.split(parts, '.'); + assert(!parts.empty()); + auto write_part = [&](auto part) + { + auto escaped = escape_string(part); + if (std::isdigit(escaped.front())) + { + res << escaped.size() + 1 << '_' << escaped; + } + else + { + res << escaped.size() << escaped; + } + }; + if (parts.size() == 1) + { + write_part(parts.front()); + template_params(res); + } + else + { + res << 'N'; + for (auto& part : parts) + { + write_part(part); + } + template_params(res); + res << 'E'; + } +} + +void mangle_ident(llvm::raw_ostream& res, llvm::StringRef ident) +{ + auto dummy = [](auto&) {}; + mangle_ident_impl(res, ident, dummy); +} + +template +void mangle_ident(llvm::raw_ostream& res, llvm::StringRef ident, F&& template_params) +{ + auto wrap_template = [&](llvm::raw_ostream& s) + { + s << 'I'; + template_params(s); + s << 'E'; + }; + mangle_ident_impl(res, ident, wrap_template); +} + +void mangle_type(llvm::raw_ostream& res, mlir::Type type) +{ + for(auto m : type_manglers) + { + if (m(res, type)) + { + return; + } + } + llvm_unreachable("Cannot mangle type"); +} + +void mangle_memref_impl(llvm::raw_ostream& res, mlir::MemRefType type) +{ + auto params = [&](llvm::raw_ostream& s) + { + mangle_type(s, type.getElementType()); + s << "Li"<< type.getRank() << "E"; + mangle_ident(s, "C"); + }; + mangle_ident(res, "array", params); +} + +void mangle_types(llvm::raw_ostream& res, mlir::TypeRange types) +{ + for (auto type : types) + { + mangle_type(res, 
type); + } +} + +} + +bool mangle(llvm::raw_ostream& res, llvm::StringRef ident, mlir::TypeRange types) +{ + for (auto type : types) + { + if (!check_type(type)) + { + return false; + } + } + res << PREFIX; + mangle_ident(res, ident); + mangle_types(res, types); + return true; +} + + +std::string mangle(llvm::StringRef ident, mlir::TypeRange types) +{ + std::string ret; + llvm::raw_string_ostream ss(ret); + if (!mangle(ss, ident, types)) + { + return {}; + } + ss.flush(); + return ret; +} diff --git a/mlir-compiler/mlir-compiler/src/mangle.hpp b/mlir-compiler/mlir-compiler/src/mangle.hpp new file mode 100644 index 00000000000..6d050392757 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/mangle.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include + +namespace llvm +{ +class StringRef; +class raw_ostream; +} + +namespace mlir +{ +class TypeRange; +} + +bool mangle(llvm::raw_ostream& res, llvm::StringRef ident, mlir::TypeRange types); + +std::string mangle(llvm::StringRef ident, mlir::TypeRange types); diff --git a/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.cpp b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.cpp new file mode 100644 index 00000000000..1685305d74f --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.cpp @@ -0,0 +1,42 @@ +#include "pipelines/base_pipeline.hpp" + +#include "plier/compiler/pipeline_registry.hpp" + +namespace +{ +const constexpr llvm::StringRef passes[] ={ + "init", + "lowering", + "terminate", +}; + +void dummy_pass_func(mlir::OpPassManager&) {} +} + +void register_base_pipeline(plier::PipelineRegistry& registry) +{ + for (std::size_t i = 0; i < llvm::array_lengthof(passes); ++i) + { + registry.register_pipeline([i](auto sink) + { + if (0 == i) + { + sink(passes[i], {}, {}, {}, dummy_pass_func); + } + else + { + sink(passes[i], {passes[i - 1]}, {}, {}, dummy_pass_func); + } + }); + } +} + +PipelineStage get_high_lowering_stage() +{ + return {passes[0], passes[1]}; +} + +PipelineStage get_lower_lowering_stage() +{ + return {passes[1], passes[2]}; +} diff --git a/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.hpp b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.hpp new file mode 100644 index 00000000000..a7cff3d87cc --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace plier +{ +class PipelineRegistry; +} + +void register_base_pipeline(plier::PipelineRegistry& registry); + +struct PipelineStage +{ + llvm::StringRef begin; + llvm::StringRef end; +}; + +PipelineStage get_high_lowering_stage(); // TODO: better name +PipelineStage get_lower_lowering_stage(); // TODO: better name diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp new file mode 100644 index 00000000000..70e94013769 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -0,0 +1,1345 @@ +#include "pipelines/lower_to_llvm.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "plier/dialect.hpp" + +#include "plier/transforms/func_utils.hpp" + +#include "base_pipeline.hpp" +#include "plier/compiler/pipeline_registry.hpp" + +#include "plier/utils.hpp" + +namespace +{ +const mlir::LowerToLLVMOptions &getLLVMOptions() +{ + static mlir::LowerToLLVMOptions options = []() + { + llvm::InitializeNativeTarget(); + 
auto triple = llvm::sys::getProcessTriple(); + std::string err_str; + auto target = llvm::TargetRegistry::lookupTarget(triple, err_str); + if (nullptr == target) + { + plier::report_error(llvm::Twine("Unable to get target: ") + err_str); + } + llvm::TargetOptions target_opts; + std::unique_ptr machine(target->createTargetMachine(triple, llvm::sys::getHostCPUName(), "", target_opts, llvm::None)); + mlir::LowerToLLVMOptions opts; + opts.dataLayout = machine->createDataLayout(); + opts.useBarePtrCallConv = true; + return opts; + }(); + return options; +} + +struct LLVMTypeHelper +{ + LLVMTypeHelper(mlir::MLIRContext& ctx): + type_converter(&ctx) {} + + mlir::Type i(unsigned bits) + { + return mlir::IntegerType::get(&type_converter.getContext(), bits); + } + + mlir::Type ptr(mlir::Type type) + { + assert(static_cast(type)); + auto ll_type = type_converter.convertType(type); + assert(static_cast(ll_type)); + return mlir::LLVM::LLVMPointerType::get(ll_type); + } + + mlir::MLIRContext& get_context() + { + return type_converter.getContext(); + } + + mlir::LLVMTypeConverter& get_type_converter() + { + return type_converter; + } + +private: + mlir::LLVMTypeConverter type_converter; +}; + +mlir::Type getExceptInfoType(LLVMTypeHelper& type_helper) +{ + mlir::Type elems[] = { + type_helper.ptr(type_helper.i(8)), + type_helper.i(32), + type_helper.ptr(type_helper.i(8)), + }; + return mlir::LLVM::LLVMStructType::getLiteral(&type_helper.get_context(), elems); +} + +mlir::LLVM::LLVMStructType get_array_type(mlir::TypeConverter& converter, mlir::MemRefType type) +{ + assert(type); + auto ctx = type.getContext(); + auto i8p = mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(ctx, 8)); + auto i64 = mlir::IntegerType::get(ctx, 64); + auto data_type = converter.convertType(type.getElementType()); + assert(data_type); + auto shape_type = mlir::LLVM::LLVMArrayType::get(i64, static_cast(type.getRank())); + const mlir::Type members[] = { + i8p, // 0, meminfo + i8p, // 1, parent + i64, // 2, nitems + i64, // 3, itemsize + mlir::LLVM::LLVMPointerType::get(data_type), // 4, data + shape_type, // 5, shape + shape_type, // 6, strides + }; + return mlir::LLVM::LLVMStructType::getLiteral(ctx, members); +} + +template +void flatten_type(mlir::Type type, F&& func) +{ + if (auto struct_type = type.dyn_cast()) + { + for (auto elem : struct_type.getBody()) + { + flatten_type(elem, std::forward(func)); + } + } + else if (auto arr_type = type.dyn_cast()) + { + auto elem = arr_type.getElementType(); + auto size = arr_type.getNumElements(); + for (unsigned i = 0 ; i < size; ++i) + { + flatten_type(elem, std::forward(func)); + } + } + else + { + func(type); + } +} + +template +mlir::Value unflatten(mlir::Type type, mlir::Location loc, mlir::OpBuilder& builder, F&& next_func) +{ + namespace mllvm = mlir::LLVM; + if (auto struct_type = type.dyn_cast()) + { + mlir::Value val = builder.create(loc, struct_type); + for (auto elem : llvm::enumerate(struct_type.getBody())) + { + auto elem_index = builder.getI64ArrayAttr(static_cast(elem.index())); + auto elem_type = elem.value(); + auto elem_val = unflatten(elem_type, loc, builder, std::forward(next_func)); + val = builder.create(loc, val, elem_val, elem_index); + } + return val; + } + else if (auto arr_type = type.dyn_cast()) + { + auto elem_type = arr_type.getElementType(); + auto size = arr_type.getNumElements(); + mlir::Value val = builder.create(loc, arr_type); + for (unsigned i = 0 ; i < size; ++i) + { + auto elem_index = builder.getI64ArrayAttr(static_cast(i)); + auto elem_val 
= unflatten(elem_type, loc, builder, std::forward(next_func)); + val = builder.create(loc, val, elem_val, elem_index); + } + return val; + } + else + { + return next_func(); + } +} + +void write_memref_desc(llvm::raw_ostream& os, mlir::MemRefType memref_type) +{ + if (memref_type.hasRank()) + { + os << memref_type.getRank(); + } + else + { + os << "?"; + } + os << "x"; + memref_type.getElementType().print(os); +} + +std::string gen_to_memref_conversion_func_name(mlir::MemRefType memref_type) +{ + assert(memref_type); + std::string ret; + llvm::raw_string_ostream ss(ret); + ss << "__convert_to_memref_"; + write_memref_desc(ss, memref_type); + ss.flush(); + return ret; +} + +std::string gen_from_memref_conversion_func_name(mlir::MemRefType memref_type) +{ + assert(memref_type); + std::string ret; + llvm::raw_string_ostream ss(ret); + ss << "__convert_from_memref_"; + write_memref_desc(ss, memref_type); + ss.flush(); + return ret; +} + +mlir::Value div_strides(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value strides, mlir::Value m) +{ + auto array_type = strides.getType().cast(); + mlir::Value array = builder.create(loc, array_type); + for (unsigned i = 0 ; i < array_type.getNumElements(); ++i) + { + auto index = builder.getI64ArrayAttr(i); + auto prev = builder.create(loc, array_type.getElementType(), strides, index); + auto val = builder.create(loc, prev, m); + array = builder.create(loc, array, val, index); + } + return array; +} + +mlir::Value mul_strides(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value strides, mlir::Value m) +{ + auto array_type = strides.getType().cast(); + mlir::Value array = builder.create(loc, array_type); + for (unsigned i = 0 ; i < array_type.getNumElements(); ++i) + { + auto index = builder.getI64ArrayAttr(i); + auto prev = builder.create(loc, array_type.getElementType(), strides, index); + auto val = builder.create(loc, prev, m); + array = builder.create(loc, array, val, index); + } + return array; +} + +unsigned item_size(mlir::Type type) +{ + if (auto inttype = type.dyn_cast()) + { + assert((inttype.getWidth() % 8) == 0); + return inttype.getWidth() / 8; + } + if (auto floattype = type.dyn_cast()) + { + assert((floattype.getWidth() % 8) == 0); + return floattype.getWidth() / 8; + } + llvm_unreachable("item_size: invalid type"); +} + +mlir::FuncOp get_to_memref_conversion_func( + mlir::ModuleOp module, mlir::OpBuilder& builder, mlir::MemRefType memref_type, + mlir::LLVM::LLVMStructType src_type, mlir::LLVM::LLVMStructType dst_type) +{ + assert(memref_type); + assert(src_type); + assert(dst_type); + auto func_name = gen_to_memref_conversion_func_name(memref_type); + if (auto func = module.lookupSymbol(func_name)) + { + assert(func.getType().getNumResults() == 1); + assert(func.getType().getResult(0) == dst_type); + return func; + } + auto func_type = mlir::FunctionType::get(builder.getContext(), src_type, dst_type); + auto loc = builder.getUnknownLoc(); + auto new_func = plier::add_function(builder, module, func_name, func_type); + auto alwaysinline = mlir::StringAttr::get(builder.getContext(), "alwaysinline"); + new_func->setAttr("passthrough", mlir::ArrayAttr::get(builder.getContext(), alwaysinline)); + mlir::OpBuilder::InsertionGuard guard(builder); + auto block = new_func.addEntryBlock(); + builder.setInsertionPointToStart(block); + namespace mllvm = mlir::LLVM; + mlir::Value arg = block->getArgument(0); + auto extract = [&](unsigned index) + { + auto res_type = src_type.getBody()[index]; + auto i = builder.getI64ArrayAttr(index); + return 
builder.create(loc, res_type, arg, i); + }; + auto meminfo = extract(0); + auto ptr = extract(4); + auto shape = extract(5); + auto strides = extract(6); + auto i64 = mlir::IntegerType::get(builder.getContext(), 64); + auto offset = builder.create(loc, i64, builder.getI64IntegerAttr(0)); + mlir::Value res = builder.create(loc, dst_type); + auto meminfo_casted = builder.create(loc, ptr.getType(), meminfo); + auto itemsize = builder.create(loc, i64, builder.getI64IntegerAttr(item_size(memref_type.getElementType()))); + auto insert = [&](unsigned index, mlir::Value val) + { + auto i = builder.getI64ArrayAttr(index); + res = builder.create(loc, res, val, i); + }; + insert(0, meminfo_casted); + insert(1, ptr); + insert(2, offset); + insert(3, shape); + insert(4, div_strides(loc, builder, strides, itemsize)); + builder.create(loc, res); + return new_func; +} + +mlir::FuncOp get_from_memref_conversion_func( + mlir::ModuleOp module, mlir::OpBuilder& builder, mlir::MemRefType memref_type, + mlir::LLVM::LLVMStructType src_type, mlir::LLVM::LLVMStructType dst_type) +{ + assert(memref_type); + assert(src_type); + assert(dst_type); + auto func_name = gen_from_memref_conversion_func_name(memref_type); + if (auto func = module.lookupSymbol(func_name)) + { + assert(func.getType().getNumResults() == 1); + assert(func.getType().getResult(0) == dst_type); + return func; + } + auto func_type = mlir::FunctionType::get(builder.getContext(), src_type, dst_type); + auto loc = builder.getUnknownLoc(); + auto new_func = plier::add_function(builder, module, func_name, func_type); + auto alwaysinline = mlir::StringAttr::get(builder.getContext(), "alwaysinline"); + new_func->setAttr("passthrough", mlir::ArrayAttr::get(builder.getContext(), alwaysinline)); + mlir::OpBuilder::InsertionGuard guard(builder); + auto block = new_func.addEntryBlock(); + builder.setInsertionPointToStart(block); + namespace mllvm = mlir::LLVM; + mlir::Value arg = block->getArgument(0); + auto i8ptr_type = mllvm::LLVMPointerType::get(builder.getIntegerType(8)); + auto i64_type = builder.getIntegerType(64); + auto extract = [&](unsigned index) + { + auto res_type = src_type.getBody()[index]; + auto i = builder.getI64ArrayAttr(index); + return builder.create(loc, res_type, arg, i); + }; + auto meminfo = builder.create(loc, i8ptr_type, extract(0)); + auto orig_ptr = extract(1); + auto offset = extract(2); + auto shape = extract(3); + auto strides = extract(4); + auto ptr = builder.create(loc, orig_ptr.getType(), orig_ptr, offset.getResult()); + mlir::Value res = builder.create(loc, dst_type); + auto null = builder.create(loc, i8ptr_type); + mlir::Value nitems = builder.create(loc, i64_type, builder.getI64IntegerAttr(1)); + for (int64_t i = 0; i < memref_type.getRank(); ++i) + { + auto dim = builder.create(loc, nitems.getType(), shape, builder.getI64ArrayAttr(i)); + nitems = builder.create(loc, nitems, dim); + } + auto itemsize = builder.create(loc, i64_type, builder.getI64IntegerAttr(item_size(memref_type.getElementType()))); + auto insert = [&](unsigned index, mlir::Value val) + { + auto i = builder.getI64ArrayAttr(index); + res = builder.create(loc, res, val, i); + }; + insert(0, meminfo); + insert(1, null); // parent + insert(2, nitems); + insert(3, itemsize); + insert(4, ptr); + insert(5, shape); + insert(6, mul_strides(loc, builder, strides, itemsize)); + builder.create(loc, res); + return new_func; +} + +mlir::Attribute get_fastmath_attrs(mlir::MLIRContext& ctx) +{ + auto add_pair = [&](auto name, auto val) + { + const mlir::Attribute 
attrs[] = { + mlir::StringAttr::get(&ctx, name), + mlir::StringAttr::get(&ctx, val) + }; + return mlir::ArrayAttr::get(&ctx, attrs); + }; + const mlir::Attribute attrs[] = { + add_pair("denormal-fp-math", "preserve-sign,preserve-sign"), + add_pair("denormal-fp-math-f32", "ieee,ieee"), + add_pair("no-infs-fp-math", "true"), + add_pair("no-nans-fp-math", "true"), + add_pair("no-signed-zeros-fp-math", "true"), + add_pair("unsafe-fp-math", "true"), + add_pair(plier::attributes::getFastmathName(), "1"), + }; + return mlir::ArrayAttr::get(&ctx, attrs); +} + +void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) +{ + if (func.isPrivate()) + { + return; + } + if (func->getAttr(plier::attributes::getFastmathName())) + { + func->setAttr("passthrough", get_fastmath_attrs(*func.getContext())); + } + auto old_type = func.getType(); + assert(old_type.getNumResults() <= 1); + auto& ctx = *old_type.getContext(); + llvm::SmallVector args; + + auto ptr = [&](auto arg) + { + return type_helper.ptr(arg); + }; + + unsigned index = 0; + auto add_arg = [&](mlir::Type type) + { + args.push_back(type); + auto ret = func.getBody().insertArgument(index, type); + ++index; + return ret; + }; + + mlir::OpBuilder builder(&ctx); + builder.setInsertionPointToStart(&func.getBody().front()); + + auto loc = builder.getUnknownLoc(); + llvm::SmallVector new_args; + auto process_arg = [&](mlir::Type type) + { + if (auto memref_type = type.dyn_cast()) + { + new_args.clear(); + auto arr_type = get_array_type(type_helper.get_type_converter(), memref_type); + flatten_type(arr_type, [&](mlir::Type new_type) + { + new_args.push_back(add_arg(new_type)); + }); + auto it = new_args.begin(); + mlir::Value desc = unflatten(arr_type, loc, builder, [&]() + { + auto ret = *it; + ++it; + return ret; + }); + + auto mod = mlir::cast(func->getParentOp()); + auto dst_type = type_helper.get_type_converter().convertType(memref_type); + assert(dst_type); + auto conv_func = get_to_memref_conversion_func(mod, builder, memref_type, arr_type, dst_type.cast()); + auto converted = builder.create(loc, conv_func, desc).getResult(0); + auto casted = builder.create(loc, memref_type, converted); + func.getBody().getArgument(index).replaceAllUsesWith(casted); + func.getBody().eraseArgument(index); + } + else + { + args.push_back(type); + ++index; + } + }; + + auto get_res_type = [&](mlir::Type type)->mlir::Type + { + if (auto memreftype = type.dyn_cast()) + { + return get_array_type(type_helper.get_type_converter(), memreftype); + } + return type; + }; + + auto orig_ret_type = (old_type.getNumResults() != 0 ? 
get_res_type(old_type.getResult(0)) : type_helper.ptr(type_helper.i(8))); + add_arg(ptr(orig_ret_type)); + add_arg(ptr(ptr(getExceptInfoType(type_helper)))); + + auto old_args = old_type.getInputs(); + for (auto arg : old_args) + { + process_arg(arg); + } + auto ret_type = mlir::IntegerType::get(&ctx, 32); + func.setType(mlir::FunctionType::get(&ctx, args, ret_type)); +} + +struct ReturnOpLowering : public mlir::OpRewritePattern +{ + ReturnOpLowering(mlir::MLIRContext* ctx, mlir::TypeConverter& converter): + OpRewritePattern(ctx), type_converter(converter) {} + + mlir::LogicalResult matchAndRewrite(mlir::ReturnOp op, + mlir::PatternRewriter& rewriter) const + { + auto parent = op->getParentOfType(); + if (nullptr == parent || parent.isPrivate()) + { + return mlir::failure(); + } + + auto insert_ret = [&]() + { + auto ctx = op.getContext(); + auto ret_type = mlir::IntegerType::get(ctx, 32); + auto ll_ret_type = mlir::IntegerType::get(ctx, 32); + mlir::Value ret = rewriter.create(op.getLoc(), ll_ret_type, mlir::IntegerAttr::get(ret_type, 0)); + rewriter.replaceOpWithNewOp(op, ret); + }; + + auto loc = op.getLoc(); + rewriter.setInsertionPoint(op); + auto addr = op->getParentRegion()->front().getArgument(0); + if (op.getNumOperands() == 0) + { + assert(addr.getType().isa()); + auto null_type = addr.getType().cast().getElementType(); + auto ll_val = rewriter.create(op.getLoc(), null_type); + rewriter.create(loc, ll_val, addr); + insert_ret(); + return mlir::success(); + } + else if (op.getNumOperands() == 1) + { + mlir::Value val = op.getOperand(0); + auto orig_type = val.getType(); + auto ll_ret_type = type_converter.convertType(orig_type); + assert(ll_ret_type); + val = rewriter.create(loc, ll_ret_type, val); + if (auto memref_type = orig_type.dyn_cast()) + { + auto dst_type = get_array_type(type_converter, memref_type).cast(); + auto mod = op->getParentOfType(); + auto func = get_from_memref_conversion_func(mod, rewriter, memref_type, ll_ret_type.cast(), dst_type); + val = rewriter.create(loc, func, val).getResult(0); + } + rewriter.create(loc, val, addr); + insert_ret(); + return mlir::success(); + } + else + { + return mlir::failure(); + } + } + +private: + mlir::TypeConverter& type_converter; +}; + +// Remove redundant bitcasts we have created on PreLowering +struct RemoveBitcasts : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::LLVM::BitcastOp op, + mlir::PatternRewriter& rewriter) const + { + if (op.getType() == op.getOperand().getType()) + { + rewriter.replaceOp(op, op.getOperand()); + return mlir::success(); + } + return mlir::failure(); + } +}; + +template +struct ApplyFastmathFlags : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter& rewriter) const + { + auto parent = mlir::cast(op->getParentOp()); + bool changed = false; + + rewriter.startRootUpdate(op); + auto fmf = op.fastmathFlags(); + getFastmathFlags(parent, [&](auto flag) + { + if (!mlir::LLVM::bitEnumContains(fmf, flag)) + { + fmf = fmf | flag; + changed = true; + } + }); + if (changed) + { + op.fastmathFlagsAttr(mlir::LLVM::FMFAttr::get(op.getContext(), fmf)); + rewriter.finalizeRootUpdate(op); + } + else + { + rewriter.cancelRootUpdate(op); + } + + return mlir::success(changed); + } + +private: + template + static void getFastmathFlags(mlir::LLVM::LLVMFuncOp func, F&& sink) + { + if (func->hasAttr(plier::attributes::getFastmathName())) + 
{ + sink(mlir::LLVM::FastmathFlags::fast); + } + } +}; + +// Copypaste from StandardToLLVM +struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern { + using ConvertToLLVMPattern::createIndexConstant; + using ConvertToLLVMPattern::getIndexType; + using ConvertToLLVMPattern::getVoidPtrType; + + explicit AllocLikeOpLowering(mlir::StringRef opName, mlir::LLVMTypeConverter &converter) + : ConvertToLLVMPattern(opName, &converter.getContext(), converter, /*benefit*/99) {} + +protected: + // Creates a call to an allocation function with params and casts the + // resulting void pointer to ptrType. + mlir::Value createAllocCall(mlir::Location loc, mlir::StringRef name, mlir::Type ptrType, + mlir::ArrayRef params, mlir::ModuleOp module, + mlir::ConversionPatternRewriter &rewriter) const { + using namespace mlir; + SmallVector paramTypes; + auto allocFuncOp = module.lookupSymbol(name); + if (!allocFuncOp) { + for (Value param : params) + paramTypes.push_back(param.getType()); + auto allocFuncType = + LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + allocFuncOp = rewriter.create(rewriter.getUnknownLoc(), + name, allocFuncType); + } + auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp); + auto allocatedPtr = rewriter + .create(loc, getVoidPtrType(), + allocFuncSymbol, params) + .getResult(0); + return rewriter.create(loc, ptrType, allocatedPtr); + } + + /// Allocates the underlying buffer. Returns the allocated pointer and the + /// aligned pointer. + virtual std::tuple + allocateBuffer(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::Value sizeBytes, mlir::Operation *op) const = 0; + +private: + static mlir::MemRefType getMemRefResultType(mlir::Operation *op) { + return op->getResult(0).getType().cast(); + } + + mlir::LogicalResult match(mlir::Operation *op) const override { + mlir::MemRefType memRefType = getMemRefResultType(op); + return mlir::success(isConvertibleAndHasIdentityMaps(memRefType)); + } + + // An `alloc` is converted into a definition of a memref descriptor value and + // a call to `malloc` to allocate the underlying data buffer. The memref + // descriptor is of the LLVM structure type where: + // 1. the first element is a pointer to the allocated (typed) data buffer, + // 2. the second element is a pointer to the (typed) payload, aligned to the + // specified alignment, + // 3. the remaining elements serve to store all the sizes and strides of the + // memref using LLVM-converted `index` type. + // + // Alignment is performed by allocating `alignment` more bytes than + // requested and shifting the aligned pointer relative to the allocated + // memory. Note: `alignment - ` would actually be + // sufficient. If alignment is unspecified, the two pointers are equal. + + // An `alloca` is converted into a definition of a memref descriptor value and + // an llvm.alloca to allocate the underlying data buffer. + void rewrite(mlir::Operation *op, mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::MemRefType memRefType = getMemRefResultType(op); + auto loc = op->getLoc(); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). 
+ mlir::SmallVector sizes; + mlir::SmallVector strides; + mlir::Value sizeBytes; + this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes, + strides, sizeBytes); + + // Allocate the underlying buffer. + mlir::Value allocatedPtr; + mlir::Value alignedPtr; + std::tie(allocatedPtr, alignedPtr) = + this->allocateBuffer(rewriter, loc, sizeBytes, op); + + // Create the MemRef descriptor. + auto memRefDescriptor = this->createMemRefDescriptor( + loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + } +}; + +struct AllocOpLowering : public AllocLikeOpLowering { + AllocOpLowering(mlir::LLVMTypeConverter &converter) + : AllocLikeOpLowering(mlir::AllocOp::getOperationName(), converter) {} + + std::tuple allocateBuffer(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc, mlir::Value sizeBytes, + mlir::Operation *op) const override { + auto allocOp = mlir::cast(op); + auto memRefType = allocOp.getType(); + mlir::Value alignment; + if (auto alignmentAttr = allocOp.alignment()) { + alignment = createIndexConstant(rewriter, loc, *alignmentAttr); + } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { + // In the case where no alignment is specified, we may want to override + // `malloc's` behavior. `malloc` typically aligns at the size of the + // biggest scalar on a target HW. For non-scalars, use the natural + // alignment of the LLVM type given by the LLVM DataLayout. + alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); + } else { + alignment = createIndexConstant(rewriter, loc, 32/*item_size(memRefType.getElementType())*/); + } + alignment = rewriter.create(loc, rewriter.getIntegerType(32), alignment); + + auto mod = allocOp->getParentOfType(); + auto meminfo_ptr = + createAllocCall(loc, "NRT_MemInfo_alloc_safe_aligned", getVoidPtrType(), {sizeBytes, alignment}, + mod, rewriter); + auto data_ptr = createAllocCall(loc, "NRT_MemInfo_data_fast", getVoidPtrType(), {meminfo_ptr}, + mod, rewriter); + + auto elem_ptr_type = mlir::LLVM::LLVMPointerType::get(memRefType.getElementType()); + auto bitcast = [&](mlir::Value val) + { + return rewriter.create(loc, elem_ptr_type, val); + }; + + return std::make_tuple(bitcast(meminfo_ptr), bitcast(data_ptr)); + } +}; + +struct DeallocOpLowering : public mlir::ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit DeallocOpLowering(mlir::LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter, /*benefit*/99) {} + + mlir::LogicalResult + matchAndRewrite(mlir::DeallocOp op, mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 1 && "dealloc takes one operand"); + mlir::DeallocOp::Adaptor transformed(operands); + + // Insert the `free` declaration if it is not already present. 
+ auto freeFunc = + op->getParentOfType().lookupSymbol("NRT_decref"); + if (!freeFunc) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart( + op->getParentOfType().getBody()); + freeFunc = rewriter.create( + rewriter.getUnknownLoc(), "NRT_decref", + mlir::LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType())); + } + + mlir::MemRefDescriptor memref(transformed.memref()); + mlir::Value casted = rewriter.create( + op.getLoc(), getVoidPtrType(), + memref.allocatedPtr(rewriter, op.getLoc())); + rewriter.replaceOpWithNewOp( + op, mlir::TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted); + return mlir::success(); + } +}; + +class CheckForPlierTypes : + public mlir::PassWrapper> +{ + void runOnOperation() override + { + markAllAnalysesPreserved(); + auto plier_dialect = getContext().getOrLoadDialect(); + getOperation()->walk([&](mlir::Operation* op) + { + if (op->getName().getDialect() == plier_dialect) + { + op->emitOpError(": not all plier ops were translated\n"); + signalPassFailure(); + return; + } + + auto check_type = [](mlir::Type type) + { + return type.isa(); + }; + + if (llvm::any_of(op->getResultTypes(), check_type) || + llvm::any_of(op->getOperandTypes(), check_type)) + { + op->emitOpError(": plier types weren't translated\n"); + signalPassFailure(); + } + }); + } +}; + +class LLVMFunctionPass : public mlir::OperationPass +{ +public: + using OperationPass::OperationPass; + + /// The polymorphic API that runs the pass over the currently held function. + virtual void runOnFunction() = 0; + + /// The polymorphic API that runs the pass over the currently held operation. + void runOnOperation() final { + if (!getFunction().isExternal()) + runOnFunction(); + } + + /// Return the current function being transformed. 
+ mlir::LLVM::LLVMFuncOp getFunction() { return this->getOperation(); } +}; + +void copyAttrs(mlir::Operation* src, mlir::Operation* dst) +{ + const mlir::StringRef attrs[] = { + plier::attributes::getFastmathName(), + plier::attributes::getParallelName(), + plier::attributes::getMaxConcurrencyName(), + }; + for (auto name : attrs) + { + if (auto attr = src->getAttr(name)) + { + dst->setAttr(name, attr); + } + } +} + +struct LowerParallel : public mlir::OpRewritePattern +{ + LowerParallel(mlir::MLIRContext* context): + OpRewritePattern(context), + converter(context) {} + + mlir::LogicalResult + matchAndRewrite(plier::ParallelOp op, + mlir::PatternRewriter &rewriter) const override { + auto num_loops = op.getNumLoops(); + llvm::SmallVector context_vars; + llvm::SmallVector context_constants; + llvm::DenseSet context_vars_set; + auto add_context_var = [&](mlir::Value value) + { + if (0 != context_vars_set.count(value)) + { + return; + } + context_vars_set.insert(value); + if (auto op = value.getDefiningOp()) + { + mlir::ConstantOp a; + if (op->hasTrait()) + { + context_constants.emplace_back(op); + return; + } + } + context_vars.emplace_back(value); + }; + + auto is_defined_inside = [&](mlir::Value value) + { + auto& this_region = op.getLoopBody(); + auto op_region = value.getParentRegion(); + assert(nullptr != op_region); + do + { + if (op_region == &this_region) + { + return true; + } + op_region = op_region->getParentRegion(); + } + while (nullptr != op_region); + return false; + }; + + if (op->walk([&](mlir::Operation* inner)->mlir::WalkResult + { + if (op != inner) + { + for (auto arg : inner->getOperands()) + { + if (!is_defined_inside(arg)) + { + add_context_var(arg); + } + } + } + return mlir::WalkResult::advance(); + }).wasInterrupted()) + { + return mlir::failure(); + } + + auto context_type = [&]()->mlir::LLVM::LLVMStructType + { + llvm::SmallVector fields; + fields.reserve(context_vars.size()); + for (auto var : context_vars) + { + auto type = converter.convertType(var.getType()); + if (!type) + { + return {}; + } + fields.emplace_back(type); + } + return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), fields); + }(); + + if (!context_type) + { + return mlir::failure(); + } + auto context_ptr_type = mlir::LLVM::LLVMPointerType::get(context_type); + + auto loc = op.getLoc(); + auto index_type = rewriter.getIndexType(); + auto llvm_index_type = mlir::IntegerType::get(op.getContext(), 64); // TODO + auto to_llvm_index = [&](mlir::Value val)->mlir::Value + { + if (val.getType() != llvm_index_type) + { + return rewriter.create(loc, llvm_index_type, val); + } + return val; + }; + auto from_llvm_index = [&](mlir::Value val)->mlir::Value + { + if (val.getType() != index_type) + { + return rewriter.create(loc, index_type, val); + } + return val; + }; + auto llvm_i32_type = mlir::IntegerType::get(op.getContext(), 32); + auto zero = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0)); + auto one = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(1)); + auto context = rewriter.create(loc, context_ptr_type, one, 0); + for (auto it : llvm::enumerate(context_vars)) + { + auto type = context_type.getBody()[it.index()]; + auto llvm_val = rewriter.create(loc, type, it.value()); + auto i = rewriter.getI32IntegerAttr(static_cast(it.index())); + mlir::Value indices[] = { + zero, + rewriter.create(loc, llvm_i32_type, i) + }; + auto pointer_type = mlir::LLVM::LLVMPointerType::get(type); + auto ptr = rewriter.create(loc, pointer_type, context, indices); + 
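+      // Each value captured by the parallel region is spilled into its slot
+      // of the stack-allocated context struct; the outlined function below
+      // reloads it by the same index. Sketch of the IR per slot i (types
+      // illustrative):
+      //   %p = llvm.getelementptr %ctx[0, i] : (!llvm.ptr<struct<...>>) -> !llvm.ptr<T>
+      //   llvm.store %v, %p : !llvm.ptr<T>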
rewriter.create(loc, llvm_val, ptr); + } + auto void_ptr_type = mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(op.getContext(), 8)); + auto context_abstract = rewriter.create(loc, void_ptr_type, context); + + auto input_range_type = [&]() + { + const mlir::Type members[] = { + llvm_index_type, // lower_bound + llvm_index_type, // upper_bound + llvm_index_type, // step + }; + return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), members); + }(); + auto input_range_ptr = mlir::LLVM::LLVMPointerType::get(input_range_type); + auto range_type = [&]() + { + const mlir::Type members[] = { + llvm_index_type, // lower_bound + llvm_index_type, // upper_bound + }; + return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), members); + }(); + auto range_ptr = mlir::LLVM::LLVMPointerType::get(range_type); + auto func_type = [&]() + { + const mlir::Type args[] = { + range_ptr, // bounds + index_type, // thread index + void_ptr_type // context + }; + return mlir::FunctionType::get(op.getContext(), args, {}); + }(); + + auto mod = op->getParentOfType(); + auto outlined_func = [&]()->mlir::FuncOp + { + auto func = [&]() + { + auto parent_func = op->getParentOfType(); + assert(parent_func); + auto func_name = [&]() + { + auto old_name = parent_func.getName(); + for (int i = 0;;++i) + { + auto name = (0 == i ? + (llvm::Twine(old_name) + "_outlined").str() : + (llvm::Twine(old_name) + "_outlined_" + llvm::Twine(i)).str()); + if (!mod.lookupSymbol(name)) + { + return name; + } + } + }(); + + auto func = plier::add_function(rewriter, mod, func_name, func_type); + copyAttrs(parent_func, func); + return func; + }(); + mlir::BlockAndValueMapping mapping; + auto& old_entry = op.getLoopBody().front(); + auto entry = func.addEntryBlock(); + auto loc = rewriter.getUnknownLoc(); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(entry); + auto pos0 = rewriter.getI64ArrayAttr(0); + auto pos1 = rewriter.getI64ArrayAttr(1); + for (unsigned i = 0; i < num_loops; ++i) + { + auto arg = entry->getArgument(0); + const mlir::Value indices[] = { + rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast(i))) + }; + auto ptr = rewriter.create(loc, range_ptr, arg, indices); + auto dims = rewriter.create(loc, ptr); + auto lower = rewriter.create(loc, llvm_index_type, dims, pos0); + auto upper = rewriter.create(loc, llvm_index_type, dims, pos1); + mapping.map(old_entry.getArgument(i), from_llvm_index(lower)); + mapping.map(old_entry.getArgument(i + num_loops), from_llvm_index(upper)); + } + mapping.map(old_entry.getArgument(2 * num_loops), entry->getArgument(1)); // thread index + for (auto arg : context_constants) + { + rewriter.clone(*arg, mapping); + } + auto context_ptr = rewriter.create(loc, context_ptr_type, entry->getArgument(2)); + auto zero = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0)); + for (auto it : llvm::enumerate(context_vars)) + { + auto index = it.index(); + auto old_val = it.value(); + const mlir::Value indices[] = { + zero, + rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast(index))) + }; + auto pointer_type = mlir::LLVM::LLVMPointerType::get(context_type.getBody()[index]); + auto ptr = rewriter.create(loc, pointer_type, context_ptr, indices); + auto llvm_val = rewriter.create(loc, ptr); + auto val = rewriter.create(loc, old_val.getType(), llvm_val); + mapping.map(old_val, val); + } + op.getLoopBody().cloneInto(&func.getBody(), mapping); + auto& orig_entry = 
*std::next(func.getBody().begin()); + rewriter.create(loc, &orig_entry); + for (auto& block : func.getBody()) + { + if (auto term = mlir::dyn_cast(block.getTerminator())) + { + rewriter.eraseOp(term); + rewriter.setInsertionPointToEnd(&block); + rewriter.create(loc); + } + } + return func; + }(); + + auto parallel_for = [&]() + { + auto func_name = "numba_parallel_for2"; + if (auto sym = mod.lookupSymbol(func_name)) + { + return sym; + } + const mlir::Type args[] = { + input_range_ptr, // bounds + index_type, // num_loops + func_type, // func + void_ptr_type // context + }; + auto parallel_func_type = mlir::FunctionType::get(op.getContext(), args, {}); + return plier::add_function(rewriter, mod, func_name, parallel_func_type); + }(); + auto func_addr = rewriter.create(loc, func_type, rewriter.getSymbolRefAttr(outlined_func)); + + auto num_loops_var = rewriter.create(loc, num_loops); + auto input_ranges = rewriter.create(loc, input_range_ptr, to_llvm_index(num_loops_var), 0); + for (unsigned i = 0; i < num_loops; ++i) + { + mlir::Value input_range = rewriter.create(loc, input_range_type); + auto insert = [&](mlir::Value val, unsigned index) + { + input_range = rewriter.create(loc, input_range, val, rewriter.getI64ArrayAttr(index)); + }; + insert(to_llvm_index(op.lowerBounds()[i]), 0); + insert(to_llvm_index(op.upperBounds()[i]), 1); + insert(to_llvm_index(op.steps()[i]), 2); + const mlir::Value indices[] = { + rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast(i))) + }; + auto ptr = rewriter.create(loc, input_range_ptr, input_ranges, indices); + rewriter.create(loc, input_range, ptr); + } + + const mlir::Value pf_args[] = { + input_ranges, + num_loops_var, + func_addr, + context_abstract + }; + rewriter.create(loc, parallel_for, pf_args); + rewriter.eraseOp(op); + return mlir::success(); + } + +private: + mutable mlir::LLVMTypeConverter converter; // TODO +}; + +struct LowerParallelToCFGPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override final + { + mlir::OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +struct PreLLVMLowering : public mlir::PassWrapper +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + } + + void runOnFunction() override final + { + LLVMTypeHelper type_helper(getContext()); + + mlir::OwningRewritePatternList patterns; + auto func = getFunction(); + fix_func_sig(type_helper, func); + + patterns.insert(&getContext(), + type_helper.get_type_converter()); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +struct PostLLVMLowering : + public mlir::PassWrapper +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + } + + void runOnFunction() override final + { + mlir::OwningRewritePatternList patterns; + + patterns.insert< + RemoveBitcasts, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags + >(&getContext()); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +struct LowerRetain : public mlir::OpConversionPattern +{ + using 
mlir::OpConversionPattern<plier::RetainOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(plier::RetainOp op, llvm::ArrayRef<mlir::Value> operands,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1);
+    auto arg = operands[0];
+    if (!arg.getType().isa<mlir::LLVM::LLVMStructType>())
+    {
+      return mlir::failure();
+    }
+
+    auto llvmVoidPointerType =
+        mlir::LLVM::LLVMPointerType::get(rewriter.getIntegerType(8));
+    auto incref_func = [&]()
+    {
+      auto mod = op->getParentOfType<mlir::ModuleOp>();
+      assert(mod);
+      auto func = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>("NRT_incref");
+      if (!func)
+      {
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(mod.getBody());
+        auto llvmVoidType = mlir::LLVM::LLVMVoidType::get(rewriter.getContext());
+        func = rewriter.create<mlir::LLVM::LLVMFuncOp>(
+            rewriter.getUnknownLoc(), "NRT_incref",
+            mlir::LLVM::LLVMFunctionType::get(llvmVoidType, llvmVoidPointerType));
+      }
+      return func;
+    }();
+
+    auto loc = op.getLoc();
+    auto index = rewriter.getI64ArrayAttr(0);
+    auto elemType = arg.getType().cast<mlir::LLVM::LLVMStructType>().getBody()[0];
+    mlir::Value ptr = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, elemType, arg, index);
+    ptr = rewriter.create<mlir::LLVM::BitcastOp>(loc, llvmVoidPointerType, ptr);
+    rewriter.create<mlir::LLVM::CallOp>(loc, incref_func, ptr);
+    rewriter.replaceOp(op, arg);
+
+    return mlir::success();
+  }
+};
+
+struct LowerCasts : public mlir::OpConversionPattern<plier::CastOp>
+{
+  using mlir::OpConversionPattern<plier::CastOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(plier::CastOp op, llvm::ArrayRef<mlir::Value> operands,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1);
+    auto converter = getTypeConverter();
+    assert(nullptr != converter);
+    auto src_type = operands[0].getType();
+    auto dst_type = converter->convertType(op.getType());
+    if (src_type == dst_type)
+    {
+      rewriter.replaceOp(op, operands[0]);
+      return mlir::success();
+    }
+    return mlir::failure();
+  }
+};
+
+// Copypasted from mlir
+struct LLVMLoweringPass : public mlir::PassWrapper<LLVMLoweringPass, mlir::OperationPass<mlir::ModuleOp>> {
+  LLVMLoweringPass(const mlir::LowerToLLVMOptions& opts):
+    options(opts) {}
+
+  /// Run the dialect converter on the module.
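+  /// The body mirrors the upstream LowerToLLVM pass, with the NRT-aware
+  /// patterns defined above (alloc/dealloc to NRT calls, plier.retain,
+  /// plier.cast) registered on top of the stock std-to-LLVM pattern set.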
+ void runOnOperation() override { + using namespace mlir; + if (options.useBarePtrCallConv && options.emitCWrappers) { + getOperation().emitError() + << "incompatible conversion options: bare-pointer calling convention " + "and C wrapper emission"; + signalPassFailure(); + return; + } + if (failed(LLVM::LLVMDialect::verifyDataLayoutString( + options.dataLayout.getStringRepresentation(), [this](const Twine &message) { + getOperation().emitError() << message.str(); + }))) { + signalPassFailure(); + return; + } + + ModuleOp m = getOperation(); + + LLVMTypeConverter typeConverter(&getContext(), options); + + OwningRewritePatternList patterns; + populateStdToLLVMConversionPatterns(typeConverter, patterns); + patterns.insert(typeConverter, &getContext()); + patterns.insert(typeConverter); + + LLVMConversionTarget target(getContext()); + if (failed(applyPartialConversion(m, target, std::move(patterns)))) + signalPassFailure(); + m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), + StringAttr::get(m.getContext(), options.dataLayout.getStringRepresentation())); + } + +private: + mlir::LowerToLLVMOptions options; +}; + +void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) +{ + pm.addPass(std::make_unique()); + pm.addPass(mlir::createLowerToCFGPass()); + pm.addPass(mlir::createCanonicalizerPass()); +// pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); + pm.addPass(std::make_unique(getLLVMOptions())); + pm.addNestedPass(std::make_unique()); +} +} + + +void register_lower_to_llvm_pipeline(plier::PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + auto stage = get_lower_lowering_stage(); + sink(lower_to_llvm_pipeline_name(), {stage.begin}, {stage.end}, {}, &populate_lower_to_llvm_pipeline); + }); +} + +llvm::StringRef lower_to_llvm_pipeline_name() +{ + return "lower_to_llvm"; +} diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.hpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.hpp new file mode 100644 index 00000000000..861cff36af7 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.hpp @@ -0,0 +1,15 @@ +#pragma once + +namespace plier +{ +class PipelineRegistry; +} + +namespace llvm +{ +class StringRef; +} + +void register_lower_to_llvm_pipeline(plier::PipelineRegistry& registry); + +llvm::StringRef lower_to_llvm_pipeline_name(); diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp new file mode 100644 index 00000000000..f0046b1bd9e --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -0,0 +1,218 @@ +#include "pipelines/parallel_to_tbb.hpp" + +#include +#include +#include +#include +#include +#include + +#include "plier/dialect.hpp" + +#include "plier/compiler/pipeline_registry.hpp" +#include "plier/transforms/const_utils.hpp" +#include "pipelines/base_pipeline.hpp" +#include "pipelines/lower_to_llvm.hpp" + +namespace +{ +mlir::MemRefType getReduceType(mlir::Type type, int64_t count) +{ + if (type.isIntOrFloat()) + { + return mlir::MemRefType::get(count, type); + } + return {}; +} + +mlir::Value getZeroVal(mlir::OpBuilder& builder, mlir::Location loc, mlir::Type type) +{ + auto const_val = plier::getZeroVal(type); + if (const_val) + { + return builder.create(loc, const_val); + } + llvm_unreachable("Unhandled type"); +} + +struct ParallelToTbb : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + 
mlir::scf::ParallelOp op, mlir::PatternRewriter &rewriter) const override + { + if (mlir::isa(op->getParentOp())) + { + return mlir::failure(); + } + bool need_parallel = op->hasAttr(plier::attributes::getParallelName()) || + !op->getParentOfType(); + if (!need_parallel) + { + return mlir::failure(); + } + + int64_t max_concurrency = 0; + auto mod = op->getParentOfType(); + if (auto mc = mod->getAttrOfType(plier::attributes::getMaxConcurrencyName())) + { + max_concurrency = mc.getInt(); + } + + if (max_concurrency <= 1) + { + return mlir::failure(); + } + for (auto type : op.getResultTypes()) + { + if (!getReduceType(type, max_concurrency)) + { + return mlir::failure(); + } + } + + auto loc = op.getLoc(); + mlir::BlockAndValueMapping mapping; + llvm::SmallVector reduce_vars(op.getNumResults()); + for (auto it : llvm::enumerate(op.getResultTypes())) + { + auto type = it.value(); + auto reduce_type = getReduceType(type, max_concurrency); + assert(reduce_type); + auto reduce = rewriter.create(loc, reduce_type); + auto index = static_cast(it.index()); + reduce_vars[index] = reduce; + } + + auto reduce_init_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args) + { + assert(args.empty()); + (void)args; + for (auto it : llvm::enumerate(reduce_vars)) + { + auto reduce = it.value(); + auto type = op.getResultTypes()[it.index()]; + auto zero = getZeroVal(rewriter, loc, type); + builder.create(loc, zero, reduce, index); + } + builder.create(loc); + }; + + auto reduce_lower_bound = rewriter.create(loc, 0); + auto reduce_upper_bound = rewriter.create(loc, max_concurrency); + auto reduce_step = rewriter.create(loc, 1); + rewriter.create(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, llvm::None, reduce_init_body_builder); + + auto& old_body = op.getLoopBody().front(); + auto orig_lower_bound = op.lowerBound(); + auto orig_upper_bound = op.upperBound(); + auto orig_step = op.step(); + auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::ValueRange lower_bound, mlir::ValueRange upper_bound, mlir::Value thread_index) + { + llvm::SmallVector initVals(op.initVals().size()); + for (auto it : llvm::enumerate(op.initVals())) + { + auto reduce_var = reduce_vars[it.index()]; + auto val = builder.create(loc, reduce_var, thread_index); + initVals[it.index()] = val; + } + auto new_op = mlir::cast(builder.clone(*op, mapping)); + new_op->removeAttr(plier::attributes::getParallelName()); + assert(new_op->getNumResults() == reduce_vars.size()); + new_op.lowerBoundMutable().assign(lower_bound); + new_op.upperBoundMutable().assign(upper_bound); + new_op.initValsMutable().assign(initVals); + for (auto it : llvm::enumerate(new_op->getResults())) + { + auto reduce_var = reduce_vars[it.index()]; + builder.create(loc, it.value(), reduce_var, thread_index); + } + }; + + rewriter.create(loc, orig_lower_bound, orig_upper_bound, orig_step, body_builder); + + auto reduce_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args) + { + assert(args.size() == reduce_vars.size()); + mapping.clear(); + auto reduce_ops = llvm::make_filter_range(old_body.without_terminator(), [](auto& op) + { + return mlir::isa(op); + }); + llvm::SmallVector yield_args; + yield_args.reserve(args.size()); + for (auto it : llvm::enumerate(reduce_ops)) + { + auto& reduce_var = reduce_vars[it.index()]; + auto arg = args[static_cast(it.index())]; + auto reduce_op = mlir::cast(it.value()); + auto& reduce_op_body = 
reduce_op.reductionOperator().front(); + assert(reduce_op_body.getNumArguments() == 2); + auto prev_val = builder.create(loc, reduce_var, index); + mapping.map(reduce_op_body.getArgument(0), arg); + mapping.map(reduce_op_body.getArgument(1), prev_val); + for (auto& old_reduce_op : reduce_op_body.without_terminator()) + { + builder.clone(old_reduce_op, mapping); + } + auto result = mlir::cast(reduce_op_body.getTerminator()).result(); + result = mapping.lookupOrNull(result); + assert(result); + yield_args.emplace_back(result); + } + builder.create(loc, yield_args); + }; + + auto reduce_loop = rewriter.create(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, op.initVals(), reduce_body_builder); + rewriter.replaceOp(op, reduce_loop.getResults()); + + return mlir::success(); + } +}; + +struct ParallelToTbbPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +void ParallelToTbbPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + patterns.insert< + ParallelToTbb + >(&getContext()); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +void populate_parallel_to_tbb_pipeline(mlir::OpPassManager& pm) +{ + pm.addNestedPass(std::make_unique()); +} +} + +void register_parallel_to_tbb_pipeline(plier::PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + auto stage = get_lower_lowering_stage(); + auto llvm_pipeline = lower_to_llvm_pipeline_name(); + sink(parallel_to_tbb_pipeline_name(), {stage.begin}, {llvm_pipeline}, {}, &populate_parallel_to_tbb_pipeline); + }); +} + +llvm::StringRef parallel_to_tbb_pipeline_name() +{ + return "parallel_to_tbb"; +} diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.hpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.hpp new file mode 100644 index 00000000000..cca9709169e --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.hpp @@ -0,0 +1,15 @@ +#pragma once + +namespace plier +{ +class PipelineRegistry; +} + +namespace llvm +{ +class StringRef; +} + +void register_parallel_to_tbb_pipeline(plier::PipelineRegistry& registry); + +llvm::StringRef parallel_to_tbb_pipeline_name(); diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp new file mode 100644 index 00000000000..67c5d2d166d --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -0,0 +1,1138 @@ +#include "pipelines/plier_to_linalg.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "plier/dialect.hpp" + +#include "pipelines/plier_to_std.hpp" + +#include "plier/transforms/pipeline_utils.hpp" +#include "plier/rewrites/call_lowering.hpp" +#include "plier/rewrites/canonicalize_reductions.hpp" +#include "plier/rewrites/cast_lowering.hpp" +#include "plier/rewrites/common_opts.hpp" +#include "plier/rewrites/cse.hpp" +#include "plier/rewrites/promote_to_parallel.hpp" +#include "plier/rewrites/type_conversion.hpp" +#include "plier/rewrites/force_inline.hpp" +#include "plier/rewrites/index_type_propagation.hpp" +#include "plier/rewrites/loop_rewrites.hpp" +#include "plier/rewrites/memory_rewrites.hpp" +#include "plier/transforms/loop_utils.hpp" + 
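+// Pipeline registration plumbing and the Python-backed linalg resolver: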
+#include "base_pipeline.hpp" +#include "plier/compiler/pipeline_registry.hpp" +#include "py_linalg_resolver.hpp" + +#include + +namespace +{ +void applyOptimizations(mlir::FuncOp op, const mlir::FrozenRewritePatternList& patterns, llvm::function_ref additionalOpts = nullptr) +{ + bool repeat = false; + do + { + repeat = false; + (void)mlir::applyPatternsAndFoldGreedily(op, patterns); + if (mlir::succeeded(plier::applyCSE(op.getRegion(), false))) + { + repeat = true; + } + if (mlir::succeeded(plier::optimizeMemoryOps(op))) + { + repeat = true; + } + if (additionalOpts && mlir::succeeded(additionalOpts(op))) + { + repeat = true; + } + } + while(repeat); +} + +enum class ArrayLayout +{ + C, + F +}; + +bool parse_layout(llvm::StringRef& name, ArrayLayout& layout) +{ + if (name.consume_back("C")) + { + layout = ArrayLayout::C; + return true; + } + if (name.consume_back("F")) + { + layout = ArrayLayout::F; + return true; + } + return false; +} + +template +bool consume_int_back(llvm::StringRef& name, T& result) +{ + unsigned len = 0; + auto tmp_name = name; + while (!tmp_name.empty() && std::isdigit(tmp_name.back())) + { + ++len; + tmp_name = tmp_name.drop_back(); + } + tmp_name = name.substr(name.size() - len); + if (!tmp_name.consumeInteger(10, result)) + { + name = name.substr(0, name.size() - len); + return true; + } + return false; +} + +struct ArrayDesc +{ + unsigned dims = 0; + ArrayLayout layout = {}; + llvm::StringRef name; +}; + +llvm::Optional parse_array_desc(llvm::StringRef& name) +{ + unsigned num_dims = 0; + ArrayLayout layout = {}; + if (name.consume_front("array(") && + name.consume_back(")") && + parse_layout(name, layout) && + name.consume_back(", ") && + name.consume_back("d") && + consume_int_back(name, num_dims) && + name.consume_back(", ") && + !name.empty()) + { + return ArrayDesc{num_dims, layout, name}; + } + return {}; +} + +mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter, + llvm::StringRef& name) +{ + if (auto desc = parse_array_desc(name)) + { + if (desc->layout == ArrayLayout::C) + { + if (auto type = conveter.convertType(plier::PyType::get(&ctx, desc->name))) + { + llvm::SmallVector shape(desc->dims, -1); + return mlir::RankedTensorType::get(shape, type); + } + } + } + return nullptr; +} + + +mlir::Type map_plier_type(mlir::TypeConverter& converter, mlir::Type type) +{ + if (type.isa()) + { + auto name = type.cast().getName(); + return map_array_type(*type.getContext(), converter, name); + } + return nullptr; +} + +bool check_numpy_args(llvm::ArrayRef args, unsigned expected_count) +{ + if (args.size() != expected_count) + { + return false; + } + for (auto arg : args) + { + auto type = arg.getType(); + if (!type.isa() && !type.isa()) + { + return false; + } + } + return true; +} + +void rerun_std_pipeline(mlir::Operation* op) +{ + assert(nullptr != op); + auto marker = mlir::StringAttr::get(op->getContext(), plier_to_std_pipeline_name()); + auto mod = op->getParentOfType(); + assert(nullptr != mod); + plier::add_pipeline_jump_marker(mod, marker); +} + +bool is_int(mlir::Type type) +{ + assert(type); + return type.isa(); +} + +mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) +{ + if (!kwargs.empty()) + { + return mlir::failure(); + } + if ((operands.size() < 1 || operands.size() > 3) || + !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) + { + return mlir::failure(); + } + mlir::Value val = op.getResult(); + if 
(!val.getUsers().empty()) + { + auto user = mlir::dyn_cast(*val.getUsers().begin()); + auto get_bounds = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + auto lower_bound = (operands.size() >= 2 ? operands[0] : builder.create(loc, 0)); + auto upper_bound = (operands.size() >= 2 ? operands[1] : operands[0]); + auto step = (operands.size() == 3 ? operands[2] : builder.create(loc, 1)); + return std::make_tuple(lower_bound, upper_bound, step); + }; + auto get_index = [](mlir::OpBuilder& builder, mlir::Location loc, mlir::Type dst_type, mlir::Value index) + { + return builder.create(loc, dst_type, index); + }; + auto set_attr = [](mlir::scf::ForOp op) + { + op->setAttr(plier::attributes::getParallelName(), mlir::UnitAttr::get(op->getContext())); + }; + if (!user || mlir::failed(lower_while_to_for(user, rewriter, get_bounds, get_index, set_attr))) + { + return mlir::failure(); + } + } + + rerun_std_pipeline(op); + if (val.getUsers().empty()) + { + rewriter.eraseOp(op); + } + return mlir::success(); +} + +struct CallLowerer +{ + using args_t = llvm::ArrayRef; + using kwargs_t = llvm::ArrayRef>; + mlir::LogicalResult operator()( + plier::PyCallOp op, llvm::StringRef name, args_t args, + kwargs_t kwargs, + mlir::PatternRewriter& rewriter) + { + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, args_t, kwargs_t, mlir::PatternRewriter&); + std::pair handlers[] = { + {"numba.prange", lower_prange}, + }; + for (auto& handler : handlers) + { + if (handler.first == name) + { + return handler.second(op, args, kwargs, rewriter); + } + } + + if (mlir::succeeded(applyRewrite(op, rewriter, linalg_resolver.rewrite_func(name, op.getLoc(), rewriter, args, kwargs)))) + { + return mlir::success(); + } + + if (name == "len" && check_numpy_args(args, 1) && kwargs.empty()) + { + auto loc = op.getLoc(); + mlir::Value dim = rewriter.create(loc, args[0], 0); + mlir::Value res = rewriter.create(loc, op.getType(), dim); + rerun_std_pipeline(op); + rewriter.replaceOp(op, res); + return mlir::success(); + } + return mlir::failure(); + } + + mlir::LogicalResult operator()( + plier::GetattrOp op, llvm::StringRef name, mlir::Value arg, + mlir::PatternRewriter& rewriter) + { + if (!arg.getType().isa()) + { + return mlir::failure(); + } + auto full_name = (llvm::Twine("array.") + name).str(); + return applyRewrite(op, rewriter, linalg_resolver.rewrite_attr(full_name, op.getLoc(), rewriter, arg)); + } + + mlir::LogicalResult operator()( + plier::BinOp op, llvm::StringRef name, mlir::Value lhs, mlir::Value rhs, + mlir::PatternRewriter& rewriter) + { + if (!lhs.getType().isa() && + !rhs.getType().isa()) + { + return mlir::failure(); + } + const std::pair names[] = { + {"+", "operator.add"}, + {"-", "operator.sub"}, + {"*", "operator.mul"}, + }; + for (auto it : names) + { + if (it.first == name) + { + return applyRewrite(op, rewriter, linalg_resolver.rewrite_func(it.second, op.getLoc(), rewriter, {lhs, rhs}, {})); + } + } + return mlir::failure(); + } + +private: + PyLinalgResolver linalg_resolver; + + mlir::LogicalResult applyRewrite(mlir::Operation* op, mlir::PatternRewriter& rewriter, llvm::Optional result) + { + if (result) + { + assert(result->size() == op->getNumResults()); + rerun_std_pipeline(op); + if (result->empty()) + { + rewriter.eraseOp(op); + } + else + { + rewriter.replaceOp(op, *result); + } + return mlir::success(); + } + return mlir::failure(); + } +}; + +mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder) +{ + if (!value.getType().isa()) + { + auto index_type = 
mlir::IndexType::get(value.getContext()); + auto res = builder.create(loc, index_type, value); + rerun_std_pipeline(res); + return res; + } + return value; +} + +bool isValidGetitemIndex(mlir::Type type) +{ + return type.isa(); +} + +template +struct GetitemOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + T op, mlir::PatternRewriter &rewriter) const override + { + assert(op.getNumOperands() == 2); + auto val = op.getOperand(0); + auto index = op.getOperand(1); + auto type = val.getType(); + bool is_memref = type.template isa(); + bool is_tensor = type.template isa(); + if (!is_memref && !is_tensor) + { + return mlir::failure(); + } + if (!isValidGetitemIndex(index.getType())) + { + return mlir::failure(); + } + auto loc = op.getLoc(); + + llvm::SmallVector indices; + if (auto tuple_type = index.getType().template dyn_cast()) + { + indices.resize(tuple_type.size()); + for (auto it : llvm::enumerate(tuple_type)) + { + auto getitem_ind = rewriter.create(loc, it.index()); + auto ind = rewriter.create(loc, index, getitem_ind); + indices[it.index()] = index_cast(ind, loc, rewriter); + } + } + else + { + indices.push_back(index_cast(index, loc, rewriter)); + } + + mlir::Value res; + if (is_memref) + { + res = rewriter.create(loc, val, indices); + } + else if (is_tensor) + { + res = rewriter.create(loc, val, indices); + } + else + { + llvm_unreachable("Invalid getitem"); + } + rerun_std_pipeline(op); + rewriter.replaceOp(op, res); + return mlir::success(); + } +}; + +bool can_replace_ssa(mlir::Operation* op) +{ + assert(nullptr != op); + if (op->getParentRegion()->getBlocks().size() != 1) + { + return false; + } + auto parent = op->getParentOp(); + if (mlir::isa(parent)) + { + return true; + } + return false; +// return can_replace_ssa(parent); +} + +bool replace_ssa_in_block(mlir::Value value, mlir::Value new_value, mlir::PatternRewriter &rewriter) +{ + auto new_op = new_value.getDefiningOp(); + assert(nullptr != new_op); + auto block = new_op->getBlock(); + bool changed = false; + for (auto user : llvm::make_early_inc_range(value.getUsers())) + { + if (auto op = block->findAncestorOpInBlock(*user)) + { + if (op != new_op && new_op->isBeforeInBlock(op)) + { + rewriter.updateRootInPlace(user, [&]() + { + for (auto it2 : llvm::enumerate(user->getOperands())) + { + if (it2.value() == value) + { + user->setOperand(static_cast(it2.index()), new_value); + break; + } + } + }); + changed = true; + } + } + } + return changed; +} + +bool replace_ssa_value(mlir::Value value, mlir::Value new_value, mlir::PatternRewriter &rewriter) +{ + bool changed = replace_ssa_in_block(value, new_value, rewriter); + auto parent = new_value.getDefiningOp()->getParentOp(); + if (auto func = mlir::dyn_cast(parent)) + { + // TODO update return + return changed; + } + llvm_unreachable("Unhandled parent op"); +} + +template +struct SetitemOpLoweringSSA : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + T op, mlir::PatternRewriter &rewriter) const override + { + if (!can_replace_ssa(op)) + { + return mlir::failure(); + } + auto target = op.getOperand(0); + auto index = op.getOperand(1); + auto value = op.getOperand(2); + auto target_type = target.getType().template dyn_cast(); + if (!target_type) + { + return mlir::failure(); + } + auto elem_type = target_type.getElementType(); + auto loc = op.getLoc(); + if (value.getType() != elem_type) + { + // TODO + value = 
rewriter.create(loc, elem_type, value); + rerun_std_pipeline(op); +// return mlir::failure(); + } + + auto new_tensor = rewriter.create(loc, value); + auto new_index = index_cast(index, loc, rewriter); + mlir::Value one = rewriter.create(loc, 1); + auto new_value = rewriter.create(loc, new_tensor, target, new_index, one, one); + replace_ssa_value(target, new_value, rewriter); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +struct PlierToLinalgPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +template +struct SetitemOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + T op, mlir::PatternRewriter &rewriter) const override + { + auto get_target_type = [&]() + { + return op.getOperand(0).getType(); + }; + + auto index = op.index(); + if (!isValidGetitemIndex(index.getType())) + { + return mlir::failure(); + } + + if (auto target_type = get_target_type().template dyn_cast()) + { + auto target = op.getOperand(0); + mlir::OpBuilder::InsertionGuard g(rewriter); + if (auto parent_op = target.getDefiningOp()) + { + rewriter.setInsertionPointAfter(parent_op); + } + else + { + rewriter.setInsertionPointToStart(target.getParentBlock()); + } + auto memref_type = mlir::MemRefType::get(target_type.getShape(), target_type.getElementType()); + auto memref = rewriter.create(target.getLoc(), memref_type, target); + for (auto& use : llvm::make_early_inc_range(target.getUses())) + { + auto use_op = use.getOwner(); + assert(nullptr != use_op); + if (use_op != memref) + { + if (mlir::isa(use_op)) + { + use_op->setOperand(use.getOperandNumber(), memref); + } + else + { + mlir::OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(use_op); + auto new_val = rewriter.create(use_op->getLoc(), memref); + rewriter.updateRootInPlace(use_op, [&]() + { + use_op->setOperand(use.getOperandNumber(), new_val); + }); + } + } + } + } + else if (get_target_type().template isa()) + { + // nothing + } + else + { + return mlir::failure(); + } + auto target = op.getOperand(0); + auto value = op.getOperand(2); + auto loc = op.getLoc(); + auto elem_type = target.getType().template cast().getElementType(); + if (value.getType() != elem_type) + { + // TODO + value = rewriter.create(loc, elem_type, value); + rerun_std_pipeline(op); + } + + llvm::SmallVector indices; + if (auto tuple_type = index.getType().template dyn_cast()) + { + indices.resize(tuple_type.size()); + for (auto it : llvm::enumerate(tuple_type)) + { + auto getitem_ind = rewriter.create(loc, it.index()); + auto ind = rewriter.create(loc, index, getitem_ind); + indices[it.index()] = index_cast(ind, loc, rewriter); + } + rerun_std_pipeline(op); + } + else + { + indices.push_back(index_cast(index, loc, rewriter)); + } + rewriter.create(loc, value, target, indices); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +struct ArrayShape : public mlir::OpRewritePattern +{ + ArrayShape(mlir::TypeConverter& type_converter, + mlir::MLIRContext* context): + OpRewritePattern(context), + converter(type_converter) {} + + mlir::LogicalResult matchAndRewrite( + plier::GetattrOp op, mlir::PatternRewriter &rewriter) const override + { + auto type = op.value().getType().dyn_cast(); + if (!type || op.name() != "shape" || !type.hasRank()) + { + return mlir::failure(); + } + + auto rank = 
static_cast(type.getRank()); + auto elem_type = converter.convertType(op.getType()).dyn_cast_or_null(); + if (!elem_type || elem_type.size() != rank) + { + return mlir::failure(); + } + + llvm::SmallVector dims(rank); + for (size_t i = 0; i < rank; ++i) + { + auto dim = rewriter.create(op.getLoc(), op.value(), i); + dims[i] = rewriter.create(op.getLoc(), elem_type.getType(i), dim); + } + auto res = rewriter.create(op.getLoc(), op.getType(), dims); + rerun_std_pipeline(op); + rewriter.replaceOp(op, res.getResult()); + return mlir::success(); + } + +private: + mlir::TypeConverter& converter; +}; + +template +bool has_compatibale_shape(T&& a1, T&& a2) +{ + if (!a1.hasRank() || !a2.hasRank() || a1.getRank() != a2.getRank()) + { + return false; + } + for (auto it : llvm::zip(a1.getShape(), a2.getShape())) + { + auto s1 = std::get<0>(it); + auto s2 = std::get<1>(it); + if (s1 >= 0 && s2 >= 0 && s1 != s2) + { + return false; + } + } + return true; +} + +struct RankedTypesCasts : public mlir::OpRewritePattern +{ + RankedTypesCasts(mlir::TypeConverter& /*type_converter*/, + mlir::MLIRContext* context): + OpRewritePattern(context){} + + mlir::LogicalResult matchAndRewrite( + plier::CastOp op, mlir::PatternRewriter &rewriter) const override + { + auto src_type = op.value().getType(); + auto dst_type = op.getType(); + if (src_type.isa() && dst_type.isa()) + { + auto src = src_type.cast(); + auto dst = dst_type.cast(); + if (!has_compatibale_shape(src,dst)) + { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp(op, dst, op.value()); + return mlir::success(); + } + return mlir::failure(); + } +}; + +struct GetattrRewriter : public mlir::OpRewritePattern +{ + using resolver_t = std::function; + + GetattrRewriter(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context, + resolver_t resolver): + OpRewritePattern(context), + resolver(resolver) + {} + + mlir::LogicalResult matchAndRewrite( + plier::GetattrOp op, mlir::PatternRewriter &rewriter) const override + { + return resolver(op, op.name(), op.value(), rewriter); + } + +private: + resolver_t resolver; +}; + +struct BinopRewriter : public mlir::OpRewritePattern +{ + using resolver_t = std::function; + + BinopRewriter(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context, + resolver_t resolver): + OpRewritePattern(context), + resolver(resolver) + {} + + mlir::LogicalResult matchAndRewrite( + plier::BinOp op, mlir::PatternRewriter &rewriter) const override + { + return resolver(op, op.op(), op.lhs(), op.rhs(), rewriter); + } + +private: + resolver_t resolver; +}; + +struct SimplifyExpandDims : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::linalg::GenericOp op, mlir::PatternRewriter &rewriter) const override + { + if (!op.hasTensorSemantics()) + { + return mlir::failure(); + } + if (op.getNumInputs() != 1 || op.getNumOutputs() != 1) + { + return mlir::failure(); + } + + auto context = op.getContext(); + auto parallel_attr = mlir::StringAttr::get(context, "parallel"); + if (llvm::any_of(op.iterator_types(), [&](auto attr) { return attr != parallel_attr; })) + { + return mlir::failure(); + } + + auto maps = op.indexing_maps(); + assert(maps.size() == 2); + auto out_map = maps[1].cast().getValue(); + if (!out_map.isIdentity()) + { + return mlir::failure(); + } + auto in_map = maps[0].cast().getValue(); + auto num_dims = op.getNumLoops(); + if (in_map.getNumResults() != num_dims) + { + return mlir::failure(); + } + + bool changed = false; + 
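+    // An expand_dims-style input map sends a unit output dimension to the
+    // constant 0, e.g. (d0, d1) -> (0, d1) with an output shape of [1, N].
+    // Rewriting such zeros to the corresponding dim expression turns the map
+    // into an identity, which later fusion and canonicalization can fold away.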
auto out_shape = op.getOutput(0).getType().cast().getShape(); + llvm::SmallVector exprs(num_dims); + for (unsigned i = 0; i < num_dims; ++i) + { + auto prev_expr = in_map.getResult(i); + bool can_convert = [&]() + { + if (out_shape[i] == 1) + { + auto const_expr = prev_expr.dyn_cast(); + if (const_expr && const_expr.getValue() == 0) + { + return true; + } + } + return false; + }(); + if (can_convert) + { + changed = true; + exprs[i] = mlir::getAffineDimExpr(i, context); + } + else + { + exprs[i] = prev_expr; + } + } + + if (changed) + { + const mlir::Attribute new_maps[] = { + mlir::AffineMapAttr::get(mlir::AffineMap::get(num_dims, 0, exprs, context)), + maps[1] + }; + auto new_maps_attr = mlir::ArrayAttr::get(context, new_maps); + rewriter.updateRootInPlace(op, [&]() + { + op.indexing_mapsAttr(new_maps_attr); + }); + } + + return mlir::success(changed); + } +}; + +struct LowerEnforceShape : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + plier::EnforceShapeOp op, mlir::PatternRewriter &rewriter) const override + { + auto type = op.getType(); + auto src = op.value(); + rewriter.replaceOpWithNewOp(op, type, src); + return mlir::success(); + } +}; + +void PlierToLinalgPass::runOnOperation() +{ + auto context = &getContext(); + + mlir::TypeConverter type_converter; + // Convert unknown types to itself + type_converter.addConversion([](mlir::Type type) { return type; }); + populate_std_type_converter(getContext(), type_converter); + type_converter.addConversion([&](plier::PyType type)->llvm::Optional + { + auto ret = map_plier_type(type_converter, type); + if (!ret) + { + return llvm::None; + } + return ret; + }); + + mlir::OwningRewritePatternList patterns; + patterns.insert< + plier::FuncOpSignatureConversion, + plier::CastOpLowering, + RankedTypesCasts, + ArrayShape + >(type_converter, context); + + CallLowerer callLowerer; + + patterns.insert< + plier::CallOpLowering, + GetattrRewriter, + BinopRewriter + >(type_converter, context, std::ref(callLowerer)); + + patterns.insert< + GetitemOpLowering, + GetitemOpLowering, + SetitemOpLowering + >(&getContext()); + + // range/prange lowering need dead branch pruning to properly + // handle negative steps + for (auto *op : context->getRegisteredOperations()) + { + op->getCanonicalizationPatterns(patterns, context); + } + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +struct LowerLinalgPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +void LowerLinalgPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + patterns.insert< + mlir::linalg::LinalgLoweringPattern, + mlir::linalg::LinalgLoweringPattern + >(&getContext(), mlir::linalg::LinalgLoweringType::ParallelLoops); + + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +struct PostPlierToLinalgPass : + public mlir::PassWrapper +{ + void runOnFunction() override; +}; + +void PostPlierToLinalgPass::runOnFunction() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); + + patterns.insert< + SimplifyExpandDims + >(&getContext()); + + applyOptimizations(getFunction(), std::move(patterns)); +} + +struct TensorFusionPass : + public 
mlir::PassWrapper> +{ + void runOnOperation() override; +}; + +void TensorFusionPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); + + patterns.insert< + SimplifyExpandDims, + LowerEnforceShape + >(&getContext()); + + mlir::populateLinalgTensorOpsFusionPatterns(&context, patterns); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +struct CommonOptPass : + public mlir::PassWrapper> +{ + void runOnOperation() override; +}; + +void CommonOptPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +struct LoopInvariantCodeMotion : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override + { + auto parentOp = op->getParentOp(); + rewriter.startRootUpdate(parentOp); + auto res = mlir::moveLoopInvariantCode(op); + if (mlir::succeeded(res)) + { + rewriter.finalizeRootUpdate(parentOp); + } + else + { + rewriter.cancelRootUpdate(parentOp); + } + return res; + } +}; + +struct RetainArgsPass : + public mlir::PassWrapper +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + } + + void runOnFunction() override; +}; + +void RetainArgsPass::runOnFunction() +{ + auto func = getFunction(); + if (func.isPrivate() || func.isDeclaration() || func.body().empty()) + { + return; + } + + mlir::OpBuilder builder(&getContext()); + auto loc = builder.getUnknownLoc(); + auto block = &func.body().front(); + builder.setInsertionPointToStart(block); + for (auto arg : block->getArguments()) + { + if (arg.getType().isa()) + { + auto retained = builder.create(loc, arg); + llvm::SmallPtrSet except({retained}); + arg.replaceAllUsesExcept(retained, except); + } + } +} + +struct PostLinalgOptPass : + public mlir::PassWrapper +{ + void runOnFunction() override; +}; + +void PostLinalgOptPass::runOnFunction() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); + + patterns.insert< + plier::CanonicalizeReduction + >(&context); + + applyOptimizations(getFunction(), std::move(patterns), [](mlir::FuncOp op) + { + return plier::naivelyFuseParallelOps(op.getRegion()); + }); +} + +struct PromoteParallelPass : + public mlir::PassWrapper +{ + void runOnFunction() override; +}; + +void PromoteParallelPass::runOnFunction() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); + + patterns.insert< + plier::CanonicalizeReduction, + plier::PromoteToParallel // TODO + >(&context); + + applyOptimizations(getFunction(), std::move(patterns)); +} + +void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) +{ + pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); + pm.addPass(mlir::createSymbolDCEPass()); +} + +void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) +{ + pm.addPass(std::make_unique()); + + pm.addPass(mlir::createTensorConstantBufferizePass()); + pm.addNestedPass(mlir::createSCFBufferizePass()); + pm.addNestedPass(mlir::createLinalgBufferizePass()); + 
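+  // (bufferization proceeds dialect by dialect on tensors first; the
+  // func-bufferize and finalizing-bufferize passes below must come last to
+  // clean up any remaining tensor/memref casts)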
pm.addNestedPass(mlir::createStdBufferizePass()); + pm.addNestedPass(mlir::createTensorBufferizePass()); + pm.addPass(mlir::createFuncBufferizePass()); + pm.addNestedPass(mlir::createFinalizingBufferizePass()); + + pm.addNestedPass(mlir::createBufferHoistingPass()); + pm.addNestedPass(mlir::createBufferLoopHoistingPass()); + pm.addNestedPass(mlir::createPromoteBuffersToStackPass()); + + pm.addNestedPass(std::make_unique()); + pm.addNestedPass(mlir::createBufferDeallocationPass()); + pm.addPass(mlir::createCopyRemovalPass()); + + pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addNestedPass(std::make_unique()); +} +} + +void register_plier_to_linalg_pipeline(plier::PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + auto stage = get_high_lowering_stage(); + sink(plier_to_linalg_gen_pipeline_name(), {plier_to_std_pipeline_name()}, {plier_to_linalg_opt_pipeline_name()}, {plier_to_std_pipeline_name()}, &populate_plier_to_linalg_gen_pipeline); + sink(plier_to_linalg_opt_pipeline_name(), {plier_to_linalg_gen_pipeline_name()}, {stage.end}, {}, &populate_plier_to_linalg_opt_pipeline); + }); +} + +llvm::StringRef plier_to_linalg_gen_pipeline_name() +{ + return "plier_to_linalg_gen"; +} + +llvm::StringRef plier_to_linalg_opt_pipeline_name() +{ + return "plier_to_linalg_opt"; +} diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.hpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.hpp new file mode 100644 index 00000000000..660b0e51703 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.hpp @@ -0,0 +1,16 @@ +#pragma once + +namespace plier +{ +class PipelineRegistry; +} + +namespace llvm +{ +class StringRef; +} + +void register_plier_to_linalg_pipeline(plier::PipelineRegistry& registry); + +llvm::StringRef plier_to_linalg_gen_pipeline_name(); +llvm::StringRef plier_to_linalg_opt_pipeline_name(); diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp new file mode 100644 index 00000000000..4bd91a4f2db --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -0,0 +1,1461 @@ +#include "pipelines/plier_to_std.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "plier/dialect.hpp" + +#include "plier/rewrites/call_lowering.hpp" +#include "plier/rewrites/cast_lowering.hpp" +#include "plier/rewrites/type_conversion.hpp" +#include "plier/transforms/const_utils.hpp" +#include "plier/transforms/func_utils.hpp" +#include "plier/transforms/loop_utils.hpp" + +#include "base_pipeline.hpp" +#include "plier/compiler/pipeline_registry.hpp" +#include "py_func_resolver.hpp" +#include "mangle.hpp" + +namespace +{ +mlir::Type map_int_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + unsigned num_bits = 0; + if (name.consume_front("int") && + !name.consumeInteger(10, num_bits)) + { + return mlir::IntegerType::get(&ctx, num_bits); + } + return nullptr; +} + +mlir::Type map_int_literal_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + unsigned dummy = 0; + if (name.consume_front("Literal[int](") && + !name.consumeInteger(10, dummy) && name.consume_front(")")) + { + return mlir::IntegerType::get(&ctx, 64); // TODO + } + return nullptr; +} + +mlir::Type map_bool_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + if (name.consume_front("bool")) + { + return 
mlir::IntegerType::get(&ctx, 1); + } + return nullptr; +} + +mlir::Type map_float_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + unsigned num_bits = 0; + if (name.consume_front("float") && + !name.consumeInteger(10, num_bits)) + { + switch(num_bits) + { + case 64: return mlir::Float64Type::get(&ctx); + case 32: return mlir::Float32Type::get(&ctx); + case 16: return mlir::Float16Type::get(&ctx); + } + } + return nullptr; +} + +mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name); +bool map_type_helper(mlir::MLIRContext& ctx, llvm::StringRef& name, mlir::Type& ret) +{ + auto type = map_plier_type_name(ctx, name); + if (static_cast(type)) + { + ret = type; + return true; + } + return false; +} + +mlir::Type map_pair_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + mlir::Type first; + mlir::Type second; + if (name.consume_front("pair<") && + map_type_helper(ctx, name, first) && + name.consume_front(", ") && + map_type_helper(ctx, name, second) && + name.consume_front(">")) + { + return mlir::TupleType::get(&ctx, {first, second}); + } + return nullptr; +} + +mlir::Type map_unituple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + mlir::Type type; + unsigned count = 0; + if (name.consume_front("UniTuple(") && + map_type_helper(ctx, name, type) && + name.consume_front(" x ") && + !name.consumeInteger(10, count) && + name.consume_front(")")) + { + llvm::SmallVector types(count, type); + return mlir::TupleType::get(&ctx, types); + } + return nullptr; +} + +mlir::Type map_tuple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + if (!name.consume_front("Tuple(")) + { + return nullptr; + } + llvm::SmallVector types; + while (true) + { + if (name.consume_front(")")) + { + break; + } + auto type = map_plier_type_name(ctx, name); + if (!static_cast(type)) + { + return nullptr; + } + types.push_back(type); + (void)name.consume_front(", "); + } + return mlir::TupleType::get(&ctx, types); +} + +mlir::Type map_func_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + if (name.consume_front("Function(") && + name.consume_front("") && // TODO unhardcode; + name.consume_front(")")) + { + return mlir::FunctionType::get(&ctx, {}, {}); + } + return nullptr; +} + +mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + using func_t = mlir::Type(*)(mlir::MLIRContext& ctx, llvm::StringRef& name); + const func_t handlers[] = { + &map_int_type, + &map_int_literal_type, + &map_bool_type, + &map_float_type, + &map_pair_type, + &map_unituple_type, + &map_tuple_type, + &map_func_type, + }; + for (auto h : handlers) + { + auto temp_name = name; + auto t = h(ctx, temp_name); + if (static_cast(t)) + { + name = temp_name; + return t; + } + } + return nullptr; +} + +mlir::Type map_plier_type(mlir::Type type) +{ + assert(type); + if (!type.isa()) + { + return type; + } + auto name = type.cast().getName(); + return map_plier_type_name(*type.getContext(), name); +} + +bool is_supported_type(mlir::Type type) +{ + assert(type); + return type.isIntOrFloat(); +} + +bool is_int(mlir::Type type) +{ + assert(type); + return type.isa(); +} + +bool is_float(mlir::Type type) +{ + assert(type); + return type.isa(); +} + +bool is_index(mlir::Type type) +{ + assert(type); + return type.isa(); +} + +struct ConstOpLowering : public mlir::OpRewritePattern +{ + ConstOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + plier::ConstOp op, mlir::PatternRewriter 
&rewriter) const override + { + auto value = op.val(); + if (!is_supported_type(value.getType())) + { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp(op, value); + return mlir::success(); + } +}; + +struct ArgOpLowering : public mlir::OpRewritePattern +{ + ArgOpLowering(mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context): + OpRewritePattern(context), converter(typeConverter) {} + + mlir::LogicalResult matchAndRewrite( + plier::ArgOp op, mlir::PatternRewriter &rewriter) const override + { + auto func = op->getParentOfType(); + if (!func) + { + return mlir::failure(); + } + + auto index= op.index(); + if (index >= func.getNumArguments()) + { + return mlir::failure(); + } + + auto arg = func.getArgument(index); + if(converter.convertType(op.getType()) != arg.getType()) + { + return mlir::failure(); + } + rewriter.replaceOp(op, arg); + return mlir::success(); + } +private: + mlir::TypeConverter& converter; +}; + + + +struct ReturnOpLowering : public mlir::OpRewritePattern +{ + ReturnOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::ReturnOp op, mlir::PatternRewriter &rewriter) const override + { + auto operands = op.getOperands(); + auto func = mlir::cast(op->getParentOp()); + auto res_types = func.getType().getResults(); + assert(res_types.size() == operands.size() || res_types.empty()); + bool converted = (res_types.size() != operands.size()); + llvm::SmallVector new_vals; + for (auto it : llvm::zip(operands, res_types)) + { + auto src = std::get<0>(it); + auto dst = std::get<1>(it); + if (src.getType() != dst) + { + auto new_op = rewriter.create(op.getLoc(), dst, src); + new_vals.push_back(new_op); + converted = true; + } + else + { + new_vals.push_back(src); + } + } + if (converted) + { + rewriter.create(op.getLoc(), new_vals); + rewriter.eraseOp(op); + return mlir::success(); + } + return mlir::failure(); + } +}; + +struct SelectOpLowering : public mlir::OpRewritePattern +{ + SelectOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::SelectOp op, mlir::PatternRewriter &rewriter) const override + { + auto operands = op.getOperands(); + assert(operands.size() == 3); + auto true_val = operands[1]; + auto false_val = operands[2]; + if (true_val.getType() == false_val.getType() && + true_val.getType() != op.getType()) + { + auto cond = operands[0]; + rewriter.replaceOpWithNewOp(op, cond, true_val, false_val); + return mlir::success(); + } + return mlir::failure(); + } +}; + +struct CondBrOpLowering : public mlir::OpRewritePattern +{ + CondBrOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::CondBranchOp op, mlir::PatternRewriter &rewriter) const override + { + auto operands = op.getOperands(); + assert(!operands.empty()); + auto cond = operands.front(); + operands = operands.drop_front(); + bool changed = false; + + auto process_operand = [&](mlir::Block& block, auto& ret) + { + for (auto arg : block.getArguments()) + { + assert(!operands.empty()); + auto val = operands.front(); + operands = operands.drop_front(); + auto src_type = val.getType(); + auto dst_type = arg.getType(); + if (src_type != dst_type) + { + ret.push_back(rewriter.create(op.getLoc(), dst_type, val)); + changed = true; + } + else + { + ret.push_back(val); + } + } + }; + + 
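+        // Materialize casts for both successor argument lists; the branch is
+        // only rebuilt if at least one operand actually changed type,
+        // otherwise the pattern reports failure and leaves the op untouched.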
llvm::SmallVector true_vals; + llvm::SmallVector false_vals; + auto true_dest = op.getTrueDest(); + auto false_dest = op.getFalseDest(); + process_operand(*true_dest, true_vals); + process_operand(*false_dest, false_vals); + if (changed) + { + rewriter.create(op.getLoc(), cond, true_dest, true_vals, false_dest, false_vals); + rewriter.eraseOp(op); + return mlir::success(); + } + return mlir::failure(); + } +}; + +mlir::Type coerce(mlir::Type type0, mlir::Type type1) +{ + // TODO: proper rules + assert(type0 != type1); + auto get_bits_count = [](mlir::Type type)->unsigned + { + assert(type); + if (type.isa()) + { + return type.cast().getWidth(); + } + if (type.isa()) + { + return 11; + } + if (type.isa()) + { + return 24; + } + if (type.isa()) + { + return 53; + } + llvm_unreachable("Unhandled type"); + }; + auto f0 = is_float(type0); + auto f1 = is_float(type1); + if (f0 && !f1) + { + return type0; + } + if (!f0 && f1) + { + return type1; + } + return get_bits_count(type0) < get_bits_count(type1) ? type1 : type0; +} + +template +mlir::Value int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + auto src_bits = val.getType().cast().getWidth(); + auto dst_bits = dst_type.cast().getWidth(); + assert(src_bits != dst_bits); + if (dst_bits > src_bits) + { + using T = std::conditional_t; + return rewriter.create(val.getLoc(), val, dst_type); + } + else + { + return rewriter.create(val.getLoc(), val, dst_type); + } +} + +template +mlir::Value int_float_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + using T = std::conditional_t; + return rewriter.create(val.getLoc(), val, dst_type); +} + +template +mlir::Value float_int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + using T = std::conditional_t; + return rewriter.create(val.getLoc(), val, dst_type); +} + +mlir::Value index_cast_impl(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + return rewriter.create(val.getLoc(), val, dst_type); +} + +mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + assert(dst_type); + auto src_type = val.getType(); + if (src_type == dst_type) + { + return val; + } + + struct Handler + { + using selector_t = bool(*)(mlir::Type); + using cast_op_t = mlir::Value(*)(mlir::Type, mlir::Value, mlir::PatternRewriter&); + selector_t src; + selector_t dst; + cast_op_t cast_op; + }; + + const Handler handlers[] = { + {&is_int, &is_int, &int_cast}, + {&is_int, &is_float, &int_float_cast}, + {&is_float, &is_int, &float_int_cast}, + {&is_index, &is_int, &index_cast_impl}, + {&is_int, &is_index, &index_cast_impl}, + }; + + for (auto& h : handlers) + { + if (h.src(src_type) && h.dst(dst_type)) + { + return h.cast_op(dst_type, val, rewriter); + } + } + + return nullptr; +} + +template +void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) +{ + assert(nullptr != op); + llvm::SmallVector new_operands(operands.size()); + for (auto it : llvm::enumerate(operands)) + { + new_operands[it.index()] = do_cast(new_type, it.value(), rewriter); + } + rewriter.replaceOpWithNewOp(op, new_type, new_operands); +} + +void replace_itruediv_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) +{ + assert(nullptr != op); + assert(new_type.isa()); + auto lhs = do_cast(new_type, operands[0], rewriter); + auto rhs = do_cast(new_type, operands[1], rewriter); + 
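+    // Python true division always yields a float, so both (possibly integer)
+    // operands are cast to the float result type before dividing; e.g.
+    // (illustration) 5 / 2 lowers roughly to sitofp(5), sitofp(2), divf
+    // producing 2.5, rather than a signed integer division.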
rewriter.replaceOpWithNewOp(op, lhs, rhs); +} + +void replace_imod_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) +{ + auto loc = op->getLoc(); + auto a = operands[0]; + auto b = operands[1]; + auto v1 = rewriter.create(loc, a, b).getResult(); + auto v2 = rewriter.create(loc, v1, b).getResult(); + auto res = rewriter.create(loc, v2, b).getResult(); + rewriter.replaceOp(op, res); +} + +void replace_fmod_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) +{ + auto loc = op->getLoc(); + auto a = operands[0]; + auto b = operands[1]; + auto v1 = rewriter.create(loc, a, b).getResult(); + auto v2 = rewriter.create(loc, v1, b).getResult(); + auto res = rewriter.create(loc, v2, b).getResult(); + rewriter.replaceOp(op, res); +} + +template +void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) +{ + assert(nullptr != op); + auto pred_attr = mlir::IntegerAttr::get(mlir::IntegerType::get(op->getContext(), 64), Pred); + mlir::Type new_type = mlir::IntegerType::get(op->getContext(), 1); + rewriter.replaceOpWithNewOp(op, new_type, pred_attr, operands[0], operands[1]); +} + + +struct BinOpLowering : public mlir::OpRewritePattern +{ + BinOpLowering(mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context): + OpRewritePattern(context), converter(typeConverter) {} + + mlir::LogicalResult matchAndRewrite( + plier::BinOp op, mlir::PatternRewriter &rewriter) const override + { + auto operands = op.getOperands(); + assert(operands.size() == 2); + auto type0 = operands[0].getType(); + auto type1 = operands[1].getType(); + if (!is_supported_type(type0) || !is_supported_type(type1)) + { + return mlir::failure(); + } + auto res_type = converter.convertType(op.getType()); + if (!res_type || !is_supported_type(res_type)) + { + return mlir::failure(); + } + mlir::Type final_type; + std::array converted_operands; + if (type0 != type1) + { + final_type = coerce(type0, type1); + converted_operands = { + do_cast(final_type, operands[0], rewriter), + do_cast(final_type, operands[1], rewriter)}; + } + else + { + final_type = type0; + converted_operands = {operands[0], operands[1]}; + } + assert(static_cast(final_type)); + + using func_t = void(*)(mlir::Operation*, mlir::PatternRewriter&, mlir::Type, mlir::ValueRange); + struct OpDesc + { + llvm::StringRef type; + func_t iop; + func_t fop; + }; + + const OpDesc handlers[] = { + {"+", &replace_op, &replace_op}, + {"-", &replace_op, &replace_op}, + {"*", &replace_op, &replace_op}, + {"/", &replace_itruediv_op, &replace_op}, + {"%", &replace_imod_op, &replace_fmod_op}, + + {">", &replace_cmp_op(mlir::CmpIPredicate::sgt)>, + &replace_cmp_op(mlir::CmpFPredicate::OGT)>}, + {">=", &replace_cmp_op(mlir::CmpIPredicate::sge)>, + &replace_cmp_op(mlir::CmpFPredicate::OGE)>}, + {"<", &replace_cmp_op(mlir::CmpIPredicate::slt)>, + &replace_cmp_op(mlir::CmpFPredicate::OLT)>}, + {"<=", &replace_cmp_op(mlir::CmpIPredicate::sle)>, + &replace_cmp_op(mlir::CmpFPredicate::OLE)>}, + {"!=", &replace_cmp_op(mlir::CmpIPredicate::ne)>, + &replace_cmp_op(mlir::CmpFPredicate::ONE)>}, + {"==", &replace_cmp_op(mlir::CmpIPredicate::eq)>, + &replace_cmp_op(mlir::CmpFPredicate::OEQ)>}, + }; + + using membptr_t = func_t OpDesc::*; + auto call_handler = [&](membptr_t mem) + { + for (auto& h : handlers) + { + if (h.type == op.op()) + { + (h.*mem)(op, rewriter, res_type, converted_operands); + return mlir::success(); 
+ } + } + return mlir::failure(); + }; + + + if (is_int(final_type)) + { + return call_handler(&OpDesc::iop); + } + else if (is_float(final_type)) + { + return call_handler(&OpDesc::fop); + } + return mlir::failure(); + } +private: + mlir::TypeConverter& converter; +}; + +mlir::Value negate(mlir::Value val, mlir::Location loc, mlir::PatternRewriter &rewriter) +{ + auto type = val.getType(); + if (auto itype = type.dyn_cast()) + { + // TODO: not int negation? + auto zero = rewriter.create(loc, mlir::IntegerAttr::get(itype, 0)); + return rewriter.create(loc, zero, val); + } + if (type.isa()) + { + return rewriter.create(loc, val); + } + llvm_unreachable("negate: unsupported type"); +} + +struct UnaryOpLowering : public mlir::OpRewritePattern +{ + UnaryOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + plier::UnaryOp op, mlir::PatternRewriter &rewriter) const override + { + auto arg = op.getOperand(); + auto type = arg.getType(); + if (!is_supported_type(type)) + { + return mlir::failure(); + } + if (op.op() == "+") + { + rewriter.replaceOp(op, arg); + return mlir::success(); + } + assert(op.op() == "-"); + auto new_val = negate(arg, op.getLoc(), rewriter); + rewriter.replaceOp(op, new_val); + return mlir::success(); + } +}; + +mlir::Block* get_next_block(mlir::Block* block) +{ + assert(nullptr != block); + if (auto br = mlir::dyn_cast_or_null(block->getTerminator())) + { + return br.dest(); + } + return nullptr; +}; + +void erase_blocks(llvm::ArrayRef blocks) +{ + for (auto block : blocks) + { + assert(nullptr != block); + block->dropAllDefinedValueUses(); + } + for (auto block : blocks) + { + block->erase(); + } +} + +bool is_blocks_different(llvm::ArrayRef blocks) +{ + for (auto it : llvm::enumerate(blocks)) + { + auto block1 = it.value(); + assert(nullptr != block1); + for (auto block2 : blocks.drop_front(it.index() + 1)) + { + assert(nullptr != block2); + if (block1 == block2) + { + return false; + } + } + } + return true; +} + +struct ScfIfRewrite : public mlir::OpRewritePattern +{ + ScfIfRewrite(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::CondBranchOp op, mlir::PatternRewriter &rewriter) const override + { + auto getDest = [&](bool true_dest) + { + return true_dest ? op.getTrueDest() : op.getFalseDest(); + }; + auto getOperands = [&](bool true_dest) + { + return true_dest ? 
op.getTrueOperands() : op.getFalseOperands(); + }; + auto loc = op.getLoc(); + for (bool reverse : {false, true}) + { + auto true_block = getDest(!reverse); + auto post_block = get_next_block(true_block); + if (nullptr == post_block) + { + continue; + } + auto false_block = getDest(reverse); + if (false_block != post_block && + get_next_block(false_block) != post_block) + { + continue; + } + + auto start_block = op.getOperation()->getBlock(); + if (!is_blocks_different({start_block, true_block, post_block})) + { + continue; + } + mlir::Value cond = op.condition(); + if (reverse) + { + auto i1 = mlir::IntegerType::get(op.getContext(), 1); + auto one = rewriter.create(loc, mlir::IntegerAttr::get(i1, 1)); + cond = rewriter.create(loc, cond, one); + } + + mlir::BlockAndValueMapping mapper; + llvm::SmallVector yield_vals; + auto copy_block = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Block& block) + { + mapper.clear(); + for (auto& op : block.without_terminator()) + { + builder.clone(op, mapper); + } + auto term = mlir::cast(block.getTerminator()); + yield_vals.clear(); + yield_vals.reserve(term.getNumOperands()); + for (auto op : term.getOperands()) + { + yield_vals.emplace_back(mapper.lookupOrDefault(op)); + } + builder.create(loc, yield_vals); + }; + + auto true_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + copy_block(builder, loc, *true_block); + }; + + bool has_else = false_block != post_block; + auto res_types = mlir::cast(true_block->getTerminator()).getOperandTypes(); + mlir::scf::IfOp if_op; + if (has_else) + { + auto false_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + copy_block(builder, loc, *false_block); + }; + if_op = rewriter.create( + loc, + res_types, + cond, + true_body, + false_body); + } + else + { + if (res_types.empty()) + { + if_op = rewriter.create( + loc, + res_types, + cond, + true_body); + } + else + { + auto false_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + auto res = getOperands(reverse); + yield_vals.clear(); + yield_vals.reserve(res.size()); + for (auto op : res) + { + yield_vals.emplace_back(mapper.lookupOrDefault(op)); + } + builder.create(loc, yield_vals); + }; + if_op = rewriter.create( + loc, + res_types, + cond, + true_body, + false_body); + } + } + + rewriter.create(loc, post_block, if_op.getResults()); + rewriter.eraseOp(op); + + if (true_block->getUsers().empty()) + { + erase_blocks(true_block); + } + if (false_block->getUsers().empty()) + { + erase_blocks(false_block); + } + return mlir::success(); + } + return mlir::failure(); + } +}; + +mlir::scf::WhileOp create_while( + mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange iterArgs, + llvm::function_ref beforeBuilder, + llvm::function_ref afterBuilder) +{ + mlir::OperationState state(loc, mlir::scf::WhileOp::getOperationName()); + state.addOperands(iterArgs); + + { + mlir::OpBuilder::InsertionGuard g(builder); + auto add_region = [&](mlir::ValueRange args)->mlir::Block* + { + auto reg = state.addRegion(); + auto block = builder.createBlock(reg); + for (auto arg : args) + { + block->addArgument(arg.getType()); + } + return block; + }; + + auto beforeBlock = add_region(iterArgs); + beforeBuilder(builder, state.location, beforeBlock->getArguments()); + auto cond = mlir::cast(beforeBlock->getTerminator()); + state.addTypes(cond.args().getTypes()); + + auto afterblock = add_region(cond.args()); + afterBuilder(builder, state.location, afterblock->getArguments()); + } + return mlir::cast(builder.createOperation(state)); +} + +bool 
is_inside_block(mlir::Operation* op, mlir::Block* block) +{ + assert(nullptr != op); + assert(nullptr != block); + do + { + if (op->getBlock() == block) + { + return true; + } + } + while((op = op->getParentOp())); + return false; +} + +struct ScfWhileRewrite : public mlir::OpRewritePattern +{ + ScfWhileRewrite(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::BranchOp op, mlir::PatternRewriter &rewriter) const override + { + auto before_block = op.dest(); + auto before_term = mlir::dyn_cast(before_block->getTerminator()); + if (!before_term) + { + return mlir::failure(); + } + auto start_block = op.getOperation()->getBlock(); + auto after_block = before_term.trueDest(); + auto post_block = before_term.falseDest(); + if (get_next_block(after_block) != before_block || + !is_blocks_different({start_block, before_block, after_block, post_block})) + { + return mlir::failure(); + } + + auto check_outside_vals = [&](mlir::Operation* op)->mlir::WalkResult + { + for (auto user : op->getUsers()) + { + if (!is_inside_block(user, before_block) && + !is_inside_block(user, after_block)) + { + return mlir::WalkResult::interrupt(); + } + } + return mlir::WalkResult::advance(); + }; + + if (after_block->walk(check_outside_vals).wasInterrupted()) + { + return mlir::failure(); + } + + mlir::BlockAndValueMapping mapper; + llvm::SmallVector yield_vars; + auto before_block_args = before_block->getArguments(); + llvm::SmallVector orig_vars(before_block_args.begin(), before_block_args.end()); + + auto before_body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange iterargs) + { + mapper.map(before_block_args, iterargs); + yield_vars.resize(before_block_args.size()); + for (auto& op : before_block->without_terminator()) + { + auto new_op = builder.clone(op, mapper); + for (auto user : op.getUsers()) + { + if (!is_inside_block(user, before_block)) + { + for (auto it : llvm::zip(op.getResults(), new_op->getResults())) + { + orig_vars.emplace_back(std::get<0>(it)); + yield_vars.emplace_back(std::get<1>(it)); + } + break; + } + } + } + + llvm::transform(before_block->getArguments(), yield_vars.begin(), + [&](mlir::Value val) { return mapper.lookupOrDefault(val); }); + + auto term = mlir::cast(before_block->getTerminator()); + for (auto arg : term.falseDestOperands()) + { + orig_vars.emplace_back(arg); + yield_vars.emplace_back(mapper.lookupOrDefault(arg)); + } + auto cond = mapper.lookupOrDefault(term.condition()); + builder.create(loc, cond, yield_vars); + }; + auto after_body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange iterargs) + { + mapper.clear(); + assert(orig_vars.size() == iterargs.size()); + mapper.map(orig_vars, iterargs); + for (auto& op : after_block->without_terminator()) + { + builder.clone(op, mapper); + } + yield_vars.clear(); + auto term = mlir::cast(after_block->getTerminator()); + for (auto arg : term.getOperands()) + { + yield_vars.emplace_back(mapper.lookupOrDefault(arg)); + } + builder.create(loc, yield_vars); + }; + + auto while_op = create_while( + rewriter, + op.getLoc(), + op.getOperands(), + before_body, + after_body); + + assert(orig_vars.size() == while_op.getNumResults()); + for (auto arg : llvm::zip(orig_vars, while_op.getResults())) + { + std::get<0>(arg).replaceAllUsesWith(std::get<1>(arg)); + } + + rewriter.create(op.getLoc(), post_block, before_term.falseDestOperands()); + rewriter.eraseOp(op); + erase_blocks({before_block, after_block}); + + 
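+        // For reference (illustration only): the CFG shape folded here is the
+        // canonical frontend lowering of a Python `while` loop:
+        //   ^start:  ... br ^before(%init)
+        //   ^before: %cond = ... cond_br %cond, ^after, ^post
+        //   ^after:  ... br ^before(%next)
+        // ^before becomes the scf.while "before" region (terminated by
+        // scf.condition) and ^after the "after" region (terminated by
+        // scf.yield).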
return mlir::success(); + } +}; + +struct FixupWhileTypes : public mlir::OpRewritePattern +{ + FixupWhileTypes(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::scf::WhileOp op, mlir::PatternRewriter &rewriter) const override + { + bool changed = false; + mlir::OpBuilder::InsertionGuard g(rewriter); + auto before_block = &op.before().front(); + rewriter.startRootUpdate(op); + rewriter.setInsertionPointToStart(before_block); + assert(before_block->getNumArguments() == op.getNumOperands()); + auto loc = rewriter.getUnknownLoc(); + for (auto it : llvm::zip(op.getOperandTypes(), before_block->getArguments())) + { + auto new_type = std::get<0>(it); + auto arg = std::get<1>(it); + auto old_type = arg.getType(); + if (old_type != new_type) + { + rewriter.create(loc, old_type, arg); + arg.setType(new_type); + changed = true; + } + } + + auto term = mlir::cast(before_block->getTerminator()); + auto after_types = term.args().getTypes(); + + auto after_block = &op.after().front(); + rewriter.setInsertionPointToStart(after_block); + assert(after_block->getNumArguments() == term.args().size()); + for (auto it : llvm::zip(after_types, after_block->getArguments())) + { + auto new_type = std::get<0>(it); + auto arg = std::get<1>(it); + auto old_type = arg.getType(); + if (old_type != new_type) + { + rewriter.create(loc, old_type, arg); + arg.setType(new_type); + changed = true; + } + } + + rewriter.setInsertionPointAfter(op); + assert(op.getNumResults() == term.args().size()); + for (auto it : llvm::zip(after_types, op.getResults())) + { + auto new_type = std::get<0>(it); + auto arg = std::get<1>(it); + auto old_type = arg.getType(); + if (old_type != new_type) + { + rewriter.create(loc, old_type, arg); + arg.setType(new_type); + changed = true; + } + } + + if (changed) + { + rewriter.finalizeRootUpdate(op); + } + else + { + rewriter.cancelRootUpdate(op); + } + return mlir::success(changed); + } +}; + +struct PropagateBuildTupleTypes : public mlir::OpRewritePattern +{ + PropagateBuildTupleTypes(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + plier::BuildTupleOp op, mlir::PatternRewriter &rewriter) const override + { + if (op.getType().isa() || + llvm::any_of(op.getOperandTypes(), [](mlir::Type type){ return type.isa(); })) + { + return mlir::failure(); + } + + auto new_type = mlir::TupleType::get(op.getContext(), op.getOperandTypes()); + rewriter.replaceOpWithNewOp(op, new_type, op.getOperands()); + return mlir::success(); + } +}; + +template +struct FoldTupleGetitem : public mlir::OpRewritePattern +{ + FoldTupleGetitem(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + mlir::OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter &rewriter) const override + { + auto build_tuple = op.value().template getDefiningOp(); + if (!build_tuple) + { + return mlir::failure(); + } + + if (auto val = plier::getConstVal(op.getOperand(1))) + { + auto index = val.getInt(); + if (index >= 0 && index < build_tuple.getNumOperands()) + { + auto val = build_tuple.getOperand(static_cast(index)); + rewriter.replaceOp(op, val); + return mlir::success(); + } + } + return mlir::failure(); + } +}; + +mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) +{ + if (!kwargs.empty()) + { + return 
mlir::failure(); + } + if ((operands.size() < 1 || operands.size() > 3) || + !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) + { + return mlir::failure(); + } + mlir::Value val = op.getResult(); + if (!val.getUsers().empty()) + { + auto user = mlir::dyn_cast(*val.getUsers().begin()); + auto get_bounds = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + auto lower_bound = (operands.size() >= 2 ? operands[0] : builder.create(loc, 0)); + auto upper_bound = (operands.size() >= 2 ? operands[1] : operands[0]); + auto step = (operands.size() == 3 ? operands[2] : builder.create(loc, 1)); + return std::make_tuple(lower_bound, upper_bound, step); + }; + auto get_index = [](mlir::OpBuilder& builder, mlir::Location loc, mlir::Type dst_type, mlir::Value index) + { + return builder.create(loc, dst_type, index); + }; + if (!user || mlir::failed(lower_while_to_for(user, rewriter, get_bounds, get_index))) + { + return mlir::failure(); + } + } + + if (val.getUsers().empty()) + { + rewriter.eraseOp(op); + } + return mlir::success(); +} + +mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) +{ + if (!kwargs.empty()) + { + return mlir::failure(); + } + if (operands.size() != 1) + { + return mlir::failure(); + } + + auto build_tuple = operands[0].getDefiningOp(); + if (!build_tuple) + { + return mlir::failure(); + } + + auto size = rewriter.create(op.getLoc(), build_tuple.getNumOperands()); + auto cast = rewriter.create(op.getLoc(), op.getType(), size); + rewriter.replaceOp(op, cast.getResult()); + return mlir::success(); +} + +mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) +{ + if (!kwargs.empty()) + { + return mlir::failure(); + } + if (operands.size() != 1) + { + return mlir::failure(); + } + auto val = operands[0]; + bool success = false; + auto replace_op = [&](mlir::Value val) + { + assert(!success); + if (val) + { + rewriter.replaceOp(op, val); + success = true; + } + }; + auto src_type = val.getType(); + auto dst_type = mlir::IntegerType::get(op.getContext(), 1); + mlir::TypeSwitch(src_type) + .Case([&](auto) { replace_op(do_cast(dst_type, val, rewriter)); }); + return mlir::success(success); +} + +mlir::FuncOp get_lib_symbol( + mlir::ModuleOp mod, llvm::StringRef name, mlir::FunctionType type, + mlir::PatternRewriter& rewriter) +{ + assert(!name.empty()); + if (auto op = mod.lookupSymbol(name)) + { + assert(op.getType() == type); + return op; + } + + return plier::add_function(rewriter, mod, name, type); +} + +mlir::LogicalResult lower_math_func( + plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, + llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) +{ + if (!kwargs.empty()) + { + return mlir::failure(); + } + auto ret_type = map_plier_type(op.getType()); + auto valid_type = [&](mlir::Type type) + { + return type.isa(); + }; + if (ret_type && name.consume_front("math.") && args.size() == 1 && + valid_type(args[0].getType())) + { + auto loc = op.getLoc(); + mlir::Value arg = rewriter.create(loc, ret_type, args[0]); + auto is_float = ret_type.isa(); + auto func_type = mlir::FunctionType::get(op.getContext(), ret_type, ret_type); + auto module = op->getParentOfType(); + mlir::FuncOp func; + if (is_float) + { + func = get_lib_symbol(module, name.str() + "f", func_type, rewriter); + } + else // double + { + func = get_lib_symbol(module, name, func_type, rewriter); + } + auto call = 
rewriter.create(loc, func, arg); + rewriter.replaceOp(op, call.getResults()); + return mlir::success(); + } + + return mlir::failure(); +} + +struct CallLowerer +{ + mlir::LogicalResult operator()(plier::PyCallOp op, llvm::StringRef name, + llvm::ArrayRef args, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) + { + if (mlir::succeeded(lower_math_func(op, name, args, kwargs, rewriter))) + { + return mlir::success(); + } + + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, llvm::ArrayRef>, mlir::PatternRewriter&); + std::pair handlers[] = { + {"bool", lower_bool_cast}, + {"range", lower_range}, + {"len", lower_len}, + }; + for (auto& handler : handlers) + { + if (handler.first == name) + { + return handler.second(op, args, kwargs, rewriter); + } + } + + mlir::ValueRange r(args); + auto mangled_name = mangle(name, r.getTypes()); + if (!mangled_name.empty()) + { + auto mod = op->getParentOfType(); + assert(mod); + auto func = mod.lookupSymbol(mangled_name); + if (!func) + { + func = py_resolver.get_func(name, r.getTypes()); + if (func) + { + func.setPrivate(); + func.setName(mangled_name); + } + } + if (func) + { + assert(func.getType().getNumResults() == op->getNumResults()); + auto new_func_call = rewriter.create(op.getLoc(), func, args); + rewriter.replaceOp(op, new_func_call.getResults()); + return mlir::success(); + } + } + + return mlir::failure(); + } + +private: + PyFuncResolver py_resolver; +}; + +struct PlierToStdPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +void PlierToStdPass::runOnOperation() +{ + mlir::TypeConverter type_converter; + // Convert unknown types to itself + type_converter.addConversion([](mlir::Type type) { return type; }); + + auto context = &getContext(); + populate_std_type_converter(*context, type_converter); + + mlir::OwningRewritePatternList patterns; + + patterns.insert< + plier::FuncOpSignatureConversion, + ArgOpLowering, + ReturnOpLowering, + ConstOpLowering, + SelectOpLowering, + CondBrOpLowering, + BinOpLowering, + UnaryOpLowering, + ScfIfRewrite, + ScfWhileRewrite, + FixupWhileTypes, + PropagateBuildTupleTypes, + FoldTupleGetitem, + FoldTupleGetitem + >(type_converter, context); + + patterns.insert< + plier::CastOpLowering + >(type_converter, context, &do_cast); + + CallLowerer callLowerer; + + patterns.insert< + plier::CallOpLowering + >(type_converter, context, std::ref(callLowerer)); + + mlir::populateStdExpandOpsPatterns(context, patterns); + + // range/prange lowering need dead branch pruning to properly + // handle negative steps + for (auto *op : context->getRegisteredOperations()) + { + op->getCanonicalizationPatterns(patterns, context); + } + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) +{ + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(std::make_unique()); +} +} + +void populate_std_type_converter(mlir::MLIRContext& context, mlir::TypeConverter& converter) +{ + auto none_type = plier::PyType::getNone(&context); + converter.addConversion( + [none_type](mlir::Type type, llvm::SmallVectorImpl& ret_types) + ->llvm::Optional + { + if (type == none_type) + { + return mlir::success(); + } + auto ret = map_plier_type(type); + if (!ret) + { + return llvm::None; + } + ret_types.push_back(ret); + return mlir::success(); + }); 
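+    // Illustrative behaviour of this converter (assuming the usual plier type
+    // names): a plier type such as "int64" is mapped through map_plier_type to
+    // the corresponding std type, the none type legalizes to zero result types
+    // (ret_types stays empty), and anything map_plier_type cannot handle
+    // returns llvm::None so that other registered conversions may apply.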
+} + +void register_plier_to_std_pipeline(plier::PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + auto stage = get_high_lowering_stage(); + sink(plier_to_std_pipeline_name(), {stage.begin}, {stage.end}, {}, &populate_plier_to_std_pipeline); + }); +} + +llvm::StringRef plier_to_std_pipeline_name() +{ + return "plier_to_std"; +} diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.hpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.hpp new file mode 100644 index 00000000000..c71cb6aa07c --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.hpp @@ -0,0 +1,23 @@ +#pragma once + +namespace plier +{ +class PipelineRegistry; +} + +namespace llvm +{ +class StringRef; +} + +namespace mlir +{ +class MLIRContext; +class TypeConverter; +} + +void populate_std_type_converter(mlir::MLIRContext& context, mlir::TypeConverter& converter); + +void register_plier_to_std_pipeline(plier::PipelineRegistry& registry); + +llvm::StringRef plier_to_std_pipeline_name(); diff --git a/mlir-compiler/mlir-compiler/src/py_func_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_func_resolver.cpp new file mode 100644 index 00000000000..f927c028203 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_func_resolver.cpp @@ -0,0 +1,50 @@ +#include "py_func_resolver.hpp" + +#include + +#include + +#include "py_map_types.hpp" + +namespace py = pybind11; + +struct PyFuncResolver::Context +{ + py::handle resolver; + py::handle compiler; + py::handle types; +}; + +PyFuncResolver::PyFuncResolver(): + context(std::make_unique()) +{ + auto registry_mod = py::module::import("numba.mlir.func_registry"); + auto compiler_mod = py::module::import("numba.mlir.inner_compiler"); + context->resolver = registry_mod.attr("find_active_func"); + context->compiler = compiler_mod.attr("compile_func"); + context->types = py::module::import("numba.core.types"); +} + +PyFuncResolver::~PyFuncResolver() +{ + +} + +mlir::FuncOp PyFuncResolver::get_func(llvm::StringRef name, mlir::TypeRange types) +{ + assert(!name.empty()); + auto py_name = py::str(name.data(), name.size()); + auto py_func = context->resolver(py_name); + if (py_func.is_none()) + { + return {}; + } + auto py_types = map_types_to_numba(context->types, types); + if (py_types.is_none()) + { + return {}; + } + auto res = static_cast(context->compiler(py_func, py_types).cast()); + auto func = (res ? 
mlir::cast(res) : nullptr); + return func; +} diff --git a/mlir-compiler/mlir-compiler/src/py_func_resolver.hpp b/mlir-compiler/mlir-compiler/src/py_func_resolver.hpp new file mode 100644 index 00000000000..8c94dabb96e --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_func_resolver.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include + +namespace llvm +{ +class StringRef; +} + +namespace mlir +{ +class FuncOp; +class TypeRange; +} + +class PyFuncResolver +{ +public: + PyFuncResolver(); + ~PyFuncResolver(); + + mlir::FuncOp get_func(llvm::StringRef name, mlir::TypeRange types); + +private: + struct Context; + std::unique_ptr context; +}; diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp new file mode 100644 index 00000000000..06804c756c3 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -0,0 +1,1167 @@ +#include "py_linalg_resolver.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "plier/dialect.hpp" +#include "py_map_types.hpp" +#include "plier/utils.hpp" +#include "plier/transforms/const_utils.hpp" + +namespace py = pybind11; + +struct PyBuilderContext +{ + mlir::Location loc; + mlir::OpBuilder& builder; + PyLinalgResolver::Context& context; +}; + +namespace +{ +bool is_compatible_type(mlir::Type type) +{ + if (auto tuple_type = type.dyn_cast()) + { + return llvm::all_of(tuple_type, &is_compatible_type); + } + return type.isa(); +} + +template +bool is_compatible_types(R&& vals) +{ + return llvm::all_of(vals, [](auto val) { return is_compatible_type(val.getType()); }); +} + +template +py::capsule wrap_mlir(T val) +{ + return py::capsule(val.getAsOpaquePointer()); +} + +template +T unwrap_mlir(py::capsule obj) +{ + return T::getFromOpaquePointer(static_cast(obj)); +} + +auto unwrap_ssa_val(py::handle obj) +{ + return unwrap_mlir(obj.attr("_ssa_val").cast()); +} + +auto unwrap_type(py::handle obj) +{ + return unwrap_mlir(obj.attr("_mlir_type").cast()); +} + +size_t container_size(py::handle obj) +{ + if (py::isinstance(obj)) + { + return obj.cast().size(); + } + if (py::isinstance(obj)) + { + return obj.cast().size(); + } + return 1; +} + +template +void container_iterate(py::handle obj, F&& func) +{ + auto impl = [&](auto cont) + { + for (auto it : llvm::enumerate(cont)) + { + func(it.index(), it.value()); + } + }; + if (py::isinstance(obj)) + { + impl(obj.cast()); + } + else if (py::isinstance(obj)) + { + impl(obj.cast()); + } + else + { + func(std::size_t(0), obj); + } +} + +llvm::Optional make_py_literal(mlir::Value val) +{ + assert(val); + if (auto int_val = plier::getConstVal(val)) + { + return py::int_(int_val.getInt()); + } + if (auto float_val = plier::getConstVal(val)) + { + return py::float_(float_val.getValueAsDouble()); + } + return {}; +} + +mlir::Value do_cast(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value val, mlir::Type type) +{ + if (val.getType() != type) + { + return builder.create(loc, type, val); + } + return val; +} + +bool cmp_capsule(py::capsule a1, py::capsule a2) +{ + return static_cast(a1) == static_cast(a2); +} + +void setup_py_var(py::handle var); +} + +struct PyLinalgResolver::Context +{ + py::handle var; + py::handle type; + py::handle builder; + py::handle inspect; + py::handle types_mod; + py::handle compile_func; + py::handle lookup_func; + + py::object create_var(py::capsule context, mlir::Value value) + { + assert(value); + if (auto literal = make_py_literal(value)) + { + 
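+            // Constant SSA values surface in Python as plain literals: e.g. a
+            // hypothetical `plier.const 42 : i64` becomes py::int_(42), and a
+            // float constant becomes py::float_, instead of a wrapped Var.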
return *literal; + } + auto ret = var(context, wrap_mlir(value)); + setup_py_var(ret); + return ret; + } + + py::object create_type(mlir::Type t) + { + return type(wrap_mlir(t), py::cpp_function(&cmp_capsule)); + } + + mlir::FuncOp compile_body(py::handle body, py::list arg_types) + { + auto func = compile_func(body, arg_types).cast(); + auto mlir_func = mlir::cast(static_cast(func)); + mlir_func.setPrivate(); + mlir_func->setAttr(plier::attributes::getForceInlineName(), mlir::UnitAttr::get(mlir_func->getContext())); + return mlir_func; + } + + py::object wrap_result(py::capsule context, mlir::ValueRange values) + { + if (values.empty()) + { + return py::none(); + } + if (values.size() == 1) + { + return create_var(context, values.front()); + } + py::tuple ret(values.size()); + for (auto it : llvm::enumerate(values)) + { + ret[it.index()] = create_var(context, it.value()); + } + return std::move(ret); + } + + mlir::Value unwrap_val(mlir::Location loc, mlir::OpBuilder& builder, py::handle obj) + { + if (py::isinstance(obj, var)) + { + return unwrap_ssa_val(obj); + } + if (py::isinstance(obj)) + { + auto attr = builder.getI64IntegerAttr(obj.cast()); + return builder.create(loc, attr); + } + if (py::isinstance(obj)) + { + auto attr = builder.getF64FloatAttr(obj.cast()); + return builder.create(loc, attr); + } + plier::report_error("Invalid element type"); + } +}; + +namespace +{ +py::object get_args(py::handle inspect, py::handle func, llvm::function_ref create_var, + mlir::ValueRange args, llvm::ArrayRef> kwargs) +{ + auto sig_func = inspect.attr("signature"); + auto sig = sig_func(func); + auto params = sig.attr("parameters"); + auto params_list = py::list(params); + params_list = params_list[py::slice(1, static_cast(params_list.size()), 1)]; // skip builder param + auto empty = inspect.attr("Parameter").attr("empty"); + + py::list ret(py::len(params_list)); + for (auto it : llvm::enumerate(params_list)) + { + auto index = it.index(); + auto param_name = it.value(); + auto param = params[param_name]; + if (!args.empty()) + { + ret[index] = create_var(args.front()); + args = args.drop_front(); + continue; + } + if (!kwargs.empty()) + { + auto name = param_name.cast(); + auto val = [&]()->mlir::Value + { + for (auto kwarg : kwargs) + { + if (kwarg.first == name) + { + return kwarg.second; + } + } + return {}; + }(); + if (val) + { + ret[index] = create_var(val); + continue; + } + } + auto def_val = param.attr("default"); + if (!def_val.is(empty)) + { + ret[index] = def_val; + } + else + { + return py::none(); + } + } + if (!args.empty()) + { + return py::none(); + } + return std::move(ret); +} + +PyBuilderContext& get_py_context(py::capsule& ctx) +{ + return *static_cast(ctx); +} + +auto get_types(mlir::ValueRange values) +{ + return values.getTypes(); +} + +auto get_agrs_from_tuple(py::handle args, llvm::function_ref unpack) +{ + llvm::SmallVector ret; + if (args.is_none()) + { + return ret; + } + if (py::isinstance(args)) + { + auto tuple = args.cast(); + ret.resize(tuple.size()); + for (auto it : llvm::enumerate(tuple)) + { + ret[it.index()] = unpack(it.value()); + } + } + else + { + ret.emplace_back(unpack(args)); + } + return ret; +} + +auto get_iterators(py::list iterators, mlir::MLIRContext& ctx) +{ + llvm::SmallVector ret(iterators.size()); + for (auto it : llvm::enumerate(iterators)) + { + ret[it.index()] = mlir::StringAttr::get(&ctx, it.value().cast()).getValue(); + } + return ret; +} + +mlir::AffineMapAttr get_affine_map_attr(py::handle obj, mlir::MLIRContext& ctx) +{ + auto str = 
(llvm::Twine("affine_map<") + obj.cast() + ">").str(); + return mlir::parseAttribute(str, &ctx).cast(); +} + +auto get_affine_maps(py::list maps, mlir::MLIRContext& ctx) +{ + llvm::SmallVector ret(maps.size()); + for (auto it : llvm::enumerate(maps)) + { + ret[it.index()] = get_affine_map_attr(it.value(), ctx).getValue(); + } + return ret; +} + +auto get_generic_op_body_types(mlir::ValueRange inputs, mlir::ValueRange outputs) +{ + llvm::SmallVector ret; + ret.reserve(inputs.size() + outputs.size()); + for (auto r : {inputs, outputs}) + { + for (auto type : r.getTypes()) + { + auto elem_type = [&]() + { + if (auto tensor = type.dyn_cast()) + { + return tensor.getElementType(); + } + return type; + }(); + ret.emplace_back(elem_type); + } + } + return ret; +} + +auto generic_op_body_result_types(mlir::ValueRange outputs) +{ + llvm::SmallVector ret; + ret.reserve(outputs.size()); + for (auto type : outputs.getTypes()) + { + auto elem_type = type.cast().getElementType(); + ret.emplace_back(elem_type); + } + return ret; +} + +bool is_int(mlir::Type type) +{ + return type.isa(); +} + +unsigned get_int_bit_width(mlir::Type type) +{ + if (type.isa()) + { + return type.cast().getWidth(); + } + if (type.isa()) + { + return 64; // TODO + } + llvm_unreachable("No an integer type"); +} + +bool is_float(mlir::Type type) +{ + return type.isa(); +} + +unsigned get_float_bit_width(mlir::Type type) +{ + return type.cast().getWidth(); +} + +mlir::Type broadcast_type(mlir::Type type1, mlir::Type type2) +{ + if (type1 == type2) + { + return type1; + } + // TODO + if (is_int(type1) && is_int(type2)) + { + auto width = std::max(get_int_bit_width(type1), get_int_bit_width(type2)); + return mlir::IntegerType::get(type1.getContext(), width); + } + if (is_float(type1) && is_float(type2)) + { + return (get_float_bit_width(type1) > get_float_bit_width(type2) ? 
type1 : type2); + } + if (is_float(type1) && is_int(type2)) + { + return type1; + } + if (is_int(type1) && is_float(type2)) + { + return type2; + } + llvm_unreachable("Unable to broadcast type"); +} + +mlir::Value broadcast_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value val1, mlir::Value val2) +{ + assert(val1.getType().isa()); + assert(val2.getType().isa()); + auto one = builder.create(loc, 1); + auto cond = builder.create(loc, mlir::CmpIPredicate::eq, val1, one); + return builder.create(loc, cond, val2, val1); +} + +mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value initial, mlir::Value src, unsigned dim, mlir::ValueRange target_shape) +{ + auto context = builder.getContext(); + auto src_type = src.getType().cast(); + auto num_dims = static_cast(src_type.getRank()); + auto shape = llvm::to_vector<8>(src_type.getShape()); + shape[dim] = -1; + mlir::Type target_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); + auto dim_val = builder.create(loc, initial, dim); + auto one = builder.create(loc, 1); + mlir::Value cond = builder.create(loc, mlir::CmpIPredicate::eq, one, dim_val); + llvm::SmallVector new_shape(num_dims); + for (unsigned i = 0 ; i < num_dims; ++i) + { + if (i == dim) + { + new_shape[i] = target_shape[i]; + } + else + { + new_shape[i] = builder.create(loc, src, i); + } + } + auto true_body = [&](mlir::OpBuilder &builder, mlir::Location loc) + { + assert(dim < shape.size()); + shape[dim] = 1; +// mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); +// auto casted = builder.create(loc, casted_type, src).getResult(); + auto casted = src; // TODO + auto init = builder.create(loc, new_shape, src_type.getElementType()).getResult(); + llvm::SmallVector exprs(num_dims); + for (unsigned i = 0; i < num_dims; ++i) + { + if (i == dim) + { + exprs[i] = mlir::getAffineConstantExpr(0, context); + } + else + { + exprs[i] = mlir::getAffineDimExpr(i, context); + } + } + const mlir::AffineMap maps[] = { + mlir::AffineMap::get(num_dims, 0, exprs, context), + mlir::AffineMap::getMultiDimIdentityMap(num_dims, context), + }; + llvm::SmallVector iterators(num_dims, "parallel"); + + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 2); + builder.create(loc, values[0]); + }; + + auto expanded = builder.create(loc, target_type, casted, init, maps, iterators, body); + auto res = builder.create(loc, target_type, expanded.getResult(0)); + builder.create(loc, res.getResult()); + }; + auto false_body = [&](mlir::OpBuilder &builder, mlir::Location loc) + { + auto res = builder.create(loc, target_type, src); + builder.create(loc, res.getResult()); + }; + return builder.create(loc, target_type, cond, true_body, false_body).getResult(0); +} + +mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value val, unsigned num_dims, mlir::ValueRange target_shape) +{ + assert(num_dims <= target_shape.size()); + if (num_dims < target_shape.size()) + { + target_shape = target_shape.drop_front(target_shape.size() - num_dims); + } + mlir::Value current = val; + for (unsigned i = 0; i < num_dims; ++i) + { + current = expand_dim(builder, loc, val, current, i, target_shape); + } + current = builder.create(loc, current, target_shape); + return current; +} + +py::object broadcast_impl(py::capsule context, py::tuple args) +{ + if (1 == args.size()) + { + return args[0]; + } + auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& 
builder = ctx.builder; + llvm::SmallVector mlir_args(args.size()); + for (auto it : llvm::enumerate(args)) + { + mlir_args[it.index()] = ctx.context.unwrap_val(loc, builder, it.value()); + } + using shape_t = llvm::SmallVector; + auto get_shape = [&](mlir::Value val)->llvm::Optional> + { + auto type = val.getType(); + if (auto shaped = type.dyn_cast()) + { + if (!shaped.hasRank()) + { + return {}; + } + shape_t ret(static_cast(shaped.getRank())); + for (auto it : llvm::enumerate(ret)) + { + auto dim = builder.create(loc, val, it.index()); + ret[it.index()] = dim; + } + return std::make_pair(ret, shaped.getElementType()); + } + if (type.isa()) + { + return std::make_pair(shape_t{}, type); + } + return {}; + }; + mlir::Type res_type; + mlir::SmallVector shape_vals; + if (auto shape_and_type = get_shape(mlir_args.front())) + { + res_type = shape_and_type->second; + shape_vals = shape_and_type->first; + } + else + { + return py::none(); + } + + for (auto arg : llvm::drop_begin(mlir_args)) + { + auto shape_and_type = get_shape(arg); + if (!shape_and_type) + { + py::none(); + } + res_type = broadcast_type(res_type, shape_and_type->second); + auto new_shape_vals = shape_and_type->first; + for (auto it : llvm::zip(llvm::reverse(shape_vals), llvm::reverse(new_shape_vals))) + { + auto& old_val = std::get<0>(it); + auto new_val = std::get<1>(it); + old_val = broadcast_dim(builder, loc, old_val, new_val); + } + if (new_shape_vals.size() > shape_vals.size()) + { + auto front = llvm::makeArrayRef(new_shape_vals).drop_back(shape_vals.size()); + assert(!front.empty()); + shape_vals.insert(shape_vals.begin(), front.begin(), front.end()); + } + } + + py::tuple ret(mlir_args.size()); + if (shape_vals.empty()) + { + for (auto it : llvm::enumerate(mlir_args)) + { + mlir::Value val = it.value(); + if (val.getType() != res_type) + { + val = builder.create(loc, res_type, val); + } + ret[it.index()] = ctx.context.create_var(context, val); + } + return std::move(ret); + } + + llvm::SmallVector shape(static_cast(shape_vals.size()), -1); + auto tensor_type = mlir::RankedTensorType::get(shape, res_type); + for (auto it : llvm::enumerate(mlir_args)) + { + mlir::Value val = it.value(); + if (auto src_type = val.getType().dyn_cast()) + { + assert(src_type.hasRank()); + val = expand_dims(builder, loc, val, static_cast(src_type.getRank()), shape_vals); + } + if (val.getType() != tensor_type) + { + auto type = val.getType(); + if (auto src_type = type.dyn_cast()) + { + assert(src_type.hasRank()); + auto src_num_dims = static_cast(src_type.getRank()); + auto num_dims = static_cast(tensor_type.getRank()); + auto init = builder.create(loc, shape_vals, tensor_type.getElementType()).getResult(); + mlir::AffineMap maps[] = { + mlir::AffineMap::getMinorIdentityMap(num_dims, src_num_dims, builder.getContext()), +// mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()).getMajorSubMap(src_num_dims), + mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()), + }; + llvm::SmallVector iterators(num_dims, "parallel"); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 2); + auto res = builder.create(loc, tensor_type.getElementType(), values[0]); + builder.create(loc, res.getResult()); + }; + val = builder.create(loc, tensor_type, val, init, maps, iterators, body).getResult(0); + } + else + { + if (tensor_type.getElementType() != type) + { + val = builder.create(loc, tensor_type.getElementType(), val); + } + val = builder.create(loc, 
val); + auto num_dims = static_cast(tensor_type.getRank()); + auto init = builder.create(loc, shape_vals, tensor_type.getElementType()).getResult(); + mlir::AffineMap maps[] = { + mlir::AffineMap::get(num_dims, 0, mlir::getAffineConstantExpr(0, builder.getContext())), + mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()), + }; + llvm::SmallVector iterators(num_dims, "parallel"); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 2); + builder.create(loc, values[0]); + }; + val = builder.create(loc, tensor_type, val, init, maps, iterators, body).getResult(0); + } + } + ret[it.index()] = ctx.context.create_var(context, val); + } + return std::move(ret); +} + +py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dtype, py::handle init_val) +{ + auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; + auto elem_type = unwrap_type(dtype); + mlir::Value init; + auto count = py::len(shape); + if (count == 0) + { + if (init_val.is_none()) + { + // TODO: undef + auto zero_val = plier::getZeroVal(elem_type); + assert(zero_val); + init = builder.create(loc, zero_val); + } + else + { + init = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type); + } + } + else + { + auto index_type = builder.getIndexType(); + llvm::SmallVector shape_val(count); + llvm::SmallVector static_shape(count, -1); + for (size_t i = 0; i < count; ++i) + { + auto elem = shape[py::int_(i)]; + if (py::isinstance(elem)) + { + static_shape[i] = elem.cast(); + } + shape_val[i] = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, elem), index_type); + } + + if (init_val.is_none()) + { + init = builder.create(loc, shape_val, elem_type); + } + else + { + auto val = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type); + llvm::SmallVector shape(count, -1); + auto type = mlir::RankedTensorType::get(shape, elem_type); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) + { + builder.create(loc, val); + }; + init = builder.create(loc, type, shape_val, body); + } + if (llvm::any_of(static_shape, [](auto val){ return val >= 0;})) + { + auto new_type = mlir::RankedTensorType::get(static_shape, elem_type); + init = builder.create(loc, new_type, init); + } + } + return ctx.context.create_var(context, init); +} + +py::object fill_tensor_impl(py::capsule context, py::handle tensor, py::handle value) +{ + auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; + auto tensor_val = ctx.context.unwrap_val(loc, builder, tensor); + auto tensor_type = tensor_val.getType().cast(); + auto init_val = ctx.context.unwrap_val(loc, builder, value); + if (init_val.getType() != tensor_type.getElementType()) + { + init_val = builder.create(loc, tensor_type.getElementType(), init_val); + } + +// auto val = builder.create(loc, tensor_type, tensor_val, init_val); + auto rank = static_cast(tensor_type.getRank()); + mlir::AffineMap affine_maps[] = { + mlir::AffineMap::getMultiDimIdentityMap(rank, builder.getContext()), + }; + llvm::SmallVector iterators(rank, "parallel"); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 1); + builder.create(loc, init_val); + }; + auto val = builder.create( + loc, + tensor_type, + llvm::None, + tensor_val, + affine_maps, + iterators, + body); + return ctx.context.create_var(context, 
val.getResult(0)); +} + +py::object generic_impl(py::capsule context, py::handle inputs, py::handle outputs, py::list iterators, py::list maps, py::handle body) +{ + auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; + auto& mlir_context = *builder.getContext(); + + auto unpack = [&](py::handle obj)->mlir::Value + { + return ctx.context.unwrap_val(loc, builder, obj); + }; + + auto inputs_args = get_agrs_from_tuple(inputs, unpack); + auto output_args = get_agrs_from_tuple(outputs, unpack); + auto ret_types = get_types(output_args); + auto mlir_iterators = get_iterators(iterators, mlir_context); + + auto func_types = map_types_to_numba(ctx.context.types_mod, get_generic_op_body_types(inputs_args, output_args)); + auto body_func = ctx.context.compile_body(body, func_types); + + auto cast_values = [&](mlir::ValueRange vals, mlir::TypeRange types) + { + assert(vals.size() == types.size()); + llvm::SmallVector ret(vals.size()); + auto do_cast = [&](mlir::Value val, mlir::Type type) + { + if (val.getType() == type) + { + return val; + } + return builder.create(loc, type, val).getResult(); + }; + for (auto it : llvm::enumerate(vals)) + { + auto index = static_cast(it.index()); + ret[index] = do_cast(it.value(), types[index]); + } + return ret; + }; + if (mlir_iterators.empty()) + { + inputs_args.append(output_args.begin(), output_args.end()); + auto res = builder.create(loc, body_func, inputs_args); + return ctx.context.wrap_result(context, cast_values(res.getResults(), ret_types)); + } + else + { + auto affine_maps = get_affine_maps(maps, mlir_context); + auto body_builder = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + { + auto func_type = body_func.getType(); + auto new_args = cast_values(args, func_type.getInputs()); + auto call = builder.create(loc, body_func, new_args); + auto new_results = cast_values(call.getResults(), generic_op_body_result_types(output_args)); + builder.create(loc, new_results); + }; + + auto generic_op = builder.create( + loc, + ret_types, + inputs_args, + output_args, + affine_maps, + mlir_iterators, + body_builder); + return ctx.context.wrap_result(context, generic_op.getResults()); + } +} + +py::object from_elements_impl(py::capsule context, py::handle values, py::handle dtype) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto type = unwrap_type(dtype); + + llvm::SmallVector vals(container_size(values)); + container_iterate(values, [&](auto index, py::handle obj) + { + if (py::isinstance(obj, ctx.context.var)) + { + vals[index] = unwrap_ssa_val(obj); + } + else if (py::isinstance(obj) || + py::isinstance(obj)) + { + auto attr = [&]()->mlir::Attribute + { + if (type.isa()) + { + return mlir::IntegerAttr::get(type, obj.cast()); + } + if (type.isa()) + { + return mlir::FloatAttr::get(type, obj.cast()); + } + plier::report_error("Invalid dtype"); + }(); + vals[index] = builder.create(loc, attr); + } + else + { + plier::report_error("Invalid element type"); + } + }); + auto res = builder.create(loc, vals); + return ctx.context.create_var(context, res); +} + +py::object extract_impl(py::capsule context, py::handle value, py::handle indices) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + + llvm::SmallVector ind(container_size(indices)); + container_iterate(indices, [&](auto index, py::handle obj) + { + if (py::isinstance(obj, ctx.context.var)) + { + ind[index] = unwrap_ssa_val(obj); + } + else if 
(py::isinstance(obj)) + { + ind[index] = builder.create(loc, obj.cast()); + } + else + { + plier::report_error("Invalid element type"); + } + }); + auto res = builder.create(loc, ctx.context.unwrap_val(loc, builder, value), ind); + return ctx.context.create_var(context, res); +} + +py::object reshape_impl(py::capsule context, py::handle tensor, py::int_ out_dims, py::list maps) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + + auto tensor_val = ctx.context.unwrap_val(loc, builder, tensor); + if (!tensor_val.getType().isa()) + { + plier::report_error("Invalid reshapa argument"); + } + auto elem_type = tensor_val.getType().cast().getElementType(); + auto new_dims = out_dims.cast(); + llvm::SmallVector dims(new_dims, -1); + auto new_type = mlir::RankedTensorType::get(dims, elem_type); + + llvm::SmallVector affine_maps(container_size(maps)); + container_iterate(maps, [&](auto index, py::handle obj) + { + affine_maps[index] = get_affine_map_attr(obj, *builder.getContext()); + }); + auto affine_maps_attr = mlir::ArrayAttr::get(builder.getContext(), affine_maps); + auto reshape = builder.create(loc, new_type, tensor_val, affine_maps_attr); + return ctx.context.create_var(context, reshape); +} + +void setup_py_builder(py::handle builder, mlir::OpBuilder& b, llvm::function_ref create_type) +{ + py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl)); + py::setattr(builder, "_init_tensor", py::cpp_function(&init_tensor_impl)); + py::setattr(builder, "_fill_tensor", py::cpp_function(&fill_tensor_impl)); + py::setattr(builder, "_generic", py::cpp_function(&generic_impl)); + py::setattr(builder, "_from_elements", py::cpp_function(&from_elements_impl)); + py::setattr(builder, "_extract", py::cpp_function(&extract_impl)); + py::setattr(builder, "_reshape", py::cpp_function(&reshape_impl)); + + auto add_type = [&](const char* name, mlir::Type type) + { + py::setattr(builder, name, create_type(type)); + }; + + add_type("int8", b.getIntegerType(8)); + add_type("int16", b.getIntegerType(16)); + add_type("int32", b.getIntegerType(32)); + add_type("int64", b.getIntegerType(64)); + add_type("index", b.getIndexType()); + + add_type("float16", b.getF16Type()); + add_type("float32", b.getF32Type()); + add_type("float64", b.getF64Type()); +} + +py::object shape_impl(py::capsule context, py::capsule ssa_val) +{ + auto& ctx = get_py_context(context); + auto value = unwrap_mlir(ssa_val); + if (value.getType().isa()) + { + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto mlir_type = value.getType().cast(); + auto shape = mlir_type.getShape(); + llvm::SmallVector shape_vals(shape.size()); + for (auto it : llvm::enumerate(shape)) + { + auto i = it.index(); + mlir::Value mlir_dim = builder.create(loc, value, i); + shape_vals[i] = mlir_dim; + } + llvm::SmallVector shape_types(shape.size(), builder.getIndexType()); + auto shape_type = mlir::TupleType::get(builder.getContext(), shape_types); + auto shape_var = builder.create(loc, shape_type, shape_vals); + return ctx.context.create_var(context, shape_var.getResult()); + } + return py::list(); +} + +py::object dtype_impl(py::capsule context, py::capsule ssa_val) +{ + auto& ctx = get_py_context(context); + auto value = unwrap_mlir(ssa_val); + auto type = value.getType(); + if (auto tensor_type = type.dyn_cast()) + { + return ctx.context.create_type(tensor_type.getElementType()); + } + return ctx.context.create_type(type); +} + +py::object len_impl(py::capsule /*context*/, py::capsule ssa_val) +{ + auto 
value = unwrap_mlir(ssa_val); + auto type = value.getType(); + if (auto tuple_type = type.dyn_cast()) + { + return py::int_(tuple_type.size()); + } + return py::int_(1); +} + +py::object getitem_impl(py::capsule context, py::capsule ssa_val, py::handle index) +{ + auto& ctx = get_py_context(context); + auto value = unwrap_mlir(ssa_val); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto index_val = index.cast(); + auto type = value.getType(); + if (auto tuple_type = type.dyn_cast()) + { + if (index_val < 0 || index_val >= static_cast(tuple_type.size())) + { + plier::report_error("Invalid getitem index"); + } + if (auto parent_op = value.getDefiningOp()) + { + return ctx.context.create_var(context, parent_op.getOperand(static_cast(index_val))); + } + auto elem_type = tuple_type.getType(static_cast(index_val)); + auto ind = builder.create(loc, index_val); + auto item = builder.create(loc, elem_type, value, ind); + return ctx.context.create_var(context, item.getResult()); + } + else + { + if (0 != index_val) + { + plier::report_error("Invalid getitem index"); + } + return ctx.context.create_var(context, value); + } +} + +template +mlir::Value binop_func(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value lhs, mlir::Value rhs) +{ + return builder.create(loc, lhs, rhs); +} + +py::object binop_impl(py::capsule context, py::capsule ssa_val, py::handle rhs, py::str op) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto lhs = unwrap_mlir(ssa_val); + + auto type = lhs.getType(); + if (!type.isa()) + { + plier::report_error("Invalid binop arg type"); + } + + auto is_float = [&]()->bool + { + if (auto shaped_type = type.dyn_cast()) + { + return shaped_type.getElementType().isa(); + } + return type.isa(); + }(); + + using binop_func_t = mlir::Value(*)(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value lhs, mlir::Value rhs); + const std::tuple funcs[] = { + {"*", &binop_func, &binop_func}, + }; + + auto op_name = static_cast(op); + for (auto f : funcs) + { + auto name = std::get<0>(f); + auto func = (is_float ? 
std::get<2>(f) : std::get<1>(f)); + if (name == op_name) + { + auto rhs_var = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, rhs), type); + auto res = func(loc, builder, lhs, rhs_var); + return ctx.context.create_var(context, res); + } + } + plier::report_error("Unhandled binop type"); +} + +void setup_py_var(pybind11::handle var) +{ + py::setattr(var, "_shape", py::cpp_function(&shape_impl)); + py::setattr(var, "_dtype", py::cpp_function(&dtype_impl)); + py::setattr(var, "_len", py::cpp_function(&len_impl)); + py::setattr(var, "_getitem", py::cpp_function(&getitem_impl)); + py::setattr(var, "_binop", py::cpp_function(&binop_impl)); +} + +PyLinalgResolver::Values unpack_results(py::handle object) +{ + PyLinalgResolver::Values ret; + if (object.is_none()) + { + return ret; + } + if (py::isinstance(object)) + { + auto tuple = object.cast(); + ret.resize(tuple.size()); + for (auto it : llvm::enumerate(tuple)) + { + ret[it.index()] = unwrap_ssa_val(it.value()); + } + return ret; + } + ret.emplace_back(unwrap_ssa_val(object)); + return ret; +} +} + +PyLinalgResolver::PyLinalgResolver(): + context(std::make_unique()) +{ + auto builder_mod = py::module::import("numba.mlir.linalg_builder"); + context->var = builder_mod.attr("Var"); + context->type = builder_mod.attr("Type"); + context->builder = builder_mod.attr("Builder"); + context->inspect = py::module::import("inspect"); + context->types_mod = py::module::import("numba.core.types"); + context->compile_func = builder_mod.attr("compile_func"); + context->lookup_func = builder_mod.attr("lookup_func"); +} + +PyLinalgResolver::~PyLinalgResolver() +{ + +} + +llvm::Optional PyLinalgResolver::rewrite_func(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, KWArgs kwargs) +{ + auto mangled_name = (llvm::Twine(name) + "()").str(); + return rewrite(mangled_name, loc, builder, args, kwargs); +} + +llvm::Optional PyLinalgResolver::rewrite_attr(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::Value arg) +{ + return rewrite(name, loc, builder, arg, {}); +} + +llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, KWArgs kwargs) +{ + assert(!name.empty()); + if (!is_compatible_types(args) || + !is_compatible_types(llvm::make_second_range(kwargs))) + { + return {}; + } + + auto builder_func = context->lookup_func(py::str(name.data(), name.size())); + if (builder_func.is_none()) + { + return {}; + } + + PyBuilderContext py_builder_context{loc, builder, *context}; + auto py_context = py::capsule(&py_builder_context); + auto py_args = get_args( + context->inspect, + builder_func, + [&](auto val){ return context->create_var(py_context, val);}, + args, + kwargs); + if (py_args.is_none()) + { + return {}; + } + auto py_builder = context->builder(py_context); + setup_py_builder(py_builder, builder, [&](auto type){ return context->create_type(type);}); + + auto result = builder_func(py_builder, *py_args); + if (result.is_none()) + { + return {}; + } + return unpack_results(result); +} diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp new file mode 100644 index 00000000000..156a5f3acad --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include + +#include +#include +#include + +namespace llvm +{ +class StringRef; +} + +namespace mlir +{ +class Value; +class FuncOp; +class ValueRange; 
+class OpBuilder; +class Location; +} + +class PyLinalgResolver +{ +public: + PyLinalgResolver(); + ~PyLinalgResolver(); + + using Values = llvm::SmallVector; + using KWArgs = llvm::ArrayRef>; + + llvm::Optional rewrite_func(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, + KWArgs kwargs); + + llvm::Optional rewrite_attr(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::Value arg); + +private: + friend struct PyBuilderContext; + struct Context; + std::unique_ptr context; + + llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, + KWArgs kwargs); +}; diff --git a/mlir-compiler/mlir-compiler/src/py_map_types.cpp b/mlir-compiler/mlir-compiler/src/py_map_types.cpp new file mode 100644 index 00000000000..64f68ba72ae --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_map_types.cpp @@ -0,0 +1,108 @@ +#include "py_map_types.hpp" + +#include + +#include +#include + +namespace py = pybind11; + +namespace +{ +template +bool is_int(mlir::Type type) +{ + if (auto t = type.dyn_cast()) + { + if (t.getWidth() == Width && t.getSignedness() == Signed) + { + return true; + } + } + return false; +} + +template +bool is_float(mlir::Type type) +{ + if (auto f = type.dyn_cast()) + { + if (f.getWidth() == Width) + { + return true; + } + } + return false; +} + +py::object map_type(const py::handle& types_mod, mlir::Type type) +{ + using fptr_t = bool(*)(mlir::Type); + const std::pair primitive_types[] = { + {&is_int<1, mlir::IntegerType::Signed>, "boolean"}, + {&is_int<1, mlir::IntegerType::Signless>, "boolean"}, + {&is_int<1, mlir::IntegerType::Unsigned>, "boolean"}, + + {&is_int<8, mlir::IntegerType::Signed>, "int8"}, + {&is_int<8, mlir::IntegerType::Signless>, "int8"}, + {&is_int<8, mlir::IntegerType::Unsigned>, "uint8"}, + + {&is_int<16, mlir::IntegerType::Signed>, "int16"}, + {&is_int<16, mlir::IntegerType::Signless>, "int16"}, + {&is_int<16, mlir::IntegerType::Unsigned>, "uint16"}, + + {&is_int<32, mlir::IntegerType::Signed>, "int32"}, + {&is_int<32, mlir::IntegerType::Signless>, "int32"}, + {&is_int<32, mlir::IntegerType::Unsigned>, "uint32"}, + + {&is_int<64, mlir::IntegerType::Signed>, "int64"}, + {&is_int<64, mlir::IntegerType::Signless>, "int64"}, + {&is_int<64, mlir::IntegerType::Unsigned>, "uint64"}, + + {&is_float<32>, "float32"}, + {&is_float<64>, "float64"}, + }; + + for (auto h : primitive_types) + { + if (h.first(type)) + { + auto name = h.second; + return types_mod.attr(py::str(name.data(), name.size())); + } + } + + if (auto m = type.dyn_cast()) + { + auto elem_type = map_type(types_mod, m.getElementType()); + if (!elem_type) + { + return {}; + } + auto ndims = py::int_(m.getRank()); + auto array_type = types_mod.attr("Array"); + return array_type(elem_type, ndims, py::str("C")); + } + return {}; +} +} +pybind11::object map_type_to_numba(pybind11::handle types_mod, mlir::Type type) +{ + auto elem = map_type(types_mod, type); + if (!elem) + { + return py::none(); + } + return elem; +} + +pybind11::list map_types_to_numba(pybind11::handle types_mod, mlir::TypeRange types) +{ + py::list ret(types.size()); + for (auto it : llvm::enumerate(types)) + { + ret[it.index()] = map_type_to_numba(types_mod, it.value()); + } + return ret; +} + diff --git a/mlir-compiler/mlir-compiler/src/py_map_types.hpp b/mlir-compiler/mlir-compiler/src/py_map_types.hpp new file mode 100644 index 00000000000..90cd42d4fb6 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_map_types.hpp @@ 
-0,0 +1,17 @@ +#pragma once + +namespace pybind11 +{ +class list; +class object; +class handle; +} + +namespace mlir +{ +class Type; +class TypeRange; +} + +pybind11::object map_type_to_numba(pybind11::handle types_mod, mlir::Type type); +pybind11::list map_types_to_numba(pybind11::handle types_mod, mlir::TypeRange types); diff --git a/mlir-compiler/mlir-compiler/src/py_module.cpp b/mlir-compiler/mlir-compiler/src/py_module.cpp new file mode 100644 index 00000000000..20fb44dd6b6 --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_module.cpp @@ -0,0 +1,13 @@ +#include + +#include "py_module.hpp" + +#include "lowering.hpp" + +PYBIND11_MODULE(mlir_compiler, m) +{ + m.def("create_module", &create_module, "Create an empty MLIR module"); + m.def("lower_function", &lower_function, "Lower Numba function IR into an MLIR module"); + m.def("compile_module", &compile_module, "Compile an MLIR module down to LLVM bitcode"); + m.def("module_str", &module_str, "Return the textual form of an MLIR module"); +} diff --git a/mlir-compiler/mlir-compiler/src/py_module.hpp b/mlir-compiler/mlir-compiler/src/py_module.hpp new file mode 100644 index 00000000000..6f70f09beec --- /dev/null +++ b/mlir-compiler/mlir-compiler/src/py_module.hpp @@ -0,0 +1 @@ +#pragma once diff --git a/mlir-compiler/readme.md b/mlir-compiler/readme.md new file mode 100644 index 00000000000..df49d6887ce --- /dev/null +++ b/mlir-compiler/readme.md @@ -0,0 +1,19 @@ +# Building MLIR backend + +The MLIR backend is not yet integrated into the Numba build process: + +1. Follow the usual Numba build instructions (using release LLVM) +2. Install pybind11 +3. Build LLVM from the specific commit required for the backend (numba/mlir-compiler/llvm-sha.txt) +4. Build the backend with cmake (numba/mlir-compiler/CMakeLists.txt) against the LLVM from step 3 +5. Add the directory with the compiled backend to PYTHONPATH + +# Running MLIR backend tests + +`python runtests.py numba.mlir.tests` + +# Useful env variables + +* `NUMBA_MLIR_ENABLE=1` - enable/disable the MLIR backend (default - 1) +* `NUMBA_MLIR_DUMP_PLIER=1` - dump plier IR to stdout (default - 0) +* `NUMBA_MLIR_PRINT_IR=1` - dump MLIR IR to stdout before and after each pass (default - 0) diff --git a/numba/core/compiler.py b/numba/core/compiler.py index 6275a5c48ee..3b79fa2608a 100644 --- a/numba/core/compiler.py +++ b/numba/core/compiler.py @@ -33,6 +33,7 @@ from numba.core.object_mode_passes import (ObjectModeFrontEnd, ObjectModeBackEnd, CompileInterpMode) +from numba.mlir.passes import (MlirDumpPlier, MlirBackend) class Flags(utils.ConfigOptions): # These options are all false by default, but the defaults are @@ -503,6 +504,14 @@ def define_typed_pipeline(state, name="typed"): pm.add_pass(NopythonTypeInference, "nopython frontend") pm.add_pass(AnnotateTypes, "annotate types") + import numba.mlir.settings + + if numba.mlir.settings.DUMP_PLIER: + pm.add_pass(MlirDumpPlier, "mlir dump plier") + + if numba.mlir.settings.USE_MLIR: + pm.add_pass(MlirBackend, "mlir backend") + # strip phis pm.add_pass(PreLowerStripPhis, "remove phis nodes") diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 1c9c19cd3b1..f2f1680d653 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -11,6 +11,8 @@ from numba.core.funcdesc import default_mangler from numba.core.environment import Environment +import numba.mlir.settings +_use_mlir = numba.mlir.settings.USE_MLIR _VarArgItem = namedtuple("_VarArgItem", ("vararg", "index")) @@ -170,8 +172,9 @@ def lower(self): # Run target specific post lowering transformation self.context.post_lowering(self.module, self.library) - # Materialize LLVM Module - self.library.add_ir_module(self.module) + if not _use_mlir: + # Materialize LLVM Module + self.library.add_ir_module(self.module) def 
extract_function_arguments(self): self.fnargs = self.call_conv.decode_arguments(self.builder, @@ -183,15 +186,23 @@ def lower_normal_function(self, fndesc): """ Lower non-generator *fndesc*. """ + if _use_mlir: + mod_ir = self.mlir_blob + import llvmlite.binding as llvm + mod = llvm.parse_bitcode(mod_ir) + self.setup_function(fndesc) - # Init argument values - self.extract_function_arguments() - entry_block_tail = self.lower_function_body() + if _use_mlir: + self.library.add_llvm_module(mod) + else: + # Init argument values + self.extract_function_arguments() + entry_block_tail = self.lower_function_body() - # Close tail of entry block - self.builder.position_at_end(entry_block_tail) - self.builder.branch(self.blkmap[self.firstblk]) + # Close tail of entry block + self.builder.position_at_end(entry_block_tail) + self.builder.branch(self.blkmap[self.firstblk]) def lower_function_body(self): """ @@ -278,7 +289,6 @@ def pre_block(self, block): from numba.core.unsafe import eh super(Lower, self).pre_block(block) - if block == self.firstblk: # create slots for all the vars, irrespective of whether they are # initialized, SSA will pick this up and warn users about using diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 3756a3e7c22..98d7c7a0865 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -20,6 +20,7 @@ build_definitions, compute_cfg_from_blocks) from numba.core import postproc +from numba.core.lowering import _use_mlir @contextmanager def fallback_context(state, msg): @@ -367,6 +368,9 @@ def run_pass(self, state): with targetctx.push_code_library(library): lower = lowering.Lower(targetctx, library, fndesc, interp, metadata=metadata) + if _use_mlir: + setattr(lower, 'mlir_blob', state.mlir_blob) + lower.lower() if not flags.no_cpython_wrapper: lower.create_cpython_wrapper(flags.release_gil) @@ -463,7 +467,6 @@ def run_pass(self, state): ) return True - @register_pass(mutates_CFG=True, analysis_only=False) class InlineOverloads(FunctionPass): """ diff --git a/numba/mlir/__init__.py b/numba/mlir/__init__.py new file mode 100644 index 00000000000..d4ef05b47ad --- /dev/null +++ b/numba/mlir/__init__.py @@ -0,0 +1,9 @@ +from numba import runtests + +from . import builtin_funcs +from . 
import math_funcs + +from .numpy import funcs + +def test(*args, **kwargs): + return runtests.main("numba.mlir.tests", *args, **kwargs) diff --git a/numba/mlir/builtin_funcs.py b/numba/mlir/builtin_funcs.py new file mode 100644 index 00000000000..443f07b387a --- /dev/null +++ b/numba/mlir/builtin_funcs.py @@ -0,0 +1,9 @@ +from numba.mlir.func_registry import add_func + +from numba import prange + +add_func(range, 'range') +add_func(len, 'len') +add_func(bool, 'bool') + +add_func(prange, 'numba.prange') diff --git a/numba/mlir/func_registry.py b/numba/mlir/func_registry.py new file mode 100644 index 00000000000..3570a90f038 --- /dev/null +++ b/numba/mlir/func_registry.py @@ -0,0 +1,35 @@ + +_mlir_func_names = {} +_active_funcs_stack = [] + +def add_func(func, name): + key = id(func) + assert not key in _mlir_func_names + _mlir_func_names[key] = name + +def get_func_name(func): + return _mlir_func_names.get(id(func), None) + +def push_active_funcs_stack(): + global _active_funcs_stack + _active_funcs_stack.append({}) + +def pop_active_funcs_stack(): + global _active_funcs_stack + assert(len(_active_funcs_stack) > 0) + _active_funcs_stack.pop() + +def add_active_funcs(name, func): + global _active_funcs_stack + assert(len(_active_funcs_stack) > 0) + top = _active_funcs_stack[-1] + top[name] = func + +def find_active_func(name): + global _active_funcs_stack + assert(len(_active_funcs_stack) > 0) + for elem in reversed(_active_funcs_stack): + res = elem.get(name) + if not res is None: + return res + return None diff --git a/numba/mlir/inner_compiler.py b/numba/mlir/inner_compiler.py new file mode 100644 index 00000000000..b74987db8fe --- /dev/null +++ b/numba/mlir/inner_compiler.py @@ -0,0 +1,35 @@ +from numba.core.typed_passes import NopythonTypeInference, AnnotateTypes +from numba.core.compiler import CompilerBase, DefaultPassBuilder, DEFAULT_FLAGS, compile_extra +from numba.core.compiler_machinery import PassManager +from numba.core import typing, cpu + +from numba.mlir.passes import MlirBackendInner, get_mlir_func + +class MlirTempCompiler(CompilerBase): # custom compiler extends from CompilerBase + + def define_pipelines(self): + dpb = DefaultPassBuilder + pm = PassManager('MlirTempCompiler') + untyped_passes = dpb.define_untyped_pipeline(self.state) + pm.passes.extend(untyped_passes.passes) + + pm.add_pass(NopythonTypeInference, "nopython frontend") + pm.add_pass(AnnotateTypes, "annotate types") + pm.add_pass(MlirBackendInner, "mlir backend") + + pm.finalize() + return [pm] + +def _compile_isolated(func, args, return_type=None, flags=DEFAULT_FLAGS, + locals={}): + from numba.core.registry import cpu_target + typingctx = typing.Context() + targetctx = cpu.CPUContext(typingctx) + # Register the contexts in case for nested @jit or @overload calls + with cpu_target.nested_context(typingctx, targetctx): + return compile_extra(typingctx, targetctx, func, args, return_type, + flags, locals, pipeline_class=MlirTempCompiler) + +def compile_func(func, args): + _compile_isolated(func, args) + return get_mlir_func() diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py new file mode 100644 index 00000000000..ddcc95c8fb6 --- /dev/null +++ b/numba/mlir/linalg_builder.py @@ -0,0 +1,88 @@ +from .func_registry import add_func + +class Var: + def __init__(self, context, ssa_val): + self._context = context + self._ssa_val = ssa_val + + @property + def shape(self): + return self._shape(self._context, self._ssa_val) + + @property + def dtype(self): + return self._dtype(self._context, 
self._ssa_val) + + def __len__(self): + return self._len(self._context, self._ssa_val) + + def __getitem__(self, index): + return self._getitem(self._context, self._ssa_val, index) + + def __mul__(self, o): return self._binop(self._context, self._ssa_val, o, '*') + def __rmul__(self, o): return self._binop(self._context, self._ssa_val, o, '*') + +class Type: + def __init__(self, mlir_type, eq): + self._mlir_type = mlir_type + self._eq = eq + + def __eq__(self, other): + return self._eq(self._mlir_type, other._mlir_type) + +def is_literal(val): + return not isinstance(val, Var) + +class Builder: + def __init__(self, context): + self._context = context + + def broadcast(self, *args): + return self._broadcast(self._context, args) + + def init_tensor(self, shape, dtype, init_val=None): + return self._init_tensor(self._context, shape, dtype, init_val) + + def fill_tensor(self, tensor, value): + return self._fill_tensor(self._context, tensor, value) + + def generic(self, inputs, outputs, iterators, maps, body): + return self._generic(self._context, inputs, outputs, iterators, maps, body) + + def from_elements(self, values, dtype): + return self._from_elements(self._context, values, dtype) + + def extract(self, value, indices): + return self._extract(self._context, value, indices) + + def reshape(self, src, num_dims, affine_maps): + return self._reshape(self._context, src, num_dims, affine_maps) + +def compile_func(*args, **kwargs): + import numba.mlir.inner_compiler + return numba.mlir.inner_compiler.compile_func(*args, **kwargs) + +_func_registry = {} + +def register_func(name, orig_func = None): + def _decorator(func): + global _func_registry + mangled_name = name + '()' + assert not mangled_name in _func_registry + _func_registry[mangled_name] = func + if not orig_func is None: + add_func(orig_func, name) + return func + return _decorator + +def register_attr(name): + def _decorator(func): + global _func_registry + assert not name in _func_registry + _func_registry[name] = func + return func + return _decorator + +def lookup_func(name): + global _func_registry + return _func_registry.get(name) diff --git a/numba/mlir/math_funcs.py b/numba/mlir/math_funcs.py new file mode 100644 index 00000000000..7811935d431 --- /dev/null +++ b/numba/mlir/math_funcs.py @@ -0,0 +1,11 @@ +from numba.mlir.func_registry import add_func + +import math + +_funcs = ['log', 'sqrt', 'exp', 'erf'] + +for f in _funcs: + fname = 'math.' 
+ f + add_func(eval(fname), fname) + + diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py new file mode 100644 index 00000000000..ce82e4b0b90 --- /dev/null +++ b/numba/mlir/numpy/funcs.py @@ -0,0 +1,225 @@ +from ..linalg_builder import register_func, register_attr, is_literal + +import numpy +import math + +def is_int(t, b): + return t == b.int8 or t == b.int16 or t == b.int32 or t == b.int64 + +def is_float(t, b): + return t == b.float16 or t == b.float32 or t == b.float64 + +def eltwise(builder, args, body, res_type = None): + if isinstance(args, tuple): + args = builder.broadcast(*args) + else: + args = (args,) + + if res_type is None: + res_type = args[0].dtype + + shape = args[0].shape + + num_dims = len(shape) + iterators = ['parallel' for _ in range(num_dims)] + dims = ','.join(['d%s' % i for i in range(num_dims)]) + expr = f'({dims}) -> ({dims})' + maps = [expr for _ in range(len(args) + 1)] + init = builder.init_tensor(shape, res_type) + + return builder.generic(args, init, iterators, maps, body) + +@register_func('numpy.add', numpy.add) +@register_func('operator.add') +def add_impl(builder, arg1, arg2): + def body(a, b, c): + return a + b + + return eltwise(builder, (arg1, arg2), body) + +@register_func('numpy.subtract', numpy.subtract) +@register_func('operator.sub') +def sub_impl(builder, arg1, arg2): + def body(a, b, c): + return a - b + + return eltwise(builder, (arg1, arg2), body) + +@register_func('numpy.multiply', numpy.multiply) +@register_func('operator.mul') +def mul_impl(builder, arg1, arg2): + def body(a, b, c): + return a * b + + return eltwise(builder, (arg1, arg2), body) + +@register_func('array.sum') +@register_func('numpy.sum', numpy.sum) +def sum_impl(builder, arg, axis=None): + if axis is None: + shape = arg.shape + num_dims = len(shape) + iterators = ['reduction' for _ in range(num_dims)] + dims = ','.join(['d%s' % i for i in range(num_dims)]) + expr1 = f'({dims}) -> ({dims})' + expr2 = f'({dims}) -> (0)' + maps = [expr1,expr2] + init = builder.from_elements(0, arg.dtype) + + def body(a, b): + return a + b + + res = builder.generic(arg, init, iterators, maps, body) + return builder.extract(res, 0) + elif isinstance(axis, int): + shape = arg.shape + num_dims = len(shape) + iterators = [('reduction' if i == axis else 'parallel') for i in range(num_dims)] + dims1 = ','.join(['d%s' % i for i in range(num_dims)]) + dims2 = ','.join(['d%s' % i for i in range(num_dims) if i != axis]) + expr1 = f'({dims1}) -> ({dims1})' + expr2 = f'({dims1}) -> ({dims2})' + maps = [expr1,expr2] + res_shape = tuple(shape[i] for i in range(len(shape)) if i != axis) + + orig_type = arg.dtype + if is_int(orig_type, builder): + res_type = builder.int64 + else: + res_type = orig_type + init = builder.init_tensor(res_shape, res_type, 0) + + def body(a, b): + return a + b + + return builder.generic(arg, init, iterators, maps, body) + + +@register_func('numpy.sqrt', numpy.sqrt) +def sqrt_impl(builder, arg): + + def body(a, b): + return math.sqrt(a) + + return eltwise(builder, arg, body, builder.float64) + +@register_func('numpy.square', numpy.square) +def square_impl(builder, arg): + + def body(a, b): + return a * a + + return eltwise(builder, arg, body) + +@register_func('numpy.empty', numpy.empty) +def empty_impl(builder, shape): + # TODO: dtype + return builder.init_tensor(shape, builder.float64) + +@register_func('numpy.dot', numpy.dot) +def dot_impl(builder, a, b): + shape1 = a.shape + shape2 = b.shape + if len(shape1) == 1 and len(shape2) == 1: + iterators = ['reduction'] 
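+ # Both 1-D inputs are indexed with the identity affine map '(d0) -> (d0)'
+ # and reduced into a one-element init tensor mapped as '(d0) -> (0)'; the
+ # body below accumulates a * b + c and extract() reads the scalar back out.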
+ expr1 = '(d0) -> (d0)' + expr2 = '(d0) -> (0)' + maps = [expr1,expr1,expr2] + init = builder.from_elements(0, a.dtype) + + def body(a, b, c): + return a * b + c + + res = builder.generic((a,b), init, iterators, maps, body) + return builder.extract(res, 0) + if len(shape1) == 2 and len(shape2) == 2: + iterators = ['parallel','parallel','reduction'] + expr1 = '(d0,d1,d2) -> (d0,d2)' + expr2 = '(d0,d1,d2) -> (d2,d1)' + expr3 = '(d0,d1,d2) -> (d0,d1)' + maps = [expr1,expr2,expr3] + res_shape = (shape1[0], shape2[1]) + init = builder.init_tensor(res_shape, a.dtype, 0) + + def body(a, b, c): + return a * b + c + + return builder.generic((a,b), init, iterators, maps, body) + +@register_attr('array.size') +def size_impl(builder, arg): + shape = arg.shape + res = builder.init_tensor([], builder.index, 1) + for i in range(len(shape)): + res = res * shape[i] + return res + +@register_attr('array.T') +def transpose_impl(builder, arg): + shape = arg.shape + dims = len(shape) + if dims == 1: + return arg + if dims == 2: + iterators = ['parallel','parallel'] + expr1 = '(d0,d1) -> (d0,d1)' + expr2 = '(d0,d1) -> (d1,d0)' + maps = [expr1,expr2] + res_shape = (shape[1], shape[0]) + init = builder.init_tensor(res_shape, arg.dtype) + + def body(a, b): + return a + + return builder.generic(arg, init, iterators, maps, body) + +def flatten(builder, arg, src_dims_count): + if 1 == src_dims_count: + return arg + dims = ','.join(['d%s' % i for i in range(src_dims_count)]) + expr = f'({dims}) -> ({dims})' + maps = [ + expr + ] + return builder.reshape(arg, 1, maps) + +def find_size_index(shape): + size_index = -1 + for i in range(len(shape)): + d = shape[i] + if is_literal(d): + if 1 != d: + return -1 + else: + if size_index != -1: + return -1 + size_index = i + return size_index + +@register_func('array.reshape') +def reshape_impl(builder, arg, new_shape): + shape = arg.shape + src_count = len(shape) + count = len(new_shape) + if count == 1: + return flatten(builder, arg, src_count) + else: + size_index = find_size_index(new_shape) + if size_index < 0: + return + + flat = flatten(builder, arg, src_count) + init = builder.init_tensor(new_shape, arg.dtype) + + iterators = ['parallel' for _ in range(count)] + dims1 = ','.join(['d%s' % i for i in range(count)]) + dims3 = ','.join(['d%s' % i if i == size_index else '0' for i in range(count)]) + expr1 = f'({dims1}) -> (d{size_index})' + expr2 = f'({dims1}) -> ({dims1})' + maps = [expr1, expr2] + + def body(a, b): + return a + + return builder.generic(flat, init, iterators, maps, body) + diff --git a/numba/mlir/passes.py b/numba/mlir/passes.py new file mode 100644 index 00000000000..c14652494ec --- /dev/null +++ b/numba/mlir/passes.py @@ -0,0 +1,136 @@ +from numba.core.compiler_machinery import (FunctionPass, register_pass) +from numba.core import (types) + +import numba.mlir.settings +import numba.mlir.func_registry +import numba.core.types.functions +_mlir_last_compiled_func = None +_mlir_active_module = None + +def _reload_parfors(): + """Reloader for cached parfors + """ + # Re-initialize the parallel backend when load from cache. 
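+ # The MLIR backend drives its parallel loops through the same TBB runtime
+ # (see parallel_for2 in tbbpool.cpp), so functions loaded from the cache
+ # must restart the thread pool before they run.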
+ from numba.np.ufunc.parallel import _launch_threads + _launch_threads() + +class MlirBackendBase(FunctionPass): + + def __init__(self): + import numba.mlir.func_registry + self._get_func_name = numba.mlir.func_registry.get_func_name + FunctionPass.__init__(self) + + def run_pass(self, state): + numba.mlir.func_registry.push_active_funcs_stack() + try: + res = self.run_pass_impl(state) + finally: + numba.mlir.func_registry.pop_active_funcs_stack() + return res + + def _resolve_func_name(self, obj): + name, func = self._resolve_func_name_impl(obj) + if not (name is None or func is None): + numba.mlir.func_registry.add_active_funcs(name, func) + return name + + def _resolve_func_name_impl(self, obj): + if isinstance(obj, types.Function): + func = obj.typing_key + return (self._get_func_name(func), None) + if isinstance(obj, types.BoundFunction): + return (str(obj.typing_key), None) + if isinstance(obj, numba.core.types.functions.Dispatcher): + func = obj.dispatcher.py_func + return (func.__module__ + "." + func.__qualname__, func) + return (None, None) + + def _get_func_context(self, state): + from numba.core.funcdesc import default_mangler, qualifying_prefix + mangler = state.targetctx.mangler + mangler = default_mangler if mangler is None else mangler + unique_name = state.func_ir.func_id.unique_name + modname = state.func_ir.func_id.func.__module__ + qualprefix = qualifying_prefix(modname, unique_name) + fn_name = mangler(qualprefix, state.args) + + from numba.np.ufunc.parallel import get_thread_count + + ctx = {} + ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR} + ctx['typemap'] = lambda op: state.typemap[op.name] + ctx['fnargs'] = lambda: state.args + ctx['restype'] = lambda: state.return_type + ctx['fnname'] = lambda: fn_name + ctx['resolve_func'] = self._resolve_func_name + ctx['fastmath'] = lambda: state.targetctx.fastmath + ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0 + return ctx + +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirDumpPlier(MlirBackendBase): + + _name = "mlir_dump_plier" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass(self, state): + import mlir_compiler + module = mlir_compiler.create_module() + ctx = self._get_func_context(state) + mlir_compiler.lower_function(ctx, module, state.func_ir) + print(mlir_compiler.module_str(module)) + return True + +def get_mlir_func(): + global _mlir_last_compiled_func + return _mlir_last_compiled_func + +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirBackend(MlirBackendBase): + + _name = "mlir_backend" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass_impl(self, state): + import mlir_compiler + global _mlir_active_module + old_module = _mlir_active_module + + try: + module = mlir_compiler.create_module() + _mlir_active_module = module + global _mlir_last_compiled_func + ctx = self._get_func_context(state) + _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) + mod_ir = mlir_compiler.compile_module(ctx, module) + finally: + _mlir_active_module = old_module + setattr(state, 'mlir_blob', mod_ir) + _reload_parfors() + state.reload_init.append(_reload_parfors) + return True + +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirBackendInner(MlirBackendBase): + + _name = "mlir_backend_inner" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass_impl(self, state): + 
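# The inner pipeline lowers nested functions into the module that the
+ # enclosing MlirBackend pass published in _mlir_active_module, so code
+ # compiled through compile_func() lands in the same MLIR module.
+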
import mlir_compiler + global _mlir_active_module + module = _mlir_active_module + assert not module is None + global _mlir_last_compiled_func + ctx = self._get_func_context(state) + _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) + from numba.core.compiler import compile_result + state.cr = compile_result() + return True diff --git a/numba/mlir/settings.py b/numba/mlir/settings.py new file mode 100644 index 00000000000..8f09d635b3c --- /dev/null +++ b/numba/mlir/settings.py @@ -0,0 +1,17 @@ +from os import environ +import warnings + +def _readenv(name, ctor, default): + value = environ.get(name) + if value is None: + return default() if callable(default) else default + try: + return ctor(value) + except Exception: + warnings.warn("environ %s defined but failed to parse '%s'" % + (name, value), RuntimeWarning) + return default + +USE_MLIR = _readenv('NUMBA_MLIR_ENABLE', int, 1) +DUMP_PLIER = _readenv('NUMBA_MLIR_DUMP_PLIER', int, 0) +PRINT_IR = _readenv('NUMBA_MLIR_PRINT_IR', int, 0) diff --git a/numba/mlir/tests/__init__.py b/numba/mlir/tests/__init__.py new file mode 100644 index 00000000000..b2cb40fb5ce --- /dev/null +++ b/numba/mlir/tests/__init__.py @@ -0,0 +1,10 @@ +from numba.testing import unittest +from numba.testing import load_testsuite +from os.path import dirname + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + this_dir = dirname(__file__) + suite.addTests(load_testsuite(loader, this_dir)) + + return suite diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py new file mode 100644 index 00000000000..23bdf352734 --- /dev/null +++ b/numba/mlir/tests/test_basic.py @@ -0,0 +1,340 @@ +import numba +from numba import njit +from math import nan, inf, isnan +from numpy.testing import assert_equal # for nans comparison + +from numba.tests.support import TestCase +import unittest + +import itertools + +# TODO: nans and infs are not tested yet; we are not sure if we want to exactly +# follow interpreted Python rules +_test_values = [-3,-2,-1,0,1,2,3,-2.5,-1.0,-0.5, -0.0, 0.0, 0.5, 1.0, 2.5] +class TestMlirBasic(TestCase): + + def test_ret(self): + def py_func(a): + return a + + jit_func = njit(py_func) + for val in _test_values: + assert_equal(py_func(val), jit_func(val)) + + def test_ops(self): + py_funcs = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + lambda a, b: a / b, + lambda a, b: a % b, + # TODO: floordiv + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + try: + assert_equal(py_func(a, b), jit_func(a, b)) + except ZeroDivisionError: + pass + + def test_inplace_op(self): + def py_func(a,b): + a += b + return a + + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + assert_equal(py_func(a, b), jit_func(a, b)) + + def test_unary_ops(self): + py_funcs = [ + lambda a: +a, + lambda a: -a, + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for a in _test_values: + assert_equal(py_func(a), jit_func(a)) + + def test_cmp_ops(self): + py_funcs = [ + lambda a, b: a if a > b else b, + lambda a, b: a if a < b else b, + lambda a, b: a if a >= b else b, + lambda a, b: a if a <= b else b, + lambda a, b: a if a == b else b, + lambda a, b: a if a != b else b, + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + assert_equal(py_func(a, b), jit_func(a, b)) + + def test_const_ops(self): + py_funcs = [ + lambda a: a + 42, + lambda a: 43 
+ a, + lambda a: a + 42.5, + lambda a: 43.5 + a, + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for val in _test_values: + assert_equal(py_func(val), jit_func(val)) + + def test_var(self): + def py_func(a): + c = 1 + c = c + a + return c + + jit_func = njit(py_func) + for val in _test_values: + assert_equal(py_func(val), jit_func(val)) + + def test_ret_none(self): + def py_func1(): + return None + + def py_func2(): + pass + + jit_func1 = njit(py_func1) + jit_func2 = njit(py_func2) + assert_equal(py_func1(), jit_func1()) + assert_equal(py_func2(), jit_func2()) + + def test_if1(self): + def py_func(a, b): + c = 3 + if a > 5: + c = c + a + c = c + b + return c + + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + assert_equal(py_func(a, b), jit_func(a, b)) + + def test_if2(self): + def py_func(a, b): + if a > b: + return a + b + else: + return a - b + + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + assert_equal(py_func(a, b), jit_func(a, b)) + + def test_tuple(self): + def py_func(a, b, c): + t = (a,b,c) + return t[0] + t[1] + t[2] + + jit_func = njit(py_func) + for a, b, c in itertools.product(_test_values, _test_values, _test_values): + assert_equal(py_func(a, b, c), jit_func(a, b, c)) + + def test_tuple_len(self): + def py_func(a, b, c): + t = (a,b,c) + return len(t) + + jit_func = njit(py_func) + for a, b, c in itertools.product(_test_values, _test_values, _test_values): + assert_equal(py_func(a, b, c), jit_func(a, b, c)) + + def test_range1(self): + def py_func(a): + res = 0 + for i in range(a): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + + def test_range2(self): + def py_func(a, b): + res = 0 + for i in range(a, b): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20), jit_func(10, 20)) + + def test_range3(self): + def py_func(a, b, c): + res = 0 + for i in range(a, b, c): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + + def test_range_negative_step(self): + def py_func(a, b, c): + res = 0 + for i in range(a, b, c): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(5, -8, -2), jit_func(5, -8, -2)) + + def test_range_const_step1(self): + def py_func(a, b): + res = 0 + for i in range(a, b, -2): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(5, -8), jit_func(5, -8)) + + def test_range_const_step2(self): + def py_func(a, b): + res = 0 + for i in range(a, b, 2): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(-5, 8), jit_func(-5, 8)) + + def test_range_use_index_after(self): + def py_func(n): + res = 0 + for i in range(0, n, 2): + res = res + i + return res + i + + jit_func = njit(py_func) + assert_equal(py_func(9), jit_func(9)) + + def test_range_if(self): + def py_func(n): + res = 0 + res1 = 2 + for i in range(n): + if i > 5: + res = res + i + else: + res1 = res1 + i * 2 + return res + res1 + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + + def test_range_ifs(self): + def py_func(n): + res = 0 + for i in range(n): + if i == 2: + res = res + 2 + elif i == 7: + res = res + 5 + elif i == 99: + res = res + 99 + else: + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + + def test_range_continue(self): + def py_func(n): + res = 0 + res1 = 2 + for i in range(n): + res = res 
+ i + if i < 5: + continue + res1 = res1 + i * 2 + return res + res1 + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + + def test_range_nested1(self): + def py_func(a, b, c): + res = 0 + for i in range(a): + for j in range(b): + for k in range(c): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + + def test_range_nested2(self): + def py_func(a, b, c): + res = 0 + for i in range(a): + for j in range(b): + for k in range(c): + res = res + i + j * 10 + k * 100 + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + + def test_prange1(self): + def py_func(a): + res = 0 + for i in numba.prange(a): + res = res + i + return res + + jit_func = njit(py_func, parallel=True) + assert_equal(py_func(10), jit_func(10)) + + def test_prange2(self): + def py_func(a, b): + res = 0 + for i in numba.prange(a, b): + res = res + i + return res + + jit_func = njit(py_func, parallel=True) + assert_equal(py_func(10, 20), jit_func(10, 20)) + + def test_func_call1(self): + def py_func1(b): + return b + 3 + + jit_func1 = njit(py_func1) + + def py_func2(a): + return jit_func1(a) * 4 + + jit_func2 = njit(py_func2) + + assert_equal(py_func2(10), jit_func2(10)) + + def test_func_call2(self): + def py_func1(b): + return b + 3 + + jit_func1 = njit(py_func1) + + def py_func2(a): + return jit_func1(a) * jit_func1(a + 1) + + jit_func2 = njit(py_func2) + + assert_equal(py_func2(10), jit_func2(10)) + + +if __name__ == '__main__': + unittest.main() diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py new file mode 100644 index 00000000000..d28d1651753 --- /dev/null +++ b/numba/mlir/tests/test_numpy.py @@ -0,0 +1,308 @@ +import numba +from numba import njit +from numpy.testing import assert_equal # for nans comparison +import numpy as np +from numba.tests.support import TestCase +import unittest +import itertools + +_arr_1d_int = [1,2,3,4,5,6,7,8] +_arr_1d_float = [1.0,2.1,3.2,4.3,5.4,6.5,7.6,8.7] +_arr_2d_int = [[1,2,3,4],[5,6,7,8]] +_arr_2d_float = [[1.0,2.1,3.2,4.3],[5.4,6.5,7.6,8.7]] +_test_arrays = [_arr_1d_int, _arr_1d_float, _arr_2d_int, _arr_2d_float] +class TestMlirBasic(TestCase): + + def test_staticgetitem(self): + def py_func(a): + return a[1] + + jit_func = njit(py_func) + arr = np.asarray([5,6,7]) + assert_equal(py_func(arr), jit_func(arr)) + + def test_getitem(self): + def py_func(a, b): + return a[b] + + jit_func = njit(py_func) + arr = np.asarray([5,6,7]) + for i in range(3): + assert_equal(py_func(arr, i), jit_func(arr, i)) + + def test_array_len(self): + def py_func(a): + return len(a) + + jit_func = njit(py_func) + arr = np.asarray([5,6,7]) + assert_equal(py_func(arr), jit_func(arr)) + + def test_unary(self): + funcs = [ + lambda a: a.sum(), + lambda a: np.sum(a), + lambda a: np.sqrt(a), + lambda a: np.square(a), + lambda a: a.size, + # lambda a: a.T, TODO: need fortran layout support + lambda a: a.T.T, + ] + + for py_func in funcs: + jit_func = njit(py_func) + for a in _test_arrays: + arr = np.array(a) + assert_equal(py_func(arr), jit_func(arr)) + + def test_binary(self): + funcs = [ + lambda a, b: np.add(a, b), + lambda a, b: a + b, + lambda a, b: np.subtract(a, b), + lambda a, b: a - b, + lambda a, b: np.multiply(a, b), + lambda a, b: a * b, + ] + + test_data = [1, 2.5, np.array([1,2,3]), np.array([4.4,5.5,6.6])] + for py_func in funcs: + jit_func = njit(py_func) + for a1, a2 in itertools.product(test_data, test_data): + assert_equal(py_func(a1,a2), 
jit_func(a1,a2)) + + def test_sum_axis(self): + funcs = [ + lambda a: np.sum(a, axis=0), + lambda a: np.sum(a, axis=1), + ] + + for py_func in funcs: + jit_func = njit(py_func) + arr = np.array([[1,2,3],[4,5,6]]) + for a in [arr, arr.astype(np.float32)]: + assert_equal(py_func(a), jit_func(a)) + + def test_sum_add(self): + def py_func(a, b): + return np.add(a, b).sum() + + jit_func = njit(py_func) + arr1 = np.asarray([1,2,3]) + arr2 = np.asarray([4,5,6]) + assert_equal(py_func(arr1, arr2), jit_func(arr1, arr2)) + + def test_sum_add2(self): + def py_func(a, b, c): + t = np.add(a, b) + return np.add(t, c).sum() + + jit_func = njit(py_func) + arr1 = np.asarray([1,2,3]) + arr2 = np.asarray([4,5,6]) + arr3 = np.asarray([7,8,9]) + assert_equal(py_func(arr1, arr2, arr3), jit_func(arr1, arr2, arr3)) + + def test_dot(self): + def py_func(a, b): + return np.dot(a, b) + + jit_func = njit(py_func) + arr1 = np.asarray([1,2,3], np.float32) + arr2 = np.asarray([4,5,6], np.float32) + arr3 = np.asarray([[1,2,3],[4,5,6]], np.float32) + arr4 = np.asarray([[1,2],[3,4],[5,6]], np.float32) + + for a, b in [(arr1,arr2), (arr3,arr4)]: + assert_equal(py_func(a, b), jit_func(a, b)) + + def test_static_setitem(self): + def py_func(a): + a[1] = 42 + return a[1] + + jit_func = njit(py_func) + arr = np.asarray([1,2,3]) + assert_equal(py_func(arr), jit_func(arr)) + + def test_setitem1(self): + def py_func(a, b): + a[b] = 42 + return a[b] + + jit_func = njit(py_func) + arr = np.asarray([1,2,3]) + assert_equal(py_func(arr, 1), jit_func(arr, 1)) + + def test_setitem2(self): + def py_func(a, b, c): + a[b, c] = 42 + return a[b, c] + + jit_func = njit(py_func) + arr = np.asarray([[1,2,3],[4,5,6]]) + assert_equal(py_func(arr, 1, 2), jit_func(arr, 1, 2)) + + def test_setitem_loop(self): + def py_func(a): + for i in range(len(a)): + a[i] = a[i] + i + return a.sum() + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + + def test_array_bounds1(self): + def py_func(a): + res = 0 + for i in range(len(a)): + if i >= len(a) or i < 0: + res = res + 1 + else: + res = res + a[i] + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + + def test_array_bounds2(self): + def py_func(a): + res = 0 + for i in range(len(a)): + if i < len(a) and i >= 0: + res = res + a[i] + else: + res = res + 1 + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + + def test_array_bounds3(self): + def py_func(a): + res = 0 + for i in range(len(a)): + if 0 <= i < len(a): + res = res + a[i] + else: + res = res + 1 + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + + def test_array_bounds4(self): + def py_func(a): + res = 0 + for i in range(len(a) - 1): + if 0 <= i < (len(a) - 1): + res = res + a[i] + else: + res = res + 1 + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + + def test_array_shape(self): + def py_func(a): + shape = a.shape + return shape[0] + shape[1] * 10 + + jit_func = njit(py_func) + arr = np.array([[1,2,3],[4,5,6]]) + assert_equal(py_func(arr), jit_func(arr)) + + def test_array_return(self): + def py_func(a): + return a + + jit_func = njit(py_func) + arr = np.array([1,2,3]) + assert_equal(py_func(arr), jit_func(arr)) + + def test_array_prange_const(self): + def py_func(a, b): + 
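# Every iteration stores the same constant, so the result is deterministic
+ # even though prange iterations may write a[0] concurrently.
+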
a[0] = 42 + for i in numba.prange(b): + a[0] = 1 + return a[0] + + jit_func = njit(py_func, parallel=True) + arr = np.array([0.0]) + assert_equal(py_func(arr, 5), jit_func(arr, 5)) + + def test_empty1(self): + def py_func(d): + a = np.empty(d) + for i in range(d): + a[i] = i + return a + + jit_func = njit(py_func) + assert_equal(py_func(5), jit_func(5)) + + def test_empty2(self): + def py_func(d1, d2): + a = np.empty((d1, d2)) + for i in range(d1): + for j in range(d2): + a[i, j] = i + j * 10 + return a + + jit_func = njit(py_func) + assert_equal(py_func(5, 7), jit_func(5, 7)) + + def test_reshape(self): + funcs = [ + lambda a: a.reshape(a.size), + lambda a: a.reshape((a.size,)), + lambda a: a.reshape((a.size,1)), + lambda a: a.reshape((1, a.size)), + lambda a: a.reshape((1, a.size, 1)), + ] + + arr1 = np.array([1,2,3,4,5,6,7,8,9,10,11,12]) + # arr2 = arr1.reshape((2,6)) + # arr3 = arr1.reshape((2,3,2)) + for py_func in funcs: + jit_func = njit(py_func) + # for a in [arr1,arr2,arr3]: TODO: flatten support + for a in [arr1]: + assert_equal(py_func(a), jit_func(a)) + + def test_broadcast(self): + def py_func(a, b): + return np.add(a, b) + + jit_func = njit(py_func) + + test_data = [ + 1, + np.array([1]), + np.array([[1]]), + np.array([[1,2],[3,4]]), + np.array([5,6]), + np.array([[5],[6]]), + np.array([[5,6]]), + ] + + for a, b in itertools.product(test_data, test_data): + assert_equal(py_func(a,b), jit_func(a,b)) + + def test_parallel(self): + def py_func(a, b): + return np.add(a, b) + + jit_func = njit(py_func, parallel=True) + arr = np.asarray([[[1,2,3],[4,5,6]], + [[1,2,3],[4,5,6]]]) + assert_equal(py_func(arr,arr), jit_func(arr,arr)) + +if __name__ == '__main__': + unittest.main() diff --git a/numba/np/ufunc/parallel.py b/numba/np/ufunc/parallel.py index 1b4a3b53ddf..fbf250c6585 100644 --- a/numba/np/ufunc/parallel.py +++ b/numba/np/ufunc/parallel.py @@ -497,6 +497,7 @@ def raise_with_hint(required): raise_with_hint(requirements) ll.add_symbol('numba_parallel_for', lib.parallel_for) + ll.add_symbol('numba_parallel_for2', lib.parallel_for2) ll.add_symbol('do_scheduling_signed', lib.do_scheduling_signed) ll.add_symbol('do_scheduling_unsigned', lib.do_scheduling_unsigned) diff --git a/numba/np/ufunc/tbbpool.cpp b/numba/np/ufunc/tbbpool.cpp index faff4790fb7..87a829873aa 100644 --- a/numba/np/ufunc/tbbpool.cpp +++ b/numba/np/ufunc/tbbpool.cpp @@ -15,10 +15,16 @@ Implement parallel vectorize workqueue on top of Intel TBB. #include #include #include +#include +#include +#include #include "workqueue.h" #include "gufunc_scheduler.h" +#undef min +#undef max + /* TBB 2019 U5 is the minimum required version as this is needed: * https://github.com/intel/tbb/blob/18070344d755ece04d169e6cc40775cae9288cee/CHANGES#L133-L134 * and therefore @@ -38,6 +44,21 @@ Implement parallel vectorize workqueue on top of Intel TBB. 
static tbb::task_group *tg = NULL; static tbb::task_scheduler_init *tsi = NULL; + +namespace +{ +struct ThreadContext +{ + ThreadContext(int n_threads): + num_threads(n_threads), + arena(n_threads) {} + + int num_threads = 0; + tbb::task_arena arena; +}; +static ThreadContext* thread_context = nullptr; +} + static int tsi_count = 0; #ifdef _MSC_VER @@ -202,6 +223,104 @@ parallel_for(void *fn, char **args, size_t *dimensions, size_t *steps, void *dat }); } +struct InputRange +{ + size_t lower; + size_t upper; + size_t step; +}; + +struct Range +{ + size_t lower; + size_t upper; +}; + +struct Dim +{ + Range val; + Dim* prev; +}; + +using parallel_for2_fptr = void(*)(const Range*, size_t, void*); + +static void parallel_for2_nested(const InputRange* input_ranges, size_t depth, size_t num_threads, size_t num_loops, Dim* prev_dim, parallel_for2_fptr func, void* ctx) +{ + auto input = input_ranges[depth]; + auto lower_bound = input.lower; + auto upper_bound = input.upper; + auto step = input.step; + + if(_DEBUG) + { + printf("parallel_for2_nested: lower_bound=%d, upper_bound=%d, step=%d, depth=%d\n", (int)lower_bound, (int)upper_bound, (int)step, (int)depth); + } + + size_t count = (upper_bound - lower_bound + step - 1) / step; + size_t grain = std::max(size_t(1), std::min(count / num_threads / 2, size_t(64))); + tbb::parallel_for(tbb::blocked_range(0, count, grain), + [&](const tbb::blocked_range& r) + { + auto begin = lower_bound + r.begin() * step; + auto end = lower_bound + r.end() * step; + if(_DEBUG) + { + printf("parallel_for2_nested body: begin=%d, end=%d, depth=%d\n\n", (int)begin, (int)end, (int)depth); + } + auto next = depth + 1; + Dim dim{Range{begin, end}, prev_dim}; + if (next == num_loops) + { + auto thread_index = static_cast(tbb::this_task_arena::current_thread_index()); + std::array static_ranges; + std::unique_ptr dyn_ranges; + auto* range_ptr = [&]()->Range* + { + if (num_loops <= static_ranges.size()) + { + return static_ranges.data(); + } + dyn_ranges.reset(new Range[num_loops]); + return dyn_ranges.get(); + }(); + + Dim* current = &dim; + for (size_t i = 0; i < num_loops; ++i) + { + range_ptr[num_loops - i - 1] = current->val; + current = current->prev; + } + func(range_ptr, thread_index, ctx); + } + else + { + parallel_for2_nested(input_ranges, next, num_threads, num_loops, &dim, func, ctx); + } + }, tbb::auto_partitioner()); +} + +static void parallel_for2(const InputRange* input_ranges, size_t num_loops, parallel_for2_fptr func, void* ctx) +{ + auto context = thread_context; + assert(nullptr != context); + auto num_threads = context->num_threads; + if(_DEBUG) + { + printf("parallel_for2 num_loops=%d: ", (int)num_loops); + for (size_t i = 0; i < num_loops; ++i) + { + auto r = input_ranges[i]; + printf("(%d, %d, %d) ", (int)r.lower, (int)r.upper, (int)r.step); + } + puts("\n"); + } + + context->arena.execute([&] + { + parallel_for2_nested(input_ranges, 0, num_threads, num_loops, nullptr, func, ctx); + }); +} + void ignore_blocking_terminate_assertion( const char*, int, const char*, const char * ) { tbb::internal::runtime_warning("Unable to wait for threads to shut down before fork(). 
It can break multithreading in child process\n"); @@ -250,6 +369,8 @@ static void unload_tbb(void) tbb::set_assertion_handler(orig); delete tsi; tsi = NULL; + delete thread_context; + thread_context = nullptr; } } #endif @@ -266,6 +387,8 @@ static void launch_threads(int count) tg = new tbb::task_group; tg->run([] {}); // start creating threads asynchronously + thread_context = new ThreadContext(count); + _INIT_NUM_THREADS = count; #ifndef _MSC_VER @@ -307,6 +430,8 @@ MOD_INIT(tbbpool) PyLong_FromVoidPtr((void*)&add_task)); PyObject_SetAttrString(m, "parallel_for", PyLong_FromVoidPtr((void*)¶llel_for)); + PyObject_SetAttrString(m, "parallel_for2", + PyLong_FromVoidPtr((void*)¶llel_for2)); PyObject_SetAttrString(m, "do_scheduling_signed", PyLong_FromVoidPtr((void*)&do_scheduling_signed)); PyObject_SetAttrString(m, "do_scheduling_unsigned", diff --git a/numba/tests/__init__.py b/numba/tests/__init__.py index f04b1007dab..555bab11663 100644 --- a/numba/tests/__init__.py +++ b/numba/tests/__init__.py @@ -33,5 +33,8 @@ def load_tests(loader, tests, pattern): roc_dir = join(dirname(dirname(__file__)), 'roc/tests') suite.addTests(loader.discover(roc_dir)) + mlir_dir = join(dirname(dirname(__file__)), 'mlir/tests') + suite.addTests(loader.discover(mlir_dir)) + return suite
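For reference on extending the backend: new numpy lowerings hook into the registry the same way the functions in numba/mlir/numpy/funcs.py do. A minimal sketch, not part of this diff; `numpy.negative` support is hypothetical and reuses the `eltwise` helper defined in funcs.py:

    # hypothetical addition to numba/mlir/numpy/funcs.py
    from ..linalg_builder import register_func
    import numpy

    @register_func('numpy.negative', numpy.negative)
    def negative_impl(builder, arg):
        # eltwise() broadcasts the inputs, allocates an init tensor of the
        # result dtype and wraps the scalar body in a linalg.generic with
        # all-parallel iterators.
        def body(a, b):
            return -a
        return eltwise(builder, arg, body)

The string key is what PyLinalgResolver::rewrite_func looks up (it appends '()' before the lookup), and passing the numpy function itself also registers it in func_registry so the lowering pass can resolve direct calls to it.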