diff --git a/include/mlir/ExecutionEngine/ExecutionEngine.h b/include/mlir/ExecutionEngine/ExecutionEngine.h index 69f6c2e72f33..2023d0b75d15 100644 --- a/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -32,6 +32,7 @@ namespace llvm { template class Expected; class Module; +class TargetMachine; } // namespace llvm namespace mlir { @@ -61,10 +62,13 @@ class ExecutionEngine { /// can be used, e.g., for reporting or optimization. /// If `sharedLibPaths` are provided, the underlying JIT-compilation will open /// and link the shared libraries for symbol resolution. + /// If `refTM` is provided, the underlying JIT-compilation will use that + /// target machine as reference to build its target machine. static llvm::Expected> create(ModuleOp m, std::function transformer = {}, - ArrayRef sharedLibPaths = {}); + ArrayRef sharedLibPaths = {}, + llvm::TargetMachine *refTM = nullptr); /// Looks up a packed-argument function with the given name and returns a /// pointer to it. Propagates errors in case of failure. diff --git a/lib/ExecutionEngine/ExecutionEngine.cpp b/lib/ExecutionEngine/ExecutionEngine.cpp index 0317a92c43fa..8cfd0b769f1a 100644 --- a/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/lib/ExecutionEngine/ExecutionEngine.cpp @@ -144,10 +144,14 @@ class OrcJIT { loadLibraries(sharedLibPaths); } - // Create a JIT engine for the current host. + // Create a JIT engine for the reference target machine `refTM` or the current + // host if `refTM` is not provided. static Expected> - createDefault(IRTransformer transformer, ArrayRef sharedLibPaths) { - auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost(); + create(IRTransformer transformer, ArrayRef sharedLibPaths, + llvm::TargetMachine *refTM) { + + auto machineBuilder = + refTM ? createRefMachineBuilder(refTM) : createDefaultMachineBuilder(); if (!machineBuilder) return machineBuilder.takeError(); @@ -173,6 +177,50 @@ class OrcJIT { } private: + // Create a JIT engine for the reference target machine `refTM`. + static Expected + createRefMachineBuilder(llvm::TargetMachine *refTM) { + assert(refTM && "Expected reference target machine!"); + + Expected machineBuilder = + llvm::orc::JITTargetMachineBuilder(refTM->getTargetTriple()); + if (!machineBuilder) + return machineBuilder.takeError(); + + machineBuilder->setCPU(refTM->getTargetCPU()); + machineBuilder->setRelocationModel(refTM->getRelocationModel()); + machineBuilder->setCodeModel(refTM->getCodeModel()); + machineBuilder->setCodeGenOptLevel(refTM->getOptLevel()); + + std::vector features; + llvm::SubtargetFeatures::Split(features, refTM->getTargetFeatureString()); + machineBuilder->addFeatures(features); + + return machineBuilder; + } + + // Create a JIT engine for the current host. + static Expected + createDefaultMachineBuilder() { + auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!machineBuilder) + return machineBuilder.takeError(); + + // Retrieve host CPU sub-target features. + llvm::SubtargetFeatures subtargetFeatures; + llvm::StringMap featureMap; + llvm::sys::getHostCPUFeatures(featureMap); + for (auto &feature : featureMap) + subtargetFeatures.AddFeature(feature.first(), feature.second); + + // Relocation model, code model and codegen opt level are kept to default + // values. + machineBuilder->setCPU(llvm::sys::getHostCPUName()); + machineBuilder->addFeatures(subtargetFeatures.getFeatures()); + + return machineBuilder; + } + // Wrap the `irTransformer` into a function that can be called by the // IRTranformLayer. If `irTransformer` is not set up, return the module as // is without errors. @@ -321,12 +369,11 @@ void packFunctionArguments(llvm::Module *module) { // Out of line for PIMPL unique_ptr. ExecutionEngine::~ExecutionEngine() = default; -Expected> -ExecutionEngine::create(ModuleOp m, - std::function transformer, - ArrayRef sharedLibPaths) { +Expected> ExecutionEngine::create( + ModuleOp m, std::function transformer, + ArrayRef sharedLibPaths, llvm::TargetMachine *refTM) { auto engine = llvm::make_unique(); - auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths); + auto expectedJIT = impl::OrcJIT::create(transformer, sharedLibPaths, refTM); if (!expectedJIT) return expectedJIT.takeError();