Skip to content

Commit

Permalink
[wallet]: lock wallet context before adding/removing wallet settings
Browse files Browse the repository at this point in the history
- Also add test that verifies that no race condition will occur during the process.

Co-authored-by furszy <matiasfurszyfer@protonmail.com>
  • Loading branch information
ismaelsadeeq committed Aug 22, 2024
1 parent ee36717 commit 135eb1d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
33 changes: 33 additions & 0 deletions src/wallet/test/wallet_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,39 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup)
}
}

// This test verifies that wallet settings can be added and removed
// Concurrently, ensuring no race conditions occur during either process.
BOOST_FIXTURE_TEST_CASE(write_wallet_settings_concurrently, TestingSetup)
{
WalletContext context;
context.chain = m_node.chain.get();
const int NUM_WALLETS = 5;
std::vector<std::thread> threads;
{
for (int i{0}; i < NUM_WALLETS; i++) {
threads.emplace_back([i, &context] {
BOOST_CHECK_MESSAGE(AddWalletSetting(context, strprintf("wallet_%d", i)), strprintf("write wallet_%d failed", i));
});
}
for (auto& t : threads)
t.join();
auto wallets = context.chain->getRwSetting("wallet");
BOOST_CHECK_EQUAL(wallets.getValues().size(), NUM_WALLETS);
}
threads.clear();
{
for (int i{0}; i < NUM_WALLETS; i++) {
threads.emplace_back([i, &context] {
BOOST_CHECK_MESSAGE(RemoveWalletSetting(context, strprintf("wallet_%d", i)), strprintf("write wallet_%d failed", i));
});
}
for (auto& t : threads)
t.join();
auto wallets = context.chain->getRwSetting("wallet");
BOOST_CHECK_EQUAL(wallets.getValues().size(), 0);
}
}

// Check that GetImmatureCredit() returns a newly calculated value instead of
// the cached value after a MarkDirty() call.
//
Expand Down
31 changes: 16 additions & 15 deletions src/wallet/wallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,38 +91,40 @@ using util::ToString;

namespace wallet {

bool AddWalletSetting(interfaces::Chain& chain, const std::string& wallet_name)
bool AddWalletSetting(WalletContext& context, const std::string& wallet_name)
{
common::SettingsValue setting_value = chain.getRwSetting("wallet");
LOCK(context.wallets_mutex);
common::SettingsValue setting_value = context.chain->getRwSetting("wallet");
if (!setting_value.isArray()) setting_value.setArray();
for (const common::SettingsValue& value : setting_value.getValues()) {
if (value.isStr() && value.get_str() == wallet_name) return true;
}
setting_value.push_back(wallet_name);
return chain.updateRwSetting("wallet", setting_value);
return context.chain->updateRwSetting("wallet", setting_value);
}

bool RemoveWalletSetting(interfaces::Chain& chain, const std::string& wallet_name)
bool RemoveWalletSetting(WalletContext& context, const std::string& wallet_name)
{
common::SettingsValue setting_value = chain.getRwSetting("wallet");
LOCK(context.wallets_mutex);
common::SettingsValue setting_value = context.chain->getRwSetting("wallet");
if (!setting_value.isArray()) return true;
common::SettingsValue new_value(common::SettingsValue::VARR);
for (const common::SettingsValue& value : setting_value.getValues()) {
if (!value.isStr() || value.get_str() != wallet_name) new_value.push_back(value);
}
if (new_value.size() == setting_value.size()) return true;
return chain.updateRwSetting("wallet", new_value);
return context.chain->updateRwSetting("wallet", new_value);
}

static void UpdateWalletSetting(interfaces::Chain& chain,
static void UpdateWalletSetting(WalletContext& context,
const std::string& wallet_name,
std::optional<bool> load_on_startup,
std::vector<bilingual_str>& warnings)
{
if (!load_on_startup) return;
if (load_on_startup.value() && !AddWalletSetting(chain, wallet_name)) {
if (load_on_startup.value() && !AddWalletSetting(context, wallet_name)) {
warnings.emplace_back(Untranslated("Wallet load on startup setting could not be updated, so wallet may not be loaded next node startup."));
} else if (!load_on_startup.value() && !RemoveWalletSetting(chain, wallet_name)) {
} else if (!load_on_startup.value() && !RemoveWalletSetting(context, wallet_name)) {
warnings.emplace_back(Untranslated("Wallet load on startup setting could not be updated, so wallet may still be loaded next node startup."));
}
}
Expand Down Expand Up @@ -157,7 +159,6 @@ bool RemoveWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet
{
assert(wallet);

interfaces::Chain& chain = wallet->chain();
std::string name = wallet->GetName();

// Unregister with the validation interface which also drops shared pointers.
Expand All @@ -172,7 +173,7 @@ bool RemoveWallet(WalletContext& context, const std::shared_ptr<CWallet>& wallet
wallet->NotifyUnload();

// Write the wallet setting
UpdateWalletSetting(chain, name, load_on_start, warnings);
UpdateWalletSetting(context, name, load_on_start, warnings);

return true;
}
Expand Down Expand Up @@ -293,7 +294,7 @@ std::shared_ptr<CWallet> LoadWalletInternal(WalletContext& context, const std::s
wallet->postInitProcess();

// Write the wallet setting
UpdateWalletSetting(*context.chain, name, load_on_start, warnings);
UpdateWalletSetting(context, name, load_on_start, warnings);

return wallet;
} catch (const std::runtime_error& e) {
Expand Down Expand Up @@ -474,7 +475,7 @@ std::shared_ptr<CWallet> CreateWallet(WalletContext& context, const std::string&
wallet->postInitProcess();

// Write the wallet settings
UpdateWalletSetting(*context.chain, name, load_on_start, warnings);
UpdateWalletSetting(context, name, load_on_start, warnings);

// Legacy wallets are being deprecated, warn if a newly created wallet is legacy
if (!(wallet_creation_flags & WALLET_FLAG_DESCRIPTORS)) {
Expand Down Expand Up @@ -4324,7 +4325,7 @@ bool DoMigration(CWallet& wallet, WalletContext& context, bilingual_str& error,
}

// Add the wallet to settings
UpdateWalletSetting(*context.chain, wallet_name, /*load_on_startup=*/true, warnings);
UpdateWalletSetting(context, wallet_name, /*load_on_startup=*/true, warnings);
}
if (data->solvable_descs.size() > 0) {
wallet.WalletLogPrintf("Making a new watchonly wallet containing the unwatched solvable scripts\n");
Expand Down Expand Up @@ -4361,7 +4362,7 @@ bool DoMigration(CWallet& wallet, WalletContext& context, bilingual_str& error,
}

// Add the wallet to settings
UpdateWalletSetting(*context.chain, wallet_name, /*load_on_startup=*/true, warnings);
UpdateWalletSetting(context, wallet_name, /*load_on_startup=*/true, warnings);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/wallet/wallet.h
Original file line number Diff line number Diff line change
Expand Up @@ -1112,10 +1112,10 @@ class WalletRescanReserver
};

//! Add wallet name to persistent configuration so it will be loaded on startup.
bool AddWalletSetting(interfaces::Chain& chain, const std::string& wallet_name);
bool AddWalletSetting(WalletContext& context, const std::string& wallet_name);

//! Remove wallet name from persistent configuration so it will not be loaded on startup.
bool RemoveWalletSetting(interfaces::Chain& chain, const std::string& wallet_name);
bool RemoveWalletSetting(WalletContext& context, const std::string& wallet_name);

struct MigrationResult {
std::string wallet_name;
Expand Down

0 comments on commit 135eb1d

Please sign in to comment.