From d87569372378fa4a6c11c69ef695525198d5fa96 Mon Sep 17 00:00:00 2001 From: iphydf Date: Sat, 3 Feb 2024 14:38:47 +0000 Subject: [PATCH] refactor: Use `merge_sort` instead of `qsort` for sorting. --- toxcore/BUILD.bazel | 3 +- toxcore/DHT.c | 160 +++++++++++++++++++++++------------- toxcore/crypto_core_test.cc | 29 +++++++ toxcore/group.c | 68 +++++++++++++-- toxcore/onion_announce.c | 108 ++++++++++++++++-------- toxcore/onion_client.c | 105 ++++++++++++++++------- toxcore/util.c | 95 +++++++++++++++++++++ toxcore/util.h | 68 +++++++++++++++ toxcore/util_test.cc | 124 ++++++++++++++++++++++------ 9 files changed, 602 insertions(+), 158 deletions(-) diff --git a/toxcore/BUILD.bazel b/toxcore/BUILD.bazel index 5ff521a4ed6..831093ccccc 100644 --- a/toxcore/BUILD.bazel +++ b/toxcore/BUILD.bazel @@ -100,8 +100,6 @@ cc_test( size = "small", srcs = ["util_test.cc"], deps = [ - ":crypto_core", - ":crypto_core_test_util", ":util", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -953,6 +951,7 @@ cc_library( ":crypto_core", ":friend_connection", ":logger", + ":mem", ":mono_time", ":net_crypto", ":network", diff --git a/toxcore/DHT.c b/toxcore/DHT.c index 2567d1b5a8f..2467c71607d 100644 --- a/toxcore/DHT.c +++ b/toxcore/DHT.c @@ -25,6 +25,7 @@ #include "ping_array.h" #include "shared_key_cache.h" #include "state.h" +#include "util.h" /** The timeout after which a node is discarded completely. 
*/ #define KILL_NODE_TIMEOUT (BAD_NODE_TIMEOUT + PING_INTERVAL) @@ -755,49 +756,6 @@ int get_close_nodes( is_lan, want_announce); } -typedef struct DHT_Cmp_Data { - uint64_t cur_time; - const uint8_t *base_public_key; - Client_data entry; -} DHT_Cmp_Data; - -non_null() -static int dht_cmp_entry(const void *a, const void *b) -{ - const DHT_Cmp_Data *cmp1 = (const DHT_Cmp_Data *)a; - const DHT_Cmp_Data *cmp2 = (const DHT_Cmp_Data *)b; - const Client_data entry1 = cmp1->entry; - const Client_data entry2 = cmp2->entry; - const uint8_t *cmp_public_key = cmp1->base_public_key; - - const bool t1 = assoc_timeout(cmp1->cur_time, &entry1.assoc4) && assoc_timeout(cmp1->cur_time, &entry1.assoc6); - const bool t2 = assoc_timeout(cmp2->cur_time, &entry2.assoc4) && assoc_timeout(cmp2->cur_time, &entry2.assoc6); - - if (t1 && t2) { - return 0; - } - - if (t1) { - return -1; - } - - if (t2) { - return 1; - } - - const int closest = id_closest(cmp_public_key, entry1.public_key, entry2.public_key); - - if (closest == 1) { - return 1; - } - - if (closest == 2) { - return -1; - } - - return 0; -} - #ifdef CHECK_ANNOUNCE_NODE non_null() static void set_announce_node_in_list(Client_data *list, uint32_t list_len, const uint8_t *public_key) @@ -914,31 +872,117 @@ static bool store_node_ok(const Client_data *client, uint64_t cur_time, const ui || id_closest(comp_public_key, client->public_key, public_key) == 2; } +typedef struct Client_data_Cmp { + const Memory *mem; + uint64_t cur_time; + const uint8_t *comp_public_key; +} Client_data_Cmp; + non_null() -static void sort_client_list(const Memory *mem, Client_data *list, uint64_t cur_time, unsigned int length, - const uint8_t *comp_public_key) +static int client_data_cmp(const Client_data_Cmp *cmp, const Client_data *entry1, const Client_data *entry2) { - // Pass comp_public_key to qsort with each Client_data entry, so the - // comparison function can use it as the base of comparison. 
- DHT_Cmp_Data *cmp_list = (DHT_Cmp_Data *)mem_valloc(mem, length, sizeof(DHT_Cmp_Data)); + const bool t1 = assoc_timeout(cmp->cur_time, &entry1->assoc4) && assoc_timeout(cmp->cur_time, &entry1->assoc6); + const bool t2 = assoc_timeout(cmp->cur_time, &entry2->assoc4) && assoc_timeout(cmp->cur_time, &entry2->assoc6); - if (cmp_list == nullptr) { - return; + if (t1 && t2) { + return 0; } - for (uint32_t i = 0; i < length; ++i) { - cmp_list[i].cur_time = cur_time; - cmp_list[i].base_public_key = comp_public_key; - cmp_list[i].entry = list[i]; + if (t1) { + return -1; } - qsort(cmp_list, length, sizeof(DHT_Cmp_Data), dht_cmp_entry); + if (t2) { + return 1; + } - for (uint32_t i = 0; i < length; ++i) { - list[i] = cmp_list[i].entry; + const int closest = id_closest(cmp->comp_public_key, entry1->public_key, entry2->public_key); + + if (closest == 1) { + return 1; + } + + if (closest == 2) { + return -1; + } + + return 0; +} + +non_null() +static bool client_data_less_handler(const void *object, const void *a, const void *b) +{ + const Client_data_Cmp *cmp = (const Client_data_Cmp *)object; + const Client_data *entry1 = (const Client_data *)a; + const Client_data *entry2 = (const Client_data *)b; + + return client_data_cmp(cmp, entry1, entry2) < 0; +} + +non_null() +static const void *client_data_get_handler(const void *arr, uint32_t index) +{ + const Client_data *entries = (const Client_data *)arr; + return &entries[index]; +} + +non_null() +static void client_data_set_handler(void *arr, uint32_t index, const void *val) +{ + Client_data *entries = (Client_data *)arr; + const Client_data *entry = (const Client_data *)val; + entries[index] = *entry; +} + +non_null() +static void *client_data_subarr_handler(void *arr, uint32_t index, uint32_t size) +{ + Client_data *entries = (Client_data *)arr; + return &entries[index]; +} + +non_null() +static void *client_data_alloc_handler(const void *object, uint32_t size) +{ + const Client_data_Cmp *cmp = (const Client_data_Cmp 
*)object; + Client_data *tmp = (Client_data *)mem_valloc(cmp->mem, size, sizeof(Client_data)); + + if (tmp == nullptr) { + return nullptr; } - mem_delete(mem, cmp_list); + return tmp; +} + +non_null() +static void client_data_delete_handler(const void *object, void *arr, uint32_t size) +{ + const Client_data_Cmp *cmp = (const Client_data_Cmp *)object; + mem_delete(cmp->mem, arr); +} + +static const Sort_Funcs client_data_cmp_funcs = { + client_data_less_handler, + client_data_get_handler, + client_data_set_handler, + client_data_subarr_handler, + client_data_alloc_handler, + client_data_delete_handler, +}; + +non_null() +static void sort_client_list(const Memory *mem, Client_data *list, uint64_t cur_time, unsigned int length, + const uint8_t *comp_public_key) +{ + // Pass comp_public_key to merge_sort with each Client_data entry, so the + // comparison function can use it as the base of comparison. + const Client_data_Cmp cmp = { + mem, + cur_time, + comp_public_key, + }; + + merge_sort(list, length, &cmp, &client_data_cmp_funcs); } non_null() diff --git a/toxcore/crypto_core_test.cc b/toxcore/crypto_core_test.cc index 198022752ef..c62facba0a1 100644 --- a/toxcore/crypto_core_test.cc +++ b/toxcore/crypto_core_test.cc @@ -19,6 +19,35 @@ using ExtSecretKey = std::array; using Signature = std::array; using Nonce = std::array; +TEST(PkEqual, TwoRandomIdsAreNotEqual) +{ + std::mt19937 rng; + std::uniform_int_distribution dist{0, UINT8_MAX}; + + uint8_t pk1[CRYPTO_PUBLIC_KEY_SIZE]; + uint8_t pk2[CRYPTO_PUBLIC_KEY_SIZE]; + + std::generate(std::begin(pk1), std::end(pk1), [&]() { return dist(rng); }); + std::generate(std::begin(pk2), std::end(pk2), [&]() { return dist(rng); }); + + EXPECT_FALSE(pk_equal(pk1, pk2)); +} + +TEST(PkEqual, IdCopyMakesKeysEqual) +{ + std::mt19937 rng; + std::uniform_int_distribution dist{0, UINT8_MAX}; + + uint8_t pk1[CRYPTO_PUBLIC_KEY_SIZE]; + uint8_t pk2[CRYPTO_PUBLIC_KEY_SIZE] = {0}; + + std::generate(std::begin(pk1), std::end(pk1), [&]() { 
return dist(rng); }); + + pk_copy(pk2, pk1); + + EXPECT_TRUE(pk_equal(pk1, pk2)); +} + TEST(CryptoCore, EncryptLargeData) { Test_Random rng; diff --git a/toxcore/group.c b/toxcore/group.c index 14e61e6ffc3..33205643c63 100644 --- a/toxcore/group.c +++ b/toxcore/group.c @@ -20,6 +20,7 @@ #include "friend_connection.h" #include "group_common.h" #include "logger.h" +#include "mem.h" #include "mono_time.h" #include "net_crypto.h" #include "network.h" @@ -957,24 +958,75 @@ static bool delpeer(Group_Chats *g_c, uint32_t groupnumber, int peer_index, void /** Order peers with friends first and with more recently active earlier */ non_null() -static int cmp_frozen(const void *a, const void *b) +static bool group_peer_less_handler(const void *object, const void *a, const void *b) { const Group_Peer *pa = (const Group_Peer *)a; const Group_Peer *pb = (const Group_Peer *)b; - if (pa->is_friend ^ pb->is_friend) { - return pa->is_friend ? -1 : 1; + if (((pa->is_friend ? 1 : 0) ^ (pb->is_friend ? 1 : 0)) != 0) { + return pa->is_friend; } - return cmp_uint(pb->last_active, pa->last_active); + return cmp_uint(pb->last_active, pa->last_active) < 0; } +non_null() +static const void *group_peer_get_handler(const void *arr, uint32_t index) +{ + const Group_Peer *entries = (const Group_Peer *)arr; + return &entries[index]; +} + +non_null() +static void group_peer_set_handler(void *arr, uint32_t index, const void *val) +{ + Group_Peer *entries = (Group_Peer *)arr; + const Group_Peer *entry = (const Group_Peer *)val; + entries[index] = *entry; +} + +non_null() +static void *group_peer_subarr_handler(void *arr, uint32_t index, uint32_t size) +{ + Group_Peer *entries = (Group_Peer *)arr; + return &entries[index]; +} + +non_null() +static void *group_peer_alloc_handler(const void *object, uint32_t size) +{ + const Memory *mem = (const Memory *)object; + Group_Peer *tmp = (Group_Peer *)mem_valloc(mem, size, sizeof(Group_Peer)); + + if (tmp == nullptr) { + return nullptr; + } + + return tmp; 
+} + +non_null() +static void group_peer_delete_handler(const void *object, void *arr, uint32_t size) +{ + const Memory *mem = (const Memory *)object; + mem_delete(mem, arr); +} + +static const Sort_Funcs group_peer_cmp_funcs = { + group_peer_less_handler, + group_peer_get_handler, + group_peer_set_handler, + group_peer_subarr_handler, + group_peer_alloc_handler, + group_peer_delete_handler, +}; + /** @brief Delete frozen peers as necessary to ensure at most `g->maxfrozen` remain. * * @retval true if any frozen peers are removed. */ non_null() -static bool delete_old_frozen(Group_c *g) +static bool delete_old_frozen(Group_c *g, const Memory *mem) { if (g->numfrozen <= g->maxfrozen) { return false; @@ -987,7 +1039,7 @@ static bool delete_old_frozen(Group_c *g) return true; } - qsort(g->frozen, g->numfrozen, sizeof(Group_Peer), cmp_frozen); + merge_sort(g->frozen, g->numfrozen, mem, &group_peer_cmp_funcs); Group_Peer *temp = (Group_Peer *)realloc(g->frozen, g->maxfrozen * sizeof(Group_Peer)); @@ -1032,7 +1084,7 @@ static bool freeze_peer(Group_Chats *g_c, uint32_t groupnumber, int peer_index, ++g->numfrozen; - delete_old_frozen(g); + delete_old_frozen(g, g_c->m->mem); return true; } @@ -1519,7 +1571,7 @@ int group_set_max_frozen(const Group_Chats *g_c, uint32_t groupnumber, uint32_t } g->maxfrozen = maxfrozen; - delete_old_frozen(g); + delete_old_frozen(g, g_c->m->mem); return 0; } diff --git a/toxcore/onion_announce.c b/toxcore/onion_announce.c index 593d81aa2ca..d18e65e6804 100644 --- a/toxcore/onion_announce.c +++ b/toxcore/onion_announce.c @@ -24,6 +24,7 @@ #include "onion.h" #include "shared_key_cache.h" #include "timed_auth.h" +#include "util.h" #define PING_ID_TIMEOUT ONION_ANNOUNCE_TIMEOUT @@ -281,23 +282,17 @@ static int in_entries(const Onion_Announce *onion_a, const uint8_t *public_key) return -1; } -typedef struct Cmp_Data { +typedef struct Onion_Announce_Entry_Cmp { + const Memory *mem; const Mono_Time *mono_time; - const uint8_t *base_public_key; - 
Onion_Announce_Entry entry; -} Cmp_Data; + const uint8_t *comp_public_key; +} Onion_Announce_Entry_Cmp; non_null() -static int cmp_entry(const void *a, const void *b) +static int onion_announce_entry_cmp(const Onion_Announce_Entry_Cmp *cmp, const Onion_Announce_Entry *entry1, const Onion_Announce_Entry *entry2) { - const Cmp_Data *cmp1 = (const Cmp_Data *)a; - const Cmp_Data *cmp2 = (const Cmp_Data *)b; - const Onion_Announce_Entry entry1 = cmp1->entry; - const Onion_Announce_Entry entry2 = cmp2->entry; - const uint8_t *cmp_public_key = cmp1->base_public_key; - - const bool t1 = mono_time_is_timeout(cmp1->mono_time, entry1.announce_time, ONION_ANNOUNCE_TIMEOUT); - const bool t2 = mono_time_is_timeout(cmp1->mono_time, entry2.announce_time, ONION_ANNOUNCE_TIMEOUT); + const bool t1 = mono_time_is_timeout(cmp->mono_time, entry1->announce_time, ONION_ANNOUNCE_TIMEOUT); + const bool t2 = mono_time_is_timeout(cmp->mono_time, entry2->announce_time, ONION_ANNOUNCE_TIMEOUT); if (t1 && t2) { return 0; @@ -311,7 +306,7 @@ static int cmp_entry(const void *a, const void *b) return 1; } - const int closest = id_closest(cmp_public_key, entry1.public_key, entry2.public_key); + const int closest = id_closest(cmp->comp_public_key, entry1->public_key, entry2->public_key); if (closest == 1) { return 1; @@ -325,31 +320,80 @@ static int cmp_entry(const void *a, const void *b) } non_null() -static void sort_onion_announce_list(const Memory *mem, const Mono_Time *mono_time, - Onion_Announce_Entry *list, unsigned int length, - const uint8_t *comp_public_key) +static bool onion_announce_entry_less_handler(const void *object, const void *a, const void *b) { - // Pass comp_public_key to qsort with each Client_data entry, so the - // comparison function can use it as the base of comparison. 
- Cmp_Data *cmp_list = (Cmp_Data *)mem_valloc(mem, length, sizeof(Cmp_Data)); + const Onion_Announce_Entry_Cmp *cmp = (const Onion_Announce_Entry_Cmp *)object; + const Onion_Announce_Entry *entry1 = (const Onion_Announce_Entry *)a; + const Onion_Announce_Entry *entry2 = (const Onion_Announce_Entry *)b; - if (cmp_list == nullptr) { - return; - } + return onion_announce_entry_cmp(cmp, entry1, entry2) < 0; +} - for (uint32_t i = 0; i < length; ++i) { - cmp_list[i].mono_time = mono_time; - cmp_list[i].base_public_key = comp_public_key; - cmp_list[i].entry = list[i]; - } +non_null() +static const void *onion_announce_entry_get_handler(const void *arr, uint32_t index) +{ + const Onion_Announce_Entry *entries = (const Onion_Announce_Entry *)arr; + return &entries[index]; +} + +non_null() +static void onion_announce_entry_set_handler(void *arr, uint32_t index, const void *val) +{ + Onion_Announce_Entry *entries = (Onion_Announce_Entry *)arr; + const Onion_Announce_Entry *entry = (const Onion_Announce_Entry *)val; + entries[index] = *entry; +} - qsort(cmp_list, length, sizeof(Cmp_Data), cmp_entry); +non_null() +static void *onion_announce_entry_subarr_handler(void *arr, uint32_t index, uint32_t size) +{ + Onion_Announce_Entry *entries = (Onion_Announce_Entry *)arr; + return &entries[index]; +} - for (uint32_t i = 0; i < length; ++i) { - list[i] = cmp_list[i].entry; +non_null() +static void *onion_announce_entry_alloc_handler(const void *object, uint32_t size) +{ + const Onion_Announce_Entry_Cmp *cmp = (const Onion_Announce_Entry_Cmp *)object; + Onion_Announce_Entry *tmp = (Onion_Announce_Entry *)mem_valloc(cmp->mem, size, sizeof(Onion_Announce_Entry)); + + if (tmp == nullptr) { + return nullptr; } - mem_delete(mem, cmp_list); + return tmp; +} + +non_null() +static void onion_announce_entry_delete_handler(const void *object, void *arr, uint32_t size) +{ + const Onion_Announce_Entry_Cmp *cmp = (const Onion_Announce_Entry_Cmp *)object; + mem_delete(cmp->mem, arr); +} + +static 
const Sort_Funcs onion_announce_entry_cmp_funcs = { + onion_announce_entry_less_handler, + onion_announce_entry_get_handler, + onion_announce_entry_set_handler, + onion_announce_entry_subarr_handler, + onion_announce_entry_alloc_handler, + onion_announce_entry_delete_handler, +}; + +non_null() +static void sort_onion_announce_list(const Memory *mem, const Mono_Time *mono_time, + Onion_Announce_Entry *list, unsigned int length, + const uint8_t *comp_public_key) +{ + // Pass comp_public_key to sort with each Onion_Announce_Entry entry, so the + // comparison function can use it as the base of comparison. + const Onion_Announce_Entry_Cmp cmp = { + mem, + mono_time, + comp_public_key, + }; + + merge_sort(list, length, &cmp, &onion_announce_entry_cmp_funcs); } /** @brief add entry to entries list diff --git a/toxcore/onion_client.c b/toxcore/onion_client.c index 9b0ac96102a..5d1c7943968 100644 --- a/toxcore/onion_client.c +++ b/toxcore/onion_client.c @@ -694,23 +694,17 @@ static int client_send_announce_request(Onion_Client *onion_c, uint32_t num, con return send_onion_packet_tcp_udp(onion_c, &path, dest, request, len); } -typedef struct Onion_Client_Cmp_Data { +typedef struct Onion_Node_Cmp { + const Memory *mem; const Mono_Time *mono_time; - const uint8_t *base_public_key; - Onion_Node entry; -} Onion_Client_Cmp_Data; + const uint8_t *comp_public_key; +} Onion_Node_Cmp; non_null() -static int onion_client_cmp_entry(const void *a, const void *b) +static int onion_node_cmp(const Onion_Node_Cmp *cmp, const Onion_Node *entry1, const Onion_Node *entry2) { - const Onion_Client_Cmp_Data *cmp1 = (const Onion_Client_Cmp_Data *)a; - const Onion_Client_Cmp_Data *cmp2 = (const Onion_Client_Cmp_Data *)b; - const Onion_Node entry1 = cmp1->entry; - const Onion_Node entry2 = cmp2->entry; - const uint8_t *cmp_public_key = cmp1->base_public_key; - - const bool t1 = onion_node_timed_out(&entry1, cmp1->mono_time); - const bool t2 = onion_node_timed_out(&entry2, cmp2->mono_time); + const 
bool t1 = onion_node_timed_out(entry1, cmp->mono_time); + const bool t2 = onion_node_timed_out(entry2, cmp->mono_time); if (t1 && t2) { return 0; @@ -724,7 +718,7 @@ static int onion_client_cmp_entry(const void *a, const void *b) return 1; } - const int closest = id_closest(cmp_public_key, entry1.public_key, entry2.public_key); + const int closest = id_closest(cmp->comp_public_key, entry1->public_key, entry2->public_key); if (closest == 1) { return 1; @@ -738,30 +732,79 @@ static int onion_client_cmp_entry(const void *a, const void *b) } non_null() -static void sort_onion_node_list(const Memory *mem, const Mono_Time *mono_time, - Onion_Node *list, unsigned int length, const uint8_t *comp_public_key) +static bool onion_node_less_handler(const void *object, const void *a, const void *b) { - // Pass comp_public_key to qsort with each Client_data entry, so the - // comparison function can use it as the base of comparison. - Onion_Client_Cmp_Data *cmp_list = (Onion_Client_Cmp_Data *)mem_valloc(mem, length, sizeof(Onion_Client_Cmp_Data)); + const Onion_Node_Cmp *cmp = (const Onion_Node_Cmp *)object; + const Onion_Node *entry1 = (const Onion_Node *)a; + const Onion_Node *entry2 = (const Onion_Node *)b; - if (cmp_list == nullptr) { - return; - } + return onion_node_cmp(cmp, entry1, entry2) < 0; +} - for (uint32_t i = 0; i < length; ++i) { - cmp_list[i].mono_time = mono_time; - cmp_list[i].base_public_key = comp_public_key; - cmp_list[i].entry = list[i]; - } +non_null() +static const void *onion_node_get_handler(const void *arr, uint32_t index) +{ + const Onion_Node *entries = (const Onion_Node *)arr; + return &entries[index]; +} - qsort(cmp_list, length, sizeof(Onion_Client_Cmp_Data), onion_client_cmp_entry); +non_null() +static void onion_node_set_handler(void *arr, uint32_t index, const void *val) +{ + Onion_Node *entries = (Onion_Node *)arr; + const Onion_Node *entry = (const Onion_Node *)val; + entries[index] = *entry; +} + +non_null() +static void 
*onion_node_subarr_handler(void *arr, uint32_t index, uint32_t size) +{ + Onion_Node *entries = (Onion_Node *)arr; + return &entries[index]; +} - for (uint32_t i = 0; i < length; ++i) { - list[i] = cmp_list[i].entry; +non_null() +static void *onion_node_alloc_handler(const void *object, uint32_t size) +{ + const Onion_Node_Cmp *cmp = (const Onion_Node_Cmp *)object; + Onion_Node *tmp = (Onion_Node *)mem_valloc(cmp->mem, size, sizeof(Onion_Node)); + + if (tmp == nullptr) { + return nullptr; } - mem_delete(mem, cmp_list); + return tmp; +} + +non_null() +static void onion_node_delete_handler(const void *object, void *arr, uint32_t size) +{ + const Onion_Node_Cmp *cmp = (const Onion_Node_Cmp *)object; + mem_delete(cmp->mem, arr); +} + +static const Sort_Funcs onion_node_cmp_funcs = { + onion_node_less_handler, + onion_node_get_handler, + onion_node_set_handler, + onion_node_subarr_handler, + onion_node_alloc_handler, + onion_node_delete_handler, +}; + +non_null() +static void sort_onion_node_list(const Memory *mem, const Mono_Time *mono_time, + Onion_Node *list, unsigned int length, const uint8_t *comp_public_key) +{ + // Pass comp_public_key to sort with each Onion_Node entry, so the + // comparison function can use it as the base of comparison. 
+ const Onion_Node_Cmp cmp = { + mem, + mono_time, + comp_public_key, + }; + + merge_sort(list, length, &cmp, &onion_node_cmp_funcs); } non_null() diff --git a/toxcore/util.c b/toxcore/util.c index 1851e58a080..c565579d1fd 100644 --- a/toxcore/util.c +++ b/toxcore/util.c @@ -16,6 +16,7 @@ #include #include +#include "attributes.h" #include "ccompat.h" #include "mem.h" @@ -181,3 +182,97 @@ uint32_t jenkins_one_at_a_time_hash(const uint8_t *key, size_t len) hash += (uint32_t)((uint64_t)hash << 15); return hash; } + +non_null() +static void merge_sort_merge_back( + void *arr, + const void *l_arr, uint32_t l_arr_size, + const void *r_arr, uint32_t r_arr_size, + uint32_t left_start, + const void *object, const Sort_Funcs *funcs) +{ + uint32_t li = 0; + uint32_t ri = 0; + uint32_t k = left_start; + + while (li < l_arr_size && ri < r_arr_size) { + const void *l = funcs->get_callback(l_arr, li); + const void *r = funcs->get_callback(r_arr, ri); + // !(r < l) <=> (r >= l) <=> (l <= r) + if (!funcs->less_callback(object, r, l)) { + funcs->set_callback(arr, k, l); + ++li; + } else { + funcs->set_callback(arr, k, r); + ++ri; + } + ++k; + } + + /* Copy the remaining elements of `l_arr[]`, if there are any. */ + while (li < l_arr_size) { + funcs->set_callback(arr, k, funcs->get_callback(l_arr, li)); + ++li; + ++k; + } + + /* Copy the remaining elements of `r_arr[]`, if there are any. */ + while (ri < r_arr_size) { + funcs->set_callback(arr, k, funcs->get_callback(r_arr, ri)); + ++ri; + ++k; + } +} + +/** Function to merge the two halves `arr[left_start..mid]` and `arr[mid+1..right_end]` of array `arr[]`. */ +non_null() +static void merge_sort_merge( + void *arr, uint32_t left_start, uint32_t mid, uint32_t right_end, void *tmp, + const void *object, const Sort_Funcs *funcs) +{ + const uint32_t l_arr_size = mid - left_start + 1; + const uint32_t r_arr_size = right_end - mid; + + /* Temporary arrays, using the tmp buffer created in `merge_sort` below.
*/ + void *l_arr = funcs->subarr_callback(tmp, 0, l_arr_size); + void *r_arr = funcs->subarr_callback(tmp, l_arr_size, r_arr_size); + + /* Copy data to temp arrays `l_arr[]` and `r_arr[]`. */ + for (uint32_t i = 0; i < l_arr_size; ++i) { + funcs->set_callback(l_arr, i, funcs->get_callback(arr, left_start + i)); + } + for (uint32_t i = 0; i < r_arr_size; ++i) { + funcs->set_callback(r_arr, i, funcs->get_callback(arr, mid + 1 + i)); + } + + /* Merge the temp arrays back into `arr[left_start..right_end]`. */ + merge_sort_merge_back(arr, l_arr, l_arr_size, r_arr, r_arr_size, left_start, object, funcs); +} + +bool merge_sort(void *arr, uint32_t arr_size, const void *object, const Sort_Funcs *funcs) +{ + void *tmp = funcs->alloc_callback(object, arr_size); + + if (tmp == nullptr) { + return false; + } + + // Merge subarrays in bottom up manner. First merge subarrays of + // size 1 to create sorted subarrays of size 2, then merge subarrays + // of size 2 to create sorted subarrays of size 4, and so on. Comparing against `arr_size` (not `arr_size - 1`) avoids unsigned wrap-around for an empty array. + for (uint32_t curr_size = 1; curr_size < arr_size; curr_size = 2 * curr_size) { + // Pick starting point of different subarrays of current size + for (uint32_t left_start = 0; left_start < arr_size - 1; left_start += 2 * curr_size) { + // Find ending point of left subarray.
mid+1 is starting + // point of right + const uint32_t mid = min_u32(left_start + curr_size - 1, arr_size - 1); + const uint32_t right_end = min_u32(left_start + 2 * curr_size - 1, arr_size - 1); + + // Merge Subarrays arr[left_start...mid] & arr[mid+1...right_end] + merge_sort_merge(arr, left_start, mid, right_end, tmp, object, funcs); + } + } + + funcs->delete_callback(object, tmp, arr_size); + return true; +} diff --git a/toxcore/util.h b/toxcore/util.h index 5be74a8d86c..4dd31ef4a16 100644 --- a/toxcore/util.h +++ b/toxcore/util.h @@ -96,6 +96,74 @@ uint32_t jenkins_one_at_a_time_hash(const uint8_t *key, size_t len); non_null() uint16_t data_checksum(const uint8_t *data, uint32_t length); +/** @brief Compare elements with a less-than ordering: `a < b`. */ +typedef bool sort_less_cb(const void *object, const void *a, const void *b); +/** @brief Get element from array at index. */ +typedef const void *sort_get_cb(const void *arr, uint32_t index); +/** @brief Set element in array at index to new value (perform copy). */ +typedef void sort_set_cb(void *arr, uint32_t index, const void *val); +/** @brief Get a sub-array at an index of a given size (mutable pointer). + * + * Used to index in the temporary array allocated by `sort_alloc_cb` and get + * a sub-array for working memory. + */ +typedef void *sort_subarr_cb(void *arr, uint32_t index, uint32_t size); +/** @brief Allocate a new array of the element type. + * + * @param size The array size in elements of type T (not byte size). This value + * is always exactly the input array size as passed to `merge_sort`. + */ +typedef void *sort_alloc_cb(const void *object, uint32_t size); +/** @brief Free the element type array. */ +typedef void sort_delete_cb(const void *object, void *arr, uint32_t size); + +/** @brief Virtual function table for getting/setting elements in an array and + * comparing them. + * + * Only the `less`, `alloc`, and `delete` functions get a `this`-pointer.
We + * assume that indexing in an array doesn't need any other information than the + * array itself. + * + * For now, the `this`-pointer is const, because we assume sorting doesn't need + * to mutate any state, but if necessary that can be changed in the future. + */ +typedef struct Sort_Funcs { + sort_less_cb *less_callback; + sort_get_cb *get_callback; + sort_set_cb *set_callback; + sort_subarr_cb *subarr_callback; + sort_alloc_cb *alloc_callback; + sort_delete_cb *delete_callback; +} Sort_Funcs; + +/** @brief Non-recursive merge sort function to sort `arr[0...arr_size-1]`. + * + * Avoids `memcpy` and avoids treating elements as byte arrays. Instead, uses + * callbacks to index in arrays and copy elements. This makes it quite a bit + * slower than `qsort`, but works with elements that require special care when + * being copied (e.g. if they are part of a graph or other data structure + * with pointers or other invariants). + * + * Allocates a single temporary array with the provided alloc callback, and + * frees it at the end. This is significantly faster than an in-place + * implementation. + * + * This could be made more efficient by providing range-copy functions instead + * of calling the get/set callback for every element, but that increases code + * complexity on the caller. + * + * Complexity: + * - Space: `O(n) where n = arr_size`. + * - Time: `O(n * log n) where n = arr_size`. + * + * @param[in,out] arr An array of type T. + * @param arr_size Number of elements in @p arr (count, not byte size). + * @param[in] object Comparator object. + * @param[in] funcs Callback struct for elements of type T.
+ */ +non_null() +bool merge_sort(void *arr, uint32_t arr_size, const void *object, const Sort_Funcs *funcs); + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/toxcore/util_test.cc b/toxcore/util_test.cc index 94e653f2188..d62883a5b14 100644 --- a/toxcore/util_test.cc +++ b/toxcore/util_test.cc @@ -2,47 +2,117 @@ #include -#include "crypto_core.h" -#include "crypto_core_test_util.hh" +#include +#include namespace { -TEST(Util, TwoRandomIdsAreNotEqual) +TEST(Cmp, OrdersNumbersCorrectly) +{ + EXPECT_EQ(cmp_uint(1, 2), -1); + EXPECT_EQ(cmp_uint(0, UINT32_MAX), -1); + EXPECT_EQ(cmp_uint(UINT32_MAX, 0), 1); + EXPECT_EQ(cmp_uint(UINT32_MAX, UINT32_MAX), 0); + EXPECT_EQ(cmp_uint(0, UINT64_MAX), -1); + EXPECT_EQ(cmp_uint(UINT64_MAX, 0), 1); + EXPECT_EQ(cmp_uint(UINT64_MAX, UINT64_MAX), 0); +} + +template +Sort_Funcs sort_funcs() { - Test_Random rng; - uint8_t pk1[CRYPTO_PUBLIC_KEY_SIZE]; - uint8_t sk1[CRYPTO_SECRET_KEY_SIZE]; - uint8_t pk2[CRYPTO_PUBLIC_KEY_SIZE]; - uint8_t sk2[CRYPTO_SECRET_KEY_SIZE]; + return { + [](const void *object, const void *va, const void *vb) { + const T *a = static_cast(va); + const T *b = static_cast(vb); - crypto_new_keypair(rng, pk1, sk1); - crypto_new_keypair(rng, pk2, sk2); + // Just check that *something* is passed. Don't care what. 
+ EXPECT_NE(object, nullptr); - EXPECT_FALSE(pk_equal(pk1, pk2)); + return *a < *b; + }, + [](const void *arr, uint32_t index) -> const void * { + const T *vec = static_cast(arr); + return &vec[index]; + }, + [](void *arr, uint32_t index, const void *val) { + T *vec = static_cast(arr); + const T *value = static_cast(val); + vec[index] = *value; + }, + [](void *arr, uint32_t index, uint32_t size) -> void * { + T *vec = static_cast(arr); + return &vec[index]; + }, + [](const void *object, uint32_t size) -> void * { return new T[size]; }, + [](const void *object, void *arr, uint32_t size) { + T *vec = static_cast(arr); + delete[] vec; + }, + }; } -TEST(Util, IdCopyMakesKeysEqual) +TEST(MergeSort, BehavesLikeStdSort) { - Test_Random rng; - uint8_t pk1[CRYPTO_PUBLIC_KEY_SIZE]; - uint8_t sk1[CRYPTO_SECRET_KEY_SIZE]; - uint8_t pk2[CRYPTO_PUBLIC_KEY_SIZE] = {0}; + std::mt19937 rng; + // INT_MAX-1 so later we have room to add 1 larger element if needed. + std::uniform_int_distribution dist{ + std::numeric_limits::min(), std::numeric_limits::max() - 1}; + + const auto int_funcs = sort_funcs(); + + // Test with int arrays. + for (uint32_t i = 1; i < 1000; ++i) { + std::vector vec(i); + std::generate(std::begin(vec), std::end(vec), [&]() { return dist(rng); }); + + auto sorted = vec; + std::sort(sorted.begin(), sorted.end(), std::less()); - crypto_new_keypair(rng, pk1, sk1); - pk_copy(pk2, pk1); + // If vec was accidentally sorted, add another larger element that almost definitely makes + // it not sorted. + if (vec == sorted) { + int const largest = *std::prev(sorted.end()) + 1; + sorted.push_back(largest); + vec.insert(vec.begin(), largest); + } + ASSERT_NE(vec, sorted); - EXPECT_TRUE(pk_equal(pk1, pk2)); + // Just pass some arbitrary "self" to make sure the callbacks pass it through. 
+ ASSERT_TRUE(merge_sort(vec.data(), vec.size(), &i, &int_funcs)); + ASSERT_EQ(vec, sorted); + } } -TEST(Cmp, OrdersNumbersCorrectly) +TEST(MergeSort, WorksWithNonTrivialTypes) { - EXPECT_EQ(cmp_uint(1, 2), -1); - EXPECT_EQ(cmp_uint(0, UINT32_MAX), -1); - EXPECT_EQ(cmp_uint(UINT32_MAX, 0), 1); - EXPECT_EQ(cmp_uint(UINT32_MAX, UINT32_MAX), 0); - EXPECT_EQ(cmp_uint(0, UINT64_MAX), -1); - EXPECT_EQ(cmp_uint(UINT64_MAX, 0), 1); - EXPECT_EQ(cmp_uint(UINT64_MAX, UINT64_MAX), 0); + std::mt19937 rng; + std::uniform_int_distribution dist{ + std::numeric_limits::min(), std::numeric_limits::max()}; + + const auto string_funcs = sort_funcs(); + + // Test with std::string arrays. + for (uint32_t i = 1; i < 500; ++i) { + std::vector vec(i); + std::generate(std::begin(vec), std::end(vec), [&]() { return std::to_string(dist(rng)); }); + + auto sorted = vec; + std::sort(sorted.begin(), sorted.end(), std::less()); + + // If vec was accidentally sorted, add another larger element that almost definitely makes + // it not sorted. + if (vec == sorted) { + std::string const largest = "larger than largest int"; + sorted.push_back(largest); + vec.insert(vec.begin(), largest); + } + ASSERT_NE(vec, sorted); + + // Just pass some arbitrary "self" to make sure the callbacks pass it through. + ASSERT_TRUE(merge_sort(vec.data(), vec.size(), &i, &string_funcs)); + ASSERT_EQ(vec, sorted); + } } } // namespace