Skip to content

Commit

Permalink
Towards fixing the problems.
Browse files Browse the repository at this point in the history
- Increase the reference counter of string literals right after their creation such that they don't get cleared away.
- Fixed some minor issues in StringRefCounter:
  - add() should be hidden inside inc() IMHO.
  - string should be erased when the reference counter gets 0 (not 1).
- This essentially fixes the memory leak.
- Unfortunately, the script-level test case operator_eq_2.daphne fails due to a double-free that is related to the arith.select op...
  • Loading branch information
pdamme committed Aug 7, 2024
1 parent f0a0881 commit c9d0d45
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 40 deletions.
24 changes: 22 additions & 2 deletions src/compiler/lowering/ManageObjRefsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ using namespace mlir;
* op result) to prevent memory leaks.
* - We do this as soon as possible. That is, after the last use of the value,
* or directly after its definition, if it has no uses.
* - As the only exception: We do not decrease the reference of the arguments
* to block terminators.
* - As the only exceptions:
* - We do not decrease the reference counters of the arguments to block terminators.
* - We do not decrease the reference counters of string literals.
* - Whenever a value is duplicated (e.g., by passing it as a block argument),
* we increase the reference of the underlying data object. This is to ensure
* that decreasing the reference on the new value does not destroy a data
Expand Down Expand Up @@ -80,6 +81,25 @@ void processValue(OpBuilder builder, Value v) {
if (defOp && llvm::isa<daphne::ConvertDenseMatrixToMemRef>(defOp))
processMemRefInterop(builder, v);

// Increase the reference counter of string literals, such that they don't
// get gargabe collected.
if(defOp && llvm::isa<daphne::ConstantOp>(defOp) && llvm::isa<daphne::StringType>(v.getType())) {
// The given value is a string literal. We want to increase its reference
// counter right after its definition, such that it is never removed.
// But if the defining op is the block of a FuncOp, make sure not to insert the
// IncRefOp before the CreateDaphneContextOp, otherwise we will run
// into problems during/after lowering to kernel calls.
Block * pb = v.getParentBlock();
if(auto fo = dyn_cast<func::FuncOp>(pb->getParentOp())) {
Value dctx = CompilerUtils::getDaphneContext(fo);
builder.setInsertionPointAfterValue(dctx);
}
else
builder.setInsertionPointAfter(defOp);
builder.create<daphne::IncRefOp>(v.getLoc(), v);
}


if(!llvm::isa<daphne::MatrixType, daphne::FrameType, daphne::StringType>(v.getType()))
return;

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/local/kernels/DecRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ struct DecRef<Structure> {
template<>
struct DecRef<char> {
static void apply(const char* arg, DCTX(ctx)) {

// Decrease the reference counter. If it became zero, delete the string.
if(!ctx->stringRefCount.dec(arg)) {
std::cerr << "DecRef: " << reinterpret_cast<uintptr_t>(arg) << " (not found) - assuming count=1 and deleting \"" << arg << "\"" << std::endl;
delete [] arg;
}
}
Expand All @@ -50,6 +49,7 @@ struct DecRef<char> {
// ****************************************************************************
// Convenience function
// ****************************************************************************

template<class DTArg>
void decRef(const DTArg * arg, DCTX(ctx)) {
DecRef<DTArg>::apply(arg, ctx);
Expand Down
8 changes: 2 additions & 6 deletions src/runtime/local/kernels/IncRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,8 @@ struct IncRef<Structure> {
template<>
struct IncRef<char> {
static void apply(const char* arg, DCTX(ctx)) {

if(!ctx->stringRefCount.inc(arg)) {
std::cerr << "IncRef: " << reinterpret_cast<uintptr_t>(arg) << " (not found) - adding \"" << arg << "\" with count=2" << std::endl;
// reference counting for removal later
ctx->stringRefCount.add(arg);
}
// Increase the reference counter.
ctx->stringRefCount.inc(arg);
}
};

Expand Down
39 changes: 22 additions & 17 deletions src/util/StringRefCount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,44 @@

#include "StringRefCount.h"

void StringRefCounter::add(const char* arg) {
auto ptr = reinterpret_cast<uintptr_t>(arg);
const std::lock_guard<std::mutex> lock(mtxStrRefCnt);
stringRefCount.insert({ptr, 2});
logger->info("StringRefCount: Added ptr={}; arg={}", ptr, arg);
}

bool StringRefCounter::inc(const char* arg) {
void StringRefCounter::inc(const char* arg) {
auto ptr = reinterpret_cast<uintptr_t>(arg);
const std::lock_guard<std::mutex> lock(mtxStrRefCnt);
if(auto found = stringRefCount.find(ptr); found != stringRefCount.end()) {
// If the string was found, increase its reference counter.
found->second++;
std::cerr << "IncRef: " << ptr << " (found) - count incremented to " << found->second << std::endl;
logger->info("IncRef: ptr={}; arg={}", ptr, arg);
return true;
logger->info("StringRefCounter::inc: ptr={}; arg={}; found and incremented", ptr, arg);
}
else {
// If the string was not found, implicitly assume a prior counter of 1,
// and increase the counter to 2.
stringRefCount.insert({ptr, 2});
logger->info("StringRefCounter::inc: ptr={}; arg={}; not found and set to 2", ptr, arg);
}
return false;
}

bool StringRefCounter::dec(const char* arg) {
auto ptr = reinterpret_cast<uintptr_t>(arg);
const std::lock_guard<std::mutex> lock(mtxStrRefCnt);
if(auto found = stringRefCount.find(ptr); found != stringRefCount.end()) {
// If the string was found, decrease its reference counter.
found->second--;
std::cerr << "DecRef: " << ptr << " (found) - count decremented to " << found->second << std::endl;
if(found->second == 1) {
logger->debug("Removing from StringRefCounter: ptr={}; arg={}", ptr, arg);
// delete [] reinterpret_cast<char *>(found->first);
logger->info("StringRefCounter::dec: ptr={}; arg={}; found and decremented", ptr, arg);
if(found->second == 0) {
// If the reference counter became zero, erase it and return false.
logger->info("StringRefCounter::dec: ptr={}; arg={}; became zero and erased", ptr, arg);
stringRefCount.erase(found);
return false;
}
// If the reference counter did not become zero, keep it and return true.
return true;
}
return false;
else {
// If the string was not found, implicitly assume a prior counter of 1,
// don't change the stored counters, just return false.
logger->info("StringRefCounter::dec: ptr={}; arg={}; not found", ptr, arg);
return false;
}
}

StringRefCounter& StringRefCounter::instance() {
Expand Down
31 changes: 19 additions & 12 deletions src/util/StringRefCount.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,30 @@ class StringRefCounter {

~StringRefCounter() {
if(!stringRefCount.empty()) {
// This should not happen
std::cerr << "Deleting " << stringRefCount.size() << " remaining string refs in ~DaphneContext()" << std::endl;
logger->warn("{} string refs still present while destroying DaphneContext - this should not happen.", stringRefCount.size());
// This should not happen.
logger->warn("{} string refs still present while destroying StringRefCounter - this should not happen.", stringRefCount.size());
}
}

/**
* @brief This method adds a numeric representation of a char* and an initial reference count
* to delete the string/char* later on in the DecRef kernel.
*
* @param str The char* to keep track of
*
*/
void add(const char* arg);

bool inc(const char* arg);
* @brief Increases the reference counter of the given string.
*
* If no reference counter is stored for this string, a prior value of 1 is
* implicitly assumed, i.e., then the reference counter is increased to an
* explicitly stored value of 2.
*/
void inc(const char* arg);

/**
* @brief Decreases the reference counter of the given string.
*
* If no reference counter is stored for this string, a prior value of 1 is
* implicitly assumed, i.e., then no changes are made to the stored reference
* counters and `false` is returned.
*
* @return `false` if the reference counter became zero through the decrement,
* `true` otherwise.
*/
bool dec(const char* arg);

static StringRefCounter& instance();
Expand Down
4 changes: 3 additions & 1 deletion test/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ std::unique_ptr<DaphneContext> setupContextAndLogger() {
createDaphneContext(
dctx_, reinterpret_cast<uint64_t>(&user_config),
reinterpret_cast<uint64_t>(&dispatchMapping),
reinterpret_cast<uint64_t>(&Statistics::instance()));
reinterpret_cast<uint64_t>(&Statistics::instance()),
reinterpret_cast<uint64_t>(&StringRefCounter::instance())
);

#ifdef USE_CUDA
CUDA::createCUDAContext(dctx_);
Expand Down

0 comments on commit c9d0d45

Please sign in to comment.