Skip to content

Commit

Permalink
Fix virtual function unload
Browse files Browse the repository at this point in the history
  • Loading branch information
qubka committed Dec 3, 2024
1 parent 1ac6220 commit 2f073d4
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 28 deletions.
15 changes: 9 additions & 6 deletions src/callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const Callb
return m_functionPtr;
}

if (!g_jitRuntime) {
auto rt = m_rt.lock();
if (!rt) {
m_errorCode = "JitRuntime invalid";
return 0;
}
Expand All @@ -67,7 +68,7 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const Callb
physical registers may be inserted as nodes.
*/
asmjit::CodeHolder code;
code.init(g_jitRuntime->environment(), g_jitRuntime->cpuFeatures());
code.init(rt->environment(), rt->cpuFeatures());

// initialize function
asmjit::x86::Compiler cc(&code);
Expand Down Expand Up @@ -297,7 +298,7 @@ uint64_t PLH::Callback::getJitFunc(const asmjit::FuncSignature& sig, const Callb

cc.finalize();

if (asmjit::Error err = g_jitRuntime->add(&m_functionPtr, &code)) {
if (asmjit::Error err = rt->add(&m_functionPtr, &code)) {
m_functionPtr = 0;
m_errorCode = asmjit::DebugUtils::errorAsString(err);
return 0;
Expand Down Expand Up @@ -394,11 +395,13 @@ std::string_view PLH::Callback::getError() const {
return !m_functionPtr && m_errorCode ? m_errorCode : "";
}

PLH::Callback::Callback() {
PLH::Callback::Callback(std::weak_ptr<asmjit::JitRuntime> rt) : m_rt(std::move(rt)) {
}

PLH::Callback::~Callback() {
if (m_functionPtr) {
g_jitRuntime->release(m_functionPtr);
if (auto rt = m_rt.lock()) {
if (m_functionPtr) {
rt->release(m_functionPtr);
}
}
}
5 changes: 2 additions & 3 deletions src/callback.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ namespace PLH {

using View = std::pair<std::vector<CallbackHandler>&, std::shared_lock<std::shared_mutex>>;

Callback();
explicit Callback(std::weak_ptr<asmjit::JitRuntime> rt);
~Callback();

uint64_t getJitFunc(const asmjit::FuncSignature& sig, CallbackEntry pre, CallbackEntry post);
Expand All @@ -121,6 +121,7 @@ namespace PLH {
private:
static asmjit::TypeId getTypeId(DataType type);

std::weak_ptr<asmjit::JitRuntime> m_rt;
std::array<std::vector<CallbackHandler>, 2> m_callbacks;
std::shared_mutex m_mutex;
uint64_t m_functionPtr = 0;
Expand All @@ -129,8 +130,6 @@ namespace PLH {
const char* m_errorCode;
};
};

extern std::unique_ptr<asmjit::JitRuntime> g_jitRuntime;
}

inline PLH::ReturnFlag operator|(PLH::ReturnFlag lhs, PLH::ReturnFlag rhs) noexcept {
Expand Down
38 changes: 21 additions & 17 deletions src/plugin.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#include "plugin.hpp"
#include "scope_guard.hpp"

PLH::PolyHookPlugin g_polyHookPlugin;
EXPOSE_PLUGIN(PLUGIN_API, &g_polyHookPlugin)

std::unique_ptr<asmjit::JitRuntime> PLH::g_jitRuntime;

using namespace PLH;

static void PreCallback(Callback* callback, const Callback::Parameters* params, Callback::Property* property, const Callback::Return* ret) {
Expand Down Expand Up @@ -35,11 +34,10 @@ static void PostCallback(Callback* callback, const Callback::Parameters* params,
}

void PolyHookPlugin::OnPluginStart() {
g_jitRuntime = std::make_unique<asmjit::JitRuntime>();
m_jitRuntime = std::make_unique<asmjit::JitRuntime>();
}

void PolyHookPlugin::OnPluginEnd() {
g_jitRuntime.reset();
}

Callback* PolyHookPlugin::hookDetour(void* pFunc, DataType returnType, std::span<const DataType> arguments) {
Expand All @@ -56,10 +54,11 @@ Callback* PolyHookPlugin::hookDetour(void* pFunc, DataType returnType, std::span
}
}

auto callback = std::make_unique<Callback>();
auto callback = std::make_unique<Callback>(m_jitRuntime);
auto error = callback->getError();
if (!error.empty()) {
// Log ?
std::puts(error.data());
std::terminate();
return nullptr;
}

Expand Down Expand Up @@ -89,10 +88,11 @@ Callback* PolyHookPlugin::hookVirtual(void* pClass, int index, DataType returnTy
}
}

auto callback = std::make_unique<Callback>();
auto callback = std::make_unique<Callback>(m_jitRuntime);
auto error = callback->getError();
if (!error.empty()) {
// Log ?
std::puts(error.data());
std::terminate();
return nullptr;
}

Expand Down Expand Up @@ -125,11 +125,10 @@ bool PolyHookPlugin::unhookDetour(void* pFunc) {
auto it = m_detours.find(pFunc);
if (it != m_detours.end()) {
auto& detour = it->second;
if (detour->unHook()) {
m_callbacks.erase(std::pair{detour.get(), -1});
m_detours.erase(it);
return true;
}
detour->unHook();
m_detours.erase(it);
m_callbacks.erase(std::pair{detour.get(), -1});
return true;
}

return false;
Expand All @@ -144,10 +143,13 @@ bool PolyHookPlugin::unhookVirtual(void* pClass, int index) {
auto it = m_vhooks.find(pClass);
if (it != m_vhooks.end()) {
auto& vtable = it->second;
if (vtable->unHook()) {
vtable->unHook();

bool shouldBeUnhook = true;
auto hookGuard = ScopeGuard([&]() {
if (shouldBeUnhook) m_vhooks.erase(it);
m_callbacks.erase(std::pair{vtable.get(), index});
m_vhooks.erase(it);
}
});

auto it2 = m_tables.find(pClass);
if (it2 != m_tables.end()) {
Expand All @@ -163,7 +165,9 @@ bool PolyHookPlugin::unhookVirtual(void* pClass, int index) {
if (!vtable->hook())
return false;

return m_vhooks.emplace(pClass, std::move(vtable)).first->second.get();
// do not unhook, we just replace our value in map
shouldBeUnhook = false;
return true;
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/plugin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ namespace PLH {
int getVTableIndex(void* pFunc) const;

private:
std::map<void*, std::unique_ptr<IHook>> m_vhooks;
std::map<void*, std::unique_ptr<NatDetour>> m_detours;
std::shared_ptr<asmjit::JitRuntime> m_jitRuntime;
std::map<std::pair<void*, int>, std::unique_ptr<Callback>> m_callbacks;
std::map<void*, std::pair<VFuncMap, VFuncMap>> m_tables;
std::map<void*, std::unique_ptr<IHook>> m_vhooks;
std::map<void*, std::unique_ptr<NatDetour>> m_detours;
std::mutex m_mutex;
};
}
32 changes: 32 additions & 0 deletions src/scope_guard.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <functional>

namespace PLH {
class ScopeGuard {
public:
ScopeGuard() : f(nullptr) {}
~ScopeGuard() { if (f) f(); }
ScopeGuard(const ScopeGuard&) = delete;
template<class Callable> requires (!std::is_same_v<Callable, ScopeGuard>)
explicit ScopeGuard(Callable&& undo) : f(std::forward<Callable>(undo)) {}
ScopeGuard(ScopeGuard&& other) noexcept : f(std::move(other.f)) { other.f = nullptr; }

ScopeGuard& operator=(const ScopeGuard&) = delete;
ScopeGuard& operator=(ScopeGuard&& other) noexcept {
if (f) f();
f = std::move(other.f);
other.f = nullptr;
return *this;
}
template<class Callable> requires (!std::is_same_v<Callable, ScopeGuard>)
ScopeGuard& operator=(Callable&& undo) {
if (f) f();
f = std::forward<Callable>(undo);
return *this;
}

private:
std::function<void()> f;
};
}

0 comments on commit 2f073d4

Please sign in to comment.