Skip to content

Commit

Permalink
[LLVM][RUNTIME] Add optional LLVM ORCJIT runtime executor (#15964)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 authored Mar 11, 2024
1 parent 95f97e8 commit cae1af6
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 36 deletions.
21 changes: 20 additions & 1 deletion src/target/llvm/llvm_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,23 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
}
}

// Target options
// LLVM JIT engine options
if (const Optional<String>& v = target->GetAttr<String>("jit")) {
String value = v.value();
if ((value == "mcjit") || (value == "orcjit")) {
jit_engine_ = value;
} else {
LOG(FATAL) << "invalid jit option " << value << " (can be `mcjit` or `orcjit`).";
}
}

// RISCV code model
auto arch = llvm::Triple(triple_).getArch();
if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
code_model_ = llvm::CodeModel::Medium;
}

// Target options
#if TVM_LLVM_VERSION < 50
target_options_.LessPreciseFPMADOption = true;
#endif
Expand Down Expand Up @@ -525,6 +540,10 @@ std::string LLVMTargetInfo::str() const {
os << quote << Join(",", opts) << quote;
}

if (jit_engine_ != "mcjit") {
os << " -jit=" << jit_engine_;
}

return os.str();
}

Expand Down
6 changes: 6 additions & 0 deletions src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ class LLVMTargetInfo {
* \return `llvm::FastMathFlags` for this target
*/
llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; }
/*!
* \brief Get the LLVM JIT engine type
* \return the type name of the JIT engine (default "mcjit" or "orcjit")
*/
const std::string GetJITEngine() const { return jit_engine_; }
/*!
* \brief Get the LLVM optimization level
* \return optimization level for this target
Expand Down Expand Up @@ -324,6 +329,7 @@ class LLVMTargetInfo {
llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_;
llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small;
std::shared_ptr<llvm::TargetMachine> target_machine_;
std::string jit_engine_ = "mcjit";
};

/*!
Expand Down
197 changes: 177 additions & 20 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
#include <llvm/ADT/StringRef.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/MCJIT.h> // Force linking of MCJIT
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Intrinsics.h>
Expand Down Expand Up @@ -113,8 +116,11 @@ class LLVMModuleNode final : public runtime::ModuleNode {

bool ImplementsFunction(const String& name, bool query_imports) final;

void SetJITEngine(const std::string& jit_engine) { jit_engine_ = jit_engine; }

private:
void LazyInitJIT();
void InitMCJIT();
void InitORCJIT();
bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const;
void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const;
void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const;
Expand All @@ -123,21 +129,31 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<LLVMInstance> llvm_instance_;
// JIT lock
std::mutex mutex_;
// execution engine
llvm::ExecutionEngine* ee_{nullptr};
// jit execution engines
llvm::ExecutionEngine* mcjit_ee_{nullptr};
std::unique_ptr<llvm::orc::LLJIT> orcjit_ee_{nullptr};
// The raw pointer to the module.
llvm::Module* module_{nullptr};
// The unique_ptr owning the module. This becomes empty once JIT has been initialized
// (EngineBuilder takes ownership of the module).
std::unique_ptr<llvm::Module> module_owning_ptr_;
/* \brief names of the external functions declared in this module */
Array<String> function_names_;
std::string jit_engine_;
};

LLVMModuleNode::~LLVMModuleNode() {
if (ee_ != nullptr) {
ee_->runStaticConstructorsDestructors(true);
delete ee_;
if (mcjit_ee_ != nullptr) {
mcjit_ee_->runStaticConstructorsDestructors(true);
delete mcjit_ee_;
}
if (orcjit_ee_ != nullptr) {
auto dtors = llvm::orc::getDestructors(*module_);
auto dtorRunner = std::make_unique<llvm::orc::CtorDtorRunner>(orcjit_ee_->getMainJITDylib());
dtorRunner->add(dtors);
auto err = dtorRunner->run();
ICHECK(!err) << llvm::toString(std::move(err));
orcjit_ee_.reset();
}
module_owning_ptr_.reset();
}
Expand Down Expand Up @@ -166,7 +182,9 @@ PackedFunc LLVMModuleNode::GetFunction(const String& name, const ObjectPtr<Objec
std::string target_string = LLVMTarget::GetTargetMetadata(*module_);
return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; });
}
if (ee_ == nullptr) LazyInitJIT();
ICHECK(jit_engine_.size()) << "JIT engine type is missing";
if ((jit_engine_ == "mcjit") && (mcjit_ee_ == nullptr)) InitMCJIT();
if ((jit_engine_ == "orcjit") && (orcjit_ee_ == nullptr)) InitORCJIT();

std::lock_guard<std::mutex> lock(mutex_);

Expand Down Expand Up @@ -353,6 +371,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {

module_owning_ptr_ = cg->Finish();
module_ = module_owning_ptr_.get();
jit_engine_ = llvm_target->GetJITEngine();
llvm_target->SetTargetMetadata(module_);
module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);
Expand Down Expand Up @@ -384,13 +403,16 @@ bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports)
return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end();
}

void LLVMModuleNode::LazyInitJIT() {
void LLVMModuleNode::InitMCJIT() {
std::lock_guard<std::mutex> lock(mutex_);
if (ee_) {
if (mcjit_ee_) {
return;
}
// MCJIT builder
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
llvm::EngineBuilder builder(std::move(module_owning_ptr_));

// set options
builder.setEngineKind(llvm::EngineKind::JIT);
#if TVM_LLVM_VERSION <= 170
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
Expand All @@ -400,18 +422,31 @@ void LLVMModuleNode::LazyInitJIT() {
builder.setMCPU(llvm_target->GetCPU());
builder.setMAttrs(llvm_target->GetTargetFeatures());
builder.setTargetOptions(llvm_target->GetTargetOptions());

// create the taget machine
auto tm = std::unique_ptr<llvm::TargetMachine>(builder.selectTarget());
if (!IsCompatibleWithHost(tm.get())) {
LOG(FATAL) << "Cannot run module, architecture mismatch";
}

// data layout
llvm::DataLayout layout(tm->createDataLayout());
ICHECK(layout == module_->getDataLayout())
<< "Data layout mismatch between module("
<< module_->getDataLayout().getStringRepresentation() << ")"
<< " and ExecutionEngine (" << layout.getStringRepresentation() << ")";
ee_ = builder.create(tm.release());
ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);

// create MCJIT
mcjit_ee_ = builder.create(tm.release());
ICHECK(mcjit_ee_ != nullptr) << "Failed to initialize LLVM MCJIT engine for "
<< module_->getTargetTriple();

VLOG(2) << "LLVM MCJIT execute " << module_->getModuleIdentifier() << " for triple `"
<< llvm_target->GetTargetTriple() << "`"
<< " on cpu `" << llvm_target->GetCPU() << "`";

// run ctors
mcjit_ee_->runStaticConstructorsDestructors(false);

if (void** ctx_addr =
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
Expand All @@ -424,7 +459,104 @@ void LLVMModuleNode::LazyInitJIT() {
// lead to a runtime crash.
// Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize
// all loaded objects, which will resolve symbols in JITed code.
ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
mcjit_ee_->getFunctionAddress(
"__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91");
}

void LLVMModuleNode::InitORCJIT() {
std::lock_guard<std::mutex> lock(mutex_);
if (orcjit_ee_) {
return;
}
// ORCJIT builder
With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_));
llvm::orc::JITTargetMachineBuilder tm_builder(llvm::Triple(llvm_target->GetTargetTriple()));

// set options
tm_builder.setCPU(llvm_target->GetCPU());
tm_builder.setFeatures(llvm_target->GetTargetFeatureString());
tm_builder.setOptions(llvm_target->GetTargetOptions());
#if TVM_LLVM_VERSION <= 170
tm_builder.setCodeGenOptLevel(llvm::CodeGenOpt::Aggressive);
#else
tm_builder.setCodeGenOptLevel(llvm::CodeGenOptLevel::Aggressive);
#endif

// create the taget machine
std::unique_ptr<llvm::TargetMachine> tm = llvm::cantFail(tm_builder.createTargetMachine());
if (!IsCompatibleWithHost(tm.get())) {
LOG(FATAL) << "Cannot run module, architecture mismatch";
}

// data layout
String module_name = module_->getModuleIdentifier();
llvm::DataLayout layout(tm->createDataLayout());
ICHECK(layout == module_->getDataLayout())
<< "Data layout mismatch between module("
<< module_->getDataLayout().getStringRepresentation() << ")"
<< " and ExecutionEngine (" << layout.getStringRepresentation() << ")";

// compiler
const auto compilerBuilder = [&](const llvm::orc::JITTargetMachineBuilder&)
-> llvm::Expected<std::unique_ptr<llvm::orc::IRCompileLayer::IRCompiler>> {
return std::make_unique<llvm::orc::TMOwningSimpleCompiler>(std::move(tm));
};

#if TVM_LLVM_VERSION >= 130
// linker
const auto linkerBuilder = [&](llvm::orc::ExecutionSession& session, const llvm::Triple&) {
return std::make_unique<llvm::orc::ObjectLinkingLayer>(session);
};
#endif

// create LLJIT
orcjit_ee_ = llvm::cantFail(llvm::orc::LLJITBuilder()
#if TVM_LLVM_VERSION >= 110
.setDataLayout(layout)
#endif
.setCompileFunctionCreator(compilerBuilder)
#if TVM_LLVM_VERSION >= 130
.setObjectLinkingLayerCreator(linkerBuilder)
#endif
.create());

ICHECK(orcjit_ee_ != nullptr) << "Failed to initialize LLVM ORCJIT engine for "
<< module_->getTargetTriple();

// store ctors
auto ctors = llvm::orc::getConstructors(*module_);
llvm::orc::CtorDtorRunner ctorRunner(orcjit_ee_->getMainJITDylib());
ctorRunner.add(ctors);

// resolve system symbols (like pthread, dl, m, etc.)
auto gen =
llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(layout.getGlobalPrefix());
ICHECK(gen) << llvm::toString(gen.takeError()) << "\n";
orcjit_ee_->getMainJITDylib().addGenerator(std::move(gen.get()));

// transfer module to a clone
auto uctx = std::make_unique<llvm::LLVMContext>();
auto umod = llvm::CloneModule(*(std::move(module_owning_ptr_)));

// add the llvm module to run
llvm::orc::ThreadSafeModule tsm(std::move(umod), std::move(uctx));
auto err = orcjit_ee_->addIRModule(std::move(tsm));
ICHECK(!err) << llvm::toString(std::move(err));

VLOG(2) << "LLVM ORCJIT execute " << module_->getModuleIdentifier() << " for triple `"
<< llvm_target->GetTargetTriple() << "`"
<< " on cpu `" << llvm_target->GetCPU() << "`";

// run ctors
err = ctorRunner.run();
ICHECK(!err) << llvm::toString(std::move(err));

if (void** ctx_addr =
reinterpret_cast<void**>(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) {
*ctx_addr = this;
}
runtime::InitContextFunctions(
[this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); });
}

bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const {
Expand All @@ -442,20 +574,40 @@ bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const {
void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const {
// first verifies if GV exists.
if (module_->getGlobalVariable(name) != nullptr) {
return reinterpret_cast<void*>(ee_->getGlobalValueAddress(name));
} else {
return nullptr;
if (jit_engine_ == "mcjit") {
return reinterpret_cast<void*>(mcjit_ee_->getGlobalValueAddress(name));
} else if (jit_engine_ == "orcjit") {
#if TVM_LLVM_VERSION >= 150
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
#else
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
#endif
return reinterpret_cast<void*>(addr);
} else {
LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
}
}
return nullptr;
}

void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
const LLVMTarget& llvm_target) const {
// first verifies if GV exists.
if (module_->getFunction(name) != nullptr) {
return reinterpret_cast<void*>(ee_->getFunctionAddress(name));
} else {
return nullptr;
if (jit_engine_ == "mcjit") {
return reinterpret_cast<void*>(mcjit_ee_->getFunctionAddress(name));
} else if (jit_engine_ == "orcjit") {
#if TVM_LLVM_VERSION >= 150
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getValue();
#else
auto addr = llvm::cantFail(orcjit_ee_->lookup(name)).getAddress();
#endif
return reinterpret_cast<void*>(addr);
} else {
LOG(FATAL) << "Either `mcjit` or `orcjit` are not initialized.";
}
}
return nullptr;
}

TVM_REGISTER_GLOBAL("target.build.llvm")
Expand All @@ -476,6 +628,7 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
module->setTargetTriple(llvm_target->GetTargetTriple());
module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout());
n->Init(std::move(module), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());
return runtime::Module(n);
});

Expand Down Expand Up @@ -595,6 +748,7 @@ TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int {
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll")
.set_body_typed([](std::string filename, std::string fmt) -> runtime::Module {
auto n = make_object<LLVMModuleNode>();
n->SetJITEngine("mcjit");
n->LoadIR(filename);
return runtime::Module(n);
});
Expand All @@ -616,6 +770,7 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
std::unique_ptr<llvm::Module> blob =
CodeGenBlob(data, system_lib, llvm_target.get(), c_symbol_prefix);
n->Init(std::move(blob), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());
return runtime::Module(n);
});

Expand Down Expand Up @@ -645,6 +800,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata

auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());

auto meta_mod = MetadataModuleCreate(metadata);
meta_mod->Import(runtime::Module(n));
Expand Down Expand Up @@ -691,6 +847,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array<runtime::Module>& module

auto n = make_object<LLVMModuleNode>();
n->Init(std::move(mod), std::move(llvm_instance));
n->SetJITEngine(llvm_target->GetJITEngine());
for (auto m : modules) {
n->Import(m);
}
Expand Down
2 changes: 2 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Integer>("opt-level")
// LLVM command line flags, see below
.add_attr_option<Array<String>>("cl-opt")
// LLVM JIT engine mcjit/orcjit
.add_attr_option<String>("jit")
.set_default_keys({"cpu"})
// Force the external codegen kind attribute to be registered, even if no external
// codegen targets are enabled by the TVM build.
Expand Down
Loading

0 comments on commit cae1af6

Please sign in to comment.