From e2fbe963fb2bd2e971cc81cf8299f9e97b4f986e Mon Sep 17 00:00:00 2001
From: CjS77
Date: Sat, 25 Apr 2020 19:46:58 +0200
Subject: [PATCH] Testnet release (Rincewind)

** Major Changes from 0.0.9

**** Store and forward
Peers will hold onto messages for recipients that are not online and deliver the messages to them when they come online again.

**** OSX package installer

**** Many documentation improvements

**** Ephemeral keys for private messages
This is a big change that preserves privacy on the network while dramatically reducing the amount of traffic peers have to deal with.

**** Emoji Ids

*** Other changes
- The target difficulty for a specified PoW algorithm is included in the block header. This allows the target difficulty of any block height to be calculated by only processing the last set of target difficulty samples up to that height (see the sketch after this list).
- Don't mark peers as offline if there are no existing connections (#1763)
- Add UTXO selection strategy for large txs
- Base node: Dynamically determine build version (#1760)
- Include random peers for liveness ping (#1753)
- RandomX - Version Update (#1754)
- Add generic debug log function to FFI (#1752)
- Lots of logging improvements
- Added list-transactions and cancel-transaction commands (#1746)
- ASCII table output for list-peers and list-connections (#1709)
- Improve Difficulty adjustment manager
- Modular configuration via ConfigLoader and ConfigPath traits
- Fix chain monitoring bug in Transaction Service (#1739)
- Empty Emoji String Bug Fix (#1736)
- Coin-split base node cli command
- Complete the basic OSX pkg build
- Perform reorgs only on stronger tip accumulated difficulties
- Use filesystem storage for dht.db on libwallet (#1735)
- Fix duplicate message propagation (#1730)
- Introduced accumulated difficulty validators to allow different rules for testing and for running a base node
- Changes to peer offline handling (#1716)
- Update Transaction cancellation to work for Inbound and Outbound Txs
- Added oneshot reply to outbound messaging (#1703)
- Add transaction stress test command to CLI
- Implemented basic `make-it-rain` command
- Fix MmrCache rewind issue
- Use ephemeral key for private messages (e.g. Discovery) (#1686)
- Limit orphan pool size
- Added a function to list UTXOs in the console (#1678)
- Prevent adding yourself as a peer (#1665)
- Update transaction weights (#1661)
- Fix block period calculation
- Validators will now check the weight of a block when doing validation (#1648)
- Cleaned up duplicate code from the Blockchain db
- The ban peer log will now supply the reason why the peer was banned (#1638)
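For illustration, here is a minimal sketch of how a node could recompute a target difficulty from a trailing window of (timestamp, difficulty) samples. The names (`Sample`, `target_difficulty`) and the simplified linear-weighting scheme are assumptions made for this sketch, not the exact LWMA parameters implemented in `base_layer/core/src/proof_of_work/lwma_diff.rs`:

```rust
/// One (timestamp, achieved difficulty) sample per block of a given PoW algorithm.
/// Hypothetical type for illustration only.
struct Sample {
    timestamp: u64,  // block timestamp, in seconds
    difficulty: u64, // difficulty achieved by that block
}

/// Estimate the next target difficulty from a trailing window of samples,
/// weighting recent solve times more heavily and aiming for `target_time`
/// seconds per block. A simplified stand-in for the real LWMA calculation.
fn target_difficulty(window: &[Sample], target_time: u64) -> u64 {
    if window.len() < 2 {
        return 1; // minimum difficulty until enough samples exist
    }
    let n = (window.len() - 1) as u64;
    let mut weighted_solve_time = 0u64;
    let mut difficulty_sum = 0u64;
    for (i, pair) in window.windows(2).enumerate() {
        // Clamp solve times so clock skew and outliers cannot dominate.
        let solve = pair[1]
            .timestamp
            .saturating_sub(pair[0].timestamp)
            .clamp(1, 6 * target_time);
        weighted_solve_time += solve * (i as u64 + 1); // newer samples weigh more
        difficulty_sum += pair[1].difficulty;
    }
    let weight_sum = n * (n + 1) / 2; // sum of the weights 1..=n
    let avg_solve = (weighted_solve_time / weight_sum).max(1);
    let avg_difficulty = difficulty_sum / n;
    // Scale difficulty so the weighted average solve time tends to the target.
    (avg_difficulty * target_time / avg_solve).max(1)
}

fn main() {
    // Blocks aimed at 120 s but solved in ~60 s at difficulty 1000:
    // the next target roughly doubles.
    let window: Vec<Sample> = (0..30u64)
        .map(|i| Sample { timestamp: i * 60, difficulty: 1_000 })
        .collect();
    println!("next target difficulty: {}", target_difficulty(&window, 120));
}
```

Because the header carries the target difficulty, a validator that keeps only the most recent window of samples per PoW algorithm can recompute this value for any height and check it against the value recorded in an incoming block header.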
--- .github/workflows/clippy-check.yml | 2 +- README.md | 126 +- RELEASE_CHECKLIST.md | 20 +- applications/tari_base_node/Cargo.toml | 27 +- .../tari_base_node/assets/tari_banner.rs | 25 + .../tari_base_node/assets/tari_logo.rs | 21 +- applications/tari_base_node/build.rs | 97 + .../tari_base_node/osx-pkg/build-pkg.sh | 107 ++ .../tari_base_node/osx-pkg/env-sample | 8 + .../osx-pkg/scripts/postinstall | 72 + .../tari_base_node/osx-pkg/scripts/preinstall | 10 + .../tari_base_node/osx/post-install.sh | 211 +++ .../tari_base_node/osx/uninstall-pkg.sh | 33 + applications/tari_base_node/src/builder.rs | 234 ++- applications/tari_base_node/src/cli.rs | 189 +- applications/tari_base_node/src/main.rs | 115 +- applications/tari_base_node/src/miner.rs | 7 + applications/tari_base_node/src/parser.rs | 988 ++++++++++- applications/tari_base_node/src/table.rs | 147 ++ applications/tari_base_node/src/utils.rs | 66 +- applications/test_faucet/Cargo.toml | 2 +- base_layer/core/Cargo.toml | 27 +- .../chain_metadata_service/handle.rs | 8 + .../chain_metadata_service/service.rs | 66 +- .../comms_interface/comms_request.rs | 4 +- .../comms_interface/inbound_handlers.rs | 45 +- .../comms_interface/local_interface.rs | 10 +- .../comms_interface/outbound_interface.rs | 4 +- .../core/src/base_node/proto/request.proto | 2 +- .../core/src/base_node/proto/request.rs | 11 +- .../core/src/base_node/service/initializer.rs | 1 + base_layer/core/src/base_node/service/mod.rs | 1 + .../core/src/base_node/service/service.rs | 6 +- .../core/src/base_node/state_machine.rs | 10 +- .../core/src/base_node/states/block_sync.rs | 58 +- .../src/base_node/states/events_and_states.rs | 2 +- .../base_node/states/forward_block_sync.rs | 9 +- .../core/src/base_node/states/listening.rs | 30 +- base_layer/core/src/base_node/states/mod.rs | 1 - .../src/base_node/states/starting_state.rs | 14 +- base_layer/core/src/blocks/block.rs | 2 + base_layer/core/src/blocks/genesis_block.rs | 3 +- base_layer/core/src/chain_storage/async_db.rs | 5 +- .../src/chain_storage/blockchain_database.rs | 747 ++++---- .../core/src/chain_storage}/consts.rs | 7 +- base_layer/core/src/chain_storage/error.rs | 66 +- .../core/src/chain_storage/lmdb_db/lmdb_db.rs | 184 +- .../core/src/chain_storage/lmdb_db/mod.rs | 1 + .../src/chain_storage/memory_db/memory_db.rs | 121 +- .../core/src/chain_storage/memory_db/mod.rs | 1 + base_layer/core/src/chain_storage/mod.rs | 8 +- .../core/src/consensus/consensus_constants.rs | 28 +- .../core/src/consensus/consensus_manager.rs | 147 +- base_layer/core/src/consensus/network.rs | 2 +- base_layer/core/src/helpers/mock_backend.rs | 23 +- base_layer/core/src/helpers/mod.rs | 12 +- base_layer/core/src/mempool/async_mempool.rs | 3 +- base_layer/core/src/mempool/consts.rs | 1 -
base_layer/core/src/mempool/error.rs | 3 + base_layer/core/src/mempool/mempool.rs | 198 +-- .../core/src/mempool/mempool_storage.rs | 289 +++ base_layer/core/src/mempool/mod.rs | 40 +- .../core/src/mempool/orphan_pool/mod.rs | 1 + .../src/mempool/orphan_pool/orphan_pool.rs | 12 +- .../orphan_pool/orphan_pool_storage.rs | 4 +- .../core/src/mempool/pending_pool/mod.rs | 3 +- .../src/mempool/pending_pool/pending_pool.rs | 264 ++- .../pending_pool/pending_pool_storage.rs | 223 --- .../priority/prioritized_transaction.rs | 1 + .../core/src/mempool/proto/mempool_request.rs | 2 + .../src/mempool/proto/mempool_response.rs | 2 + base_layer/core/src/mempool/proto/mod.rs | 1 + .../src/mempool/proto/service_request.proto | 6 +- .../src/mempool/proto/service_response.proto | 4 +- .../src/mempool/proto/state_response.proto | 22 + .../core/src/mempool/proto/state_response.rs | 98 ++ base_layer/core/src/mempool/reorg_pool/mod.rs | 1 + .../core/src/mempool/reorg_pool/reorg_pool.rs | 31 +- .../mempool/reorg_pool/reorg_pool_storage.rs | 5 + .../src/mempool/service/inbound_handlers.rs | 3 + .../core/src/mempool/service/initializer.rs | 10 +- .../core/src/mempool/service/local_service.rs | 13 +- base_layer/core/src/mempool/service/mod.rs | 3 + .../core/src/mempool/service/request.rs | 2 + .../core/src/mempool/service/response.rs | 7 +- .../core/src/mempool/service/service.rs | 48 +- .../src/mempool/unconfirmed_pool/error.rs | 3 - .../core/src/mempool/unconfirmed_pool/mod.rs | 3 +- .../unconfirmed_pool/unconfirmed_pool.rs | 322 ++-- .../unconfirmed_pool_storage.rs | 210 --- base_layer/core/src/mining/blake_miner.rs | 19 +- base_layer/core/src/mining/miner.rs | 146 +- .../diff_adj_manager/diff_adj_manager.rs | 137 -- .../diff_adj_manager/diff_adj_storage.rs | 462 ----- base_layer/core/src/proof_of_work/error.rs | 2 + .../core/src/proof_of_work/lwma_diff.rs | 8 +- .../src/proof_of_work/median_timestamp.rs | 43 + base_layer/core/src/proof_of_work/mod.rs | 8 +- .../core/src/proof_of_work/monero_rx.rs | 8 +- .../core/src/proof_of_work/proof_of_work.rs | 9 +- .../src/proof_of_work/target_difficulty.rs | 44 + base_layer/core/src/proto/block.proto | 1 + base_layer/core/src/proto/block.rs | 2 + .../core/src/transactions/aggregated_body.rs | 6 + base_layer/core/src/transactions/fee.rs | 24 +- base_layer/core/src/transactions/helpers.rs | 2 +- .../core/src/transactions/transaction.rs | 14 +- .../transactions/transaction_protocol/mod.rs | 31 + .../transaction_protocol/recipient.rs | 16 +- .../transaction_protocol/sender.rs | 50 +- .../transaction_initializer.rs | 34 +- .../validation/accum_difficulty_validators.rs | 63 + .../core/src/validation/block_validators.rs | 92 +- base_layer/core/src/validation/error.rs | 4 + base_layer/core/src/validation/helpers.rs | 78 +- base_layer/core/src/validation/mocks.rs | 34 +- base_layer/core/src/validation/mod.rs | 10 +- base_layer/core/src/validation/traits.rs | 27 +- .../src/validation/transaction_validators.rs | 69 +- base_layer/core/tests/async_db.rs | 12 +- base_layer/core/tests/block_validation.rs | 13 +- .../chain_storage_tests/chain_backend.rs | 605 +++++-- .../chain_storage_tests/chain_storage.rs | 528 ++++-- base_layer/core/tests/diff_adj_manager.rs | 548 ------ base_layer/core/tests/helpers/mod.rs | 1 + base_layer/core/tests/helpers/nodes.rs | 26 +- .../core/tests/helpers/pow_blockchain.rs | 104 ++ base_layer/core/tests/median_timestamp.rs | 136 ++ base_layer/core/tests/mempool.rs | 18 +- base_layer/core/tests/node_comms_interface.rs | 11 - 
base_layer/core/tests/node_service.rs | 25 +- base_layer/core/tests/node_state_machine.rs | 15 +- base_layer/core/tests/target_difficulty.rs | 206 +++ base_layer/core/tests/wallet.rs | 10 +- base_layer/mmr/Cargo.toml | 4 +- base_layer/mmr/src/mmr_cache.rs | 5 +- base_layer/mmr/tests/mmr_cache.rs | 68 + base_layer/p2p/Cargo.toml | 10 +- base_layer/p2p/examples/pingpong.rs | 2 +- .../src/comms_connector/inbound_connector.rs | 54 +- .../p2p/src/comms_connector/peer_message.rs | 19 +- base_layer/p2p/src/domain_message.rs | 25 +- base_layer/p2p/src/initialization.rs | 13 +- .../p2p/src/services/liveness/config.rs | 13 +- .../p2p/src/services/liveness/handle.rs | 2 - base_layer/p2p/src/services/liveness/mock.rs | 3 - base_layer/p2p/src/services/liveness/mod.rs | 31 +- .../p2p/src/services/liveness/neighbours.rs | 66 - .../p2p/src/services/liveness/peer_pool.rs | 120 ++ .../p2p/src/services/liveness/service.rs | 334 +++- base_layer/p2p/src/services/liveness/state.rs | 16 +- base_layer/p2p/src/services/utils.rs | 1 + base_layer/p2p/src/test_utils.rs | 39 +- base_layer/p2p/tests/services/liveness.rs | 18 +- base_layer/wallet/Cargo.toml | 14 +- base_layer/wallet/src/error.rs | 2 + .../src/output_manager_service/handle.rs | 28 +- .../src/output_manager_service/service.rs | 199 ++- .../storage/database.rs | 11 +- .../storage/memory_db.rs | 8 +- .../storage/sqlite_db.rs | 18 +- base_layer/wallet/src/testnet_utils.rs | 169 +- .../src/text_message_service/service.rs | 24 +- .../wallet/src/transaction_service/error.rs | 34 + .../wallet/src/transaction_service/handle.rs | 69 +- .../wallet/src/transaction_service/mod.rs | 19 +- .../src/transaction_service/protocols}/mod.rs | 12 +- .../transaction_broadcast_protocol.rs | 460 +++++ .../transaction_chain_monitoring_protocol.rs | 479 +++++ .../transaction_receive_protocol.rs} | 26 +- .../protocols/transaction_send_protocol.rs | 511 ++++++ .../wallet/src/transaction_service/service.rs | 1503 +++++++--------- .../transaction_service/storage/database.rs | 43 +- .../transaction_service/storage/memory_db.rs | 113 +- .../transaction_service/storage/sqlite_db.rs | 82 +- base_layer/wallet/src/util/emoji.rs | 43 +- base_layer/wallet/src/util/luhn.rs | 3 + base_layer/wallet/src/wallet.rs | 14 +- .../tests/output_manager_service/service.rs | 97 +- .../tests/output_manager_service/storage.rs | 34 +- .../tests/support/comms_and_services.rs | 4 +- .../tests/transaction_service/service.rs | 1568 ++++++++++------- .../tests/transaction_service/storage.rs | 15 +- base_layer/wallet/tests/wallet/mod.rs | 265 ++- base_layer/wallet_ffi/Cargo.toml | 12 +- base_layer/wallet_ffi/README.md | 35 +- base_layer/wallet_ffi/src/callback_handler.rs | 116 +- base_layer/wallet_ffi/src/error.rs | 4 + base_layer/wallet_ffi/src/lib.rs | 258 ++- base_layer/wallet_ffi/wallet.h | 23 +- buildtools/base_node.Dockerfile | 2 +- common/Cargo.toml | 10 +- {config => common/config}/README.md | 0 common/config/presets/rincewind-simple.toml | 33 + {config => common/config}/tari.config.json | 0 .../config}/tari_config_sample.toml | 0 common/examples/base_node_init.rs | 77 + common/examples/mempool_config.rs | 153 ++ common/presets | 1 - common/src/configuration/bootstrap.rs | 424 +++++ common/src/configuration/error.rs | 35 + .../global.rs} | 547 ++---- common/src/configuration/loader.rs | 505 ++++++ common/src/configuration/mod.rs | 46 + common/src/configuration/utils.rs | 229 +++ common/src/dir_utils.rs | 4 +- common/src/lib.rs | 221 +-- comms/Cargo.toml | 6 +- comms/dht/Cargo.toml | 12 +- 
comms/dht/diesel.toml | 5 + comms/dht/examples/memorynet.rs | 236 ++- comms/dht/migrations/.gitkeep | 0 .../2020-04-01-095825_initial/down.sql | 2 + .../2020-04-01-095825_initial/up.sql | 27 + .../down.sql | 1 + .../up.sql | 20 + .../down.sql | 1 + .../up.sql | 1 + .../down.sql | 1 + .../up.sql | 1 + comms/dht/src/actor.rs | 288 ++- comms/dht/src/broadcast_strategy.rs | 27 +- comms/dht/src/builder.rs | 21 +- comms/dht/src/config.rs | 29 +- comms/dht/src/crypt.rs | 8 +- comms/dht/src/dedup.rs | 232 +++ comms/dht/src/dht.rs | 219 ++- comms/dht/src/discovery/error.rs | 7 +- comms/dht/src/discovery/requester.rs | 80 +- comms/dht/src/discovery/service.rs | 386 ++-- comms/dht/src/envelope.rs | 196 ++- comms/dht/src/inbound/decryption.rs | 261 ++- comms/dht/src/inbound/dedup.rs | 166 -- comms/dht/src/inbound/deserialize.rs | 99 +- .../dht/src/inbound/dht_handler/middleware.rs | 6 +- comms/dht/src/inbound/dht_handler/task.rs | 85 +- comms/dht/src/inbound/error.rs | 3 - comms/dht/src/inbound/message.rs | 68 +- comms/dht/src/inbound/mod.rs | 2 - comms/dht/src/inbound/validate.rs | 129 +- comms/dht/src/lib.rs | 16 +- comms/dht/src/logging_middleware.rs | 28 +- comms/dht/src/macros.rs | 8 +- comms/dht/src/outbound/broadcast.rs | 173 +- comms/dht/src/outbound/encryption.rs | 202 --- comms/dht/src/outbound/error.rs | 8 +- comms/dht/src/outbound/message.rs | 126 +- comms/dht/src/outbound/message_params.rs | 11 +- comms/dht/src/outbound/message_send_state.rs | 230 +++ comms/dht/src/outbound/mock.rs | 28 +- comms/dht/src/outbound/mod.rs | 24 +- comms/dht/src/outbound/requester.rs | 6 +- comms/dht/src/outbound/serialize.rs | 174 +- comms/dht/src/proto/envelope.proto | 19 +- comms/dht/src/proto/mod.rs | 15 + comms/dht/src/proto/store_forward.proto | 18 +- comms/dht/src/proto/tari.dht.envelope.rs | 22 +- comms/dht/src/proto/tari.dht.store_forward.rs | 26 +- comms/dht/src/schema.rs | 25 + comms/dht/src/storage/connection.rs | 151 ++ comms/dht/src/storage/database.rs | 87 + comms/dht/src/storage/dht_setting_entry.rs | 51 + comms/dht/src/storage/error.rs | 37 + .../error.rs => comms/dht/src/storage/mod.rs | 19 +- comms/dht/src/store_forward/database/mod.rs | 278 +++ .../store_forward/database/stored_message.rs | 93 + comms/dht/src/store_forward/error.rs | 26 +- comms/dht/src/store_forward/forward.rs | 123 +- comms/dht/src/store_forward/message.rs | 66 +- comms/dht/src/store_forward/mod.rs | 25 +- .../src/store_forward/saf_handler/layer.rs | 15 +- .../store_forward/saf_handler/middleware.rs | 10 +- .../dht/src/store_forward/saf_handler/task.rs | 453 +++-- comms/dht/src/store_forward/service.rs | 430 +++++ comms/dht/src/store_forward/store.rs | 468 +++-- comms/dht/src/test_utils/dht_actor_mock.rs | 38 +- .../dht/src/test_utils/dht_discovery_mock.rs | 5 +- comms/dht/src/test_utils/makers.rs | 126 +- comms/dht/src/test_utils/mod.rs | 15 +- .../src/test_utils/store_and_forward_mock.rs | 136 ++ comms/dht/src/tower_filter/error.rs | 46 - comms/dht/src/tower_filter/future.rs | 16 +- comms/dht/src/tower_filter/mod.rs | 13 +- comms/dht/src/tower_filter/predicate.rs | 6 +- comms/dht/src/tower_filter/test.rs | 81 - .../src/{store_forward/state.rs => utils.rs} | 43 +- comms/dht/tests/dht.rs | 269 ++- comms/examples/tor.rs | 5 +- comms/src/builder/comms_node.rs | 5 + comms/src/builder/tests.rs | 4 - comms/src/connection_manager/common.rs | 2 + comms/src/connection_manager/dialer.rs | 2 +- comms/src/connection_manager/error.rs | 2 - comms/src/connection_manager/manager.rs | 61 +- .../src/connection_manager/peer_connection.rs | 
11 + comms/src/connection_manager/requester.rs | 27 +- comms/src/connection_manager/tests/manager.rs | 55 +- comms/src/connection_manager/wire_mode.rs | 4 +- comms/src/consts.rs | 4 - comms/src/message/envelope.rs | 151 +- comms/src/message/error.rs | 4 +- comms/src/message/inbound.rs | 7 +- comms/src/message/mod.rs | 25 +- comms/src/message/outbound.rs | 55 +- comms/src/peer_manager/connection_stats.rs | 25 +- comms/src/peer_manager/manager.rs | 242 ++- comms/src/peer_manager/node_id.rs | 31 +- comms/src/peer_manager/peer.rs | 92 +- comms/src/peer_manager/peer_features.rs | 7 + comms/src/peer_manager/peer_query.rs | 6 +- comms/src/peer_manager/peer_storage.rs | 189 +- comms/src/pipeline/error.rs | 12 +- comms/src/pipeline/outbound.rs | 9 +- comms/src/proto/control_service/header.proto | 17 - comms/src/proto/control_service/ping.proto | 6 - .../control_service/request_connection.proto | 35 - comms/src/proto/envelope.proto | 14 - comms/src/proto/mod.rs | 3 - comms/src/proto/tari.comms.control_service.rs | 61 - comms/src/proto/tari.comms.envelope.rs | 20 - comms/src/protocol/identity.rs | 3 +- comms/src/protocol/messaging/error.rs | 10 +- comms/src/protocol/messaging/inbound.rs | 74 - comms/src/protocol/messaging/mod.rs | 3 +- comms/src/protocol/messaging/outbound.rs | 94 +- comms/src/protocol/messaging/protocol.rs | 98 +- comms/src/protocol/messaging/test.rs | 60 +- .../test_utils/mocks/connection_manager.rs | 5 +- comms/src/test_utils/mocks/peer_connection.rs | 2 +- comms/src/test_utils/node_identity.rs | 6 +- comms/src/utils/cidr.rs | 2 +- comms/src/utils/datetime.rs | 33 + comms/src/utils/mod.rs | 1 + config/presets/rincewind-simple.toml | 16 - infrastructure/storage/Cargo.toml | 3 +- .../src/key_val_store/lmdb_database.rs | 2 +- .../storage/src/key_val_store/mod.rs | 1 + .../storage/src/lmdb_store/error.rs | 37 +- .../storage/src/lmdb_store/store.rs | 35 +- scripts/build-dists-tarball.sh | 37 +- scripts/create_bundle.sh | 2 +- scripts/publish_crates.sh | 27 +- scripts/update_crate_metadata.sh | 24 +- 343 files changed, 18181 insertions(+), 9654 deletions(-) create mode 100644 applications/tari_base_node/assets/tari_banner.rs create mode 100644 applications/tari_base_node/build.rs create mode 100755 applications/tari_base_node/osx-pkg/build-pkg.sh create mode 100644 applications/tari_base_node/osx-pkg/env-sample create mode 100755 applications/tari_base_node/osx-pkg/scripts/postinstall create mode 100755 applications/tari_base_node/osx-pkg/scripts/preinstall create mode 100755 applications/tari_base_node/osx/post-install.sh create mode 100755 applications/tari_base_node/osx/uninstall-pkg.sh create mode 100644 applications/tari_base_node/src/table.rs rename {applications/tari_base_node/src => base_layer/core/src/chain_storage}/consts.rs (89%) create mode 100644 base_layer/core/src/mempool/mempool_storage.rs delete mode 100644 base_layer/core/src/mempool/pending_pool/pending_pool_storage.rs create mode 100644 base_layer/core/src/mempool/proto/state_response.proto create mode 100644 base_layer/core/src/mempool/proto/state_response.rs delete mode 100644 base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool_storage.rs delete mode 100644 base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_manager.rs delete mode 100644 base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_storage.rs create mode 100644 base_layer/core/src/proof_of_work/median_timestamp.rs create mode 100644 base_layer/core/src/proof_of_work/target_difficulty.rs create mode 100644 
base_layer/core/src/validation/accum_difficulty_validators.rs delete mode 100644 base_layer/core/tests/diff_adj_manager.rs create mode 100644 base_layer/core/tests/helpers/pow_blockchain.rs create mode 100644 base_layer/core/tests/median_timestamp.rs create mode 100644 base_layer/core/tests/target_difficulty.rs delete mode 100644 base_layer/p2p/src/services/liveness/neighbours.rs create mode 100644 base_layer/p2p/src/services/liveness/peer_pool.rs rename base_layer/{core/src/proof_of_work/diff_adj_manager => wallet/src/transaction_service/protocols}/mod.rs (88%) create mode 100644 base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs create mode 100644 base_layer/wallet/src/transaction_service/protocols/transaction_chain_monitoring_protocol.rs rename base_layer/{core/src/base_node/states/error.rs => wallet/src/transaction_service/protocols/transaction_receive_protocol.rs} (77%) create mode 100644 base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs rename {config => common/config}/README.md (100%) create mode 100644 common/config/presets/rincewind-simple.toml rename {config => common/config}/tari.config.json (100%) rename {config => common/config}/tari_config_sample.toml (100%) create mode 100644 common/examples/base_node_init.rs create mode 100644 common/examples/mempool_config.rs delete mode 120000 common/presets create mode 100644 common/src/configuration/bootstrap.rs create mode 100644 common/src/configuration/error.rs rename common/src/{configuration.rs => configuration/global.rs} (57%) create mode 100644 common/src/configuration/loader.rs create mode 100644 common/src/configuration/mod.rs create mode 100644 common/src/configuration/utils.rs create mode 100644 comms/dht/diesel.toml create mode 100644 comms/dht/migrations/.gitkeep create mode 100644 comms/dht/migrations/2020-04-01-095825_initial/down.sql create mode 100644 comms/dht/migrations/2020-04-01-095825_initial/up.sql create mode 100644 comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql create mode 100644 comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql create mode 100644 comms/dht/migrations/2020-04-16-165626_clear_stored_messages/down.sql create mode 100644 comms/dht/migrations/2020-04-16-165626_clear_stored_messages/up.sql create mode 100644 comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/down.sql create mode 100644 comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/up.sql create mode 100644 comms/dht/src/dedup.rs delete mode 100644 comms/dht/src/inbound/dedup.rs delete mode 100644 comms/dht/src/outbound/encryption.rs create mode 100644 comms/dht/src/outbound/message_send_state.rs create mode 100644 comms/dht/src/schema.rs create mode 100644 comms/dht/src/storage/connection.rs create mode 100644 comms/dht/src/storage/database.rs create mode 100644 comms/dht/src/storage/dht_setting_entry.rs create mode 100644 comms/dht/src/storage/error.rs rename base_layer/core/src/proof_of_work/diff_adj_manager/error.rs => comms/dht/src/storage/mod.rs (81%) create mode 100644 comms/dht/src/store_forward/database/mod.rs create mode 100644 comms/dht/src/store_forward/database/stored_message.rs create mode 100644 comms/dht/src/store_forward/service.rs create mode 100644 comms/dht/src/test_utils/store_and_forward_mock.rs delete mode 100644 comms/dht/src/tower_filter/error.rs delete mode 100644 comms/dht/src/tower_filter/test.rs rename comms/dht/src/{store_forward/state.rs => utils.rs} (61%) delete mode 100644 
comms/src/proto/control_service/header.proto delete mode 100644 comms/src/proto/control_service/ping.proto delete mode 100644 comms/src/proto/control_service/request_connection.proto delete mode 100644 comms/src/proto/tari.comms.control_service.rs delete mode 100644 comms/src/protocol/messaging/inbound.rs create mode 100644 comms/src/utils/datetime.rs delete mode 100644 config/presets/rincewind-simple.toml diff --git a/.github/workflows/clippy-check.yml b/.github/workflows/clippy-check.yml index 7133aebc59..0acb6003ae 100644 --- a/.github/workflows/clippy-check.yml +++ b/.github/workflows/clippy-check.yml @@ -8,7 +8,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: toolchain: nightly - components: clippy + components: clippy, rustfmt override: true - name: Install dependencies run: | diff --git a/README.md b/README.md index c3b60236e3..9136692dd3 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ You can check that the binaries match the hash by running If you have docker on your machine, you can run a prebuilt node using one of the docker images on [quay.io](https://quay.io/user/tarilabs). -### Building from Source (Ubuntu 18.04) +### Building from source (Ubuntu 18.04) To build the Tari codebase from source, there are a few dependencies you need to have installed. @@ -72,88 +72,121 @@ A successful build should output something as follows The executable is currently inside your `target/release` folder. You can run it from that folder if you like, but you'll more likely want to copy it somewhere more convenient. You can simply run - cargo install -p tari_base_node [![Build](https://circleci.com/gh/tari-project/tari.svg?style=svg)](https://circleci.com/gh/tari-project/tari) - -# The Tari protocol - -## Installing the base node software - -### Using binaries - -[Download binaries from tari.com](https://tari.com/downloads). This is the easiest way to run a Tari node, but you're -essentially trusting the person that built and uploaded them that nothing untoward has happened. - -We've tried to limit the risks by publishing [hashes of the binaries](https://tari.com/downloads) on our website. - -You can check that the binaries match the hash by running + cargo install -p tari_base_node - sha256sum path/to/tari_base_node +and cargo will copy the executable into `~/.cargo/bin`. This folder was added to your path in a previous step, so it +will be executable from anywhere on your system. -### Running a node in Docker +Alternatively, you can run the node from your source folder with the command -If you have docker on your machine, you can run a prebuilt node using one of the docker images on -[quay.io](https://quay.io/user/tarilabs). + cargo run -p tari_base_node -### Building from Source (Ubuntu 18.04) +### Building from source (Windows 10) -To build the Tari codebase from source, there are a few dependencies you need to have installed. +To build the Tari codebase from source on Windows 10, there are a few dependencies you need to have installed. +_**Note:** The Tari codebase does not work in Windows Subsystem for Linux version 1 (WSL 1), as the low-level calls +used by LMDB break it.
Compatibility with WSL 2 must still be tested once it is released in a stable +Windows build._ -#### Install development packages +#### Install dependencies First you'll need to make sure you have a full development environment set up: -``` -sudo apt-get -y install openssl libssl-dev pkg-config libsqlite3-dev clang git cmake libc++-dev libc++abi-dev -``` +- git + - https://git-scm.com/downloads + +- LLVM + - https://releases.llvm.org/ + - Create a `LIBCLANG_PATH` environment variable pointing to the LLVM lib path, e.g. + ``` + setx LIBCLANG_PATH "C:\Program Files\LLVM\lib" + ``` + +- Build Tools + - Microsoft Visual Studio Version 2019 or later + - C++ CMake tools for Windows + - MSVC build tools (latest version for your platform ARM, ARM64 or x64\x86) + - Spectre-mitigated libs (latest version for your platform ARM, ARM64 or x64\x86) + + or + + - [CMake](https://cmake.org/download/) + - [Build Tools for Visual Studio 2019]( +https://visualstudio.microsoft.com/thank-you-downloading-visual-studio/?sku=BuildTools&rel=16) + +- SQLite: + - Download 32bit/64bit Precompiled Binaries for Windows for [SQLite](https://www.sqlite.org/index.html) and unzip + to local path, e.g. `%USERPROFILE%\.sqlite` + - Open the appropriate x64\x86 `Native Tools Command Prompt for VS 2019` in `%USERPROFILE%\.sqlite` + - Run either of these, depending on your environment (32bit/64bit): + ``` + lib /DEF:sqlite3.def /OUT:sqlite3.lib /MACHINE:x64 + ``` + ``` + lib /DEF:sqlite3.def /OUT:sqlite3.lib /MACHINE:x86 + ``` + - Ensure folder containing `sqlite3.dll`, e.g. `%USERPROFILE%\.sqlite`, is in the path + - Create a `SQLITE3_LIB_DIR` environment variable pointing to the SQLite lib path, e.g. + ``` + setx SQLITE3_LIB_DIR "%USERPROFILE%\.sqlite" + ``` + +- Tor + - Download [Tor Windows Expert Bundle](https://www.torproject.org/download/tor/) + - Extract to local path, e.g. `C:\Program Files (x86)\Tor Services` + - Ensure folder containing the Tor executable, e.g. `C:\Program Files (x86)\Tor Services\Tor`, is in the path #### Install Rust -You can follow along at [The Rust Website](https://www.rust-lang.org/tools/install) or just follow these steps to get -Rust installed on your machine. +Follow the installation process for Windows at [The Rust Website](https://www.rust-lang.org/tools/install). Then make +sure that `cargo` and `rustc` have been added to your path: - curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh - -Then make sure that `cargo` has been added to your path. - - export PATH="$HOME/.cargo/bin:$PATH" + cargo --version + rustc --version #### Checkout the source code In your folder of choice, e.g. `%USERPROFILE%\Code`, clone the Tari repo git clone https://github.com/tari-project/tari.git #### Build This is similar to [building in Ubuntu](#building-from-source-ubuntu-1804), except the Microsoft Visual Studio +environment must be sourced.
- cd tari +Open the appropriate _x64\x86 Native Tools Command Prompt for VS 2019_, and in your main Tari folder perform the +build, which will create the executable inside your `%USERPROFILE%\Code\tari\target\release` folder: + + cd %USERPROFILE%\Code\tari cargo build --release A successful build should output something as follows ``` - Compiling tari_wallet v0.0.9 (.../tari/base_layer/wallet) - Compiling test_faucet v0.0.1 (.../tari/applications/test_faucet) - Compiling tari_wallet_ffi v0.0.9 (.../tari/base_layer/wallet_ffi) - Compiling tari_base_node v0.0.9 (.../tari/applications/tari_base_node) + Compiling tari_wallet v0.0.9 (...\tari\base_layer\wallet) + Compiling test_faucet v0.0.1 (...\tari\applications\test_faucet) + Compiling tari_wallet_ffi v0.0.9 (...\tari\base_layer\wallet_ffi) + Compiling tari_base_node v0.0.9 (...\tari\applications\tari_base_node) Finished release [optimized] target(s) in 12m 24s ``` -#### Run +Alternatively, cargo can build and install the executable into `%USERPROFILE%\.cargo\bin`: -The executable is currently inside your `target/release` folder. You can run it from that folder if you like, but you'll -more likely want to copy it somewhere more convenient. You can simply run + cargo install tari_base_node - cargo install -p tari_base_node +#### Run -and cargo will copy the executable into `~/.cargo/bin`. This folder was added to your path in a previous step, so it -will be executable from anywhere on your system. +The executable will either be inside your `%USERPROFILE%\Code\tari\target\release` or the `%USERPROFILE%\.cargo\bin` +folder, depending on the build choice above. If the former build method was used, you can run it from that folder, +or, more likely, copy it somewhere more convenient. Using the latter method, it will be executable from +anywhere on your system, as `%USERPROFILE%\.cargo\bin` was added to your path in a previous step. Alternatively, you can run the node from your source folder with the command - cargo run -p tari_base_node + cargo run --bin tari_base_node + ### Building a docker image @@ -178,6 +211,7 @@ Test your image * [Building with Vagrant](https://github.com/tari-project/tari/issues/1407) + # Project documentation * [RFC documents](https://rfc.tari.com) are hosted on Github Pages. The source markdown is in the `RFC` directory. diff --git a/RELEASE_CHECKLIST.md b/RELEASE_CHECKLIST.md index 0f0d73a990..f2b98e6ffc 100644 --- a/RELEASE_CHECKLIST.md +++ b/RELEASE_CHECKLIST.md @@ -11,4 +11,22 @@ Things to do before pushing a new commit to `master`: * Tag commit * Write release notes on GitHub.
* Merge back into development (where appropriate) -* Delete branch \ No newline at end of file +* Delete branch + +| Crate | Version | Last change | +|:-----------------------------|:--------|:-----------------------------------------| +| infrastructure/derive | 0.0.10 | 7d734a2e79bfe2dd5d4ae00a2b760614d21e69c4 | +| infrastructure/shutdown | 0.0.10 | 7d734a2e79bfe2dd5d4ae00a2b760614d21e69c4 | +| infrastructure/storage | 0.1.0 | | +| infrastructure/test_utils | 0.0.10 | 7d734a2e79bfe2dd5d4ae00a2b760614d21e69c4 | +| base_layer/core | 0.1.0 | | +| base_layer/key_manager | 0.0.10 | 7d734a2e79bfe2dd5d4ae00a2b760614d21e69c4 | +| base_layer/mmr | 0.1.0 | | +| base_layer/p2p | 0.1.0 | | +| base_layer/service_framework | 0.0.10 | 7d734a2e79bfe2dd5d4ae00a2b760614d21e69c4 | +| base_layer/wallet | 0.1.0 | | +| base_layer/wallet_ffi | 0.1.0 | | +| common | 0.1.0 | | +| comms | 0.1.0 | | +| comms/dht | 0.1.0 | | +| applications/tari_base_node | 0.1.0 | | diff --git a/applications/tari_base_node/Cargo.toml b/applications/tari_base_node/Cargo.toml index 169dcbfb1d..f9f101e8fa 100644 --- a/applications/tari_base_node/Cargo.toml +++ b/applications/tari_base_node/Cargo.toml @@ -4,24 +4,25 @@ authors = ["The Tari Development Community"] description = "The tari full base node implementation" repository = "https://github.com/tari-project/tari" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tari_common = {path = "../../common", version= "^0.0"} -tari_comms = { version = "^0.0", path = "../../comms"} -tari_comms_dht = { version = "^0.0", path = "../../comms/dht"} -tari_core = {path = "../../base_layer/core", version= "^0.0"} -tari_p2p = {path = "../../base_layer/p2p", version= "^0.0"} +tari_common = { version= "^0.1", path = "../../common" } +tari_comms = { version = "^0.1", path = "../../comms"} +tari_comms_dht = { version = "^0.1", path = "../../comms/dht"} +tari_core = {path = "../../base_layer/core", version= "^0.1"} +tari_p2p = {path = "../../base_layer/p2p", version= "^0.1"} tari_service_framework = { version = "^0.0", path = "../../base_layer/service_framework"} tari_shutdown = { path = "../../infrastructure/shutdown", version = "^0.0" } -tari_mmr = { path = "../../base_layer/mmr", version = "^0.0" } -tari_wallet = { path = "../../base_layer/wallet", version = "^0.0" } +tari_mmr = { path = "../../base_layer/mmr", version = "^0.1" } +tari_wallet = { path = "../../base_layer/wallet", version = "^0.1" } tari_broadcast_channel = "^0.1" +tari_crypto = { version = "^0.3" } -clap = "2.33.0" +structopt = { version = "0.3.13", default_features = false } config = { version = "0.9.3" } dirs = "2.0.2" futures = { version = "^0.3.1", default-features = false, features = ["alloc"]} @@ -35,3 +36,11 @@ rustyline-derive = "0.3" strum = "0.18.0" strum_macros = "0.18.0" qrcode = { version = "0.12" } +chrono = "0.4" +chrono-english = "0.1" +regex = "1" + +[build-dependencies] +serde = "1.0.90" +toml = "0.5" +git2 = "0.8" diff --git a/applications/tari_base_node/assets/tari_banner.rs b/applications/tari_base_node/assets/tari_banner.rs new file mode 100644 index 0000000000..0813734c55 --- /dev/null +++ b/applications/tari_base_node/assets/tari_banner.rs @@ -0,0 +1,25 @@ +" +                                            ▄▄▄▄▄▄▄                                                                                   +  ▄▄▄  ▓██▄     ░                      ▄▒█████████████▄                       
░▐  ░▐                           ▄▄▄▄          ▄▄▄▄     +  ▀███▄ ▀▓█▓     ░                   ▄▓█████████████████▌                      ░░░░ ▒ ░                      ███████▄      ▄███████   + ▄▄ ▀▓██▄ ▓██▄      ▓███            ▓█████████████████████▄                  ░▄▄▒░▒▄▒░░░░░                 ▐█████████▌    ██████████  +▐▓██▄ ▀███▄▐███▄   ▐▓▓▓            ▓▌  ▄▄▄▄  ░▀▀▀  ▄▄▄▄▄  ▓░               ▒▒▓▓▀▒▀▓▓▓▒░░░░░░               ███████████   ▐███████████ +  ▀▓██▌ ▀███▓████▄▄██▓▒            ▓▓ ▐██ ░░   ▄  ░░ ███░▐▓▒             ▐▒▒▓▓▓▄ ▄▓▓▓▒▒░ ░░░              ▐▐█░   █████▌  ▀██   ▀█████ +  ▄ ▐▓██▌░▓███████████▓           ▐▒▓▌▐███▄▄  ▓█▌ ▄▄████ ▓▓▒            ▌▒▒▓▓▓▓▓▓▓▓▓▓▒▒▒░ ░░                     ▐████▌        ░█████ + ▀██▓▄▐███████████████▓░           ▒▓▓▓▄▄▄▄▄▄█████▄▄▄▄▄▄█▓▒▒           ██▒▒▒▒▒▒▒▒▒▒▒▒▒░▒▒                  ░   ▄▒█████▒  ░    ▒██████ +   ▀▓██▄▄████████████▓▓▒            ▒▓▓▓▓ ▀▀██████▀▀ ▓▓▓▓▒▒           ▀▀██▌▒▒░░░░░░░░░▒▒▒▒                 ▀█████████▌    ██████████░ +      ▓█████████████▓▒▒              ▒▒▓▓▓▓       ▓▓▓▓▓▒▒░            ▒▒███▒       ░░▒▒▒▒▒                  ▀███████▌      ████████░  +        ▀▓███████▓▓▓▒▒                 ▀▓▓▓▓▓▓▓▓▓▓▓▓▓▓░                            ░▒▒▒▀                     ▀▀█▀▀          ▀██▀▀     +           ▀▓▓▓▓▓▒▒▒                       ▀▒▒▒▒▒▒▀                                                                                   +                                                                                                                                      +                                                                                                                                      +                                                                                                                                      +                                                                                                                                      +                                           ▄▄▄▄▄▄▄▄▄▄        ▄▄      ▄▄▄▄▄▄▄▄▄▄    ▄▄▄▄▄▄                                             +                                           ██ ▀███ ▐█      ▐████      ████   ███    ▐███                                              +                                              ▐███        ███  ███    ███▌   ███    ▐███                                              +                                              ▐███      ███▌    ███▌  ████▄▄▄██▀    ▐███                                              +                                              ▐███      ███████████▌  ███▌ ███▌     ▐███                                              +                                              ▐███      ███▌    ███▌  ███▌   ███    ▐███                                              +                                             ▄████▄▄    ███▌    ███▌ ▄███▌   ███   ▄████▄                                             +" \ No newline at end of file diff --git a/applications/tari_base_node/assets/tari_logo.rs b/applications/tari_base_node/assets/tari_logo.rs index 14c4eec4fb..eed0070532 100644 --- a/applications/tari_base_node/assets/tari_logo.rs +++ b/applications/tari_base_node/assets/tari_logo.rs @@ -1,11 +1,14 @@ " -⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⣶⣿⣿⣿⣿⣶⣦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ -⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣾⣿⡿⠋⠀⠀⠀⠀⠉⠛⠿⣿⣿⣶⣤⣀⠀⠀⠀⠀⠀⠀⢰⣿⣾⣾⣾⣾⣾⣾⣾⣾⣾⣿⠀⠀⠀⣾⣾⣾⡀⠀⠀⠀⠀⢰⣾⣾⣾⣾⣿⣶⣶⡀⠀⠀⠀⢸⣾⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ -⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣶⣶⣤⣄⡀⠀⠀⠀⠀⠀⠉⠛⣿⣿⠀⠀⠀⠀⠀⠈⠉⠉⠉⠉⣿⣿⡏⠉⠉⠉⠉⠀⠀⣰⣿⣿⣿⣿⠀⠀⠀⠀⢸⣿⣿⠉⠉⠉⠛⣿⣿⡆⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ -⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⠀⠀⠀⠈⠙⣿⡿⠿⣿⣿⣿⣶⣶⣤⣤⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⠀⠀⢠⣿⣿⠃⣿⣿⣷⠀⠀⠀⢸⣿⣿⣀⣀⣀⣴⣿⣿⠃⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ 
-⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⣤⠀⠀⠀⢸⣿⡟⠀⠀⠀⠀⠀⠉⣽⣿⣿⠟⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⠀⠀⣿⣿⣿⣤⣬⣿⣿⣆⠀⠀⢸⣿⣿⣿⣿⣿⡿⠟⠉⠀⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ -⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⣿⣿⣤⠀⢸⣿⡟⠀⠀⠀⣠⣾⣿⡿⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⠀⣾⣿⣿⠿⠿⠿⢿⣿⣿⡀⠀⢸⣿⣿⠙⣿⣿⣿⣄⠀⠀⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ -⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⣿⣿⣼⣿⡟⣀⣶⣿⡿⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⣰⣿⣿⠃⠀⠀⠀⠀⣿⣿⣿⠀⢸⣿⣿⠀⠀⠙⣿⣿⣷⣄⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ -⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⣿⣿⣿⣿⠛⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ -⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ +      ▄███████▄        +    ▄████▀ ▀▀████▄     + ▄███▀ ▀▀███▄  +█████████▄▄ ▀██ +██▌ ▀▀▀███████▄▄ ▐██ +██▌ ██ ▀▀████████ + ▀██▄ ██ ▄██▀  +   ▀██▄ ██ ▄██▀    +     ▀██▄ ██ ▄██▀      +       ▀███████▀        +         ▀███▀          +         ▀          " \ No newline at end of file diff --git a/applications/tari_base_node/build.rs b/applications/tari_base_node/build.rs new file mode 100644 index 0000000000..3f1b967dc8 --- /dev/null +++ b/applications/tari_base_node/build.rs @@ -0,0 +1,97 @@ +// Copyright 2020. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// + +use serde::Deserialize; +use std::{env, fs, path::Path, string::ToString}; + +fn main() { + write_constants_file(); +} + +#[derive(Deserialize)] +struct Package { + authors: Vec<String>, + version: String, +} + +#[derive(Deserialize)] +struct Manifest { + pub package: Package, +} + +fn write_constants_file() { + let data = extract_manifest(); + let mut package = data.package; + package.version = full_version(&package.version); + let out_dir = env::var_os("OUT_DIR").unwrap(); + let dest_path = Path::new(&out_dir).join("consts.rs"); + let output = format!( + r#" + pub const VERSION: &str = "{}"; + pub const AUTHOR: &str = "{}"; + "#, + package.version, + package.authors.join(",") + ); + fs::write(&dest_path, output.as_bytes()).unwrap(); +} + +fn extract_manifest() -> Manifest { + let cargo_path = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap()).join("Cargo.toml"); + let cargo = fs::read(cargo_path).expect("Could not read Cargo.toml"); + let cargo = std::str::from_utf8(&cargo).unwrap(); + toml::from_str(&cargo).unwrap() +} + +/// Add the git version commit and build type to the version number +/// The final output looks like 0.1.2-fc435c-release +fn full_version(ver: &str) -> String { + let sha = get_commit(); + let build = env::var("PROFILE").unwrap_or_else(|_| "Unknown".to_string()); + format!("{}-{}-{}", ver, sha, build) +} + +#[allow(clippy::let_and_return)] +fn get_commit() -> String { + let path = Path::new(&env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("..") + .join(".."); + let repo = match git2::Repository::open(&path) { + Ok(r) => r, + Err(e) => { + println!("cargo:warning=Could not open repo: {}", e.to_string()); + return "NoGitRepository".to_string(); + }, + }; + let result = match repo.revparse_single("HEAD") { + Ok(head) => { + let id = format!("{:?}", head.id()); + id.split_at(7).0.to_string() + }, + Err(e) => { + println!("cargo:warning=Could not find latest commit: {}", e.to_string()); + String::from("NoGitRepository") + }, + }; + result +} diff --git a/applications/tari_base_node/osx-pkg/build-pkg.sh b/applications/tari_base_node/osx-pkg/build-pkg.sh new file mode 100755 index 0000000000..42cb64f3e6 --- /dev/null +++ b/applications/tari_base_node/osx-pkg/build-pkg.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# +# build OSX pkg and submit to Apple for signing and notarization +# + +# Debugging enabled +set -x + +# ToDo +# Check options + +# Env +instName="tari_base_node" +sName=$(basename $0) +#sPath=$(realpath $0) +sPath=$(dirname $0) + +if [ $# -lt 3 ];then + echo "Usage: $0 {packageRoot} {packageVersion} {destDir}" + echo " ie: $0 /tmp/packageRoot 1.2.3.4 /tmp/destDir" + exit 1 +else + pkgRoot="$1" + pkgVersion="$2" + destDir="$3" +fi + +envFile="$sPath/.env" +if [ -f "$envFile" ]; then + echo "Overriding environment with $envFile for settings ..." + source "$envFile" +fi + +# Some Error checking +if [ "$(uname)" == "Darwin" ]; then + echo "Building OSX pkg ..." +else + echo "Not OSX!" + exit 2 +fi + +mkdir -p "$destDir/pkgRoot" +mkdir -p "$destDir/pkgRoot/usr/local/bin/" +# Verify signed?
+codesign --verify --deep --display --verbose=4 \ + "$destDir/dist/$instName" +#spctl -a -v "$destDir/dist/$instName" +spctl -vvv --assess --type exec "$destDir/dist/$instName" + +cp "$destDir/dist/$instName" "$destDir/pkgRoot/usr/local/bin/" + +mkdir -p "$destDir/pkgRoot/usr/local/share/$instName" +COPY_SHARE_FILES=( + *.sh +) +for COPY_SHARE_FILE in "${COPY_SHARE_FILES[@]}"; do + cp "$destDir/dist/"$COPY_SHARE_FILE "$destDir/pkgRoot/usr/local/share/$instName/" +done + +mkdir -p "$destDir/pkgRoot/usr/local/share/doc/$instName" +COPY_DOC_FILES=( + "rincewind-simple.toml" + "tari_config_sample.toml" +# "log4rs.yml" + "log4rs-sample.yml" + "README.md" +) +for COPY_DOC_FILE in "${COPY_DOC_FILES[@]}"; do + cp "$destDir/dist/"$COPY_DOC_FILE "$destDir/pkgRoot/usr/local/share/doc/$instName/" +done + +mkdir -p "$destDir/scripts" +cp -r "${sPath}/scripts/"* "$destDir/scripts/" + +pkgbuildResult=$(pkgbuild --root $destDir/pkgRoot \ + --identifier "com.tarilabs.pkg.basenode" \ + --version "$pkgVersion" --install-location "/" \ + --scripts "$destDir/scripts" \ + --sign "$osxSigDevIDInstaller" $destDir/$instName-$pkgVersion.pkg) + +echo $pkgbuildResult + +echo "Submitting package, please wait ..." +RequestUUIDR=$(xcrun altool --notarize-app \ + --primary-bundle-id "com.tarilabs.com" \ + --username "$osxUsername" --password "$osxPassword" \ + --asc-provider "$osxASCProvider" \ + --file $destDir/$instName-$pkgVersion.pkg) + +requestStatus=$? +if [ $requestStatus -eq 0 ]; then + RequestLen=${#RequestUUIDR} + echo "Length of response is $RequestLen ..." + echo "|$RequestUUIDR|" + RequestUUID=${RequestUUIDR#*RequestUUID = } + echo "Our request UUID is $RequestUUID ..." + echo "|$RequestUUID|" +else + echo "Error submitting ..." + echo $RequestUUIDR + exit 1 +fi + +RequestResult=$(xcrun altool --notarization-info "$RequestUUID" \ + --username "$osxUsername" --password "$osxPassword") + +echo "Our $RequestResult ..." diff --git a/applications/tari_base_node/osx-pkg/env-sample b/applications/tari_base_node/osx-pkg/env-sample new file mode 100644 index 0000000000..1d6ff24ff1 --- /dev/null +++ b/applications/tari_base_node/osx-pkg/env-sample @@ -0,0 +1,8 @@ + +# Copy env-sample to .env if need be ... + +# Package Signing/Notarization Envs +#osxSigDevIDInstaller="Developer ID Installer: XXXX (XXXXXXXXXX)" +#osxUsername="email@example.com" +#osxPassword="@keychain:Label" +#osxASCProvider="XXXXXXXXX" diff --git a/applications/tari_base_node/osx-pkg/scripts/postinstall b/applications/tari_base_node/osx-pkg/scripts/postinstall new file mode 100755 index 0000000000..259b1a2a2e --- /dev/null +++ b/applications/tari_base_node/osx-pkg/scripts/postinstall @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +# +# Installer script for Tari base node. +# This script is bundled with the OSX PKG version +# of the Tari base node binary distribution. +# + +# Debugging enabled +#set -x +set -e + +if [ ! "$(uname)" == "Darwin" ]; then + echo "Installer script meant for OSX" + echo "Please visit https://tari.com/downloads/" + echo " and download the binary distro for your platform" + exit 1 +fi + +logging_file=/tmp/my_postinstall.log +echo "Running postinstall - $(date +'%Y-%m-%d %Hh%M:%S')" > "$logging_file" + +loggedInUserID=`id -u "${USER}"` + +if [[ -n "$logging_file" ]] && [[ -f "$logging_file" ]]; then + echo "Redirecting output to logging file $logging_file ..." + + # all output to log file + exec > >(tee -a "$logging_file") + exec 2>&1 + + echo "Redirecting should be working ..." +else + echo "No logging."
+fi + +#env +#echo "Positional arguments" $@ + +# Detects if /Users is present. If /Users is present, +# the chflags command will unhide it +if [[ -d "$3/Users" ]]; then + #chflags nohidden "$3/Users" + echo "chflags nohidden $3/Users" +fi + +# Detects if /Users/Shared is present. If /Users/Shared is present, +# the chflags command will unhide it +if [[ -d "$3/Users/Shared" ]]; then + #chflags nohidden "$3/Users/Shared" + echo "chflags nohidden $3/Users/Shared" +fi + +echo "Checking XCode ..." +if ! xcode-select -p 1>&2 2>/dev/null; then + echo "XCode not installed. Installing..." +# xcode-select --install 1>&2 + echo "XCode successfully installed" +else + echo "XCode already installed." +fi + +if [ "${COMMAND_LINE_INSTALL}" = "" ]; then + #/bin/launchctl asuser "${loggedInUserID}" /usr/bin/open -g "$3"/usr/local/share/tari_base_node/ + /bin/launchctl asuser "${loggedInUserID}" /usr/bin/open "$3"/usr/local/share/doc/tari_base_node/ + /bin/launchctl asuser "${loggedInUserID}" /usr/bin/open "$3"/usr/local/share/tari_base_node/ + #osascript -e 'tell app "Terminal" to do script "echo hello"' + /bin/launchctl asuser "${loggedInUserID}" /usr/bin/open -a Terminal.app -g "$3"/usr/local/share/tari_base_node/post-install.sh +fi + +echo "Done postinstall - $(date +'%Y-%m-%d %Hh%M:%S')" + +exit 0 # all good diff --git a/applications/tari_base_node/osx-pkg/scripts/preinstall b/applications/tari_base_node/osx-pkg/scripts/preinstall new file mode 100755 index 0000000000..8fbfb40e3a --- /dev/null +++ b/applications/tari_base_node/osx-pkg/scripts/preinstall @@ -0,0 +1,10 @@ +#!/bin/sh +# +# +# +echo "Running preinstall - $(date +'%Y-%m-%d %Hh%M:%S')" > /tmp/my_preinstall.log +env >> /tmp/my_preinstall.log +echo "Positional arguments" $@ >> /tmp/my_preinstall_envs.log +echo "Done preinstall - $(date +'%Y-%m-%d %Hh%M:%S')" >> /tmp/my_preinstall.log + +exit 0 # all good diff --git a/applications/tari_base_node/osx/post-install.sh b/applications/tari_base_node/osx/post-install.sh new file mode 100755 index 0000000000..f485308aeb --- /dev/null +++ b/applications/tari_base_node/osx/post-install.sh @@ -0,0 +1,211 @@ +#!/usr/bin/env bash +# +# Setup init Tari Base Node - Default +# + +# Installer script for Tari base node. This script is bundled with OSX +# versions of the Tari base node binary distributions. + +set -e + +logo=" +⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣤⣾⣿⣿⣶⣤⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ +⠀⠀⠀⠀⠀⠀⣠⣶⣿⣿⣿⣿⠛⠿⣿⣿⣿⣿⣿⣦⣤⠀⠀⠀⠀⠀⠀⠀⠀⠀ +⠀⠀⠀⣤⣾⣿⣿⣿⡿⠋⠀⠀⠀⠀⠀⠀⠉⠛⠿⣿⣿⣿⣿⣷⣦⣄⠀⠀⠀⠀ +⣴⣿⣿⣿⣿⣿⣉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠛⢿⣿⣿⣿⣿⣶⣤ +⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣶⣦⣤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⣿⣿⣿ +⣿⣿⣿⠀⠀⠀⠀⠉⠉⠛⠿⣿⣿⣿⣿⣿⣿⣿⣿⣶⣶⣤⣄⣀⠀⠀⠀⣿⣿⣿ +⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⣿⠀⠈⠉⠛⠛⠿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿ +⢿⣿⣿⣷⣄⠀⠀⠀⠀⠀⠀⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣉⣿⣿⣿⣿⠟ +⠀⠈⢿⣿⣿⣷⣄⠀⠀⠀⠀⣿⣿⣿⠀⠀⠀⠀⠀⠀⠀⢀⣴⣿⣿⣿⡿⠋⠀⠀ +⠀⠀⠀⠈⢿⣿⣿⣷⡄⠀⠀⣿⣿⣿⠀⠀⠀⠀⢀⣴⣿⣿⣿⡿⠋⠀⠀⠀⠀⠀ +⠀⠀⠀⠀⠀⠈⢿⣿⣿⣷⡀⣿⣿⣿⠀⠀⣤⣾⣿⣿⣿⠛⠀⠀⠀⠀⠀⠀⠀⠀ +⠀⠀⠀⠀⠀⠀⠀⠈⢿⣿⣿⣿⣿⣿⣾⣿⣿⣿⠟⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ +⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⢿⣿⣿⣿⣿⠟⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ +⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠿⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ +" + +function display_center() { +# columns="$(tput cols)" + echo "$1" | while IFS= read -r line; do + printf "%*s\n" $(( (${#line} + columns) / 2)) "$line" + done +} + +function banner() { +# columns="$(tput cols)" + for (( c=1; c<=$columns; c++ )); do + echo -n "—" + done + + display_center " ✨ $1 ✨ " + for (( c=1; c<=$columns; c++ )); do + echo -n "—" + done + + echo +} + +columns="$(tput cols)" +if [ $? -eq 0 ]; then + echo "." +else + # not in terminal - force columns + echo ".." + columns=80 +fi + +for line in $logo; do + printf "%*s\n" $(( (31 + columns) / 2)) "$line" +done + +if [ !
"$(uname)" == "Darwin" ]; then + echo "Installer script meant for OSX" + echo "Please visit https://tari.com/downloads/" + echo " and download the binary distro for your platform" + exit 1 +fi + +banner "Installing XCode/Brew and Tor for OSX ..." + +if !xcode-select -p 1>&2 2>/dev/null; then + echo "XCode not installed. Installing..." +# xcode-select --install 1>&2 + echo "XCode successfully installed" +else + echo "XCode already installed." +fi + +if [[ $(command -v brew) == "" ]]; then + echo "Homebrew not installed. Installing now ... " +# ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" + bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)" + echo "Homebrew successfully installed" +else + echo "Updating Homebrew... " + brew update +fi + +echo "Brew packages ..." +brew services + +# sqlite3 +for pkg in sqlite tor torsocks wget; do + if brew list -1 | grep -q "^${pkg}\$"; then + echo "Package '$pkg' is installed" + else + echo "Package '$pkg' is not installed, installing now ..." + brew install $pkg + fi +done + +echo "brew serivces list ..." +result=$(brew services list | grep -e "^tor") +echo $result + +if [[ $result =~ start ]];then + echo "Tor is running, stopping before making changes" + brew services stop tor +else + echo "Tor is Stopped" +fi + +echo "Setup Tor as a running service ..." +#/usr/local/etc/tor/torrc +if [ ! -f "/usr/local/etc/tor/torrc.custom" ]; then +#sudo tee -a /etc/tor/torrc.custom >> /dev/null << EOD +tee -a /usr/local/etc/tor/torrc.custom >> /dev/null << EOD + +# basenode only supports single port +SocksPort 127.0.0.1:9050 + +# Control Port Enable +ControlPort 127.0.0.1:9051 +CookieAuthentication 0 +#HashedControlPassword "" + +ClientOnly 1 +ClientUseIPv6 1 + +SafeLogging 0 + +EOD +fi + +if [ -f /usr/local/etc/tor/torrc ] ;then + if grep -Fxq "%include /usr/local/etc/tor/torrc.custom" /usr/local/etc/tor/torrc ;then + echo " torrc.custom already included for torrc ..." + else + echo "Adding torrc.custom include to torrc ..." + #sudo tee -a /etc/tor/torrc >> /dev/null << EOD + tee -a /usr/local/etc/tor/torrc >> /dev/null << EOD + +# Include torrc.custom +%include /usr/local/etc/tor/torrc.custom +# + +EOD + fi +else + echo "No /usr/local/etc/tor/torrc for Tor!" + echo "Adding torrc.custom include to torrc ..." + #sudo tee -a /etc/tor/torrc >> /dev/null << EOD + tee -a /usr/local/etc/tor/torrc >> /dev/null << EOD + +# Include torrc.custom +%include /usr/local/etc/tor/torrc.custom +# + +EOD + +fi + +brew services start tor +brew services list + +echo "Sleeping fro 30sec while Tor starts up ..." +sleep 30 + +# Check Tor service +#curl --socks5 localhost:9050 --socks5-hostname localhost:9050 -s https://check.torproject.org/ | cat | grep -m 1 Congratulations | xargs +#torsocks curl icanhazip.com +#curl icanhazip.com + +wget -qO - https://api.ipify.org; echo +torsocks wget -qO - https://api.ipify.org; echo + +DATA_DIR=${1:-"$HOME/.tari"} +NETWORK=rincewind + +banner Installing and setting up your Tari Base Node +if [ ! -d "$DATA_DIR/$NETWORK" ]; then + echo "Creating Tari data folder in $DATA_DIR" + mkdir -p $DATA_DIR/$NETWORK +fi + +if [ ! -f "$DATA_DIR/config.toml" ]; then + echo "Copying configuraton files" +# cp rincewind-simple.toml $DATA_DIR/config.toml +# cp log4rs-sample.yml $DATA_DIR/log4rs.yml + + # Configure Base Node + tari_base_node --init --create-id + + echo "Configuration complete." 
+fi + +banner Running Tari Base Node +# Run Base Node +if [ -e ~/Desktop/tari_base_node ]; then + echo "Desktop Link to Tari Base Node exists" +else + ln -s /usr/local/bin/tari_base_node ~/Desktop/tari_base_node +fi +cd "$DATA_DIR" +open /usr/local/bin/tari_base_node + +# Start Tari Base Node in another Terminal +#osascript -e "tell application \"Terminal\" to do script \"sh ${PWD}/start_tor.sh\"" + +banner Tari Base Node Install Done! +exit 0 diff --git a/applications/tari_base_node/osx/uninstall-pkg.sh b/applications/tari_base_node/osx/uninstall-pkg.sh new file mode 100755 index 0000000000..c30e2aa428 --- /dev/null +++ b/applications/tari_base_node/osx/uninstall-pkg.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# +# Uninstall Tari Base Node for OSX +# + +# ToDo: +# Force/Check/Files/PKG +# + +if [ ! "$(uname)" == "Darwin" ]; then + echo "Uninstaller script meant for OSX" + echo " Please visit https://tari.com/downloads/" + echo " and download the binary distro for your platform" + exit 1 +fi + +#osascript -e 'tell application \"Terminal\" to do script \"cd directory\" in \ +# selected tab of the front window' > /dev/null 2>&1 + +# Check pkg +pkgutil --pkgs=com.tarilabs.pkg.basenode* + +pkgutil --files com.tarilabs.pkg.basenode +# rm -fr /usr/local/bin/tari_base_node +# rm -fr /usr/local/share/tari_base_node +# rm -fr /usr/local/share/doc/tari_base_node + +#tariLabsReceipts=$(pkgutil --pkgs=com.tarilabs.pkg.basenode*) +#for myReceipt in $tariLabsReceipts; do +# pkgutil --forget $myReceipt +#done + +#pkgutil --forget com.tarilabs.pkg.basenode diff --git a/applications/tari_base_node/src/builder.rs b/applications/tari_base_node/src/builder.rs index 98b0d1e156..09b8f3e24b 100644 --- a/applications/tari_base_node/src/builder.rs +++ b/applications/tari_base_node/src/builder.rs @@ -46,7 +46,7 @@ use tari_comms::{ ConnectionManagerEvent, PeerManager, }; -use tari_comms_dht::Dht; +use tari_comms_dht::{DbConnectionUrl, Dht, DhtConfig}; use tari_core::{ base_node::{ chain_metadata_service::{ChainMetadataHandle, ChainMetadataServiceInitializer}, @@ -60,20 +60,28 @@ use tari_core::{ create_lmdb_database, BlockchainBackend, BlockchainDatabase, + BlockchainDatabaseConfig, LMDBDatabase, MemoryDatabase, Validators, }, consensus::{ConsensusManager, ConsensusManagerBuilder, Network as NetworkType}, - mempool::{Mempool, MempoolConfig, MempoolServiceConfig, MempoolServiceInitializer, MempoolValidators}, + mempool::{ + service::LocalMempoolService, + Mempool, + MempoolConfig, + MempoolServiceConfig, + MempoolServiceInitializer, + MempoolValidators, + }, mining::Miner, - proof_of_work::DiffAdjManager, tari_utilities::{hex::Hex, message_format::MessageFormat}, transactions::{ crypto::keys::SecretKey as SK, types::{CryptoFactories, HashDigest, PrivateKey, PublicKey}, }, validation::{ + accum_difficulty_validators::AccumDifficultyValidator, block_validators::{FullConsensusValidator, StatelessBlockValidator}, transaction_validators::{FullTxValidator, TxInputAndMaturityValidator}, }, @@ -105,7 +113,7 @@ use tari_wallet::{ TransactionServiceInitializer, }, }; -use tokio::{runtime, stream::StreamExt, sync::broadcast, task}; +use tokio::{runtime, stream::StreamExt, sync::broadcast, task, time::delay_for}; const LOG_TARGET: &str = "c::bn::initialization"; @@ -145,11 +153,22 @@ impl NodeContainer { using_backend!(self, ctx, ctx.local_node()) } + /// Returns a handle to the local mempool service.
This function panics if it has not been registered + /// with the comms service + pub fn local_mempool(&self) -> LocalMempoolService { + using_backend!(self, ctx, ctx.local_mempool()) + } + /// Returns the CommsNode. pub fn base_node_comms(&self) -> &CommsNode { using_backend!(self, ctx, &ctx.base_node_comms) } + /// Returns the wallet CommsNode. + pub fn wallet_comms(&self) -> &CommsNode { + using_backend!(self, ctx, &ctx.wallet_comms) + } + /// Returns this node's identity. pub fn base_node_identity(&self) -> Arc { using_backend!(self, ctx, ctx.base_node_comms.node_identity()) @@ -186,10 +205,18 @@ impl NodeContainer { debug!(target: LOG_TARGET, "Mining wallet ready to receive coins."); while let Some(utxo) = rx.next().await { match wallet_output_handle.add_output(utxo).await { - Ok(_) => info!( - target: LOG_TARGET, - "🤑💰🤑 Newly mined coinbase output added to wallet 🤑💰🤑" - ), + Ok(_) => { + info!( + target: LOG_TARGET, + "🤑💰🤑 Newly mined coinbase output added to wallet 🤑💰🤑" + ); + // TODO Remove this when the wallet monitors the UTXO's more intelligently + let mut oms_handle_clone = wallet_output_handle.clone(); + tokio::spawn(async move { + delay_for(Duration::from_secs(240)).await; + oms_handle_clone.sync_with_base_node().await; + }); + }, Err(e) => warn!(target: LOG_TARGET, "Error adding output: {}", e), } } @@ -209,7 +236,14 @@ impl NodeContainer { } } -pub struct BaseNodeContext { +/// The base node context is a container for all the key structural pieces for the base node application, including the +/// communications stack, the node state machine, the miner and handles to the various services that are registered +/// on the comms stack. +/// +/// `BaseNodeContext` is not intended to ever be used directly, so it is a private struct. It is only ever created in the +/// [NodeContainer] enum, which serves the purpose of abstracting the specific `BlockchainBackend` instance away +/// from users of the full base node stack. +struct BaseNodeContext { pub base_node_comms: CommsNode, pub base_node_dht: Dht, pub wallet_comms: CommsNode, @@ -222,18 +256,28 @@ pub struct BaseNodeContext { } impl BaseNodeContext { + /// Returns a handle to the Output Manager pub fn output_manager(&self) -> OutputManagerHandle { self.wallet_handles .get_handle::() .expect("Problem getting wallet output manager handle") } + /// Returns the handle to the Comms Interface pub fn local_node(&self) -> LocalNodeCommsInterface { self.base_node_handles .get_handle::() - .expect("Could not get local comms interface handle") + .expect("Could not get local node interface handle") + } + + /// Returns the handle to the Mempool + pub fn local_mempool(&self) -> LocalMempoolService { + self.base_node_handles + .get_handle::() + .expect("Could not get local mempool interface handle") } + /// Returns the handle to the Transaction Service pub fn wallet_transaction_service(&self) -> TransactionServiceHandle { self.wallet_handles .get_handle::() @@ -243,6 +287,11 @@ impl BaseNodeContext { /// Tries to construct a node identity by loading the secret key and other metadata from disk and calculating the /// missing fields from that information. 
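The delayed wallet re-sync added above is a plain fire-and-forget tokio task: clone the service handle, move it into a spawned future, sleep, then call the async sync method. A minimal sketch of the same pattern, assuming the tokio 0.2 API this patch imports (`delay_for`) and using a hypothetical stand-in for `OutputManagerHandle`:

```rust
use std::time::Duration;
use tokio::time::delay_for; // tokio 0.2; later tokio versions renamed this to `sleep`

// Hypothetical stand-in for the wallet's OutputManagerHandle.
#[derive(Clone)]
struct OmsHandle;

impl OmsHandle {
    async fn sync_with_base_node(&mut self) -> Result<(), String> {
        Ok(()) // pretend the sync succeeded
    }
}

#[tokio::main]
async fn main() {
    let oms_handle = OmsHandle;
    // Clone the handle and detach a task, as the builder does after adding a
    // freshly mined coinbase output: wait 240 s, then trigger a re-sync.
    let mut oms_handle_clone = oms_handle.clone();
    tokio::spawn(async move {
        delay_for(Duration::from_secs(240)).await;
        let _ = oms_handle_clone.sync_with_base_node().await;
    });
}
```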
+/// ## Parameters +/// `path` - Reference to a path +/// +/// ## Returns +/// Result containing a NodeIdentity on success, string indicates the reason on failure pub fn load_identity(path: &Path) -> Result { if !path.exists() { return Err(format!("Identity file, {}, does not exist.", path.to_str().unwrap())); @@ -271,6 +320,13 @@ pub fn load_identity(path: &Path) -> Result { } /// Create a new node id and save it to disk +/// ## Parameters +/// `path` - Reference to path to save the file +/// `public_addr` - Network address of the base node +/// `peer_features` - The features enabled for the base node +/// +/// ## Returns +/// Result containing the node identity, string will indicate reason on error pub fn create_new_base_node_identity>( path: P, public_addr: Multiaddr, @@ -284,6 +340,12 @@ pub fn create_new_base_node_identity>( Ok(node_identity) } +/// Loads the node identity from json at the given path +/// ## Parameters +/// `path` - Path to file from which to load the node identity +/// +/// ## Returns +/// Result containing an object on success, string will indicate reason on error pub fn load_from_json, T: MessageFormat>(path: P) -> Result { if !path.as_ref().exists() { return Err(format!( @@ -297,6 +359,13 @@ pub fn load_from_json, T: MessageFormat>(path: P) -> Result, T: MessageFormat>(path: P, object: &T) -> Result<(), String> { let json = object.to_json().unwrap(); if let Some(p) = path.as_ref().parent() { @@ -315,6 +384,14 @@ pub fn save_as_json, T: MessageFormat>(path: P, object: &T) -> Re Ok(()) } +/// Sets up and initializes the base node, creating the context and database +/// ## Parameters +/// `config` - The configuration for the base node +/// `node_identity` - The node identity information of the base node +/// `wallet_node_identity` - The node identity information of the base node's wallet +/// `interrupt_signal` - The signal used to stop the application +/// ## Returns +/// Result containing the NodeContainer, String will contain the reason on error pub async fn configure_and_initialize_node( config: &GlobalConfig, node_identity: Arc, @@ -357,6 +434,16 @@ pub async fn configure_and_initialize_node( Ok(result) } +/// Constructs the base node context, this includes setting up the consensus manager, mempool, base node, wallet, miner +/// and state machine ## Parameters +/// `backend` - Backend interface +/// `network` - The NetworkType (rincewind, mainnet, local) +/// `base_node_identity` - The node identity information of the base node +/// `wallet_node_identity` - The node identity information of the base node's wallet +/// `config` - The configuration for the base node +/// `interrupt_signal` - The signal used to stop the application +/// ## Returns +/// Result containing the BaseNodeContext, String will contain the reason on error async fn build_node_context( backend: B, network: NetworkType, @@ -375,13 +462,14 @@ where let validators = Validators::new( FullConsensusValidator::new(rules.clone(), factories.clone()), StatelessBlockValidator::new(&rules.consensus_constants()), + AccumDifficultyValidator {}, ); - let db = BlockchainDatabase::new(backend, &rules, validators).map_err(|e| e.to_string())?; + // TODO - make BlockchainDatabaseConfig configurable + let db = BlockchainDatabase::new(backend, &rules, validators, BlockchainDatabaseConfig::default()) + .map_err(|e| e.to_string())?; let mempool_validator = MempoolValidators::new(FullTxValidator::new(factories.clone()), TxInputAndMaturityValidator {}); let mempool = Mempool::new(db.clone(), MempoolConfig::default(), 
mempool_validator); - let diff_adj_manager = DiffAdjManager::new(&rules.consensus_constants()).map_err(|e| e.to_string())?; - rules.set_diff_manager(diff_adj_manager).map_err(|e| e.to_string())?; let handle = runtime::Handle::current(); //---------------------------------- Base Node --------------------------------------------// @@ -512,6 +600,14 @@ where }) } +/// Asynchronously syncs peers with base node, adding peers if the peer is not already known +/// ## Parameters +/// `events_rx` - The event stream +/// `base_node_peer_manager` - The peer manager for the base node wrapped in an atomic reference counter +/// `wallet_peer_manager` - The peer manager for the base node's wallet wrapped in an atomic reference counter +/// +/// ## Returns +/// Nothing is returned async fn sync_peers( mut events_rx: broadcast::Receiver>, base_node_peer_manager: Arc, wallet_peer_manager: Arc, @@ -537,6 +633,12 @@ async fn sync_peers( } } +/// Parses the seed peers from a delimited string into a list of peers +/// ## Parameters +/// `seeds` - A string of peers delimited by '::' +/// +/// ## Returns +/// A list of peers, peers which do not have a valid public key are excluded fn parse_peer_seeds(seeds: &[String]) -> Vec { info!("Adding {} peers to the peer database", seeds.len()); let mut result = Vec::with_capacity(seeds.len()); @@ -595,6 +697,12 @@ fn parse_peer_seeds(seeds: &[String]) -> Vec { result } +/// Creates a transport type from the given configuration +/// ## Parameters +/// `config` - The reference to the configuration in which to set up the comms stack, see [GlobalConfig] +/// +/// ## Returns +/// TransportType based on the configuration fn setup_transport_type(config: &GlobalConfig) -> TransportType { debug!(target: LOG_TARGET, "Transport is set to '{:?}'", config.comms_transport); @@ -665,6 +773,12 @@ fn setup_transport_type(config: &GlobalConfig) -> TransportType { } } +/// Creates a transport type for the base node's wallet using the provided configuration +/// ## Parameters +/// `config` - The reference to the configuration in which to set up the comms stack, see [GlobalConfig] +/// +/// ## Returns +/// TransportType based on the configuration fn setup_wallet_transport_type(config: &GlobalConfig) -> TransportType { debug!( target: LOG_TARGET, @@ -749,6 +863,12 @@ fn setup_wallet_transport_type(config: &GlobalConfig) -> TransportType { } } +/// Converts one socks authentication struct into another +/// ## Parameters +/// `auth` - Socks authentication of type SocksAuthentication +/// +/// ## Returns +/// Socks authentication of type socks::Authentication fn into_socks_authentication(auth: SocksAuthentication) -> socks::Authentication { match auth { SocksAuthentication::None => socks::Authentication::None, @@ -758,6 +878,12 @@ fn into_socks_authentication(auth: SocksAuthentication) -> socks::Authentication } } +/// Creates the storage location for the base node's wallet +/// ## Parameters +/// `wallet_path` - Reference to a file path +/// +/// ## Returns +/// A Result to determine if it was successful or not, string will indicate the reason on error fn create_wallet_folder>(wallet_path: P) -> Result<(), String> { let path = wallet_path.as_ref(); match fs::create_dir_all(path) { @@ -781,6 +907,12 @@ fn create_wallet_folder>(wallet_path: P) -> Result<(), String> { } } +/// Creates the directory to store the peer database +/// ## Parameters +/// `peer_db_path` - Reference to a file path +/// +/// ## Returns +/// A Result to determine if it was successful or not, string will indicate the reason on error fn 
create_peer_db_folder>(peer_db_path: P) -> Result<(), String> { let path = peer_db_path.as_ref(); match fs::create_dir_all(path) { @@ -800,6 +932,13 @@ fn create_peer_db_folder>(peer_db_path: P) -> Result<(), String> } } +/// Asynchronously initializes comms for the base node +/// ## Parameters +/// `node_identity` - The node identity to initialize the comms stack with, see [NodeIdentity] +/// `config` - The reference to the configuration in which to set up the comms stack, see [GlobalConfig] +/// `publisher` - The publisher for the publish-subscribe messaging system +/// ## Returns +/// A Result containing the commsnode and dht on success, string will indicate the reason on error async fn setup_base_node_comms( node_identity: Arc, config: &GlobalConfig, @@ -814,7 +953,10 @@ async fn setup_base_node_comms( max_concurrent_inbound_tasks: 100, outbound_buffer_size: 100, // TODO - make this configurable - dht: Default::default(), + dht: DhtConfig { + database_url: DbConnectionUrl::File(config.data_dir.join("dht.db")), + ..Default::default() + }, // TODO: This should be false unless testing locally - make this configurable allow_test_addresses: true, listener_liveness_whitelist_cidrs: config.listener_liveness_whitelist_cidrs.clone(), @@ -838,6 +980,15 @@ async fn setup_base_node_comms( Ok((comms, dht)) } +/// Asynchronously initializes comms for the base node's wallet +/// ## Parameters +/// `node_identity` - The node identity to initialize the comms stack with, see [NodeIdentity] +/// `config` - The configuration in which to set up the comms stack, see [GlobalConfig] +/// `publisher` - The publisher for the publish-subscribe messaging system +/// `base_node_peer` - The base node for the wallet to connect to +/// `peers` - A list of peers to be added to the comms node, the current node identity of the comms stack is excluded if +/// found in the list. ## Returns +/// A Result containing the commsnode and dht on success, string will indicate the reason on error async fn setup_wallet_comms( node_identity: Arc, config: &GlobalConfig, @@ -853,7 +1004,10 @@ async fn setup_wallet_comms( max_concurrent_inbound_tasks: 100, outbound_buffer_size: 100, // TODO - make this configurable - dht: Default::default(), + dht: DhtConfig { + database_url: DbConnectionUrl::File(config.data_dir.join("dht-wallet.db")), + ..Default::default() + }, // TODO: This should be false unless testing locally - make this configurable allow_test_addresses: true, listener_liveness_whitelist_cidrs: Vec::new(), @@ -877,10 +1031,26 @@ async fn setup_wallet_comms( Ok((comms, dht)) } +/// Adds a new peer to the base node +/// ## Parameters +/// `comms_node` - A reference to the comms node. This is the communications stack +/// `peers` - A list of peers to be added to the comms node, the current node identity of the comms stack is excluded if +/// found in the list. 
+/// +/// ## Returns +/// A Result to determine if the call was successful or not, string will indicate the reason on error async fn add_peers_to_comms(comms: &CommsNode, peers: Vec) -> Result<(), String> { for p in peers { let peer_desc = p.to_string(); info!(target: LOG_TARGET, "Adding seed peer [{}]", peer_desc); + + if &p.public_key == comms.node_identity().public_key() { + info!( + target: LOG_TARGET, + "Attempting to add yourself [{}] as a seed peer to comms layer, ignoring request", peer_desc + ); + continue; + } comms .peer_manager() .add_peer(p) @@ -890,6 +1060,19 @@ async fn add_peers_to_comms(comms: &CommsNode, peers: Vec) -> Result<(), S Ok(()) } +/// Asynchronously registers services of the base node +/// +/// ## Parameters +/// `comms` - A reference to the comms node. This is the communications stack +/// `db` - The interface to the blockchain database, for all transactions stored in a block +/// `dht` - A reference to the peer discovery service +/// `subscription_factory` - The publish-subscribe messaging system, wrapped in an atomic reference counter +/// `mempool` - The mempool interface, for all transactions not yet included or recently included in a block +/// `consensus_manager` - The consensus manager for the blockchain +/// `factories` - Cryptographic factory based on Pedersen Commitments +/// +/// ## Returns +/// A hashmap of handles wrapped in an atomic reference counter async fn register_base_node_services( comms: &CommsNode, dht: &Dht, @@ -921,11 +1104,13 @@ where LivenessConfig { auto_ping_interval: Some(Duration::from_secs(30)), enable_auto_join: true, - enable_auto_stored_message_request: true, refresh_neighbours_interval: Duration::from_secs(3 * 60), + random_peer_selection_ratio: 0.4, + ..Default::default() }, subscription_factory, dht.dht_requester(), + comms.connection_manager(), )) .add_initializer(ChainMetadataServiceInitializer) .finish() @@ -933,6 +1118,16 @@ where .expect("Service initialization failed") } +/// Asynchronously registers services for the base node's wallet +/// ## Parameters +/// `wallet_comms` - A reference to the comms node. 
This is the communications stack +/// `wallet_dht` - A reference to the peer discovery service +/// `wallet_db_conn` - A reference to the sqlite database connection for the transaction and output manager services +/// `subscription_factory` - The publish-subscribe messaging system, wrapped in an atomic reference counter +/// `factories` - Cryptographic factory based on Pedersen Commitments +/// +/// ## Returns +/// A hashmap of handles wrapped in an atomic reference counter async fn register_wallet_services( wallet_comms: &CommsNode, wallet_dht: &Dht, @@ -947,13 +1142,13 @@ async fn register_wallet_services( LivenessConfig{ auto_ping_interval: None, enable_auto_join: true, - enable_auto_stored_message_request: true, ..Default::default() }, subscription_factory.clone(), - wallet_dht.dht_requester() + wallet_dht.dht_requester(), + wallet_comms.connection_manager() - )) + )) // Wallet services .add_initializer(OutputManagerServiceInitializer::new( OutputManagerServiceConfig::default(), @@ -964,7 +1159,6 @@ async fn register_wallet_services( .add_initializer(TransactionServiceInitializer::new( TransactionServiceConfig::default(), subscription_factory, - wallet_comms.subscribe_messaging_events(), TransactionServiceSqliteDatabase::new(wallet_db_conn.clone()), wallet_comms.node_identity(), factories, diff --git a/applications/tari_base_node/src/cli.rs b/applications/tari_base_node/src/cli.rs index 9e3739edf3..dc24ee5059 100644 --- a/applications/tari_base_node/src/cli.rs +++ b/applications/tari_base_node/src/cli.rs @@ -21,49 +21,164 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -use crate::consts; -use clap::clap_app; -use tari_common::{bootstrap_config_from_cli, ConfigBootstrap}; +use structopt::StructOpt; +use tari_common::ConfigBootstrap; -/// Prints a pretty banner on the console -pub fn print_banner() { - let logo = include!("../assets/tari_logo.rs"); - println!( - "{}\n\n$ Tari Base Node\n$ Copyright 2019-2020. {}\n$ Version {}\n\nPress Ctrl-C to quit..", - logo, - consts::AUTHOR, - consts::VERSION - ); +// Import the auto-generated const values from the Manifest and Git +include!(concat!(env!("OUT_DIR"), "/consts.rs")); + +/// returns the top or bottom box line of the specified length +fn box_line(length: usize, is_top: bool) -> String { + if length < 2 { + return format!(""); + } + if is_top { + format!("{}{}{}", "┌", "─".repeat(length - 2), "┐") + } else { + format!("{}{}{}", "└", "─".repeat(length - 2), "┘") + } } -/// Parsed command-line arguments -pub struct Arguments { - pub bootstrap: ConfigBootstrap, - pub create_id: bool, - pub init: bool, +/// returns a horizontal rule of the box of the specified length +fn box_separator(length: usize) -> String { + if length < 2 { + return format!(""); + } + format!("{}{}{}", "├", "─".repeat(length - 2), "┤") +} + +/// returns a line in the box, with box borders at the beginning and end, contents centered. 
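Both `register_base_node_services` and `register_wallet_services` above hand back a collection of service handles that callers query by type via `get_handle::<T>()`. The sketch below shows the type-indexed-map idea behind that pattern using only the standard library; it is an illustration, not the actual `ServiceHandles` from `tari_service_framework`:

```rust
use std::{
    any::{Any, TypeId},
    collections::HashMap,
    sync::Arc,
};

// Illustrative type-indexed handle map; not the real tari_service_framework API.
#[derive(Default)]
struct Handles(HashMap<TypeId, Arc<dyn Any + Send + Sync>>);

impl Handles {
    fn register<T: Any + Send + Sync>(&mut self, handle: T) {
        self.0.insert(TypeId::of::<T>(), Arc::new(handle));
    }

    // Service handles are cheap to clone (like the handles in this patch), so
    // lookups hand out an owned clone rather than a reference.
    fn get_handle<T: Any + Send + Sync + Clone>(&self) -> Option<T> {
        self.0
            .get(&TypeId::of::<T>())
            .and_then(|h| h.downcast_ref::<T>())
            .cloned()
    }
}

fn main() {
    let mut handles = Handles::default();
    handles.register(42u32); // stand-in for e.g. a LocalMempoolService handle
    assert_eq!(handles.get_handle::<u32>(), Some(42));
}
```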
+fn box_data(data: String, target_length: usize) -> String { + let padding = if ((target_length - 2) / 2) > (data.chars().count() / 2) { + ((target_length - 2) / 2) - (data.chars().count() / 2) + } else { + 0 + }; + let mut s = format!("{}{}{}{}", " ".repeat(padding), data, " ".repeat(padding), "│"); + // for integer rounding error, usually only 1 char, to ensure border lines up + while s.chars().count() < target_length - 1 { + s = format!("{}{}", " ", s); + } + format!("{}{}", "│", s) +} + +/// returns a vector of strings, for each vector of strings, the strings are combined (padded and spaced as necessary), +/// then the result is sent to `box_data` before being added to the result +fn box_tabular_data_rows( + data: Vec>, + sizes: Vec, + target_length: usize, + spacing: usize, +) -> Vec +{ + let max_cell_length = sizes.iter().max().unwrap(); + let mut result = Vec::new(); + for items in data { + let mut s = " ".repeat(spacing); + for string in items { + if &string.chars().count() < max_cell_length { + let padding = (max_cell_length / 2) - (&string.chars().count() / 2); + s = format!( + "{}{}{}{}{}", + s, + " ".repeat(padding), + string, + " ".repeat(padding), + " ".repeat(spacing) + ); + } else { + s = format!("{}{}{}", s, string, " ".repeat(spacing)); + } + } + result.push(box_data(s, target_length)); + } + result } -/// Parse the command-line args and populate the minimal bootstrap config object -pub fn parse_cli_args() -> Arguments { - let matches = clap_app!(myapp => - (version: consts::VERSION) - (author: consts::AUTHOR) - (about: "The reference Tari cryptocurrency base node implementation") - (@arg base_dir: -b --base_dir +takes_value "A path to a directory to store your files") - (@arg config: -c --config +takes_value "A path to the configuration file to use (config.toml)") - (@arg log_config: -l --log_config +takes_value "A path to the logfile configuration (log4rs.yml))") - (@arg init: --init "Create a default configuration file if it doesn't exist") - (@arg create_id: --create_id "Create and save new node identity if one doesn't exist ") - ) - .get_matches(); +fn multiline_find_display_length(lines: &str) -> usize { + let mut result = 0; + if let Some(line) = lines.lines().max_by(|x, y| x.chars().count().cmp(&y.chars().count())) { + result = line.as_bytes().len(); + result /= 2; + result -= result / 10; + } + result +} + +/// Prints a pretty banner on the console as well as the list of available commands +pub fn print_banner(commands: Vec, chunk_size: i32) { + let chunks: Vec> = commands.chunks(chunk_size as usize).map(|x| x.to_vec()).collect(); + let mut cell_sizes = Vec::new(); + + let mut row_cell_count: i32 = 0; + let mut command_data: Vec> = Vec::new(); + for chunk in chunks { + let mut cells: Vec = Vec::new(); + for item in chunk { + cells.push(item.clone()); + cell_sizes.push(item.chars().count()); + row_cell_count += 1; + } + if row_cell_count < chunk_size { + while row_cell_count < chunk_size { + cells.push(" ".to_string()); + cell_sizes.push(1); + row_cell_count += 1; + } + } else { + row_cell_count = 0; + } + command_data.push(cells); + } - let bootstrap = bootstrap_config_from_cli(&matches); - let create_id = matches.is_present("create_id"); - let init = matches.is_present("init"); + let row_cell_sizes: Vec> = cell_sizes.chunks(chunk_size as usize).map(|x| x.to_vec()).collect(); + let mut row_cell_size = Vec::new(); + let mut max_cell_size: usize = 0; + for sizes in row_cell_sizes { + for size in sizes { + if size > max_cell_size { + max_cell_size = size; + } + } + 
row_cell_size.push(max_cell_size); + max_cell_size = 0; + } - Arguments { - bootstrap, - create_id, - init, + let banner = include!("../assets/tari_banner.rs"); + let target_line_length = multiline_find_display_length(banner); + for line in banner.lines() { + println!("{}", line.to_string()); + } + println!("\n{}", box_line(target_line_length, true)); + let logo = include!("../assets/tari_logo.rs"); + for line in logo.lines() { + println!("{}", box_data(line.to_string(), target_line_length)); } + println!("{}", box_data(" ".to_string(), target_line_length)); + println!("{}", box_data("Tari Base Node".to_string(), target_line_length)); + println!("{}", box_data("~~~~~~~~~~~~~~".to_string(), target_line_length)); + println!( + "{}", + box_data(format!("Copyright 2019-2020. {}", AUTHOR), target_line_length) + ); + println!("{}", box_data(format!("Version {}", VERSION), target_line_length)); + println!("{}", box_separator(target_line_length)); + println!("{}", box_data("Commands".to_string(), target_line_length)); + println!("{}", box_data("~~~~~~~~".to_string(), target_line_length)); + println!("{}", box_separator(target_line_length)); + let rows = box_tabular_data_rows(command_data, row_cell_size, target_line_length, 10); + for row in rows { + println!("{}", row); + } + println!("{}", box_line(target_line_length, false)); +} + +/// The reference Tari cryptocurrency base node implementation +#[derive(StructOpt)] +pub struct Arguments { + /// Create and save new node identity if one doesn't exist + #[structopt(long)] + pub create_id: bool, + #[structopt(flatten)] + pub bootstrap: ConfigBootstrap, } diff --git a/applications/tari_base_node/src/main.rs b/applications/tari_base_node/src/main.rs index 9f65cf2780..49fe331aec 100644 --- a/applications/tari_base_node/src/main.rs +++ b/applications/tari_base_node/src/main.rs @@ -19,14 +19,68 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
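The rewritten cli.rs above replaces the `clap_app!` macro with a `StructOpt` derive, flattening `ConfigBootstrap` into `Arguments` so the bootstrap flags stay on the top-level command line. A toy version of the same derive-and-flatten pattern (the `Bootstrap` struct here is illustrative, not the real `tari_common::ConfigBootstrap`):

```rust
use structopt::StructOpt;

// Illustrative stand-in for tari_common::ConfigBootstrap.
#[derive(StructOpt, Debug)]
struct Bootstrap {
    /// Create a default configuration file if it doesn't exist
    #[structopt(long)]
    init: bool,
    /// A path to a directory to store your files
    #[structopt(long, default_value = ".")]
    base_dir: String,
}

/// The reference Tari cryptocurrency base node implementation
#[derive(StructOpt, Debug)]
struct Arguments {
    /// Create and save new node identity if one doesn't exist
    #[structopt(long)]
    create_id: bool,
    // Flattened: Bootstrap's flags appear alongside --create-id.
    #[structopt(flatten)]
    bootstrap: Bootstrap,
}

fn main() {
    let args = Arguments::from_args();
    println!("{:?}", args);
}
```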
-// + +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⣶⣿⣿⣿⣿⣶⣦⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣾⣿⡿⠋⠀⠀⠀⠀⠉⠛⠿⣿⣿⣶⣤⣀⠀⠀⠀⠀⠀⠀⢰⣿⣾⣾⣾⣾⣾⣾⣾⣾⣾⣿⠀⠀⠀⣾⣾⣾⡀⠀⠀⠀⠀⢰⣾⣾⣾⣾⣿⣶⣶⡀⠀⠀⠀⢸⣾⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣶⣶⣤⣄⡀⠀⠀⠀⠀⠀⠉⠛⣿⣿⠀⠀⠀⠀⠀⠈⠉⠉⠉⠉⣿⣿⡏⠉⠉⠉⠉⠀⠀⣰⣿⣿⣿⣿⠀⠀⠀⠀⢸⣿⣿⠉⠉⠉⠛⣿⣿⡆⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⠀⠀⠀⠈⠙⣿⡿⠿⣿⣿⣿⣶⣶⣤⣤⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⠀⠀⢠⣿⣿⠃⣿⣿⣷⠀⠀⠀⢸⣿⣿⣀⣀⣀⣴⣿⣿⠃⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⣤⠀⠀⠀⢸⣿⡟⠀⠀⠀⠀⠀⠉⣽⣿⣿⠟⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⠀⠀⣿⣿⣿⣤⣬⣿⣿⣆⠀⠀⢸⣿⣿⣿⣿⣿⡿⠟⠉⠀⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⣿⣿⣤⠀⢸⣿⡟⠀⠀⠀⣠⣾⣿⡿⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⠀⣾⣿⣿⠿⠿⠿⢿⣿⣿⡀⠀⢸⣿⣿⠙⣿⣿⣿⣄⠀⠀⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⣿⣿⣼⣿⡟⣀⣶⣿⡿⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⡇⠀⠀⠀⣰⣿⣿⠃⠀⠀⠀⠀⣿⣿⣿⠀⢸⣿⣿⠀⠀⠙⣿⣿⣷⣄⠀⠀⢸⣿⣿⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⣿⣿⣿⣿⠛⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀ +/// +/// # Tari Base Node +/// +/// The Tari Base Node is a major application in the Tari Network +/// +/// It consists of the Base Node itself, a Wallet and a Miner +/// +/// ## Running the Tari Base Node +/// +/// Tor needs to be started first +/// ``` +/// tor --allow-missing-torrc --ignore-missing-torrc \ +/// --clientonly 1 --socksport 9050 --controlport 127.0.0.1:9051 \ +/// --log "notice stdout" --clientuseipv6 1 +/// ``` +/// +/// For the first run +/// ```cargo run tari_base_node -- --create_id``` +/// +/// Subsequent runs +/// ```cargo run tari_base_node``` +/// +/// ## Commands +/// +/// `help` - Displays a list of commands +/// `get-balance` - Displays the balance of the wallet (available, pending incoming, pending outgoing) +/// `send-tari` - Sends Tari, the amount needs to be specified, followed by the destination (public key or emoji id) and +/// an optional message `get-chain-metadata` - Lists information about the blockchain of this Base Node +/// `list-peers` - Lists information about peers known by this base node +/// `ban-peer` - Bans a peer +/// `unban-peer` - Removes a ban for a peer +/// `list-connections` - Lists active connections to this Base Node +/// `list-headers` - Lists header information. 
Either the first header height and the last header height need to be +/// specified, or the amount of headers from the top `check-db` - Checks the blockchain database for missing blocks and +/// headers `calc-timing` - Calculates the average time taken to mine a given range of blocks +/// `discover-peer` - Attempts to discover a peer on the network, a public key or emoji id needs to be specified +/// `get-block` - Retrieves a block, the height of the block needs to be specified +/// `get-mempool-stats` - Displays information about the mempool +/// `get-mempool-state` - Displays state information for the mempool +/// `whoami` - Displays identity information about this Base Node and its wallet +/// `toggle-mining` - Turns the miner on or off +/// `quit` - Exits the Base Node +/// `exit` - Same as quit + +/// Used to display tabulated data +#[macro_use] +mod table; /// Utilities and helpers for building the base node instance mod builder; /// The command line interface definition and configuration mod cli; -/// Application-specific constants -mod consts; /// Miner lib Todo hide behind feature flag mod miner; /// Parser module used to control user commands @@ -38,40 +92,48 @@ use log::*; use parser::Parser; use rustyline::{config::OutputStreamType, error::ReadlineError, CompletionType, Config, EditMode, Editor}; use std::{path::PathBuf, sync::Arc}; -use tari_common::{load_configuration, GlobalConfig}; +use structopt::StructOpt; +use tari_common::GlobalConfig; use tari_comms::{multiaddr::Multiaddr, peer_manager::PeerFeatures, NodeIdentity}; use tari_shutdown::Shutdown; use tokio::runtime::Runtime; pub const LOG_TARGET: &str = "base_node::app"; +/// Enum to show failure information enum ExitCodes { ConfigError = 101, UnknownError = 102, } +impl From for ExitCodes { + fn from(err: tari_common::ConfigError) -> Self { + error!(target: LOG_TARGET, "{}", err); + Self::ConfigError + } +} + +/// Application entry point fn main() { - cli::print_banner(); match main_inner() { Ok(_) => std::process::exit(0), Err(exit_code) => std::process::exit(exit_code as i32), } } +/// Sets up the base node and runs the cli_loop fn main_inner() -> Result<(), ExitCodes> { // Parse and validate command-line arguments - let arguments = cli::parse_cli_args(); + let mut arguments = cli::Arguments::from_args(); + + // check and initialize configuration files + arguments.bootstrap.init_dirs()?; // Initialise the logger - if !tari_common::initialize_logging(&arguments.bootstrap.log_config) { - return Err(ExitCodes::ConfigError); - } + arguments.bootstrap.initialize_logging()?; // Load and apply configuration file - let cfg = load_configuration(&arguments.bootstrap).map_err(|err| { - error!(target: LOG_TARGET, "{}", err); - ExitCodes::ConfigError - })?; + let cfg = arguments.bootstrap.load_configuration()?; // Populate the configuration struct let node_config = GlobalConfig::convert_from(cfg).map_err(|err| { @@ -127,13 +189,16 @@ fn main_inner() -> Result<(), ExitCodes> { return Ok(()); } - if arguments.init { + if arguments.bootstrap.init { info!(target: LOG_TARGET, "Default configuration created. Done."); return Ok(()); } // Run, node, run! 
let parser = Parser::new(rt.handle().clone(), &ctx); + + cli::print_banner(parser.get_commands(), 3); + let base_node_handle = rt.spawn(ctx.run(rt.handle().clone())); info!( @@ -152,6 +217,12 @@ fn main_inner() -> Result<(), ExitCodes> { Ok(()) } +/// Sets up the tokio runtime based on the configuration +/// ## Parameters +/// `config` - The configuration of the base node +/// +/// ## Returns +/// A result containing the runtime on success, string indicating the error on failure fn setup_runtime(config: &GlobalConfig) -> Result { let num_core_threads = config.core_threads; let num_blocking_threads = config.blocking_threads; @@ -173,6 +244,13 @@ fn setup_runtime(config: &GlobalConfig) -> Result { .map_err(|e| format!("There was an error while building the node runtime. {}", e.to_string())) } +/// Runs the Base Node +/// ## Parameters +/// `parser` - The parser to process input commands +/// `shutdown` - The trigger for shutting down +/// +/// ## Returns +/// Doesn't return anything fn cli_loop(parser: Parser, mut shutdown: Shutdown) { let cli_config = Config::builder() .history_ignore_space(true) @@ -214,6 +292,15 @@ fn cli_loop(parser: Parser, mut shutdown: Shutdown) { } } +/// Loads the node identity, or creates a new one if the --create_id flag was specified +/// ## Parameters +/// `identity_file` - Reference to file path +/// `public_address` - Network address of the base node +/// `create_id` - Whether an identity needs to be created or not +/// `peer_features` - Enables features of the base node +/// +/// # Return +/// A NodeIdentity wrapped in an atomic reference counter on success, the exit code indicating the reason on failure fn setup_node_identity( identity_file: &PathBuf, public_address: &Multiaddr, diff --git a/applications/tari_base_node/src/miner.rs b/applications/tari_base_node/src/miner.rs index 3abbd61688..c15d248a1a 100644 --- a/applications/tari_base_node/src/miner.rs +++ b/applications/tari_base_node/src/miner.rs @@ -30,6 +30,13 @@ use tari_core::{ use tari_service_framework::handles::ServiceHandles; use tari_shutdown::ShutdownSignal; +/// Builds the miner for the base node +/// ## Parameters +/// `handles` - Handles to the base node services +/// `kill_signal` - Signal to stop the miner +/// `event_stream` - Message stream of the publish-subscribe message system +/// `consensus_manager`- The rules for the blockchain +/// `num_threads` - The number of threads on which to run the miner pub fn build_miner>( handles: H, kill_signal: ShutdownSignal, diff --git a/applications/tari_base_node/src/parser.rs b/applications/tari_base_node/src/parser.rs index 9916c69045..30cb8021fa 100644 --- a/applications/tari_base_node/src/parser.rs +++ b/applications/tari_base_node/src/parser.rs @@ -21,9 +21,17 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
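`setup_runtime` above sizes a multi-threaded tokio runtime from the `core_threads` and `blocking_threads` values in `GlobalConfig`. A hedged sketch of what that builder call can look like on tokio 0.2 (the exact builder options the patch uses are not visible in this hunk):

```rust
use tokio::runtime::{Builder, Runtime};

// Illustrative only: the thread counts stand in for GlobalConfig::core_threads
// and GlobalConfig::blocking_threads.
fn setup_runtime(core_threads: usize, blocking_threads: usize) -> Result<Runtime, String> {
    Builder::new()
        .threaded_scheduler() // tokio 0.2 multi-threaded runtime
        .enable_all()
        .core_threads(core_threads)
        .max_threads(core_threads + blocking_threads)
        .build()
        .map_err(|e| format!("There was an error while building the node runtime. {}", e))
}

fn main() {
    // Runtime::block_on takes &mut self on tokio 0.2.
    let mut rt = setup_runtime(4, 2).expect("runtime");
    rt.block_on(async { println!("runtime up"); });
}
```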
use super::LOG_TARGET; -use crate::{builder::NodeContainer, utils}; +use crate::{ + builder::NodeContainer, + table::Table, + utils, + utils::{format_duration_basic, format_naive_datetime}, +}; +use chrono::Utc; +use chrono_english::{parse_date_string, Dialect}; use log::*; use qrcode::{render::unicode, QrCode}; +use regex::Regex; use rustyline::{ completion::Completer, error::ReadlineError, @@ -33,6 +41,7 @@ use rustyline::{ }; use rustyline_derive::{Helper, Highlighter, Validator}; use std::{ + io::{self, Write}, str::FromStr, string::ToString, sync::{ @@ -52,12 +61,18 @@ use tari_comms::{ use tari_comms_dht::{envelope::NodeDestination, DhtDiscoveryRequester}; use tari_core::{ base_node::LocalNodeCommsInterface, + blocks::BlockHeader, + mempool::service::LocalMempoolService, tari_utilities::{hex::Hex, Hashable}, - transactions::tari_amount::{uT, MicroTari}, + transactions::{ + tari_amount::{uT, MicroTari}, + transaction::OutputFeatures, + }, }; +use tari_crypto::ristretto::pedersen::PedersenCommitmentFactory; use tari_shutdown::Shutdown; use tari_wallet::{ - output_manager_service::handle::OutputManagerHandle, + output_manager_service::{error::OutputManagerError, handle::OutputManagerHandle}, transaction_service::{error::TransactionServiceError, handle::TransactionServiceHandle}, util::emoji::EmojiId, }; @@ -69,17 +84,28 @@ use tokio::{runtime, time}; pub enum BaseNodeCommand { Help, GetBalance, + ListUtxos, + ListTransactions, + ListCompletedTransactions, + CancelTransaction, SendTari, GetChainMetadata, ListPeers, + ResetOfflinePeers, BanPeer, UnbanPeer, ListConnections, ListHeaders, + CheckDb, + CalcTiming, DiscoverPeer, GetBlock, + GetMempoolStats, + GetMempoolState, Whoami, ToggleMining, + MakeItRain, + CoinSplit, Quit, Exit, } @@ -92,16 +118,24 @@ pub struct Parser { discovery_service: DhtDiscoveryRequester, base_node_identity: Arc, peer_manager: Arc, + wallet_peer_manager: Arc, connection_manager: ConnectionManagerRequester, commands: Vec, hinter: HistoryHinter, wallet_output_service: OutputManagerHandle, node_service: LocalNodeCommsInterface, + mempool_service: LocalMempoolService, wallet_transaction_service: TransactionServiceHandle, enable_miner: Arc, } -// This will go through all instructions and look for potential matches +const MAKE_IT_RAIN_USAGE: &str = "\nmake-it-rain [Txs/s] [duration (s)] [start amount (uT)] [increment (uT)/Tx] \ + [\"start time (UTC)\" / 'now' for immediate start] [public key or emoji id to send \ + to] [message]\n or\nmake-it-rain [Txs/s] [duration (s)] [start amount (uT)] \ + [increment (uT)/Tx] [\"start time (UTC)\" / 'now' for immediate start] --file \ + [\"path to file\" containing list of 'public key or emoji id' 'message']\n"; + +/// This will go through all instructions and look for potential matches impl Completer for Parser { type Candidate = String; @@ -121,7 +155,7 @@ impl Completer for Parser { } } -// This allows us to make hints based on historic inputs +/// This allows us to make hints based on historic inputs impl Hinter for Parser { fn hint(&self, line: &str, pos: usize, ctx: &rustyline::Context<'_>) -> Option { self.hinter.hint(line, pos, ctx) @@ -137,21 +171,39 @@ impl Parser { discovery_service: ctx.base_node_dht().discovery_service_requester(), base_node_identity: ctx.base_node_identity(), peer_manager: ctx.base_node_comms().peer_manager(), + wallet_peer_manager: ctx.wallet_comms().peer_manager(), connection_manager: ctx.base_node_comms().connection_manager(), commands: BaseNodeCommand::iter().map(|x| x.to_string()).collect(), 
hinter: HistoryHinter {}, wallet_output_service: ctx.output_manager(), node_service: ctx.local_node(), + mempool_service: ctx.local_mempool(), wallet_transaction_service: ctx.wallet_transaction_service(), enable_miner: ctx.miner_enabled(), } } + /// This will return the list of commands from the parser + pub fn get_commands(&self) -> Vec { + self.commands.clone() + } + /// This will parse the provided command and execute the task pub fn handle_command(&mut self, command_str: &str, shutdown: &mut Shutdown) { if command_str.trim().is_empty() { return; } + + // Delimit arguments using spaces and pairs of quotation marks, which may include spaces + let arg_temp = command_str.trim().to_string(); + let re = Regex::new(r#"[^\s"]+|"(?:\\"|[^"])+""#).unwrap(); + let arg_temp_vec: Vec<&str> = re.find_iter(&arg_temp).map(|mat| mat.as_str()).collect(); + // Remove quotation marks left behind by `Regex` - it does not support look ahead and look behind + let mut del_arg_vec = Vec::new(); + for arg in arg_temp_vec.iter().skip(1) { + del_arg_vec.push(str::replace(arg, "\"", "")); + } + let mut args = command_str.split_whitespace(); let command = BaseNodeCommand::from_str(args.next().unwrap_or(&"help")); if command.is_err() { @@ -160,14 +212,15 @@ impl Parser { return; } let command = command.unwrap(); - self.process_command(command, args, shutdown); + self.process_command(command, args, del_arg_vec, shutdown); } - // Function to process commands + /// Function to process commands fn process_command<'a, I: Iterator>( &mut self, command: BaseNodeCommand, args: I, + del_arg_vec: Vec, shutdown: &mut Shutdown, ) { @@ -179,6 +232,18 @@ impl Parser { GetBalance => { self.process_get_balance(); }, + ListUtxos => { + self.process_list_unspent_outputs(); + }, + ListTransactions => { + self.process_list_transactions(); + }, + ListCompletedTransactions => { + self.process_list_completed_transactions(args); + }, + CancelTransaction => { + self.process_cancel_transaction(args); + }, SendTari => { self.process_send_tari(args); }, @@ -191,6 +256,12 @@ impl Parser { ListPeers => { self.process_list_peers(args); }, + ResetOfflinePeers => { + self.process_reset_offline_peers(); + }, + CheckDb => { + self.process_check_db(); + }, BanPeer => { self.process_ban_peer(args, true); }, @@ -203,15 +274,30 @@ impl Parser { ListHeaders => { self.process_list_headers(args); }, + CalcTiming => { + self.process_calc_timing(args); + }, ToggleMining => { self.process_toggle_mining(); }, GetBlock => { self.process_get_block(args); }, + GetMempoolStats => { + self.process_get_mempool_stats(); + }, + GetMempoolState => { + self.process_get_mempool_state(); + }, Whoami => { self.process_whoami(); }, + MakeItRain => { + self.process_make_it_rain(del_arg_vec); + }, + CoinSplit => { + self.process_coin_split(args); + }, Exit | Quit => { println!("Shutting down..."); info!( @@ -223,6 +309,7 @@ impl Parser { } } + /// Displays the commands or context specific help for a given command fn print_help<'a, I: Iterator>(&self, mut args: I) { let help_for = BaseNodeCommand::from_str(args.next().unwrap_or_default()).unwrap_or(BaseNodeCommand::Help); use BaseNodeCommand::*; @@ -235,6 +322,20 @@ impl Parser { GetBalance => { println!("Gets your balance"); }, + ListUtxos => { + println!("List your UTXOs"); + }, + ListTransactions => { + println!("Print a list of pending inbound and outbound transactions"); + }, + ListCompletedTransactions => { + println!("Print a list of completed transactions."); + println!("USAGE: list-completed-transactions [last n] or 
list-completed-transactions [n] [m]"); + }, + CancelTransaction => { + println!("Cancel a transaction"); + println!("USAGE: cancel-transaction [transaction ID]"); + }, SendTari => { println!("Sends an amount of Tari to an address; call this command via:"); println!("send-tari [amount of tari to send] [destination public key or emoji id] [optional: msg]"); }, @@ -248,12 +349,18 @@ impl Parser { ListPeers => { println!("Lists the peers that this node knows about"); }, + ResetOfflinePeers => { + println!("Clear offline flag from all peers"); + }, BanPeer => { println!("Bans a peer"); }, UnbanPeer => { println!("Removes the peer ban"); }, + CheckDb => { + println!("Checks the blockchain database for missing blocks and headers"); + }, ListConnections => { println!("Lists the peer connections currently held by this node"); }, @@ -262,6 +369,9 @@ impl Parser { println!("list-headers [first header height] [last header height]"); println!("list-headers [number of headers starting from the chain tip back]"); }, + CalcTiming => { + println!("Calculates the average time taken to mine a given range of blocks."); + }, ToggleMining => { println!("Enable or disable the miner on this node, calling this command will toggle the state"); }, @@ -269,19 +379,32 @@ impl Parser { println!("View a block of a height, call this command via:"); println!("get-block [height of the block]"); }, + GetMempoolStats => { + println!("Retrieves your mempool's stats"); + }, + GetMempoolState => { + println!("Retrieves your mempool's state"); + }, Whoami => { println!( "Display identity information about this node, including: public key, node ID and the public \ address" ); }, + MakeItRain => { + println!("Sends multiple amounts of Tari to a public wallet address via this command:"); + println!("{}", MAKE_IT_RAIN_USAGE); + }, + CoinSplit => { + println!("Constructs a transaction to split a small set of UTXOs into a large set of UTXOs"); + }, Exit | Quit => { println!("Exits the base node"); }, } } - // Function to process the get balance command + /// Function to process the get-balance command fn process_get_balance(&mut self) { let mut handler = self.wallet_output_service.clone(); self.executor.spawn(async move { @@ -296,7 +419,216 @@ impl Parser { }); } - // Function to process the get chain meta data + /// Function to process the list utxos command + fn process_list_unspent_outputs(&mut self) { + let mut handler1 = self.node_service.clone(); + let mut handler2 = self.wallet_output_service.clone(); + self.executor.spawn(async move { + let current_height = match handler1.get_metadata().await { + Err(err) => { + println!("Failed to retrieve chain metadata: {:?}", err); + warn!(target: LOG_TARGET, "Error communicating with base node: {:?}", err); + return; + }, + Ok(data) => data.height_of_longest_chain.unwrap() as i64, + }; + match handler2.get_unspent_outputs().await { + Err(e) => { + println!("Something went wrong"); + warn!(target: LOG_TARGET, "Error communicating with wallet: {:?}", e); + return; + }, + Ok(unspent_outputs) => { + if !unspent_outputs.is_empty() { + println!( + "\nYou have {} UTXOs: (value, commitment, mature in ? 
blocks, flags)", + unspent_outputs.len() + ); + let factory = PedersenCommitmentFactory::default(); + for uo in unspent_outputs.iter() { + let mature_in = std::cmp::max(uo.features.maturity as i64 - current_height, 0); + println!( + " {}, {}, {:>3}, {:?}", + uo.value, + uo.as_transaction_input(&factory, OutputFeatures::default()) + .commitment + .to_hex(), + mature_in, + uo.features.flags + ); + } + println!(); + } else { + println!("\nNo valid UTXOs found at this time\n"); + } + }, + }; + }); + } + + fn process_list_transactions(&mut self) { + let mut transactions = self.wallet_transaction_service.clone(); + + self.executor.spawn(async move { + println!("Inbound Transactions"); + match transactions.get_pending_inbound_transactions().await { + Ok(transactions) => { + if transactions.is_empty() { + println!("No pending inbound transactions found."); + } else { + let mut table = Table::new(); + table.set_titles(vec![ + "Transaction ID", + "Source Public Key", + "Amount", + "Status", + "Receiver State", + "Timestamp", + "Message", + ]); + for (tx_id, txn) in transactions { + table.add_row(row![ + tx_id, + txn.source_public_key, + txn.amount, + txn.status, + txn.receiver_protocol.state, + format_naive_datetime(&txn.timestamp), + txn.message + ]); + } + + table.print_std(); + } + }, + Err(err) => { + println!("Failed to retrieve inbound transactions: {:?}", err); + return; + }, + } + + println!(); + println!("Outbound Transactions"); + match transactions.get_pending_outbound_transactions().await { + Ok(transactions) => { + if transactions.is_empty() { + println!("No pending outbound transactions found."); + return; + } + + let mut table = Table::new(); + table.set_titles(vec![ + "Transaction ID", + "Dest Public Key", + "Amount", + "Fee", + "Status", + "Sender State", + "Timestamp", + "Message", + ]); + for (tx_id, txn) in transactions { + table.add_row(row![ + tx_id, + txn.destination_public_key, + txn.amount, + txn.fee, + txn.status, + txn.sender_protocol, + format_naive_datetime(&txn.timestamp), + txn.message + ]); + } + + table.print_std(); + }, + Err(err) => { + println!("Failed to retrieve inbound transactions: {:?}", err); + return; + }, + } + }); + } + + fn process_list_completed_transactions<'a, I: Iterator>(&self, mut args: I) { + let mut transactions = self.wallet_transaction_service.clone(); + let n = args.next().and_then(|s| s.parse::().ok()).unwrap_or(10); + let m = args.next().and_then(|s| s.parse::().ok()); + + self.executor.spawn(async move { + match transactions.get_completed_transactions().await { + Ok(transactions) => { + if transactions.is_empty() { + println!("No completed transactions found."); + return; + } + // TODO: This doesn't scale well because hashmap has a random ordering. 
Support for this query + // should be added at the database level + let mut transactions = transactions.into_iter().map(|(_, txn)| txn).collect::>(); + transactions.sort_by(|a, b| b.timestamp.cmp(&a.timestamp)); + let transactions = match m { + Some(m) => transactions.into_iter().skip(n).take(m).collect::>(), + None => transactions.into_iter().take(n).collect::>(), + }; + + let mut table = Table::new(); + table.set_titles(vec![ + "Transaction ID", + "Sender", + "Receiver", + "Amount", + "Fee", + "Status", + "Timestamp", + "Message", + ]); + for txn in transactions { + table.add_row(row![ + txn.tx_id, + txn.source_public_key, + txn.destination_public_key, + txn.amount, + txn.fee, + txn.status, + format_naive_datetime(&txn.timestamp), + txn.message + ]); + } + + table.print_std(); + }, + Err(err) => { + println!("Failed to retrieve inbound transactions: {:?}", err); + return; + }, + } + }); + } + + fn process_cancel_transaction<'a, I: Iterator>(&self, mut args: I) { + let mut transactions = self.wallet_transaction_service.clone(); + let tx_id = match args.next().and_then(|s| s.parse::().ok()) { + Some(id) => id, + None => { + println!("Please enter a valid transaction ID"); + println!("USAGE: cancel-transaction [transaction id]"); + return; + }, + }; + + self.executor.spawn(async move { + match transactions.cancel_transaction(tx_id).await { + Ok(_) => { + println!("Transaction {} successfully cancelled", tx_id); + }, + Err(err) => { + println!("Failed to cancel transaction: {:?}", err); + }, + } + }); + } + + /// Function to process the get-chain-metadata command fn process_get_chain_meta(&mut self) { let mut handler = self.node_service.clone(); self.executor.spawn(async move { @@ -311,6 +643,7 @@ impl Parser { }); } + /// Function to process the get-block command fn process_get_block<'a, I: Iterator>(&self, args: I) { let command_arg = args.take(4).collect::>(); let height = if command_arg.len() == 1 { @@ -332,7 +665,10 @@ impl Parser { match handler.get_blocks(vec![height]).await { Err(err) => { println!("Failed to retrieve blocks: {:?}", err); - warn!(target: LOG_TARGET, "Error communicating with base node: {:?}", err,); + warn!( + target: LOG_TARGET, + "Error communicating with local base node: {:?}", err, + ); return; }, Ok(mut data) => match data.pop() { @@ -343,6 +679,37 @@ impl Parser { }); } + /// Function to process the get-mempool-stats command + fn process_get_mempool_stats(&mut self) { + let mut handler = self.mempool_service.clone(); + self.executor.spawn(async move { + match handler.get_mempool_stats().await { + Ok(stats) => println!("{}", stats), + Err(err) => { + println!("Failed to retrieve mempool stats: {:?}", err); + warn!(target: LOG_TARGET, "Error communicating with local mempool: {:?}", err,); + return; + }, + }; + }); + } + + /// Function to process the get-mempool-state command + fn process_get_mempool_state(&mut self) { + let mut handler = self.mempool_service.clone(); + self.executor.spawn(async move { + match handler.get_mempool_state().await { + Ok(state) => println!("{}", state), + Err(err) => { + println!("Failed to retrieve mempool state: {:?}", err); + warn!(target: LOG_TARGET, "Error communicating with local mempool: {:?}", err,); + return; + }, + }; + }); + } + + /// Function to process the discover-peer command fn process_discover_peer<'a, I: Iterator>(&mut self, mut args: I) { let mut dht = self.discovery_service.clone(); @@ -358,10 +725,12 @@ impl Parser { self.executor.spawn(async move { let start = Instant::now(); println!("🌎 Peer discovery 
started."); - match dht.discover_peer(dest_pubkey, None, NodeDestination::Unknown).await { + match dht + .discover_peer(dest_pubkey.clone(), NodeDestination::PublicKey(dest_pubkey)) + .await + { Ok(p) => { - let end = Instant::now(); - println!("⚡️ Discovery succeeded in {}ms!", (end - start).as_millis()); + println!("⚡️ Discovery succeeded in {}ms!", start.elapsed().as_millis()); println!("This peer was found:"); println!("{}", p); }, @@ -372,6 +741,7 @@ impl Parser { }); } + /// Function to process the list-peers command fn process_list_peers<'a, I: Iterator>(&mut self, mut args: I) { let peer_manager = self.peer_manager.clone(); let filter = args.next().map(ToOwned::to_owned); @@ -391,12 +761,48 @@ impl Parser { match peer_manager.perform_query(query).await { Ok(peers) => { let num_peers = peers.len(); - println!( - "{}", - peers - .into_iter() - .fold(String::new(), |acc, p| format!("{}\n{}", acc, p)) - ); + println!(); + let mut table = Table::new(); + table.set_titles(vec![ + "NodeId", + "Public Key", + "Flags", + "Role", + "Status", + "Added at", + "Last connection", + ]); + + for peer in peers { + let status_str = { + let mut s = Vec::new(); + if let Some(offline_at) = peer.offline_at.as_ref() { + s.push(format!("OFFLINE since {}", format_naive_datetime(offline_at))); + } + + if let Some(dt) = peer.banned_until() { + s.push(format!("BANNED until {}", format_naive_datetime(dt))); + } + s.join(", ") + }; + table.add_row(row![ + peer.node_id.short_str(), + peer.public_key, + format!("{:?}", peer.flags), + { + if peer.features == PeerFeatures::COMMUNICATION_CLIENT { + "Wallet" + } else { + "Base node" + } + }, + status_str, + peer.added_at.date(), + peer.connection_stats, + ]); + } + table.print_std(); + println!("{} peer(s) known by this node", num_peers); }, Err(err) => { @@ -408,62 +814,150 @@ impl Parser { }); } - fn process_ban_peer<'a, I: Iterator>(&mut self, mut args: I, is_banned: bool) { + /// Function to process the ban-peer command + fn process_ban_peer<'a, I: Iterator>(&mut self, mut args: I, must_ban: bool) { let peer_manager = self.peer_manager.clone(); + let wallet_peer_manager = self.wallet_peer_manager.clone(); let mut connection_manager = self.connection_manager.clone(); let public_key = match args.next().and_then(parse_emoji_id_or_public_key) { Some(v) => Box::new(v), None => { println!("Please enter a valid destination public key or emoji id"); - println!("ban-peer/unban-peer [hex public key or emoji id]"); + println!( + "ban-peer/unban-peer [hex public key or emoji id] (length of time to ban the peer for in seconds)" + ); return; }, }; + let pubkeys = vec![ + self.base_node_identity.public_key(), + self.wallet_node_identity.public_key(), + ]; + if pubkeys.contains(&&*public_key) { + println!("Cannot ban our own wallet or node"); + return; + } + + let duration = args + .next() + .and_then(|s| s.parse::().ok()) + .map(Duration::from_secs) + .unwrap_or_else(|| Duration::from_secs(std::u64::MAX)); + self.executor.spawn(async move { - match peer_manager.set_banned(&public_key, is_banned).await { - Ok(node_id) => { - if is_banned { - match connection_manager.disconnect_peer(node_id).await { - Ok(_) => { - println!("Peer was banned."); - }, - Err(err) => { - println!( - "Peer was banned but an error occurred when disconnecting them: {:?}", - err - ); - }, - } - } else { - println!("Peer ban was removed."); - } - }, - Err(err) => { - println!("Failed to ban/unban peer: {:?}", err); - error!(target: LOG_TARGET, "Could not ban/unban peer: {:?}", err); - return; - }, + if 
must_ban { + match peer_manager.ban_for(&public_key, duration).await { + Ok(node_id) => match connection_manager.disconnect_peer(node_id).await { + Ok(_) => { + println!("Peer was banned in base node."); + }, + Err(err) => { + println!( + "Peer was banned but an error occurred when disconnecting them: {:?}", + err + ); + }, + }, + Err(err) if err.is_peer_not_found() => { + println!("Peer not found in base node"); + }, + Err(err) => { + println!("Failed to ban peer: {:?}", err); + error!(target: LOG_TARGET, "Could not ban peer: {:?}", err); + }, + } + + match wallet_peer_manager.ban_for(&public_key, duration).await { + Ok(node_id) => match connection_manager.disconnect_peer(node_id).await { + Ok(_) => { + println!("Peer was banned in wallet."); + }, + Err(err) => { + println!( + "Peer was banned but an error occurred when disconnecting them: {:?}", + err + ); + }, + }, + Err(err) if err.is_peer_not_found() => { + println!("Peer not found in wallet"); + }, + Err(err) => { + println!("Failed to ban peer: {:?}", err); + error!(target: LOG_TARGET, "Could not ban peer: {:?}", err); + }, + } + } else { + match peer_manager.unban(&public_key).await { + Ok(_) => { + println!("Peer ban was removed from base node."); + }, + Err(err) if err.is_peer_not_found() => { + println!("Peer not found in base node"); + }, + Err(err) => { + println!("Failed to ban peer: {:?}", err); + error!(target: LOG_TARGET, "Could not ban peer: {:?}", err); + }, + } + + match wallet_peer_manager.unban(&public_key).await { + Ok(_) => { + println!("Peer ban was removed from wallet."); + }, + Err(err) if err.is_peer_not_found() => { + println!("Peer not found in wallet"); + }, + Err(err) => { + println!("Failed to ban peer: {:?}", err); + error!(target: LOG_TARGET, "Could not ban peer: {:?}", err); + }, + } } }); } + /// Function to process the list-connections command fn process_list_connections(&self) { let mut connection_manager = self.connection_manager.clone(); + let peer_manager = self.peer_manager.clone(); + self.executor.spawn(async move { match connection_manager.get_active_connections().await { Ok(conns) if conns.is_empty() => { println!("No active peer connections."); }, Ok(conns) => { + println!(); let num_connections = conns.len(); - println!( - "{}", - conns - .into_iter() - .fold(String::new(), |acc, p| format!("{}\n{}", acc, p)) - ); + let mut table = Table::new(); + table.set_titles(vec!["NodeId", "Public Key", "Address", "Direction", "Uptime", "Role"]); + for conn in conns { + let peer = peer_manager + .find_by_node_id(conn.peer_node_id()) + .await + .expect("Unexpected peer database error or peer not found"); + + table.add_row(row![ + peer.node_id.short_str(), + peer.public_key, + conn.address(), + conn.direction(), + format_duration_basic(conn.connected_since()), + { + if peer.features == PeerFeatures::COMMUNICATION_CLIENT { + "Wallet" + } else { + "Base node" + } + }, + ]); + } + + table.print_std(); + println!("{} active connection(s)", num_connections); }, Err(err) => { @@ -475,6 +969,34 @@ impl Parser { }); } + fn process_reset_offline_peers(&self) { + let peer_manager = self.peer_manager.clone(); + self.executor.spawn(async move { + let result = peer_manager + .update_each(|mut peer| { + if peer.is_offline() { + peer.set_offline(false); + Some(peer) + } else { + None + } + }) + .await; + + match result { + Ok(num_updated) => { + println!("{} peer(s) were unmarked as offline.", num_updated); + }, + Err(err) => { + println!("Failed to clear offline peer states: {:?}", err); + error!(target: LOG_TARGET, 
"{:?}", err); + return; + }, + } + }); + } + + /// Function to process the toggle-mining command fn process_toggle_mining(&mut self) { let new_state = !self.enable_miner.load(Ordering::SeqCst); self.enable_miner.store(new_state, Ordering::SeqCst); @@ -486,19 +1008,32 @@ impl Parser { debug!(target: LOG_TARGET, "Mining state is now switched to {}", new_state); } + /// Function to process the list-headers command fn process_list_headers<'a, I: Iterator>(&self, args: I) { - let command_arg = args.take(4).collect::>(); + let command_arg = args.map(|arg| arg.to_string()).take(4).collect::>(); if (command_arg.is_empty()) || (command_arg.len() > 2) { println!("Command entered incorrectly, please use the following formats: "); println!("list-headers [first header height] [last header height]"); println!("list-headers [amount of headers from top]"); return; } + let handler = self.node_service.clone(); + self.executor.spawn(async move { + let headers = Parser::get_headers(handler, command_arg).await; + for header in headers { + println!("\n\nHeader hash: {}", header.hash().to_hex()); + println!("{}", header); + } + }); + } + + /// Function to process the get-headers command + async fn get_headers(mut handler: LocalNodeCommsInterface, command_arg: Vec) -> Vec { let height = if command_arg.len() == 2 { let height = command_arg[1].parse::(); if height.is_err() { println!("Invalid number provided"); - return; + return Vec::new(); }; Some(height.unwrap()) } else { @@ -507,55 +1042,105 @@ impl Parser { let start = command_arg[0].parse::(); if start.is_err() { println!("Invalid number provided"); - return; + return Vec::new(); }; let counter = if command_arg.len() == 2 { let start = start.unwrap(); let temp_height = height.clone().unwrap(); if temp_height <= start { - println!("start hight should be bigger than the end height"); - return; + println!("Start height should be bigger than the end height"); + return Vec::new(); } (temp_height - start) as usize } else { start.unwrap() as usize }; + let mut height = if let Some(v) = height { + v + } else { + match handler.get_metadata().await { + Err(err) => { + println!("Failed to retrieve chain height: {:?}", err); + warn!(target: LOG_TARGET, "Error communicating with base node: {}", err,); + 0 + }, + Ok(data) => data.height_of_longest_chain.unwrap_or(0), + } + }; + let mut headers = Vec::new(); + headers.push(height); + while (headers.len() <= counter) && (height > 0) { + height -= 1; + headers.push(height); + } + match handler.get_headers(headers).await { + Err(err) => { + println!("Failed to retrieve headers: {:?}", err); + warn!(target: LOG_TARGET, "Error communicating with base node: {}", err,); + Vec::new() + }, + Ok(data) => data, + } + } - let mut handler = self.node_service.clone(); + /// Function to process the calc-timing command + fn process_calc_timing<'a, I: Iterator>(&self, args: I) { + let command_arg = args.map(|arg| arg.to_string()).take(4).collect::>(); + if (command_arg.is_empty()) || (command_arg.len() > 2) { + println!("Command entered incorrectly, please use the following formats: "); + println!("calc-timing [first header height] [last header height]"); + println!("calc-timing [number of headers from chain tip]"); + return; + } + + let handler = self.node_service.clone(); self.executor.spawn(async move { - let mut height = if let Some(v) = height { - v - } else { - match handler.get_metadata().await { - Err(err) => { - println!("Failed to retrieve chain height: {:?}", err); - warn!(target: LOG_TARGET, "Error communicating with base node: 
{}", err,); - 0 - }, - Ok(data) => data.height_of_longest_chain.unwrap_or(0), - } - }; - let mut headers = Vec::new(); - headers.push(height); - while (headers.len() <= counter) && (height > 0) { + let headers = Parser::get_headers(handler, command_arg).await; + let (max, min, avg) = timing_stats(&headers); + println!("Max block time: {}", max); + println!("Min block time: {}", min); + println!("Avg block time: {}", avg); + }); + } + + /// Function to process the check-db command + fn process_check_db(&mut self) { + // Todo, add calls to ask peers for missing data + let mut node = self.node_service.clone(); + self.executor.spawn(async move { + let meta = node.get_metadata().await.expect("Could not retrieve chain meta"); + + let mut height = meta.height_of_longest_chain.expect("Could not retrieve chain height"); + let mut missing_blocks = Vec::new(); + let mut missing_headers = Vec::new(); + print!("Searching for height: "); + while height > 0 { + print!("{}", height); + io::stdout().flush().unwrap(); + let block = node.get_blocks(vec![height]).await; + if block.is_err() { + // for some apparent reason this block is missing, means we have to ask for it again + missing_blocks.push(height); + }; height -= 1; - headers.push(height); + let next_header = node.get_headers(vec![height]).await; + if next_header.is_err() { + // this header is missing, so we stop here and need to ask for this header + missing_headers.push(height); + }; + print!("\x1B[{}D\x1B[K", (height + 1).to_string().chars().count()); } - let headers = match handler.get_header(headers).await { - Err(err) => { - println!("Failed to retrieve headers: {:?}", err); - warn!(target: LOG_TARGET, "Error communicating with base node: {}", err,); - return; - }, - Ok(data) => data, - }; - for header in headers { - println!("\n\nHeader hash: {}", header.hash().to_hex()); - println!("{}", header); + println!("Complete"); + for missing_block in missing_blocks { + println!("Missing block at height: {}", missing_block); + } + for missing_header_height in missing_headers { + println!("Missing header at height: {}", missing_header_height) } }); } + /// Function to process the whoami command fn process_whoami(&self) { println!("======== Wallet =========="); println!("{}", self.wallet_node_identity); @@ -579,7 +1164,52 @@ impl Parser { println!("{}", self.base_node_identity); } - // Function to process the send transaction function + /// Function to process the coin split command + fn process_coin_split<'a, I: Iterator>(&mut self, mut args: I) { + let amount_per_split = args.next().and_then(|v| v.parse::().ok()); + let split_count = args.next().and_then(|v| v.parse::().ok()); + if amount_per_split.is_none() | split_count.is_none() { + println!("Command entered incorrectly, please use the following format: "); + println!("coin-split [amount of tari to allocated to each UTXO] [number of UTXOs to create]"); + return; + } + let amount_per_split: MicroTari = amount_per_split.unwrap().into(); + let split_count = split_count.unwrap(); + + // Use output manager service to get utxo and create the coin split transaction + let fee_per_gram = 25 * uT; // TODO: use configured fee per gram + let mut output_manager = self.wallet_output_service.clone(); + let mut txn_service = self.wallet_transaction_service.clone(); + self.executor.spawn(async move { + match output_manager + .create_coin_split(amount_per_split, split_count, fee_per_gram, None) + .await + { + Ok((tx_id, tx, fee, amount)) => { + match txn_service + .submit_transaction(tx_id, tx, fee, amount, "Coin 
split".into()) + .await + { + Ok(_) => println!("Coin split transaction created with tx_id:\n{}", tx_id), + Err(e) => { + println!("Something went wrong creating a coin split transaction"); + println!("{:?}", e); + warn!(target: LOG_TARGET, "Error communicating with wallet: {:?}", e); + return; + }, + }; + }, + Err(e) => { + println!("Something went wrong creating a coin split transaction"); + println!("{:?}", e); + warn!(target: LOG_TARGET, "Error communicating with wallet: {:?}", e); + return; + }, + }; + }); + } + + /// Function to process the send transaction command fn process_send_tari<'a, I: Iterator>(&mut self, mut args: I) { let amount = args.next().and_then(|v| v.parse::().ok()); if amount.is_none() { @@ -628,25 +1258,23 @@ impl Parser { .await { Ok(true) => { - let end = Instant::now(); println!( "Discovery succeeded for peer {} after {}ms", dest_pubkey, - (end - start).as_millis() + start.elapsed().as_millis() ); debug!( target: LOG_TARGET, "Discovery succeeded for peer {} after {}ms", dest_pubkey, - (end - start).as_millis() + start.elapsed().as_millis() ); }, Ok(false) => { - let end = Instant::now(); println!( "Discovery failed for peer {} after {}ms", dest_pubkey, - (end - start).as_millis() + start.elapsed().as_millis() ); println!("The peer may be offline. Please try again later."); @@ -654,7 +1282,7 @@ impl Parser { target: LOG_TARGET, "Discovery failed for peer {} after {}ms", dest_pubkey, - (end - start).as_millis() + start.elapsed().as_millis() ); }, Err(_) => { @@ -667,6 +1295,9 @@ impl Parser { }, } }, + Err(TransactionServiceError::OutputManagerError(OutputManagerError::NotEnoughFunds)) => { + println!("Not enough funds to fulfill the transaction."); + }, Err(e) => { println!("Something went wrong sending funds"); println!("{:?}", e); @@ -677,10 +1308,187 @@ impl Parser { }; }); } + + // Function to process the make it rain transaction function + fn process_make_it_rain(&mut self, command_arg: Vec) { + let command_error_msg = + "Command entered incorrectly, please use the following format:\n".to_owned() + MAKE_IT_RAIN_USAGE; + + if (command_arg.is_empty()) || (command_arg.len() < 6) { + println!("{}", command_error_msg); + println!("Expected at least 6 arguments, received {}\n", command_arg.len()); + return; + } + + // [number of Txs/s] + let mut inc: u8 = 0; + let tx_per_s = command_arg[inc as usize].parse::(); + if tx_per_s.is_err() { + println!("{}", command_error_msg); + println!("Invalid data provided for [number of Txs/s]\n"); + return; + }; + let tx_per_s = tx_per_s.unwrap(); + + // [test duration (s)] + inc += 1; + let duration = command_arg[inc as usize].parse::(); + if duration.is_err() { + println!("{}", command_error_msg); + println!("Invalid data provided for [test duration (s)]\n"); + return; + }; + let duration = duration.unwrap(); + if (tx_per_s * duration as f64) < 1.0 { + println!("{}", command_error_msg); + println!("Invalid data provided for [number of Txs/s] * [test duration (s)], must be >= 1\n"); + return; + } + + // [starting amount (uT)] + inc += 1; + let start_amount = command_arg[inc as usize].parse::(); + if start_amount.is_err() { + println!("{}", command_error_msg); + println!("Invalid data provided for [starting amount (uT)]\n"); + return; + } + let start_amount: MicroTari = start_amount.unwrap().into(); + + // [increment (uT)/Tx] + inc += 1; + let amount_inc = command_arg[inc as usize].parse::(); + if amount_inc.is_err() { + println!("{}", command_error_msg); + println!("Invalid data provided for [increment (uT)/Tx]\n"); + return; + 
} + let amount_inc: MicroTari = amount_inc.unwrap().into(); + + // [start time (UTC) / 'now'] + inc += 1; + let time = command_arg[inc as usize].to_string(); + let time_utc_ref = Utc::now(); + let mut _time_utc_start = Utc::now(); + let datetime = parse_date_string(&time, Utc::now(), Dialect::Uk); + match datetime { + Ok(t) => { + if t > time_utc_ref { + _time_utc_start = t; + } + }, + Err(e) => { + println!("{}", command_error_msg); + println!("Invalid data provided for [start time (UTC) / 'now']\n"); + println!("{}", e); + return; + }, + } + + // TODO: Read in recipient address list and custom message from file + // [public key or emoji id to send to] + inc += 1; + let key = command_arg[inc as usize].to_string(); + let dest_pubkey = match parse_emoji_id_or_public_key(&key) { + Some(v) => v, + None => { + println!("{}", command_error_msg); + println!("Invalid data provided for [public key or emoji id to send to]\n"); + return; + }, + }; + + // [message] + let mut msg = "".to_string(); + inc += 1; + if command_arg.len() > inc as usize { + for arg in command_arg.iter().skip(inc as usize) { + msg = msg + arg + " "; + } + msg = msg.trim().to_string(); + } + + // TODO: Implement Tx rate vs. as fast as possible, must be non-blocking + // TODO: Start at specified time, must be non-blocking + for i in 0..(tx_per_s * duration as f64) as usize { + // `send-tari` commands: [amount of tari to send] [destination public key or emoji id] [optional: msg] + let command_str = + (start_amount.0 + amount_inc.0 * i as u64).to_string() + " " + &dest_pubkey.to_string() + " " + &msg; + let args = command_str.split_whitespace(); + // Execute + self.process_send_tari(args); + } + } } +/// Returns a CommsPublicKey from either a emoji id or a public key fn parse_emoji_id_or_public_key(key: &str) -> Option { EmojiId::str_to_pubkey(&key.trim().replace('|', "")) .or_else(|_| CommsPublicKey::from_hex(key)) .ok() } + +/// Given a slice of headers (in reverse order), calculate the maximum, minimum and average periods between them +fn timing_stats(headers: &[BlockHeader]) -> (u64, u64, f64) { + let (max, min) = headers.windows(2).fold((0u64, std::u64::MAX), |(max, min), next| { + let delta_t = match next[0].timestamp.checked_sub(next[1].timestamp) { + Some(delta) => delta.as_u64(), + None => 0u64, + }; + let min = min.min(delta_t); + let max = max.max(delta_t); + (max, min) + }); + let avg = if headers.len() >= 2 { + let dt = headers.first().unwrap().timestamp - headers.last().unwrap().timestamp; + let n = headers.len() - 1; + dt.as_u64() as f64 / n as f64 + } else { + 0.0 + }; + (max, min, avg) +} + +#[cfg(test)] +mod test { + use crate::parser::timing_stats; + use tari_core::{blocks::BlockHeader, tari_utilities::epoch_time::EpochTime}; + + #[test] + fn test_timing_stats() { + let headers = vec![500, 350, 300, 210, 100u64] + .into_iter() + .map(|t| BlockHeader { + timestamp: EpochTime::from(t), + ..BlockHeader::default() + }) + .collect::>(); + let (max, min, avg) = timing_stats(&headers); + assert_eq!(max, 150); + assert_eq!(min, 50); + assert_eq!(avg, 100f64); + } + + #[test] + fn timing_negative_blocks() { + let headers = vec![150, 90, 100u64] + .into_iter() + .map(|t| BlockHeader { + timestamp: EpochTime::from(t), + ..BlockHeader::default() + }) + .collect::>(); + let (max, min, avg) = timing_stats(&headers); + assert_eq!(max, 60); + assert_eq!(min, 0); + assert_eq!(avg, 25f64); + } + + #[test] + fn timing_empty_list() { + let (max, min, avg) = timing_stats(&[]); + assert_eq!(max, 0); + assert_eq!(min, std::u64::MAX); 
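
For reference, the core of `timing_stats` (defined and tested above) is a single `windows(2)` fold over headers ordered newest-first. A standalone sketch with plain `u64` second counts standing in for `EpochTime`; the saturating subtraction mirrors the original's `checked_sub`-to-zero clamp for out-of-order timestamps:

```rust
/// Max/min/average spacing between consecutive timestamps given in
/// descending order (a simplified stand-in for `timing_stats`).
fn spacing_stats(timestamps: &[u64]) -> (u64, u64, f64) {
    let (max, min) = timestamps.windows(2).fold((0u64, u64::MAX), |(max, min), w| {
        // Negative intervals are clamped to zero.
        let dt = w[0].saturating_sub(w[1]);
        (max.max(dt), min.min(dt))
    });
    let avg = if timestamps.len() >= 2 {
        // Total span divided by the number of intervals.
        let dt = timestamps.first().unwrap() - timestamps.last().unwrap();
        dt as f64 / (timestamps.len() - 1) as f64
    } else {
        0.0
    };
    (max, min, avg)
}

fn main() {
    // Same fixture as the unit test above.
    assert_eq!(spacing_stats(&[500, 350, 300, 210, 100]), (150, 50, 100.0));
}
```
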
+ assert_eq!(avg, 0f64); + } +} diff --git a/applications/tari_base_node/src/table.rs b/applications/tari_base_node/src/table.rs new file mode 100644 index 0000000000..71ee599de9 --- /dev/null +++ b/applications/tari_base_node/src/table.rs @@ -0,0 +1,147 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::{cmp, io, io::Write}; + +/// Basic ASCII table implementation that is easy to put in a spreadsheet. 
+pub struct Table<'t, 's> { + titles: Option<Vec<&'t str>>, + rows: Vec<Vec<String>>, + delim_str: &'s str, +} + +impl<'t, 's> Table<'t, 's> { + pub fn new() -> Self { + Self { + titles: None, + rows: Vec::new(), + delim_str: "|", + } + } + + pub fn set_titles(&mut self, titles: Vec<&'t str>) { + self.titles = Some(titles); + } + + pub fn add_row(&mut self, row: Vec<String>) { + self.rows.push(row); + } + + pub fn render<T: Write>(&self, out: &mut T) -> io::Result<()> { + self.render_titles(out)?; + if self.rows.len() > 0 { + out.write_all(b"\n")?; + self.render_rows(out)?; + out.write_all(b"\n")?; + } + Ok(()) + } + + pub fn print_std(&self) { + self.render(&mut io::stdout()).unwrap(); + } + + fn col_width(&self, idx: usize) -> usize { + let title_width = self.titles.as_ref().map(|titles| titles[idx].len()).unwrap_or(0); + let rows_width = self.rows.iter().fold(0, |max, r| { + if idx < r.len() { + cmp::max(max, r[idx].len()) + } else { + max + } + }); + cmp::max(title_width, rows_width) + } + + fn render_titles<T: Write>(&self, out: &mut T) -> io::Result<()> { + if let Some(titles) = self.titles.as_ref() { + self.render_row(titles, out)?; + } + Ok(()) + } + + fn render_rows<T: Write>(&self, out: &mut T) -> io::Result<()> { + let rows_len = self.rows.len(); + for (i, row) in self.rows.iter().enumerate() { + self.render_row(row, out)?; + if i < rows_len - 1 { + out.write_all(b"\n")?; + } + } + Ok(()) + } + + fn render_row<T: Write, I: AsRef<[S]>, S: ToString>(&self, row: I, out: &mut T) -> io::Result<()> { + let row_len = row.as_ref().len(); + for (i, string) in row.as_ref().iter().enumerate() { + let s = string.to_string(); + let width = self.col_width(i); + let pad_left = if i == 0 { "" } else { " " }; + let pad_right = " ".repeat(width - s.len() + 1); + out.write_all(pad_left.as_bytes())?; + out.write_all(s.as_bytes())?; + out.write_all(pad_right.as_bytes())?; + if i < row_len - 1 { + out.write_all(self.delim_str.as_bytes())?; + } + } + Ok(()) + } +} + +macro_rules! row { + ($($s:expr),*$(,)?) => { + vec![$($s.to_string()),*] + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn renders_titles() { + let mut table = Table::new(); + table.set_titles(vec!["Hello", "World", "Bonjour", "Le", "Monde"]); + let mut buf = io::Cursor::new(Vec::new()); + table.render(&mut buf).unwrap(); + assert_eq!( + String::from_utf8_lossy(&buf.into_inner()), + "Hello | World | Bonjour | Le | Monde " + ); + } + + #[test] + fn renders_rows_with_titles() { + let mut table = Table::new(); + table.set_titles(vec!["Name", "Age", "Telephone Number", "Favourite Headwear"]); + table.add_row(row!["Trevor", 132, "+123 12323223", "Pith Helmet"]); + table.add_row(row![]); + table.add_row(row!["Hatless", 2]); + let mut buf = io::Cursor::new(Vec::new()); + table.render(&mut buf).unwrap(); + assert_eq!( + String::from_utf8_lossy(&buf.into_inner()), + "Name | Age | Telephone Number | Favourite Headwear \nTrevor | 132 | +123 12323223 | Pith Helmet \n\nHatless | 2 \n" + ); + } +} diff --git a/applications/tari_base_node/src/utils.rs b/applications/tari_base_node/src/utils.rs index bdf59b26c9..33769864bb 100644 --- a/applications/tari_base_node/src/utils.rs +++ b/applications/tari_base_node/src/utils.rs @@ -20,20 +20,37 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
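
As a quick illustration, the new `Table` and its `row!` macro are used like this. This is a hypothetical fragment, not code from the patch: the titles and values are made up, and `Table`/`row!` are assumed to be in scope within `table.rs`'s crate:

```rust
// Hypothetical usage of the Table renderer added above.
fn print_sample_peers() {
    let mut table = Table::new();
    table.set_titles(vec!["NodeId", "Address", "Status"]);
    // `row!` converts each expression with `to_string()`.
    table.add_row(row!["6c8f7a", "/ip4/1.2.3.4/tcp/18141", "Online"]);
    table.add_row(row!["f21e0d", "/ip4/5.6.7.8/tcp/18141", "Offline"]);
    // Columns are padded to the widest cell and separated by "|".
    table.print_std();
}
```
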
+use chrono::NaiveDateTime; use futures::{Stream, StreamExt}; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use tari_wallet::transaction_service::handle::TransactionEvent; +use tokio::sync::broadcast::RecvError; +pub const LOG_TARGET: &str = "base_node::app::utils"; + +/// Asynchronously processes the event stream checking to see if the given tx_id is present or not +/// ## Parameters +/// `event_stream` - The stream of events to search +/// `expected_tx_id` - The transaction id to be searched for +/// +/// ## Returns +/// True if found, false otherwise pub async fn wait_for_discovery_transaction_event(mut event_stream: S, expected_tx_id: u64) -> bool -where S: Stream> + Unpin { +where S: Stream, RecvError>> + Unpin { loop { match event_stream.next().await { - Some(event) => { - if let TransactionEvent::TransactionSendDiscoveryComplete(tx_id, is_success) = &*event { - if *tx_id == expected_tx_id { - break *is_success; + Some(event_result) => match event_result { + Ok(event) => { + if let TransactionEvent::TransactionDirectSendResult(tx_id, is_success) = &*event { + if *tx_id == expected_tx_id { + break *is_success; + } } - } + }, + Err(e) => { + log::error!(target: LOG_TARGET, "Error reading from event broadcast channel {:?}", e); + break false; + }, }, None => { break false; @@ -41,3 +58,38 @@ where S: Stream> + Unpin { } } } + +pub fn format_duration_basic(duration: Duration) -> String { + let secs = duration.as_secs(); + if secs > 60 { + let mins = secs / 60; + if mins > 60 { + let hours = mins / 60; + format!("{}h {}m {}s", hours, mins % 60, secs % 60) + } else { + format!("{}m {}s", mins, secs % 60) + } + } else { + format!("{}s", secs) + } +} + +/// Standard formatting helper function for a NaiveDateTime +pub fn format_naive_datetime(dt: &NaiveDateTime) -> String { + dt.format("%Y-%m-%d %H:%M:%S").to_string() +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn formats_duration() { + let s = format_duration_basic(Duration::from_secs(5)); + assert_eq!(s, "5s"); + let s = format_duration_basic(Duration::from_secs(23 * 60 + 10)); + assert_eq!(s, "23m 10s"); + let s = format_duration_basic(Duration::from_secs(9 * 60 * 60 + 35 * 60 + 45)); + assert_eq!(s, "9h 35m 45s"); + } +} diff --git a/applications/test_faucet/Cargo.toml b/applications/test_faucet/Cargo.toml index 64429bb338..dedea95134 100644 --- a/applications/test_faucet/Cargo.toml +++ b/applications/test_faucet/Cargo.toml @@ -13,7 +13,7 @@ serde_json = "1.0" rand = "0.7.2" [dependencies.tari_core] -version = "^0.0" +version = "^0.1" path = "../../base_layer/core/" default-features = false features = ["transactions", "avx2"] diff --git a/base_layer/core/Cargo.toml b/base_layer/core/Cargo.toml index 1520851352..3fe1146ad3 100644 --- a/base_layer/core/Cargo.toml +++ b/base_layer/core/Cargo.toml @@ -6,7 +6,7 @@ repository = "https://github.com/tari-project/tari" homepage = "https://tari.com" readme = "README.md" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" [features] @@ -18,25 +18,26 @@ base_node_proto = [] avx2 = ["tari_crypto/avx2"] [dependencies] -tari_comms = { version = "^0.0", path = "../../comms"} -tari_infra_derive = { path = "../../infrastructure/derive", version = "^0.0" } +tari_comms = { version = "^0.1", path = "../../comms"} +tari_infra_derive = { version = "^0.0", path = "../../infrastructure/derive" } tari_crypto = { version = "^0.3" } -tari_storage = { path = "../../infrastructure/storage", version = "^0.0" } -tari_common = {path = "../../common", version= 
"^0.0"} +tari_storage = { version = "^0.1", path = "../../infrastructure/storage" } +tari_common = { version= "^0.1", path = "../../common"} tari_service_framework = { version = "^0.0", path = "../service_framework"} -tari_p2p = {path = "../../base_layer/p2p", version = "^0.0"} -tari_comms_dht = { version = "^0.0", path = "../../comms/dht"} +tari_p2p = { version = "^0.1", path = "../../base_layer/p2p" } +tari_comms_dht = { version = "^0.1", path = "../../comms/dht"} tari_broadcast_channel = "^0.1" tari_pubsub = "^0.1" -tari_shutdown = { path = "../../infrastructure/shutdown", version = "^0.0"} -tari_mmr = { path = "../../base_layer/mmr", version = "^0.0", optional = true } +tari_shutdown = { version = "^0.0", path = "../../infrastructure/shutdown" } +tari_mmr = { version = "^0.1", path = "../../base_layer/mmr", optional = true } -randomx-rs = { version = "0.1.2", optional = true } +randomx-rs = { version = "0.2.0", optional = true } monero = { version = "0.5", features= ["serde_support"], optional = true } bitflags = "1.0.4" chrono = { version = "0.4.6", features = ["serde"]} digest = "0.8.0" derive-error = "0.0.4" +thiserror = "1.0.15" rand = "0.7.2" serde = { version = "1.0.97", features = ["derive"] } rmp-serde = "0.13.7" @@ -65,13 +66,13 @@ strum = "0.17.1" strum_macros = "0.17.1" [dev-dependencies] -tari_p2p = {path = "../../base_layer/p2p", version = "^0.0", features=["test-mocks"]} +tari_p2p = {path = "../../base_layer/p2p", version = "^0.1", features=["test-mocks"]} tari_test_utils = { path = "../../infrastructure/test_utils", version = "^0.0" } env_logger = "0.7.0" tempdir = "0.3.7" tokio-macros = "0.2.4" -tari_wallet = { path = "../../base_layer/wallet", version = "^0.0" } +tari_wallet = { path = "../../base_layer/wallet", version = "^0.1" } tokio-test = "0.2.0" [build-dependencies] -tari_common = { version = "^0.0", path="../../common"} +tari_common = { version = "^0.1", path="../../common"} diff --git a/base_layer/core/src/base_node/chain_metadata_service/handle.rs b/base_layer/core/src/base_node/chain_metadata_service/handle.rs index f3bdd00463..2477a00956 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/handle.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/handle.rs @@ -22,6 +22,7 @@ use crate::chain_storage::ChainMetadata; use futures::{stream::Fuse, StreamExt}; +use std::fmt::{Display, Error, Formatter}; use tari_broadcast_channel::Subscriber; use tari_comms::peer_manager::NodeId; @@ -40,6 +41,13 @@ impl PeerChainMetadata { } } +impl Display for PeerChainMetadata { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { + writeln!(f, "Node ID: {}", self.node_id)?; + writeln!(f, "Chain metadata: {}", self.chain_metadata) + } +} + #[derive(Debug)] pub enum ChainMetadataEvent { PeerChainMetadataReceived(Vec), diff --git a/base_layer/core/src/base_node/chain_metadata_service/service.rs b/base_layer/core/src/base_node/chain_metadata_service/service.rs index a3107f9e50..feff551df0 100644 --- a/base_layer/core/src/base_node/chain_metadata_service/service.rs +++ b/base_layer/core/src/base_node/chain_metadata_service/service.rs @@ -120,7 +120,7 @@ impl ChainMetadataService { /// Send this node's metadata to async fn update_liveness_chain_metadata(&mut self) -> Result<(), ChainMetadataSyncError> { let chain_metadata = self.base_node.get_metadata().await?; - let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes()?; + let bytes = proto::ChainMetadata::from(chain_metadata).to_encoded_bytes(); self.liveness 
.set_pong_metadata_entry(MetadataKey::ChainMetadata, bytes) .await?; @@ -131,35 +131,26 @@ impl ChainMetadataService { match event { // Received a pong, check if our neighbour sent it and it contains ChainMetadata LivenessEvent::ReceivedPong(event) => { - if event.is_neighbour { - trace!( - target: LOG_TARGET, - "Received pong from neighbouring node '{}'.", - event.node_id - ); - self.collect_chain_state_from_pong(&event.node_id, &event.metadata)?; - - // All peers have responded in this round, send the chain metadata to the base node service - if self.peer_chain_metadata.len() == self.peer_chain_metadata.capacity() { - self.flush_chain_metadata_to_event_publisher().await?; - } - } else { - debug!( - target: LOG_TARGET, - "Received pong from non-neighbouring node '{}'. Pong ignored...", event.node_id - ) + trace!( + target: LOG_TARGET, + "Received pong from neighbouring node '{}'.", + event.node_id + ); + self.collect_chain_state_from_pong(&event.node_id, &event.metadata)?; + + // All peers have responded in this round, send the chain metadata to the base node service + if self.peer_chain_metadata.len() == self.peer_chain_metadata.capacity() { + self.flush_chain_metadata_to_event_publisher().await?; } }, // New ping round has begun LivenessEvent::BroadcastedNeighbourPings(num_peers) => { + debug!( + target: LOG_TARGET, + "New chain metadata round sent to {} peer(s)", num_peers + ); // If we have chain metadata to send to the base node service, send them now // because the next round of pings is happening. - // TODO: It's pretty easy for this service to require either a percentage of peers - // to respond or, a time limit before assuming some peers will never respond - // between rounds (even if this time limit is larger than one or more ping rounds) - // before publishing the chain metadata event. - // The following will send the chain metadata at the start of a new round if at - // least one node has responded. 
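
The hunk above simplifies when collected peer metadata is flushed to the base node service: immediately once every peer in the round has answered, or at the start of the next ping round if at least one peer responded. A minimal sketch of that rule with hypothetical types (`MetadataRound` is a stand-in, not a Tari type):

```rust
/// Collects per-peer metadata for one liveness round and decides when to
/// publish it (a simplified model of the chain metadata service).
struct MetadataRound<M> {
    expected: usize,
    collected: Vec<M>,
}

impl<M> MetadataRound<M> {
    /// A pong arrived: store its metadata and flush if the round is full.
    fn on_pong(&mut self, metadata: M) -> Option<Vec<M>> {
        self.collected.push(metadata);
        if self.collected.len() == self.expected {
            Some(std::mem::take(&mut self.collected))
        } else {
            None
        }
    }

    /// A new ping round begins: flush whatever the last round produced.
    fn on_new_round(&mut self) -> Option<Vec<M>> {
        if self.collected.is_empty() {
            None
        } else {
            Some(std::mem::take(&mut self.collected))
        }
    }
}

fn main() {
    let mut round = MetadataRound { expected: 2, collected: Vec::new() };
    assert_eq!(round.on_pong("peer-a"), None);
    assert_eq!(round.on_pong("peer-b"), Some(vec!["peer-a", "peer-b"]));
    assert_eq!(round.on_new_round(), None);
}
```
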
if !self.peer_chain_metadata.is_empty() { self.flush_chain_metadata_to_event_publisher().await?; } @@ -300,10 +291,7 @@ mod test { let (liveness_handle, _) = create_p2p_liveness_mock(1); let mut metadata = Metadata::new(); let proto_chain_metadata = create_sample_proto_chain_metadata(); - metadata.insert( - MetadataKey::ChainMetadata, - proto_chain_metadata.to_encoded_bytes().unwrap(), - ); + metadata.insert(MetadataKey::ChainMetadata, proto_chain_metadata.to_encoded_bytes()); let node_id = NodeId::new(); let pong_event = PongEvent { @@ -356,28 +344,6 @@ mod test { assert_eq!(service.peer_chain_metadata.len(), 0); } - #[tokio_macros::test] - async fn handle_liveness_event_not_neighbour() { - let (liveness_handle, _) = create_p2p_liveness_mock(1); - let metadata = Metadata::new(); - let node_id = NodeId::new(); - let pong_event = PongEvent { - is_neighbour: false, - metadata, - node_id, - latency: None, - is_monitored: false, - }; - - let (base_node, _) = create_base_node_nci(); - let (publisher, _subscriber) = broadcast_channel::bounded(1); - let mut service = ChainMetadataService::new(liveness_handle, base_node, publisher); - - let sample_event = LivenessEvent::ReceivedPong(Box::new(pong_event)); - service.handle_liveness_event(&sample_event).await.unwrap(); - assert_eq!(service.peer_chain_metadata.len(), 0); - } - #[tokio_macros::test] async fn handle_liveness_event_bad_metadata() { let (liveness_handle, _) = create_p2p_liveness_mock(1); diff --git a/base_layer/core/src/base_node/comms_interface/comms_request.rs b/base_layer/core/src/base_node/comms_interface/comms_request.rs index 75537916dd..37f5e8e5de 100644 --- a/base_layer/core/src/base_node/comms_interface/comms_request.rs +++ b/base_layer/core/src/base_node/comms_interface/comms_request.rs @@ -48,7 +48,7 @@ pub enum NodeCommsRequest { FetchUtxos(Vec), FetchBlocks(Vec), FetchBlocksWithHashes(Vec), - GetNewBlockTemplate, + GetNewBlockTemplate(PowAlgorithm), GetNewBlock(NewBlockTemplate), GetTargetDifficulty(PowAlgorithm), } @@ -64,7 +64,7 @@ impl Display for NodeCommsRequest { NodeCommsRequest::FetchUtxos(v) => f.write_str(&format!("FetchUtxos (n={})", v.len())), NodeCommsRequest::FetchBlocks(v) => f.write_str(&format!("FetchBlocks (n={})", v.len())), NodeCommsRequest::FetchBlocksWithHashes(v) => f.write_str(&format!("FetchBlocks (n={})", v.len())), - NodeCommsRequest::GetNewBlockTemplate => f.write_str("GetNewBlockTemplate"), + NodeCommsRequest::GetNewBlockTemplate(algo) => f.write_str(&format!("GetNewBlockTemplate ({})", algo)), NodeCommsRequest::GetNewBlock(b) => f.write_str(&format!("GetNewBlock (Block Height={})", b.header.height)), NodeCommsRequest::GetTargetDifficulty(algo) => f.write_str(&format!("GetTargetDifficulty ({})", algo)), } diff --git a/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs b/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs index 7b5d5abe5f..7c971bf68a 100644 --- a/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs +++ b/base_layer/core/src/base_node/comms_interface/inbound_handlers.rs @@ -36,6 +36,7 @@ use crate::{ }, consensus::ConsensusManager, mempool::{async_mempool, Mempool}, + proof_of_work::{get_target_difficulty, Difficulty, PowAlgorithm}, transactions::transaction::{TransactionKernel, TransactionOutput}, }; use futures::SinkExt; @@ -202,15 +203,18 @@ where T: BlockchainBackend + 'static } Ok(NodeCommsResponse::HistoricalBlocks(blocks)) }, - NodeCommsRequest::GetNewBlockTemplate => { + NodeCommsRequest::GetNewBlockTemplate(pow_algo) => { let 
metadata = async_db::get_metadata(self.blockchain_db.clone()).await?; let best_block_hash = metadata .best_block .ok_or_else(|| CommsInterfaceError::UnexpectedApiResponse)?; let best_block_header = async_db::fetch_header_with_block_hash(self.blockchain_db.clone(), best_block_hash).await?; + + let constants = self.consensus_manager.consensus_constants(); let mut header = BlockHeader::from_previous(&best_block_header); - header.version = self.consensus_manager.consensus_constants().blockchain_version(); + header.version = constants.blockchain_version(); + header.pow.target_difficulty = self.get_target_difficulty(*pow_algo).await?; let transactions = async_mempool::retrieve( self.mempool.clone(), @@ -233,12 +237,9 @@ where T: BlockchainBackend + 'static let block = async_db::calculate_mmr_roots(self.blockchain_db.clone(), block_template.clone()).await?; Ok(NodeCommsResponse::NewBlock(block)) }, - NodeCommsRequest::GetTargetDifficulty(pow_algo) => { - let (db, metadata) = &self.blockchain_db.db_and_metadata_read_access()?; - Ok(NodeCommsResponse::TargetDifficulty( - self.consensus_manager.get_target_difficulty(metadata, db, *pow_algo)?, - )) - }, + NodeCommsRequest::GetTargetDifficulty(pow_algo) => Ok(NodeCommsResponse::TargetDifficulty( + self.get_target_difficulty(*pow_algo).await?, + )), } } @@ -262,7 +263,7 @@ where T: BlockchainBackend + 'static // Create block event on block event stream let block_event = match add_block_result.clone() { Ok(block_add_result) => { - debug!(target: LOG_TARGET, "Block event created: {:?}", block_add_result); + debug!(target: LOG_TARGET, "Block event created: {}", block_add_result); BlockEvent::Verified((Box::new(block.clone()), block_add_result)) }, Err(e) => { @@ -296,6 +297,32 @@ where T: BlockchainBackend + 'static } Ok(()) } + + async fn get_target_difficulty(&self, pow_algo: PowAlgorithm) -> Result { + let height_of_longest_chain = async_db::get_metadata(self.blockchain_db.clone()) + .await? 
+ .height_of_longest_chain + .ok_or_else(|| CommsInterfaceError::UnexpectedApiResponse)?; + debug!( + target: LOG_TARGET, + "Calculating target difficulty at height:{} for PoW:{}", height_of_longest_chain, pow_algo + ); + let constants = self.consensus_manager.consensus_constants(); + let target_difficulties = self.blockchain_db.fetch_target_difficulties( + pow_algo, + height_of_longest_chain, + constants.get_difficulty_block_window() as usize, + )?; + let target = get_target_difficulty( + target_difficulties, + constants.get_difficulty_block_window() as usize, + constants.get_diff_target_block_interval(), + constants.min_pow_difficulty(pow_algo), + constants.get_difficulty_max_block_interval(), + )?; + debug!(target: LOG_TARGET, "Target difficulty:{} for PoW:{}", target, pow_algo); + Ok(target) + } } impl Clone for InboundNodeCommsHandlers diff --git a/base_layer/core/src/base_node/comms_interface/local_interface.rs b/base_layer/core/src/base_node/comms_interface/local_interface.rs index 806d13120e..1dd05e5243 100644 --- a/base_layer/core/src/base_node/comms_interface/local_interface.rs +++ b/base_layer/core/src/base_node/comms_interface/local_interface.rs @@ -84,7 +84,7 @@ impl LocalNodeCommsInterface { } /// Request the block header of the current tip at the block height - pub async fn get_header(&mut self, block_heights: Vec) -> Result, CommsInterfaceError> { + pub async fn get_headers(&mut self, block_heights: Vec) -> Result, CommsInterfaceError> { match self .request_sender .call(NodeCommsRequest::FetchHeaders(block_heights)) @@ -96,10 +96,14 @@ impl LocalNodeCommsInterface { } /// Request the construction of a new mineable block template from the base node service. - pub async fn get_new_block_template(&mut self) -> Result { + pub async fn get_new_block_template( + &mut self, + pow_algorithm: PowAlgorithm, + ) -> Result + { match self .request_sender - .call(NodeCommsRequest::GetNewBlockTemplate) + .call(NodeCommsRequest::GetNewBlockTemplate(pow_algorithm)) .await?? { NodeCommsResponse::NewBlockTemplate(new_block_template) => Ok(new_block_template), diff --git a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs index 8a480d945f..361fb402bc 100644 --- a/base_layer/core/src/base_node/comms_interface/outbound_interface.rs +++ b/base_layer/core/src/base_node/comms_interface/outbound_interface.rs @@ -173,13 +173,13 @@ impl OutboundNodeCommsInterface { node_id: Option, ) -> Result, CommsInterfaceError> { - let to_hash = to_hash.unwrap_or(HashOutput::new()); + let to_hash = to_hash.unwrap_or_default(); if let NodeCommsResponse::FetchHeadersAfterResponse(headers) = self .request_sender .call((NodeCommsRequest::FetchHeadersAfter(from_hash, to_hash), node_id)) .await?? { - Ok(headers.clone()) + Ok(headers) } else { Err(CommsInterfaceError::UnexpectedApiResponse) } diff --git a/base_layer/core/src/base_node/proto/request.proto b/base_layer/core/src/base_node/proto/request.proto index adff5e5c00..a10fed1744 100644 --- a/base_layer/core/src/base_node/proto/request.proto +++ b/base_layer/core/src/base_node/proto/request.proto @@ -23,7 +23,7 @@ message BaseNodeServiceRequest { // Indicates a FetchBlocksWithHashes request. HashOutputs fetch_blocks_with_hashes = 8; // Indicates a GetNewBlockTemplate request. - bool get_new_block_template = 9; + uint64 get_new_block_template = 9; // Indicates a GetNewBlock request. tari.core.NewBlockTemplate get_new_block = 10; // Indicates a GetTargetDifficulty request. 
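
For context on the `get_target_difficulty` helper above: it fetches a window of past (timestamp, difficulty) samples for one PoW algorithm and feeds them, together with the consensus constants, into `proof_of_work::get_target_difficulty`. The toy function below only illustrates the shape of such a calculation, scaling average difficulty by the observed versus target block interval; the real algorithm in the patch (a weighted moving average) is more involved:

```rust
/// Toy difficulty adjustment over a sample window (illustrative only).
/// `samples` holds (timestamp secs, achieved difficulty), oldest first.
fn toy_target_difficulty(samples: &[(u64, u64)], target_interval_secs: u64, min_difficulty: u64) -> u64 {
    if samples.len() < 2 {
        return min_difficulty;
    }
    // Average observed spacing between blocks in the window.
    let span = samples.last().unwrap().0 - samples.first().unwrap().0;
    let avg_interval = (span / (samples.len() as u64 - 1)).max(1);
    let avg_difficulty = samples.iter().map(|(_, d)| d).sum::<u64>() / samples.len() as u64;
    // Blocks arriving faster than the target interval push difficulty up.
    (avg_difficulty * target_interval_secs / avg_interval).max(min_difficulty)
}

fn main() {
    // Blocks arriving twice as fast as the 120 s target: difficulty doubles.
    let samples = [(0, 1000), (60, 1000), (120, 1000)];
    assert_eq!(toy_target_difficulty(&samples, 120, 1), 2000);
}
```
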
diff --git a/base_layer/core/src/base_node/proto/request.rs b/base_layer/core/src/base_node/proto/request.rs index 2e7ab63789..18e31140d3 100644 --- a/base_layer/core/src/base_node/proto/request.rs +++ b/base_layer/core/src/base_node/proto/request.rs @@ -47,7 +47,9 @@ impl TryInto for ProtoNodeCommsRequest { FetchUtxos(hash_outputs) => ci::NodeCommsRequest::FetchUtxos(hash_outputs.outputs), FetchBlocks(block_heights) => ci::NodeCommsRequest::FetchBlocks(block_heights.heights), FetchBlocksWithHashes(block_hashes) => ci::NodeCommsRequest::FetchBlocksWithHashes(block_hashes.outputs), - GetNewBlockTemplate(_) => ci::NodeCommsRequest::GetNewBlockTemplate, + GetNewBlockTemplate(pow_algo) => { + ci::NodeCommsRequest::GetNewBlockTemplate(PowAlgorithm::try_from(pow_algo)?) + }, GetNewBlock(block_template) => ci::NodeCommsRequest::GetNewBlock(block_template.try_into()?), GetTargetDifficulty(pow_algo) => { ci::NodeCommsRequest::GetTargetDifficulty(PowAlgorithm::try_from(pow_algo)?) @@ -66,15 +68,12 @@ impl From for ProtoNodeCommsRequest { FetchHeaders(block_heights) => ProtoNodeCommsRequest::FetchHeaders(block_heights.into()), FetchHeadersWithHashes(block_hashes) => ProtoNodeCommsRequest::FetchHeadersWithHashes(block_hashes.into()), FetchHeadersAfter(hashes, stopping_hash) => { - ProtoNodeCommsRequest::FetchHeadersAfter(ProtoFetchHeadersAfter { - hashes: hashes.into(), - stopping_hash: stopping_hash.into(), - }) + ProtoNodeCommsRequest::FetchHeadersAfter(ProtoFetchHeadersAfter { hashes, stopping_hash }) }, FetchUtxos(hash_outputs) => ProtoNodeCommsRequest::FetchUtxos(hash_outputs.into()), FetchBlocks(block_heights) => ProtoNodeCommsRequest::FetchBlocks(block_heights.into()), FetchBlocksWithHashes(block_hashes) => ProtoNodeCommsRequest::FetchBlocksWithHashes(block_hashes.into()), - GetNewBlockTemplate => ProtoNodeCommsRequest::GetNewBlockTemplate(true), + GetNewBlockTemplate(pow_algo) => ProtoNodeCommsRequest::GetNewBlockTemplate(pow_algo as u64), GetNewBlock(block_template) => ProtoNodeCommsRequest::GetNewBlock(block_template.into()), GetTargetDifficulty(pow_algo) => ProtoNodeCommsRequest::GetTargetDifficulty(pow_algo as u64), } diff --git a/base_layer/core/src/base_node/service/initializer.rs b/base_layer/core/src/base_node/service/initializer.rs index 643ac067de..d2aed71668 100644 --- a/base_layer/core/src/base_node/service/initializer.rs +++ b/base_layer/core/src/base_node/service/initializer.rs @@ -136,6 +136,7 @@ async fn extract_block(msg: Arc) -> Option> { Some(DomainMessage { source_peer: msg.source_peer.clone(), dht_header: msg.dht_header.clone(), + authenticated_origin: msg.authenticated_origin.clone(), inner: block, }) }, diff --git a/base_layer/core/src/base_node/service/mod.rs b/base_layer/core/src/base_node/service/mod.rs index ccf0861bb1..4b1b40fab2 100644 --- a/base_layer/core/src/base_node/service/mod.rs +++ b/base_layer/core/src/base_node/service/mod.rs @@ -22,6 +22,7 @@ mod error; mod initializer; +#[allow(clippy::module_inception)] mod service; mod service_request; mod service_response; diff --git a/base_layer/core/src/base_node/service/service.rs b/base_layer/core/src/base_node/service/service.rs index 1dbaf41768..b9be68502e 100644 --- a/base_layer/core/src/base_node/service/service.rs +++ b/base_layer/core/src/base_node/service/service.rs @@ -401,7 +401,7 @@ async fn handle_incoming_request( outbound_message_service .send_direct( origin_public_key, - OutboundEncryption::EncryptForPeer, + OutboundEncryption::None, OutboundDomainMessage::new(TariMessageType::BaseNodeResponse, 
message), ) .await?; @@ -463,7 +463,7 @@ async fn handle_outbound_request( .map_err(|e| CommsInterfaceError::OutboundMessageService(e.to_string()))?; match send_result.resolve_ok().await { - Some(tags) if tags.is_empty() => { + Some(send_states) if send_states.is_empty() => { let _ = reply_tx .send(Err(CommsInterfaceError::NoBootstrapNodesConfigured)) .or_else(|resp| { @@ -506,7 +506,7 @@ async fn handle_outbound_block( outbound_message_service .propagate( NodeDestination::Unknown, - OutboundEncryption::EncryptForPeer, + OutboundEncryption::None, exclude_peers, OutboundDomainMessage::new(TariMessageType::NewBlock, ProtoBlock::from(block)), ) diff --git a/base_layer/core/src/base_node/state_machine.rs b/base_layer/core/src/base_node/state_machine.rs index 6822ce422a..90560aae0a 100644 --- a/base_layer/core/src/base_node/state_machine.rs +++ b/base_layer/core/src/base_node/state_machine.rs @@ -105,11 +105,9 @@ impl BaseNodeStateMachine { (Starting(s), Initialized) => Listening(s.into()), (BlockSync(s, _, _), BlocksSynchronized) => Listening(s.into()), (BlockSync(s, _, _), BlockSyncFailure) => Waiting(s.into()), - (Listening(_), FallenBehind(Lagging(network_tip, sync_peers))) => BlockSync( - self.config.block_sync_config.sync_strategy.clone(), - network_tip, - sync_peers, - ), + (Listening(_), FallenBehind(Lagging(network_tip, sync_peers))) => { + BlockSync(self.config.block_sync_config.sync_strategy, network_tip, sync_peers) + }, (Waiting(s), Continue) => Listening(s.into()), (_, FatalError(s)) => Shutdown(states::Shutdown::with_reason(s)), (_, UserQuit) => Shutdown(states::Shutdown::with_reason("Shutdown initiated by user".to_string())), @@ -152,7 +150,7 @@ impl BaseNodeStateMachine { let _ = self.event_sender.send(next_event.clone()).await; debug!( target: LOG_TARGET, - "=== Base Node event in State [{}]: {:?}", state, next_event + "=== Base Node event in State [{}]: {}", state, next_event ); state = self.transition(state, next_event); } diff --git a/base_layer/core/src/base_node/states/block_sync.rs b/base_layer/core/src/base_node/states/block_sync.rs index 85c5b1be71..1fbbf5b6a2 100644 --- a/base_layer/core/src/base_node/states/block_sync.rs +++ b/base_layer/core/src/base_node/states/block_sync.rs @@ -36,7 +36,7 @@ use core::cmp::min; use derive_error::Error; use log::*; use rand::seq::SliceRandom; -use std::str::FromStr; +use std::{str::FromStr, time::Duration}; use tari_comms::{ connection_manager::ConnectionManagerError, peer_manager::{NodeId, PeerManagerError}, @@ -59,6 +59,8 @@ const MAX_ADD_BLOCK_RETRY_ATTEMPTS: usize = 3; const HEADER_REQUEST_SIZE: usize = 100; // The number of blocks that can be requested in a single query. const BLOCK_REQUEST_SIZE: usize = 5; +// The default length of time to ban a misbehaving/malfunctioning sync peer (24 hours) +const DEFAULT_PEER_BAN_DURATION: Duration = Duration::from_secs(24 * 60 * 60); /// Configuration for the Block Synchronization. 
#[derive(Clone, Copy)] @@ -71,6 +73,7 @@ pub struct BlockSyncConfig { pub max_add_block_retry_attempts: usize, pub header_request_size: usize, pub block_request_size: usize, + pub peer_ban_duration: Duration, } impl Default for BlockSyncConfig { @@ -84,6 +87,7 @@ impl Default for BlockSyncConfig { max_add_block_retry_attempts: MAX_ADD_BLOCK_RETRY_ATTEMPTS, header_request_size: HEADER_REQUEST_SIZE, block_request_size: BLOCK_REQUEST_SIZE, + peer_ban_duration: DEFAULT_PEER_BAN_DURATION, } } } @@ -325,12 +329,20 @@ async fn find_chain_split_height( if prev_header.height + 1 == header.height { return Ok(header.height); } else { + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they supplied invalid chain link", sync_peer + ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; return Err(BlockSyncError::InvalidChainLink); } } } } + warn!( + target: LOG_TARGET, + "Banning all peers from local node, because they could not provide a valid chain link", + ); ban_all_sync_peers(shared, sync_peers).await?; Err(BlockSyncError::ForkChainNotLinked) } @@ -364,14 +376,23 @@ async fn request_and_add_blocks( "Invalid block {} received from peer. Retrying", block_hash.to_hex(), ); + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they supplied invalid block", sync_peer + ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; break; }, - Err(ChainStorageError::ValidationError(_)) => { + Err(ChainStorageError::ValidationError { source }) => { warn!( target: LOG_TARGET, - "Validation on block {} from peer failed. Retrying", + "Validation on block {} from peer failed due to: {:?}. Retrying", block_hash.to_hex(), + source, + ); + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they supplied invalid block", sync_peer ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; break; @@ -379,7 +400,7 @@ async fn request_and_add_blocks( Err(e) => return Err(BlockSyncError::ChainStorageError(e)), } } - if block_nums.len() == 0 { + if block_nums.is_empty() { return Ok(()); } info!(target: LOG_TARGET, "Retrying block add. 
Attempt {}", attempt); @@ -419,6 +440,10 @@ async fn request_blocks( return Ok((blocks, sync_peer)); } else { debug!(target: LOG_TARGET, "This was NOT the blocks we were expecting."); + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they supplied the incorrect blocks", sync_peer + ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; } } else { @@ -428,6 +453,11 @@ async fn request_blocks( block_nums.len(), hist_blocks.len() ); + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they supplied the incorrect number of blocks", + sync_peer + ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; } }, @@ -486,6 +516,10 @@ async fn request_headers( return Ok((headers, sync_peer)); } else { debug!(target: LOG_TARGET, "This was NOT the headers we were expecting."); + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they supplied the incorrect headers", sync_peer + ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; } } else { @@ -495,11 +529,20 @@ async fn request_headers( block_nums.len(), headers.len() ); + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they supplied the incorrect number of headers", + sync_peer + ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; } }, Err(CommsInterfaceError::UnexpectedApiResponse) => { debug!(target: LOG_TARGET, "Remote node provided an unexpected api response.",); + warn!( + target: LOG_TARGET, + "Banning peer {} from local node, because they provided an unexpected api response", sync_peer + ); ban_sync_peer(shared, sync_peers, sync_peer.clone()).await?; }, Err(CommsInterfaceError::RequestTimedOut) => { @@ -567,10 +610,12 @@ async fn ban_sync_peer( sync_peer: NodeId, ) -> Result<(), BlockSyncError> { - warn!(target: LOG_TARGET, "Banning peer {} from local node.", sync_peer); sync_peers.retain(|p| *p != sync_peer); let peer = shared.peer_manager.find_by_node_id(&sync_peer).await?; - shared.peer_manager.set_banned(&peer.public_key, true).await?; + shared + .peer_manager + .ban_for(&peer.public_key, shared.config.block_sync_config.peer_ban_duration) + .await?; shared.connection_manager.disconnect_peer(sync_peer).await??; if sync_peers.is_empty() { return Err(BlockSyncError::NoSyncPeers); @@ -585,6 +630,7 @@ async fn ban_all_sync_peers( ) -> Result<(), BlockSyncError> { while !sync_peers.is_empty() { + warn!(target: LOG_TARGET, "Banning peer {} from local node.", sync_peers[0]); ban_sync_peer(shared, sync_peers, sync_peers[0].clone()).await?; } Ok(()) diff --git a/base_layer/core/src/base_node/states/events_and_states.rs b/base_layer/core/src/base_node/states/events_and_states.rs index 8b24595099..ed349f0ec2 100644 --- a/base_layer/core/src/base_node/states/events_and_states.rs +++ b/base_layer/core/src/base_node/states/events_and_states.rs @@ -72,7 +72,7 @@ impl Display for SyncStatus { "Lagging behind {} peers (#{}, Difficulty: {})", v.len(), m.height_of_longest_chain.unwrap_or(0), - m.accumulated_difficulty.unwrap_or(Difficulty::min()) + m.accumulated_difficulty.unwrap_or_else(Difficulty::min) ), UpToDate => f.write_str("UpToDate"), } diff --git a/base_layer/core/src/base_node/states/forward_block_sync.rs b/base_layer/core/src/base_node/states/forward_block_sync.rs index 3ef0b73a57..2f8f5875ee 100644 --- a/base_layer/core/src/base_node/states/forward_block_sync.rs +++ b/base_layer/core/src/base_node/states/forward_block_sync.rs @@ -105,7 +105,7 @@ async fn synchronize_blocks( Ok(headers) => { if let Some(first_header) 
= headers.first() { if let Ok(block) = shared.db.fetch_header_with_block_hash(first_header.prev_hash.clone()) { - if &shared.db.fetch_tip_header().map_err(|e| e.to_string())? != &block { + if shared.db.fetch_tip_header().map_err(|e| e.to_string())? != block { // If peer returns genesis block, it means that there is a split, but it is further back // than the headers we sent. let oldest_header_sent = from_headers.last().unwrap(); @@ -148,7 +148,7 @@ async fn synchronize_blocks( target: LOG_TARGET, "Could not sync with node '{}': Node did not return headers", sync_node_string ); - sync_node = sync_nodes.pop().map(|n| n.clone()); + sync_node = sync_nodes.pop().map(|n| n); continue; } @@ -245,11 +245,12 @@ async fn download_blocks( ); return Ok(false); }, - Err(ChainStorageError::ValidationError(_)) => { + Err(ChainStorageError::ValidationError { source }) => { warn!( target: LOG_TARGET, - "Validation on block {} from peer failed. Retrying", + "Validation on block {} because of {} from peer failed. Retrying", block_hash.to_hex(), + source ); return Ok(false); }, diff --git a/base_layer/core/src/base_node/states/listening.rs b/base_layer/core/src/base_node/states/listening.rs index 38615799f9..1e1c896601 100644 --- a/base_layer/core/src/base_node/states/listening.rs +++ b/base_layer/core/src/base_node/states/listening.rs @@ -91,20 +91,18 @@ fn find_sync_peers(best_metadata: &ChainMetadata, peer_metadata_list: &Vec ChainMetadata { // TODO: Use heuristics to weed out outliers / dishonest nodes. - metadata_list - .into_iter() - .fold(ChainMetadata::default(), |best, current| { - if current - .chain_metadata - .accumulated_difficulty - .unwrap_or(Difficulty::min()) >= - best.accumulated_difficulty.unwrap_or_else(|| 0.into()) - { - current.chain_metadata.clone() - } else { - best - } - }) + metadata_list.iter().fold(ChainMetadata::default(), |best, current| { + if current + .chain_metadata + .accumulated_difficulty + .unwrap_or_else(Difficulty::min) >= + best.accumulated_difficulty.unwrap_or_else(|| 0.into()) + { + current.chain_metadata.clone() + } else { + best + } + }) } /// Given a local and the network chain state respectively, figure out what synchronisation state we should be in. @@ -142,8 +140,8 @@ fn determine_sync_mode( } else { info!( target: log_target, - "Our local blockchain is up-to-date. We're at block #{} with an accumulated difficulty of {} and \ - the network chain tip is at #{} with an accumulated difficulty of {}", + "Our blockchain is up-to-date. We're at block {} with an accumulated difficulty of {} and the \ + network chain tip is at {} with an accumulated difficulty of {}", local.height_of_longest_chain.unwrap_or(0), local_tip_accum_difficulty, network.height_of_longest_chain.unwrap_or(0), diff --git a/base_layer/core/src/base_node/states/mod.rs b/base_layer/core/src/base_node/states/mod.rs index b18f73c4c9..596ab33961 100644 --- a/base_layer/core/src/base_node/states/mod.rs +++ b/base_layer/core/src/base_node/states/mod.rs @@ -57,7 +57,6 @@ //! required, and then shutdown. 
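
The reworked `best_metadata` fold above picks the peer report claiming the strongest chain by accumulated difficulty. A simplified, self-contained version (hypothetical `Meta` type, with `u64` standing in for `Difficulty`):

```rust
/// Stand-in for the peer chain metadata compared in `best_metadata`.
#[derive(Clone, Default, Debug, PartialEq)]
struct Meta {
    accumulated_difficulty: Option<u64>,
}

fn best_metadata(list: &[Meta]) -> Meta {
    // Missing difficulties count as zero; ties favour later entries (>=).
    list.iter().fold(Meta::default(), |best, current| {
        if current.accumulated_difficulty.unwrap_or(0) >= best.accumulated_difficulty.unwrap_or(0) {
            current.clone()
        } else {
            best
        }
    })
}

fn main() {
    let peers = [
        Meta { accumulated_difficulty: Some(10) },
        Meta { accumulated_difficulty: Some(42) },
        Meta { accumulated_difficulty: None },
    ];
    assert_eq!(best_metadata(&peers).accumulated_difficulty, Some(42));
}
```
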
mod block_sync; -mod error; mod events_and_states; mod forward_block_sync; mod listening; diff --git a/base_layer/core/src/base_node/states/starting_state.rs b/base_layer/core/src/base_node/states/starting_state.rs index fbcc8e0a30..262786062a 100644 --- a/base_layer/core/src/base_node/states/starting_state.rs +++ b/base_layer/core/src/base_node/states/starting_state.rs @@ -22,7 +22,7 @@ // use crate::{ base_node::{ - states::{error::BaseNodeError, listening::ListeningInfo, StateEvent}, + states::{listening::ListeningInfo, StateEvent}, BaseNodeStateMachine, }, chain_storage::BlockchainBackend, @@ -36,22 +36,12 @@ const LOG_TARGET: &str = "c::bn::states::starting_state"; pub struct Starting; impl Starting { - /// Apply the configuration settings for this node. - fn apply_config(&mut self) -> Result<(), BaseNodeError> { - // TODO apply configuration - Ok(()) - } - pub async fn next_event( &mut self, _shared: &BaseNodeStateMachine, ) -> StateEvent { - info!(target: LOG_TARGET, "Configuring node."); - if let Err(err) = self.apply_config() { - return err.as_fatal("There was an error with the base node configuration."); - } - info!(target: LOG_TARGET, "Node configuration complete."); + info!(target: LOG_TARGET, "Starting node."); StateEvent::Initialized } } diff --git a/base_layer/core/src/blocks/block.rs b/base_layer/core/src/blocks/block.rs index 877c588022..cc50d2e3bd 100644 --- a/base_layer/core/src/blocks/block.rs +++ b/base_layer/core/src/blocks/block.rs @@ -63,6 +63,8 @@ pub enum BlockValidationError { MismatchedMmrRoots, // The block contains transactions that should have been cut through. NoCutThrough, + // The block weight is above the maximum + BlockTooLarge, } /// A Tari block. Blocks are linked together into a blockchain. diff --git a/base_layer/core/src/blocks/genesis_block.rs b/base_layer/core/src/blocks/genesis_block.rs index 1cfec932ea..6d4e2ba782 100644 --- a/base_layer/core/src/blocks/genesis_block.rs +++ b/base_layer/core/src/blocks/genesis_block.rs @@ -108,7 +108,7 @@ pub fn get_rincewind_genesis_block_raw() -> Block { prev_hash: vec![ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], - timestamp: 1_585_476_000.into(), // Sunday, 29 March 2020 12:00:00 GMT+02:00 + timestamp: 1_587_837_600.into(), // Saturday, 25 April 2020 18:00:00 GMT output_mr: from_hex("fab84d9d797c272b33011caa78718f93c3d5fc44c7d35bbf138613440fca2c79").unwrap(), range_proof_mr: from_hex("63a36ba139a884434702dffccec348b02ba886d3851a19732d8d111a54e17d56").unwrap(), kernel_mr: from_hex("b097af173dc852862f48af67aa57f48c47d20bc608d77b46a3018999bffba911").unwrap(), @@ -120,6 +120,7 @@ pub fn get_rincewind_genesis_block_raw() -> Block { pow: ProofOfWork { accumulated_monero_difficulty: 1.into(), accumulated_blake_difficulty: 1.into(), + target_difficulty: 1.into(), pow_algo: PowAlgorithm::Blake, pow_data: vec![], }, diff --git a/base_layer/core/src/chain_storage/async_db.rs b/base_layer/core/src/chain_storage/async_db.rs index 91396aff76..7cc56c98bd 100644 --- a/base_layer/core/src/chain_storage/async_db.rs +++ b/base_layer/core/src/chain_storage/async_db.rs @@ -54,12 +54,11 @@ where F: FnOnce() -> R { trace_id ); let ret = f(); - let end = Instant::now(); trace!( target: LOG_TARGET, "[{}] Exited blocking thread after {}ms. 
trace_id: '{}'", name, - (end - start).as_millis(), + start.elapsed().as_millis(), trace_id ); ret @@ -99,13 +98,13 @@ make_async!(fetch_utxo(hash: HashOutput) -> TransactionOutput, "fetch_utxo"); make_async!(fetch_stxo(hash: HashOutput) -> TransactionOutput, "fetch_stxo"); make_async!(fetch_orphan(hash: HashOutput) -> Block, "fetch_orphan"); make_async!(is_utxo(hash: HashOutput) -> bool, "is_utxo"); +make_async!(is_stxo(hash: HashOutput) -> bool, "is_stxo"); make_async!(fetch_mmr_root(tree: MmrTree) -> HashOutput, "fetch_mmr_root"); make_async!(fetch_mmr_only_root(tree: MmrTree) -> HashOutput, "fetch_mmr_only_root"); make_async!(calculate_mmr_root(tree: MmrTree,additions: Vec,deletions: Vec) -> HashOutput, "calculate_mmr_root"); make_async!(add_block(block: Block) -> BlockAddResult, "add_block"); make_async!(calculate_mmr_roots(template: NewBlockTemplate) -> Block, "calculate_mmr_roots"); -// make_async!(is_new_best_block(block: &Block) -> bool); make_async!(fetch_block(height: u64) -> HistoricalBlock, "fetch_block"); make_async!(fetch_block_with_hash(hash: HashOutput) -> Option, "fetch_block_with_hash"); make_async!(rewind_to_height(height: u64) -> Vec, "rewind_to_height"); diff --git a/base_layer/core/src/chain_storage/blockchain_database.rs b/base_layer/core/src/chain_storage/blockchain_database.rs index 3b52d0c2be..bd59cc8a34 100644 --- a/base_layer/core/src/chain_storage/blockchain_database.rs +++ b/base_layer/core/src/chain_storage/blockchain_database.rs @@ -23,36 +23,48 @@ use crate::{ blocks::{blockheader::BlockHash, Block, BlockHeader, NewBlockTemplate}, chain_storage::{ + consts::BLOCKCHAIN_DATABASE_ORPHAN_STORAGE_CAPACITY, db_transaction::{DbKey, DbKeyValuePair, DbTransaction, DbValue, MetadataKey, MetadataValue, MmrTree}, error::ChainStorageError, ChainMetadata, HistoricalBlock, }, consensus::ConsensusManager, - proof_of_work::{Difficulty, ProofOfWork}, + proof_of_work::{Difficulty, PowAlgorithm, ProofOfWork}, transactions::{ transaction::{TransactionInput, TransactionKernel, TransactionOutput}, - types::{BlindingFactor, Commitment, CommitmentFactory, HashOutput}, + types::{Commitment, HashOutput}, }, - validation::{StatelessValidation, StatelessValidator, ValidationError, ValidationWriteGuard, ValidatorWriteGuard}, + validation::{StatelessValidation, StatelessValidator, Validation, ValidationError, Validator}, }; use croaring::Bitmap; use log::*; use serde::{Deserialize, Serialize}; use std::{ collections::VecDeque, - ops::{Deref, DerefMut}, sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, }; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - tari_utilities::{hex::Hex, Hashable}, -}; +use strum_macros::Display; +use tari_crypto::tari_utilities::{epoch_time::EpochTime, hex::Hex, Hashable}; use tari_mmr::{Hash, MerkleCheckPoint, MerkleProof, MutableMmrLeafNodes}; const LOG_TARGET: &str = "c::cs::database"; -#[derive(Clone, Debug, PartialEq)] +/// Configuration for the BlockchainDatabase. +#[derive(Clone, Copy)] +pub struct BlockchainDatabaseConfig { + pub orphan_storage_capacity: usize, +} + +impl Default for BlockchainDatabaseConfig { + fn default() -> Self { + Self { + orphan_storage_capacity: BLOCKCHAIN_DATABASE_ORPHAN_STORAGE_CAPACITY, + } + } +} + +#[derive(Clone, Debug, PartialEq, Display)] pub enum BlockAddResult { Ok, BlockExists, @@ -74,19 +86,22 @@ pub struct MutableMmrState { /// The `GenesisBlockValidator` is used to check that the chain builds on the correct genesis block. 
/// The `ChainTipValidator` is used to check that the accounting balance and MMR states of the chain state is valid. pub struct Validators { - block: Arc>, + block: Arc>, orphan: Arc>, + accum_difficulty: Arc>, } impl Validators { pub fn new( - block: impl ValidationWriteGuard + 'static, + block: impl Validation + 'static, orphan: impl StatelessValidation + 'static, + accum_difficulty: impl Validation + 'static, ) -> Self { Self { block: Arc::new(Box::new(block)), orphan: Arc::new(Box::new(orphan)), + accum_difficulty: Arc::new(Box::new(accum_difficulty)), } } } @@ -96,6 +111,7 @@ impl Clone for Validators { Validators { block: Arc::clone(&self.block), orphan: Arc::clone(&self.orphan), + accum_difficulty: Arc::clone(&self.accum_difficulty), } } } @@ -144,6 +160,8 @@ pub trait BlockchainBackend: Send + Sync { where Self: Sized, F: FnMut(Result<(HashOutput, Block), ChainStorageError>); + /// Returns the number of blocks in the block orphan pool. + fn get_orphan_count(&self) -> Result; /// Performs the function F for each transaction kernel. fn for_each_kernel(&self, f: F) -> Result<(), ChainStorageError> where @@ -161,6 +179,15 @@ pub trait BlockchainBackend: Send + Sync { F: FnMut(Result<(HashOutput, TransactionOutput), ChainStorageError>); /// Returns the stored header with the highest corresponding height. fn fetch_last_header(&self) -> Result, ChainStorageError>; + /// Returns the stored chain metadata. + fn fetch_metadata(&self) -> Result; + /// Returns the set of target difficulties for the specified proof of work algorithm. + fn fetch_target_difficulties( + &self, + pow_algo: PowAlgorithm, + height: u64, + block_window: usize, + ) -> Result, ChainStorageError>; } // Private macro that pulls out all the boiler plate of extracting a DB query result from its variants @@ -174,22 +201,6 @@ macro_rules! fetch { Err(e) => log_error(key, e), } }}; - - (meta $db:expr, $meta_key:ident, $default:expr) => {{ - match $db.fetch(&DbKey::Metadata(MetadataKey::$meta_key)) { - Ok(None) => { - warn!( - target: LOG_TARGET, - "The {} entry is not present in the database. Assuming the database is empty.", - DbKey::Metadata(MetadataKey::$meta_key) - ); - $default - }, - Ok(Some(DbValue::Metadata(MetadataValue::$meta_key(v)))) => v, - Ok(Some(other)) => return unexpected_result(DbKey::Metadata(MetadataKey::$meta_key), other), - Err(e) => return log_error(DbKey::Metadata(MetadataKey::$meta_key), e), - } - }}; } /// A generic blockchain storage mechanism. This struct defines the API for storing and retrieving Tari blockchain @@ -203,25 +214,29 @@ macro_rules! 
fetch { /// /// ``` /// use tari_core::{ -/// chain_storage::{BlockchainDatabase, MemoryDatabase, Validators}, +/// chain_storage::{BlockchainDatabase, BlockchainDatabaseConfig, MemoryDatabase, Validators}, /// consensus::{ConsensusManagerBuilder, Network}, /// transactions::types::HashDigest, -/// validation::{mocks::MockValidator, Validation}, +/// validation::{accum_difficulty_validators::AccumDifficultyValidator, mocks::MockValidator, Validation}, /// }; /// let db_backend = MemoryDatabase::::default(); -/// let validators = Validators::new(MockValidator::new(true), MockValidator::new(true)); +/// let validators = Validators::new( +/// MockValidator::new(true), +/// MockValidator::new(true), +/// AccumDifficultyValidator {}, +/// ); /// let db = MemoryDatabase::::default(); /// let network = Network::LocalNet; /// let rules = ConsensusManagerBuilder::new(network).build(); -/// let db = BlockchainDatabase::new(db_backend, &rules, validators).unwrap(); +/// let db = BlockchainDatabase::new(db_backend, &rules, validators, BlockchainDatabaseConfig::default()).unwrap(); /// // Do stuff with db /// ``` pub struct BlockchainDatabase where T: BlockchainBackend { - metadata: Arc>, db: Arc>, validators: Validators, + config: BlockchainDatabaseConfig, } impl BlockchainDatabase @@ -232,129 +247,24 @@ where T: BlockchainBackend db: T, consensus_manager: &ConsensusManager, validators: Validators, + config: BlockchainDatabaseConfig, ) -> Result { - let metadata = Self::read_metadata(&db)?; let blockchain_db = BlockchainDatabase { - metadata: Arc::new(RwLock::new(metadata)), db: Arc::new(RwLock::new(db)), validators, + config, }; if blockchain_db.get_height()?.is_none() { let genesis_block = consensus_manager.get_genesis_block(); - let genesis_block_hash = genesis_block.hash(); - let mut pow = genesis_block.header.pow.clone(); - pow.add_difficulty( - &genesis_block.header.pow, - ProofOfWork::achieved_difficulty(&genesis_block.header), - ); - let pow = pow.total_accumulated_difficulty(); blockchain_db.store_new_block(genesis_block)?; - blockchain_db.update_metadata(0, genesis_block_hash, pow)?; } Ok(blockchain_db) } - /// Reads the blockchain metadata (block height etc) from the underlying backend and returns it. - /// If the metadata values aren't in the database, (e.g. when running a node for the first time), - /// then log as much and return a reasonable default. - fn read_metadata(db: &T) -> Result { - let height = fetch!(meta db, ChainHeight, None); - let hash = fetch!(meta db, BestBlock, None); - let accumulated_difficulty = fetch!(meta db, AccumulatedWork, None); - // Set a default of 2880 blocks (2 days with 1min blocks) - let horizon = fetch!(meta db, PruningHorizon, 2880); - Ok(ChainMetadata { - height_of_longest_chain: height, - best_block: hash, - pruning_horizon: horizon, - accumulated_difficulty, - }) - } - - fn read_metadata_with_guard(db: &RwLockReadGuard) -> Result { - let height = fetch!(meta db, ChainHeight, None); - let hash = fetch!(meta db, BestBlock, None); - let accumulated_difficulty = fetch!(meta db, AccumulatedWork, None); - // Set a default of 2880 blocks (2 days with 1min blocks) - let horizon = fetch!(meta db, PruningHorizon, 2880); - Ok(ChainMetadata { - height_of_longest_chain: height, - best_block: hash, - pruning_horizon: horizon, - accumulated_difficulty, - }) - } - - /// If a call to any metadata function fails, you can try and force a re-sync with this function. 
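For completeness, a variant of the doc example above that overrides the single configuration knob instead of taking the default; the capacity value chosen here is arbitrary:

```
use tari_core::{
    chain_storage::{BlockchainDatabase, BlockchainDatabaseConfig, MemoryDatabase, Validators},
    consensus::{ConsensusManagerBuilder, Network},
    transactions::types::HashDigest,
    validation::{accum_difficulty_validators::AccumDifficultyValidator, mocks::MockValidator},
};

// Cap the orphan pool at 100 blocks instead of the default
// BLOCKCHAIN_DATABASE_ORPHAN_STORAGE_CAPACITY (720).
let config = BlockchainDatabaseConfig {
    orphan_storage_capacity: 100,
};
let db_backend = MemoryDatabase::<HashDigest>::default();
let validators = Validators::new(MockValidator::new(true), MockValidator::new(true), AccumDifficultyValidator {});
let rules = ConsensusManagerBuilder::new(Network::LocalNet).build();
let db = BlockchainDatabase::new(db_backend, &rules, validators, config).unwrap();
```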
If the RWLock - /// is poisoned because a write attempt failed, this function will replace the old lock with a new one with data - /// freshly read from the underlying database. If this still fails, there's probably something badly wrong. - /// - /// # Returns - /// Ok(true) - The lock was refreshed and data was successfully re-read from the database. Proceed with caution. - /// The database *may* be inconsistent. - /// Ok(false) - Everything looks fine. Why did you call this function again? - /// Err(ChainStorageError::CriticalError) - Refreshing the lock failed. We couldn't refresh the metadata from the DB - /// backend, so you should probably just shut things down and look at the logs. - pub fn try_recover_metadata(&mut self) -> Result { - if !self.metadata.is_poisoned() { - // metadata is fine. Nothing to do here - return Ok(false); - } - match BlockchainDatabase::read_metadata_with_guard( - &self - .db - .read() - .map_err(|e| ChainStorageError::AccessError(e.to_string()))?, - ) { - Ok(data) => { - self.metadata = Arc::new(RwLock::new(data)); - Ok(true) - }, - Err(e) => { - error!( - target: LOG_TARGET, - "Could not read metadata from database. {:?}. We're going to panic here. Perhaps restarting will \ - fix things", - e - ); - Err(ChainStorageError::CriticalError) - }, - } - } - - pub fn metadata_read_access(&self) -> Result, ChainStorageError> { - self.metadata.read().map_err(|e| { - error!( - target: LOG_TARGET, - "An attempt to get a read lock on the blockchain metadata failed. {:?}", e - ); - ChainStorageError::AccessError("Read lock on blockchain metadata failed".into()) - }) - } - - pub fn metadata_write_access(&self) -> Result, ChainStorageError> { - self.metadata.write().map_err(|e| { - error!( - target: LOG_TARGET, - "An attempt to get a write lock on the blockchain metadata failed. {:?}", e - ); - ChainStorageError::AccessError("Write lock on blockchain metadata failed".into()) - }) - } - - pub fn db_and_metadata_read_access( - &self, - ) -> Result<(RwLockReadGuard, RwLockReadGuard), ChainStorageError> { - // Always get metadata first so that deadlocks can't occur. - let metadata = self.metadata_read_access()?; - let db = self.db_read_access()?; - Ok((db, metadata)) - } - // Be careful about making this method public. Rather use `db_and_metadata_read_access` // so that metadata and db are read in the correct order so that deadlocks don't occur - fn db_read_access(&self) -> Result, ChainStorageError> { + pub fn db_read_access(&self) -> Result, ChainStorageError> { self.db.read().map_err(|e| { error!( target: LOG_TARGET, @@ -364,7 +274,7 @@ where T: BlockchainBackend }) } - fn db_write_access(&self) -> Result, ChainStorageError> { + pub fn db_write_access(&self) -> Result, ChainStorageError> { self.db.write().map_err(|e| { error!( target: LOG_TARGET, @@ -374,39 +284,27 @@ where T: BlockchainBackend }) } - fn update_metadata( - &self, - new_height: u64, - new_hash: Vec, - accumulated_difficulty: Difficulty, - ) -> Result<(), ChainStorageError> - { - let mut metadata = self.metadata_write_access()?; - let mut db = self.db_write_access()?; - update_metadata(&mut metadata, &mut db, new_height, new_hash, accumulated_difficulty) - } - /// Returns the height of the current longest chain. This method will only fail if there's a fairly serious /// synchronisation problem on the database. You can try calling [BlockchainDatabase::try_recover_metadata] in /// that case to re-sync the metadata; or else just exit the program. 
/// /// If the chain is empty (the genesis block hasn't been added yet), this function returns `None` pub fn get_height(&self) -> Result, ChainStorageError> { - let metadata = self.metadata_read_access()?; - Ok(metadata.height_of_longest_chain) + let db = self.db_read_access()?; + Ok(db.fetch_metadata()?.height_of_longest_chain) } /// Return the geometric mean of the proof of work of the longest chain. /// The proof of work is returned as the geometric mean of all difficulties pub fn get_accumulated_difficulty(&self) -> Result, ChainStorageError> { - let metadata = self.metadata_read_access()?; - Ok(metadata.accumulated_difficulty) + let db = self.db_read_access()?; + Ok(db.fetch_metadata()?.accumulated_difficulty) } /// Returns a copy of the current blockchain database metadata pub fn get_metadata(&self) -> Result { - let metadata = self.metadata_read_access()?; - Ok(metadata.clone()) + let db = self.db_read_access()?; + Ok(db.fetch_metadata()?.clone()) } /// Returns the transaction kernel with the given hash. @@ -418,7 +316,13 @@ where T: BlockchainBackend /// Returns the block header at the given block height. pub fn fetch_header(&self, block_num: u64) -> Result { let db = self.db_read_access()?; - fetch_header(&db, block_num) + fetch_header(&*db, block_num) + } + + /// Returns the set of block headers specified by the block numbers. + pub fn fetch_headers(&self, block_nums: Vec) -> Result, ChainStorageError> { + let db = self.db_read_access()?; + fetch_headers(&*db, block_nums) } /// Returns the block header corresponding` to the provided BlockHash @@ -429,7 +333,7 @@ where T: BlockchainBackend pub fn fetch_tip_header(&self) -> Result { let db = self.db_read_access()?; - fetch_tip_header(&db) + fetch_tip_header(&*db) } /// Returns the UTXO with the given hash. @@ -444,28 +348,46 @@ where T: BlockchainBackend fetch_stxo(&*db, hash) } + /// Returns the STXO with the given hash. + pub fn is_stxo(&self, hash: HashOutput) -> Result { + let db = self.db_read_access()?; + is_stxo(&*db, hash) + } + /// Returns the orphan block with the given hash. pub fn fetch_orphan(&self, hash: HashOutput) -> Result { let db = self.db_read_access()?; fetch_orphan(&*db, hash) } + /// Returns the set of target difficulties for the specified proof of work algorithm. + pub fn fetch_target_difficulties( + &self, + pow_algo: PowAlgorithm, + height: u64, + block_window: usize, + ) -> Result, ChainStorageError> + { + let db = self.db_read_access()?; + fetch_target_difficulties(&*db, pow_algo, height, block_window) + } + /// Returns true if the given UTXO, represented by its hash exists in the UTXO set. pub fn is_utxo(&self, hash: HashOutput) -> Result { let db = self.db_read_access()?; - is_utxo(&db, hash) + is_utxo(&*db, hash) } /// Calculate the Merklish root of the specified merkle mountain range. pub fn fetch_mmr_root(&self, tree: MmrTree) -> Result { let db = self.db_read_access()?; - fetch_mmr_root(&db, tree) + fetch_mmr_root(&*db, tree) } /// Returns only the MMR merkle root without the state of the roaring bitmap. 
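Taken together, the read path now goes through the backend's `fetch_metadata` instead of a separately cached copy, and gains batch header fetching plus a spent-output query. A small caller-side sketch (fragment only; `db` is a `BlockchainDatabase` handle and `hash` a candidate output hash, both assumed):

```
// Fetch a batch of headers in one call instead of looping over fetch_header.
let headers = db.fetch_headers(vec![0, 1, 2])?;
assert_eq!(headers.len(), 3);

// Classify an output hash with the paired UTXO/STXO queries.
let status = if db.is_utxo(hash.clone())? {
    "unspent"
} else if db.is_stxo(hash)? {
    "spent"
} else {
    "unknown output"
};
```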
pub fn fetch_mmr_only_root(&self, tree: MmrTree) -> Result { let db = self.db_read_access()?; - fetch_mmr_only_root(&db, tree) + fetch_mmr_only_root(&*db, tree) } /// Apply the current change set to a pruned copy of the merkle mountain range and calculate the resulting Merklish @@ -486,13 +408,13 @@ where T: BlockchainBackend /// actually be a valid extension to the chain; only the new MMR roots are calculated pub fn calculate_mmr_roots(&self, template: NewBlockTemplate) -> Result { let db = self.db_read_access()?; - calculate_mmr_roots(&db, template) + calculate_mmr_roots(&*db, template) } /// Fetch a Merklish proof for the given hash, tree and position in the MMR pub fn fetch_mmr_proof(&self, tree: MmrTree, pos: usize) -> Result { let db = self.db_read_access()?; - fetch_mmr_proof(&db, tree, pos) + fetch_mmr_proof(&*db, tree, pos) } /// Tries to add a block to the longest chain. @@ -522,14 +444,16 @@ where T: BlockchainBackend /// If an error does occur while writing the new block parts, all changes are reverted before returning. pub fn add_block(&self, block: Block) -> Result { // Perform orphan block validation. - self.validators - .orphan - .validate(&block) - .map_err(ChainStorageError::ValidationError)?; + self.validators.orphan.validate(&block)?; - let mut metadata = self.metadata_write_access()?; let mut db = self.db_write_access()?; - add_block(&mut metadata, &mut db, &self.validators.block, block) + add_block( + &mut db, + &self.validators.block, + &self.validators.accum_difficulty, + block, + self.config.orphan_storage_capacity, + ) } fn store_new_block(&self, block: Block) -> Result<(), ChainStorageError> { @@ -537,18 +461,6 @@ where T: BlockchainBackend store_new_block(&mut db, block) } - /// Returns true if the given block -- assuming everything else is valid -- would be added to the tip of the - /// longest chain; i.e. the following conditions are met: - /// * The blockchain is empty, - /// * or ALL of: - /// * the block's parent hash is the hash of the block at the current chain tip, - /// * the block height is one greater than the parent block - pub fn is_at_chain_tip(&self, block: &Block) -> Result { - let metadata = self.metadata_read_access()?; - let db = self.db_read_access()?; - is_at_chain_tip(&metadata, &db, block) - } - /// Fetch a block from the blockchain database. /// /// # Returns @@ -585,27 +497,8 @@ where T: BlockchainBackend /// * The block height is in the future /// * The block height is before pruning horizon pub fn rewind_to_height(&self, height: u64) -> Result, ChainStorageError> { - let mut metadata = self.metadata_write_access()?; let mut db = self.db_write_access()?; - rewind_to_height(&mut metadata, &mut db, height) - } - - /// Calculate the total kernel excess for all kernels in the chain. - pub fn total_kernel_excess(&self) -> Result { - let db = self.db_read_access()?; - total_kernel_excess(&db) - } - - /// Calculate the total kernel offset for all the kernel offsets recorded in the headers of the chain. - pub fn total_kernel_offset(&self) -> Result { - let db = self.db_read_access()?; - total_kernel_offset(&db) - } - - /// Calculate the total sum of all the UTXO commitments in the chain. 
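Stepping back to the caller's view of `add_block` above: the result now also drives the orphan-pool cleanup, and callers can branch on it. A hypothetical fragment (the exact payload of `ChainReorg` is elided here):

```
match db.add_block(block)? {
    BlockAddResult::Ok => info!(target: LOG_TARGET, "Block appended to the longest chain"),
    BlockAddResult::BlockExists => debug!(target: LOG_TARGET, "Block was already known"),
    BlockAddResult::OrphanBlock => debug!(target: LOG_TARGET, "Block stored in the orphan pool"),
    BlockAddResult::ChainReorg(_) => warn!(target: LOG_TARGET, "A chain re-org was triggered"),
}
```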
- pub fn total_utxo_commitment(&self) -> Result { - let db = self.db_read_access()?; - total_utxo_commitment(&db) + rewind_to_height(&mut db, height) } } @@ -615,56 +508,24 @@ fn unexpected_result(req: DbKey, res: DbValue) -> Result( - metadata: &mut RwLockWriteGuard, - db: &mut RwLockWriteGuard, - new_height: u64, - new_hash: Vec, - accumulated_difficulty: Difficulty, -) -> Result<(), ChainStorageError> -{ - metadata.height_of_longest_chain = Some(new_height); - metadata.best_block = Some(new_hash); - metadata.accumulated_difficulty = Some(accumulated_difficulty); - - let mut txn = DbTransaction::new(); - txn.insert(DbKeyValuePair::Metadata( - MetadataKey::ChainHeight, - MetadataValue::ChainHeight(metadata.height_of_longest_chain), - )); - txn.insert(DbKeyValuePair::Metadata( - MetadataKey::BestBlock, - MetadataValue::BestBlock(metadata.best_block.clone()), - )); - txn.insert(DbKeyValuePair::Metadata( - MetadataKey::AccumulatedWork, - MetadataValue::AccumulatedWork(metadata.accumulated_difficulty), - )); - commit(db, txn) -} - fn fetch_kernel(db: &T, hash: HashOutput) -> Result { fetch!(db, hash, TransactionKernel) } -pub fn fetch_header( - db: &RwLockReadGuard, - block_num: u64, -) -> Result -{ - fetch_header_impl(db.deref(), block_num) -} - -fn fetch_header_impl(db: &T, block_num: u64) -> Result { +pub fn fetch_header(db: &T, block_num: u64) -> Result { fetch!(db, block_num, BlockHeader) } -pub fn fetch_header_writeguard( - db: &RwLockWriteGuard, - block_num: u64, -) -> Result +pub fn fetch_headers( + db: &T, + block_nums: Vec, +) -> Result, ChainStorageError> { - fetch!(db, block_num, BlockHeader) + let mut headers = Vec::::with_capacity(block_nums.len()); + for block_num in block_nums { + headers.push(fetch_header(db, block_num)?); + } + Ok(headers) } fn fetch_header_with_block_hash( @@ -675,7 +536,7 @@ fn fetch_header_with_block_hash( fetch!(db, hash, BlockHash) } -fn fetch_tip_header(db: &RwLockReadGuard) -> Result { +fn fetch_tip_header(db: &T) -> Result { db.fetch_last_header() .or_else(|e| { error!(target: LOG_TARGET, "Could not fetch the tip header of the db. {:?}", e); @@ -696,58 +557,36 @@ fn fetch_orphan(db: &T, hash: HashOutput) -> Result(db: &RwLockReadGuard, hash: HashOutput) -> Result { +pub fn fetch_target_difficulties( + db: &T, + pow_algo: PowAlgorithm, + height: u64, + block_window: usize, +) -> Result, ChainStorageError> +{ + db.fetch_target_difficulties(pow_algo, height, block_window) +} + +pub fn is_utxo(db: &T, hash: HashOutput) -> Result { let key = DbKey::UnspentOutput(hash); db.contains(&key) } -pub fn is_utxo_writeguard( - db: &RwLockWriteGuard, - hash: HashOutput, -) -> Result -{ - let key = DbKey::UnspentOutput(hash); +pub fn is_stxo(db: &T, hash: HashOutput) -> Result { + let key = DbKey::SpentOutput(hash); db.contains(&key) } -fn fetch_mmr_root( - db: &RwLockReadGuard, - tree: MmrTree, -) -> Result -{ +fn fetch_mmr_root(db: &T, tree: MmrTree) -> Result { db.fetch_mmr_root(tree) } -fn fetch_mmr_only_root( - db: &RwLockReadGuard, - tree: MmrTree, -) -> Result -{ +fn fetch_mmr_only_root(db: &T, tree: MmrTree) -> Result { db.fetch_mmr_only_root(tree) } pub fn calculate_mmr_roots( - db: &RwLockReadGuard, - template: NewBlockTemplate, -) -> Result -{ - let NewBlockTemplate { header, mut body } = template; - // Make sure the body components are sorted. If they already are, this is a very cheap call. 
- body.sort(); - let kernel_hashes: Vec = body.kernels().iter().map(|k| k.hash()).collect(); - let out_hashes: Vec = body.outputs().iter().map(|out| out.hash()).collect(); - let rp_hashes: Vec = body.outputs().iter().map(|out| out.proof().hash()).collect(); - let inp_hashes: Vec = body.inputs().iter().map(|inp| inp.hash()).collect(); - - let mut header = BlockHeader::from(header); - header.kernel_mr = db.calculate_mmr_root(MmrTree::Kernel, kernel_hashes, vec![])?; - header.output_mr = db.calculate_mmr_root(MmrTree::Utxo, out_hashes, inp_hashes)?; - header.range_proof_mr = db.calculate_mmr_root(MmrTree::RangeProof, rp_hashes, vec![])?; - Ok(Block { header, body }) -} - -pub fn calculate_mmr_roots_writeguard( - db: &RwLockWriteGuard, + db: &T, template: NewBlockTemplate, ) -> Result { @@ -767,65 +606,69 @@ pub fn calculate_mmr_roots_writeguard( } /// Fetch a Merklish proof for the given hash, tree and position in the MMR -fn fetch_mmr_proof( - db: &RwLockReadGuard, - tree: MmrTree, - pos: usize, -) -> Result -{ +fn fetch_mmr_proof(db: &T, tree: MmrTree, pos: usize) -> Result { db.fetch_mmr_proof(tree, pos) } fn add_block( - metadata: &mut RwLockWriteGuard, db: &mut RwLockWriteGuard, - block_validator: &Arc>, + block_validator: &Arc>, + accum_difficulty_validator: &Arc>, block: Block, + orphan_storage_capacity: usize, ) -> Result { let block_hash = block.hash(); if db.contains(&DbKey::BlockHash(block_hash))? { return Ok(BlockAddResult::BlockExists); } - - handle_possible_reorg(metadata, db, block_validator, block) + let block_add_result = handle_possible_reorg(db, block_validator, accum_difficulty_validator, block)?; + // Cleanup orphan block pool + match block_add_result { + BlockAddResult::Ok => {}, + BlockAddResult::BlockExists => {}, + BlockAddResult::OrphanBlock => cleanup_orphans_single(db, orphan_storage_capacity)?, + BlockAddResult::ChainReorg(_) => cleanup_orphans_comprehensive(db, orphan_storage_capacity)?, + } + Ok(block_add_result) } +// Adds a new block onto the chain tip. 
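`store_new_block`, which follows, now writes `ChainHeight`, `BestBlock` and `AccumulatedWork` in the same `DbTransaction` as the block itself, so the metadata can no longer drift from the stored chain. The accumulated work figure is derived straight from the header's own proof of work; the pattern, extracted as a sketch (`header` is any header about to become the tip):

```
// Same calls as used below and in rewind_to_height: fold the header's achieved
// difficulty into its PoW record, then take the total accumulated difficulty.
let accumulated_difficulty =
    ProofOfWork::new_from_difficulty(&header.pow, ProofOfWork::achieved_difficulty(&header))
        .total_accumulated_difficulty();
```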
fn store_new_block(db: &mut RwLockWriteGuard, block: Block) -> Result<(), ChainStorageError> { let (header, inputs, outputs, kernels) = block.dissolve(); + let height = header.height; + let best_block = header.hash(); + let accumulated_difficulty = + ProofOfWork::new_from_difficulty(&header.pow, ProofOfWork::achieved_difficulty(&header)) + .total_accumulated_difficulty(); // Build all the DB queries needed to add the block and the add it atomically let mut txn = DbTransaction::new(); + // Update metadata + txn.insert(DbKeyValuePair::Metadata( + MetadataKey::ChainHeight, + MetadataValue::ChainHeight(Some(height)), + )); + txn.insert(DbKeyValuePair::Metadata( + MetadataKey::BestBlock, + MetadataValue::BestBlock(Some(best_block.clone())), + )); + txn.insert(DbKeyValuePair::Metadata( + MetadataKey::AccumulatedWork, + MetadataValue::AccumulatedWork(Some(accumulated_difficulty)), + )); + // Insert block txn.insert_header(header); txn.spend_inputs(&inputs); outputs.iter().for_each(|utxo| txn.insert_utxo(utxo.clone(), true)); kernels.iter().for_each(|k| txn.insert_kernel(k.clone(), true)); txn.commit_block(); - commit(db, txn) -} - -fn is_at_chain_tip( - metadata: &ChainMetadata, - db: &RwLockReadGuard, - block: &Block, -) -> Result -{ - let (height, parent_hash) = { - // If the database is empty, the best block must be the genesis block - if metadata.height_of_longest_chain.is_none() { - return Ok(block.header.height == 0); - } - ( - metadata.height_of_longest_chain.clone().unwrap(), - metadata.best_block.clone().unwrap(), - ) - }; - let best_block = fetch_header(db, height)?; - Ok(block.header.prev_hash == parent_hash && block.header.height == best_block.height + 1) + commit(db, txn)?; + Ok(()) } fn fetch_block(db: &T, height: u64) -> Result { let tip_height = check_for_valid_height(&*db, height)?; - let header = fetch_header_impl(db, height)?; + let header = fetch_header(db, height)?; let kernel_cp = fetch_checkpoint(db, MmrTree::Kernel, height)?; let (kernel_hashes, _) = kernel_cp.into_parts(); let kernels = fetch_kernels(db, kernel_hashes)?; @@ -857,7 +700,7 @@ fn fetch_block_with_hash( } fn check_for_valid_height(db: &T, height: u64) -> Result { - let db_height = db.fetch_last_header()?.map(|tip| tip.height).unwrap_or(0); + let db_height = db.fetch_metadata()?.height_of_longest_chain.unwrap_or(0); if height > db_height { return Err(ChainStorageError::InvalidQuery(format!( "Cannot get block at height {}. 
Chain tip is at {}", @@ -926,11 +769,10 @@ fn fetch_checkpoint( } pub fn commit(db: &mut RwLockWriteGuard, txn: DbTransaction) -> Result<(), ChainStorageError> { - db.deref_mut().write(txn) + db.write(txn) } fn rewind_to_height( - metadata: &mut RwLockWriteGuard, db: &mut RwLockWriteGuard, height: u64, ) -> Result, ChainStorageError> @@ -977,26 +819,39 @@ fn rewind_to_height( txn.rewind_kernel_mmr(steps_back); txn.rewind_utxo_mmr(steps_back); txn.rewind_rp_mmr(steps_back); + // Update metadata + let last_header = fetch_header(&**db, height)?; + let accumulated_work = + ProofOfWork::new_from_difficulty(&last_header.pow, ProofOfWork::achieved_difficulty(&last_header)) + .total_accumulated_difficulty(); + txn.insert(DbKeyValuePair::Metadata( + MetadataKey::ChainHeight, + MetadataValue::ChainHeight(Some(last_header.height)), + )); + txn.insert(DbKeyValuePair::Metadata( + MetadataKey::BestBlock, + MetadataValue::BestBlock(Some(last_header.hash())), + )); + txn.insert(DbKeyValuePair::Metadata( + MetadataKey::AccumulatedWork, + MetadataValue::AccumulatedWork(Some(accumulated_work)), + )); commit(db, txn)?; - let last_header = fetch_header_writeguard(db, height)?.clone(); - let pow = ProofOfWork::new_from_difficulty(&last_header.pow, ProofOfWork::achieved_difficulty(&last_header)); - let pow = pow.total_accumulated_difficulty(); - update_metadata(metadata, db, height, last_header.hash(), pow)?; - Ok(removed_blocks) } // Checks whether we should add the block as an orphan. If it is the case, the orphan block is added and the chain // is reorganised if necessary. fn handle_possible_reorg( - metadata: &mut RwLockWriteGuard, db: &mut RwLockWriteGuard, - block_validator: &Arc>, + block_validator: &Arc>, + accum_difficulty_validator: &Arc>, block: Block, ) -> Result { - let db_height = metadata + let db_height = db + .fetch_metadata()? .height_of_longest_chain .ok_or_else(|| ChainStorageError::InvalidQuery("Cannot retrieve block. Blockchain DB is empty".into())) .or_else(|e| { @@ -1016,7 +871,7 @@ fn handle_possible_reorg( trace!(target: LOG_TARGET, "{}", block); // Trigger a reorg check for all blocks in the orphan block pool debug!(target: LOG_TARGET, "Checking for chain re-org."); - handle_reorg(metadata, db, block_validator, block) + handle_reorg(db, block_validator, accum_difficulty_validator, block) } // The handle_reorg function is triggered by the adding of orphaned blocks. Reorg chains are constructed by @@ -1026,16 +881,16 @@ fn handle_possible_reorg( // reorg chain is constructed with a higher accumulated difficulty, then the main chain is rewound and updated // with the newly un-orphaned blocks from the reorg chain. fn handle_reorg( - metadata: &mut RwLockWriteGuard, db: &mut RwLockWriteGuard, - block_validator: &Arc>, + block_validator: &Arc>, + accum_difficulty_validator: &Arc>, new_block: Block, ) -> Result { // We can assume that the new block is part of the re-org chain if it exists, otherwise the re-org would have // happened on the previous call to this function. // Try and construct a path from `new_block` to the main chain: - let reorg_chain = try_construct_fork(db, new_block.clone())?; + let mut reorg_chain = try_construct_fork(db, new_block.clone())?; if reorg_chain.is_empty() { trace!( target: LOG_TARGET, @@ -1046,9 +901,9 @@ fn handle_reorg( // Try and find all orphaned chain tips that can be linked to the new orphan block, if no better orphan chain // tips can be found then the new_block is a tip. 
let new_block_hash = new_block.hash(); - let orphan_chain_tips = find_orphan_chain_tips(db, new_block.header.height, new_block_hash); + let orphan_chain_tips = find_orphan_chain_tips(&**db, new_block.header.height, new_block_hash.clone()); // Check the accumulated difficulty of the best fork chain compared to the main chain. - let (fork_accum_difficulty, fork_tip_hash) = find_strongest_orphan_tip(db, orphan_chain_tips)?; + let (fork_accum_difficulty, fork_tip_hash) = find_strongest_orphan_tip(&**db, orphan_chain_tips)?; let tip_header = db .fetch_last_header()? .ok_or_else(|| ChainStorageError::InvalidQuery("Cannot retrieve header. Blockchain DB is empty".into()))?; @@ -1060,26 +915,23 @@ fn handle_reorg( tip_header.total_accumulated_difficulty_inclusive(), tip_header.hash().to_hex() ); - if fork_accum_difficulty >= tip_header.total_accumulated_difficulty_inclusive() { - // TODO: this should be > and not >=, this breaks some of the tests that assume that they can be the same. + if accum_difficulty_validator.validate(&fork_accum_difficulty, db).is_ok() { // We've built the strongest orphan chain we can by going backwards and forwards from the new orphan block // that is linked with the main chain. let fork_tip_block = fetch_orphan(&**db, fork_tip_hash.clone())?; let fork_tip_header = fork_tip_block.header.clone(); - let reorg_chain = try_construct_fork(db, fork_tip_block)?; + if fork_tip_hash != new_block_hash { + // New block is not the tip, find complete chain from tip to main chain. + reorg_chain = try_construct_fork(db, fork_tip_block)?; + } let added_blocks: Vec = reorg_chain.iter().map(Clone::clone).collect(); - let pow = - ProofOfWork::new_from_difficulty(&fork_tip_header.pow, ProofOfWork::achieved_difficulty(&fork_tip_header)); - let pow = pow.total_accumulated_difficulty(); - let fork_height = reorg_chain .front() .expect("The new orphan block should be in the queue") .header .height - 1; - let removed_blocks = reorganize_chain(metadata, db, block_validator, fork_height, reorg_chain)?; - update_metadata(metadata, db, fork_tip_header.height, fork_tip_hash, pow)?; + let removed_blocks = reorganize_chain(db, block_validator, fork_height, reorg_chain)?; if removed_blocks.is_empty() { return Ok(BlockAddResult::Ok); } else { @@ -1103,21 +955,20 @@ fn handle_reorg( // Reorganize the main chain with the provided fork chain, starting at the specified height. 
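The hard-coded `>=` comparison above (and its TODO about using `>`) is now delegated to the injected `accum_difficulty` validator, so production rules and tests can apply different policies. A hypothetical strict implementation; the `Validation<Difficulty, B>` signature and the `CustomError` variant are both assumptions here, not taken from this patch:

```
use tari_core::{
    chain_storage::BlockchainBackend,
    proof_of_work::Difficulty,
    validation::{Validation, ValidationError},
};

// Hypothetical: accept a fork only when it is strictly stronger than the tip.
struct StrictAccumDifficultyValidator;

impl<B: BlockchainBackend> Validation<Difficulty, B> for StrictAccumDifficultyValidator {
    fn validate(&self, fork_total: &Difficulty, db: &B) -> Result<(), ValidationError> {
        let tip = db
            .fetch_last_header()
            .map_err(|e| ValidationError::CustomError(e.to_string()))? // assumed error variant
            .ok_or_else(|| ValidationError::CustomError("empty chain".into()))?;
        if *fork_total > tip.total_accumulated_difficulty_inclusive() {
            Ok(())
        } else {
            Err(ValidationError::CustomError("fork is not stronger than the tip".into()))
        }
    }
}
```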
fn reorganize_chain( - metadata: &mut RwLockWriteGuard, db: &mut RwLockWriteGuard, - block_validator: &Arc>, + block_validator: &Arc>, height: u64, chain: VecDeque, ) -> Result, ChainStorageError> { - let removed_blocks = rewind_to_height(metadata, db, height)?; + let removed_blocks = rewind_to_height(db, height)?; trace!(target: LOG_TARGET, "Validate and add chain blocks.",); let mut validation_result: Result<(), ValidationError> = Ok(()); let mut orphan_hashes = Vec::::with_capacity(chain.len()); for block in chain { let block_hash = block.hash(); orphan_hashes.push(block_hash.clone()); - validation_result = block_validator.validate(&block, db, metadata); + validation_result = block_validator.validate(&block, db); if validation_result.is_err() { debug!( target: LOG_TARGET, @@ -1144,7 +995,7 @@ fn reorganize_chain( }, Err(e) => { trace!(target: LOG_TARGET, "Restoring previous chain after failed reorg.",); - let invalid_chain = rewind_to_height(metadata, db, height)?; + let invalid_chain = rewind_to_height(db, height)?; debug!( target: LOG_TARGET, "Removed incomplete chain of blocks during chain restore: {:?}.", @@ -1159,7 +1010,7 @@ fn reorganize_chain( store_new_block(db, block)?; } commit(db, txn)?; - Err(ChainStorageError::ValidationError(e)) + Err(e.into()) }, } } @@ -1182,33 +1033,6 @@ fn remove_orphan( commit(db, txn) } -fn total_kernel_excess(db: &RwLockReadGuard) -> Result { - let mut excess = CommitmentFactory::default().zero(); - db.for_each_kernel(|pair| { - let (_, kernel) = pair.unwrap(); - excess = &excess + &kernel.excess; - })?; - Ok(excess) -} - -fn total_kernel_offset(db: &RwLockReadGuard) -> Result { - let mut offset = BlindingFactor::default(); - db.for_each_header(|pair| { - let (_, header) = pair.unwrap(); - offset = &offset + &header.total_kernel_offset; - })?; - Ok(offset) -} - -fn total_utxo_commitment(db: &RwLockReadGuard) -> Result { - let mut total_commitment = CommitmentFactory::default().zero(); - db.for_each_utxo(|pair| { - let (_, utxo) = pair.unwrap(); - total_commitment = &total_commitment + &utxo.commitment; - })?; - Ok(total_commitment) -} - /// We try and build a chain from this block to the main chain. If we can't do that we can stop. /// We start with the current, newly received block, and look for a blockchain sequence (via `prev_hash`). /// Each successful link is pushed to the front of the queue. An empty queue is returned if the fork chain did not @@ -1222,49 +1046,81 @@ fn try_construct_fork( let mut hash = new_block.header.prev_hash.clone(); let mut height = new_block.header.height; fork_chain.push_front(new_block); - while let Ok(b) = fetch_orphan(&**db, hash.clone()) { + + loop { + let fork_start_header = fork_chain + .front() + .expect("The new orphan block should be in the queue") + .header + .clone(); trace!( target: LOG_TARGET, - "Checking block #{} forms a sequence to main chain or is orphaned", - b.header.height + "Checking if block {} ({}) is connected to the main chain.", + fork_start_header.height, + fork_start_header.hash().to_hex(), ); - if b.header.height + 1 != height { - // Well now. The block heights don't form a sequence, which means that we should not only stop now, - // but remove one or both of these orphans from the pool because the blockchain is broken at this point. - info!( - target: LOG_TARGET, - "A broken blockchain sequence was detected in the database. 
Cleaning up and removing block with hash \ - {}", - hash.to_hex() - ); - remove_orphan(db, hash)?; - return Err(ChainStorageError::InvalidBlock); + if let Ok(header) = fetch_header_with_block_hash(&**db, fork_start_header.prev_hash) { + if header.height + 1 == fork_start_header.height { + trace!( + target: LOG_TARGET, + "Connection with main chain found at block {} ({}).", + header.height, + header.hash().to_hex(), + ); + return Ok(fork_chain); + } } - hash = b.header.prev_hash.clone(); - height -= 1; - fork_chain.push_front(b); - } - // Check if the constructed fork chain is connected to the main chain. - let fork_start_header = fork_chain - .front() - .expect("The new orphan block should be in the queue") - .header - .clone(); - if let Ok(header) = fetch_header_with_block_hash(&**db, fork_start_header.prev_hash) { - if header.height + 1 == fork_start_header.height { - return Ok(fork_chain); + + trace!( + target: LOG_TARGET, + "Not connected, checking if fork chain can be extended.", + ); + match fetch_orphan(&**db, hash.clone()) { + Ok(prev_block) => { + trace!( + target: LOG_TARGET, + "Checking if block {} ({}) forms a sequence with next block.", + prev_block.header.height, + hash.to_hex(), + ); + if prev_block.header.height + 1 != height { + // Well now. The block heights don't form a sequence, which means that we should not only stop now, + // but remove one or both of these orphans from the pool because the blockchain is broken at this + // point. + info!( + target: LOG_TARGET, + "A broken blockchain sequence was detected, removing block {} ({}).", + prev_block.header.height, + hash.to_hex() + ); + remove_orphan(db, hash)?; + return Err(ChainStorageError::InvalidBlock); + } + trace!( + target: LOG_TARGET, + "Fork chain extended with block {} ({}).", + prev_block.header.height, + hash.to_hex(), + ); + hash = prev_block.header.prev_hash.clone(); + height -= 1; + fork_chain.push_front(prev_block); + }, + Err(ChainStorageError::ValueNotFound(_)) => { + trace!( + target: LOG_TARGET, + "Fork chain extension not found and it isn't connected to the main chain.", + ); + break; + }, + Err(e) => return Err(e), } } Ok(VecDeque::new()) } /// Try to find all orphan chain tips that originate from the current orphan parent block. -fn find_orphan_chain_tips( - db: &RwLockWriteGuard, - parent_height: u64, - parent_hash: BlockHash, -) -> Vec -{ +fn find_orphan_chain_tips(db: &T, parent_height: u64, parent_hash: BlockHash) -> Vec { let mut tip_hashes = Vec::::new(); let mut parents = Vec::<(BlockHash, u64)>::new(); db.for_each_orphan(|pair| { @@ -1294,14 +1150,14 @@ fn find_orphan_chain_tips( /// Find and return the orphan chain tip with the highest accumulated difficulty. fn find_strongest_orphan_tip( - db: &RwLockWriteGuard, + db: &T, orphan_chain_tips: Vec, ) -> Result<(Difficulty, BlockHash), ChainStorageError> { let mut best_accum_difficulty = Difficulty::min(); let mut best_tip_hash: Vec = vec![0; 32]; for tip_hash in orphan_chain_tips { - let header = fetch_orphan(db.deref(), tip_hash.clone())?.header; + let header = fetch_orphan(db, tip_hash.clone())?.header; let accum_difficulty = header.total_accumulated_difficulty_inclusive(); if accum_difficulty >= best_accum_difficulty { best_tip_hash = tip_hash; @@ -1311,6 +1167,73 @@ fn find_strongest_orphan_tip( Ok((best_accum_difficulty, best_tip_hash)) } +// Discards the orphan block with the minimum height from the block orphan pool to maintain the configured orphan pool +// storage limit. 
+fn cleanup_orphans_single( + db: &mut RwLockWriteGuard, + orphan_storage_capacity: usize, +) -> Result<(), ChainStorageError> +{ + if db.get_orphan_count()? > orphan_storage_capacity { + trace!( + target: LOG_TARGET, + "Orphan block storage limit reached, performing simple cleanup.", + ); + let mut min_height: u64 = u64::max_value(); + let mut remove_hash: Option = None; + db.for_each_orphan(|pair| { + let (_, block) = pair.unwrap(); + if block.header.height < min_height { + min_height = block.header.height; + remove_hash = Some(block.hash()); + } + }) + .expect("Unexpected result for database query"); + if let Some(hash) = remove_hash { + trace!(target: LOG_TARGET, "Discarding orphan block ({}).", hash.to_hex()); + remove_orphan(db, hash)?; + } + } + Ok(()) +} + +// Perform a comprehensive search to remove all the minimum height orphans to maintain the configured orphan pool +// storage limit. +fn cleanup_orphans_comprehensive( + db: &mut RwLockWriteGuard, + orphan_storage_capacity: usize, +) -> Result<(), ChainStorageError> +{ + let orphan_count = db.get_orphan_count()?; + if orphan_count > orphan_storage_capacity { + trace!( + target: LOG_TARGET, + "Orphan block storage limit reached, performing comprehensive cleanup.", + ); + let remove_count = orphan_count - orphan_storage_capacity; + + let mut orphans = Vec::<(u64, BlockHash)>::with_capacity(orphan_count); + db.for_each_orphan(|pair| { + let (_, block) = pair.unwrap(); + orphans.push((block.header.height, block.hash())); + }) + .expect("Unexpected result for database query"); + orphans.sort_by(|a, b| a.0.cmp(&b.0)); + + let mut txn = DbTransaction::new(); + for i in 0..remove_count { + trace!( + target: LOG_TARGET, + "Discarding orphan block ({}).", + orphans[i].1.to_hex() + ); + txn.delete(DbKey::OrphanBlock(orphans[i].1.clone())); + } + commit(db, txn)?; + } + Ok(()) +} + fn log_error(req: DbKey, err: ChainStorageError) -> Result { error!( target: LOG_TARGET, @@ -1326,9 +1249,9 @@ where T: BlockchainBackend { fn clone(&self) -> Self { BlockchainDatabase { - metadata: self.metadata.clone(), db: self.db.clone(), validators: self.validators.clone(), + config: self.config.clone(), } } } diff --git a/applications/tari_base_node/src/consts.rs b/base_layer/core/src/chain_storage/consts.rs similarity index 89% rename from applications/tari_base_node/src/consts.rs rename to base_layer/core/src/chain_storage/consts.rs index 83357cd40f..7eb3bfe74f 100644 --- a/applications/tari_base_node/src/consts.rs +++ b/base_layer/core/src/chain_storage/consts.rs @@ -1,4 +1,4 @@ -// Copyright 2019. The Tari Project +// Copyright 2020. The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the // following conditions are met: @@ -19,7 +19,6 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// -pub const VERSION: &str = "0.0.10.build-8560aab9"; -pub const AUTHOR: &str = "The Tari Community"; +/// The maximum number of orphans that can be stored in the Orphan block pool. 
+pub const BLOCKCHAIN_DATABASE_ORPHAN_STORAGE_CAPACITY: usize = 720; diff --git a/base_layer/core/src/chain_storage/error.rs b/base_layer/core/src/chain_storage/error.rs index d44b6fae01..a7aac538dd 100644 --- a/base_layer/core/src/chain_storage/error.rs +++ b/base_layer/core/src/chain_storage/error.rs @@ -24,50 +24,58 @@ use crate::{ chain_storage::{db_transaction::DbKey, MmrTree}, validation::ValidationError, }; -use derive_error::Error; use tari_mmr::{error::MerkleMountainRangeError, MerkleProofError}; +use thiserror::Error; #[derive(Debug, Clone, Error, PartialEq)] pub enum ChainStorageError { - // Access to the underlying storage mechanism failed - #[error(non_std, no_from)] + #[error("Access to the underlying storage mechanism failed:{0}")] AccessError(String), - // The database may be corrupted or otherwise be in an inconsistent state. Please check logs to try and identify - // the issue - #[error(non_std, no_from)] + #[error( + "The database may be corrupted or otherwise be in an inconsistent state. Please check logs to try and \ + identify the issue:{0}" + )] CorruptedDatabase(String), - // A given input could not be spent because it was not in the UTXO set + #[error("A given input could not be spent because it was not in the UTXO set")] UnspendableInput, - // A problem occurred trying to move a STXO back into the UTXO pool during a re-org. + #[error("A problem occurred trying to move a STXO back into the UTXO pool during a re-org.")] UnspendError, - // An unexpected result type was received for the given database request. This suggests that there is an internal - // error or bug of sorts. - #[error(msg_embedded, non_std, no_from)] + #[error( + "An unexpected result type was received for the given database request. This suggests that there is an \ + internal error or bug of sorts: {0}" + )] UnexpectedResult(String), - // You tried to execute an invalid Database operation - #[error(msg_embedded, non_std, no_from)] + #[error("You tried to execute an invalid Database operation:{0}")] InvalidOperation(String), - // There appears to be a critical error on the back end. The database might be in an inconsistent state. Check - // the logs for more information. - CriticalError, - // Cannot return data for requests older than the current pruning horizon + #[error("There appears to be a critical error on the back end:{0}. 
Check the logs for more information.")] + CriticalError(String), + #[error("Cannot return data for requests older than the current pruning horizon")] BeyondPruningHorizon, - // A parameter to the request was invalid - #[error(msg_embedded, non_std, no_from)] + #[error("A parameter to the request was invalid")] InvalidQuery(String), - // The requested value was not found in the database - #[error(non_std, no_from)] + #[error("The requested value '{0}' was not found in the database")] ValueNotFound(DbKey), - MerkleMountainRangeError(MerkleMountainRangeError), - MerkleProofError(MerkleProofError), - ValidationError(ValidationError), - // An MMR root in the provided block header did not match the MMR root in the database - #[error(non_std, no_from)] + #[error("MMR error: {source}")] + MerkleMountainRangeError { + #[from] + source: MerkleMountainRangeError, + }, + #[error("Merkle proof error: {source}")] + MerkleProofError { + #[from] + source: MerkleProofError, + }, + #[error("Validation error:{source}")] + ValidationError { + #[from] + source: ValidationError, + }, + #[error("The MMR root for {0} in the provided block header did not match the MMR root in the database")] MismatchedMmrRoot(MmrTree), - // An invalid block was submitted to the database + #[error("An invalid block was submitted to the database")] InvalidBlock, - #[error(msg_embedded, non_std, no_from)] + #[error("Blocking task spawn error:{0}")] BlockingTaskSpawnError(String), - // A request was out of range + #[error("A request was out of range")] OutOfRange, } diff --git a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs index ea95d8f932..a21a658647 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs @@ -21,10 +21,22 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - blocks::{blockheader::BlockHeader, Block}, + blocks::{ + blockheader::{BlockHash, BlockHeader}, + Block, + }, chain_storage::{ blockchain_database::BlockchainBackend, - db_transaction::{DbKey, DbKeyValuePair, DbTransaction, DbValue, MetadataValue, MmrTree, WriteOperation}, + db_transaction::{ + DbKey, + DbKeyValuePair, + DbTransaction, + DbValue, + MetadataKey, + MetadataValue, + MmrTree, + WriteOperation, + }, error::ChainStorageError, lmdb_db::{ lmdb::{lmdb_delete, lmdb_exists, lmdb_for_each, lmdb_get, lmdb_insert, lmdb_len, lmdb_replace}, @@ -42,7 +54,9 @@ use crate::{ LMDB_DB_UTXO_MMR_CP_BACKEND, }, memory_db::MemDbVec, + ChainMetadata, }, + proof_of_work::{Difficulty, PowAlgorithm}, transactions::{ transaction::{TransactionKernel, TransactionOutput}, types::{HashDigest, HashOutput}, @@ -52,8 +66,8 @@ use croaring::Bitmap; use digest::Digest; use lmdb_zero::{Database, Environment, WriteTransaction}; use log::*; -use std::{path::Path, sync::Arc}; -use tari_crypto::tari_utilities::hash::Hashable; +use std::{collections::VecDeque, path::Path, sync::Arc}; +use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable}; use tari_mmr::{ functions::{prune_mutable_mmr, PrunedMutableMmr}, ArrayLike, @@ -76,6 +90,7 @@ where D: Digest { env: Arc, metadata_db: DatabaseRef, + mem_metadata: ChainMetadata, // Memory copy of stored metadata headers_db: DatabaseRef, block_hashes_db: DatabaseRef, utxos_db: DatabaseRef, @@ -102,7 +117,7 @@ where D: Digest + Send + Sync store.env(), store .get_handle(LMDB_DB_UTXO_MMR_CP_BACKEND) - .ok_or_else(|| ChainStorageError::CriticalError)? 
+ .ok_or_else(|| ChainStorageError::CriticalError("Could not create UTXO MMR backend".to_string()))? .db() .clone(), ); @@ -110,7 +125,7 @@ where D: Digest + Send + Sync store.env(), store .get_handle(LMDB_DB_KERNEL_MMR_CP_BACKEND) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| ChainStorageError::CriticalError("Could not create kernel MMR backend".to_string()))? .db() .clone(), ); @@ -118,49 +133,64 @@ where D: Digest + Send + Sync store.env(), store .get_handle(LMDB_DB_RANGE_PROOF_MMR_CP_BACKEND) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| { + ChainStorageError::CriticalError("Could not create range proof MMR backend".to_string()) + })? .db() .clone(), ); + // Restore memory metadata + let env = store.env(); + let metadata_db = store + .get_handle(LMDB_DB_METADATA) + .ok_or_else(|| ChainStorageError::CriticalError("Could not create metadata backend".to_string()))? + .db() + .clone(); + let metadata = ChainMetadata { + height_of_longest_chain: fetch_chain_height(&env, &metadata_db)?, + best_block: fetch_best_block(&env, &metadata_db)?, + pruning_horizon: fetch_pruning_horizon(&env, &metadata_db)?, + accumulated_difficulty: fetch_accumulated_work(&env, &metadata_db)?, + }; + Ok(Self { - metadata_db: store - .get_handle(LMDB_DB_METADATA) - .ok_or_else(|| ChainStorageError::CriticalError)? - .db() - .clone(), + metadata_db, + mem_metadata: metadata, headers_db: store .get_handle(LMDB_DB_HEADERS) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| ChainStorageError::CriticalError("Could not get handle to headers DB".to_string()))? .db() .clone(), block_hashes_db: store .get_handle(LMDB_DB_BLOCK_HASHES) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| { + ChainStorageError::CriticalError("Could not create handle to block hashes DB".to_string()) + })? .db() .clone(), utxos_db: store .get_handle(LMDB_DB_UTXOS) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| ChainStorageError::CriticalError("Could not create handle to UTXOs DB".to_string()))? .db() .clone(), stxos_db: store .get_handle(LMDB_DB_STXOS) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| ChainStorageError::CriticalError("Could not create handle to STXOs DB".to_string()))? .db() .clone(), txos_hash_to_index_db: store .get_handle(LMDB_DB_TXOS_HASH_TO_INDEX) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| ChainStorageError::CriticalError("Could not create handle to TXOs DB".to_string()))? .db() .clone(), kernels_db: store .get_handle(LMDB_DB_KERNELS) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| ChainStorageError::CriticalError("Could not create handle to kernels DB".to_string()))? .db() .clone(), orphans_db: store .get_handle(LMDB_DB_ORPHANS) - .ok_or_else(|| ChainStorageError::CriticalError)? + .ok_or_else(|| ChainStorageError::CriticalError("Could not create handle to orphans DB".to_string()))? .db() .clone(), utxo_mmr: MmrCache::new(MemDbVec::new(), utxo_checkpoints.clone(), mmr_cache_config)?, @@ -172,7 +202,7 @@ where D: Digest + Send + Sync range_proof_mmr: MmrCache::new(MemDbVec::new(), range_proof_checkpoints.clone(), mmr_cache_config)?, range_proof_checkpoints, curr_range_proof_checkpoint: MerkleCheckPoint::new(Vec::new(), Bitmap::create()), - env: store.env(), + env, }) } @@ -276,6 +306,7 @@ where D: Digest + Send + Sync // changes committed to the backend databases. 
CreateMmrCheckpoint and RewindMmr txns will be performed after these // txns have been successfully applied. fn apply_mmr_and_storage_txs(&mut self, tx: &DbTransaction) -> Result<(), ChainStorageError> { + let mut update_mem_metadata = false; let txn = WriteTransaction::new(self.env.clone()).map_err(|e| ChainStorageError::AccessError(e.to_string()))?; { for op in tx.operations.iter() { @@ -283,6 +314,7 @@ where D: Digest + Send + Sync WriteOperation::Insert(insert) => match insert { DbKeyValuePair::Metadata(k, v) => { lmdb_replace(&txn, &self.metadata_db, &(k.clone() as u32), &v)?; + update_mem_metadata = true; }, DbKeyValuePair::BlockHeader(k, v) => { if lmdb_exists(&self.env, &self.headers_db, &k)? { @@ -389,7 +421,18 @@ where D: Digest + Send + Sync } } } - txn.commit().map_err(|e| ChainStorageError::AccessError(e.to_string())) + txn.commit() + .map_err(|e| ChainStorageError::AccessError(e.to_string()))?; + + if update_mem_metadata { + self.mem_metadata = ChainMetadata { + height_of_longest_chain: fetch_chain_height(&self.env, &self.metadata_db)?, + best_block: fetch_best_block(&self.env, &self.metadata_db)?, + pruning_horizon: fetch_pruning_horizon(&self.env, &self.metadata_db)?, + accumulated_difficulty: fetch_accumulated_work(&self.env, &self.metadata_db)?, + }; + } + Ok(()) } // Returns the leaf index of the hash. If the hash is in the newly added hashes it returns the future MMR index for @@ -464,7 +507,7 @@ pub fn create_lmdb_database( std::fs::create_dir_all(&path).unwrap_or_default(); let lmdb_store = LMDBBuilder::new() .set_path(path.to_str().unwrap()) - .set_environment_size(15) + .set_environment_size(50000) .set_max_number_of_databases(15) .add_database(LMDB_DB_METADATA, flags) .add_database(LMDB_DB_HEADERS, flags) @@ -478,7 +521,7 @@ pub fn create_lmdb_database( .add_database(LMDB_DB_KERNEL_MMR_CP_BACKEND, flags) .add_database(LMDB_DB_RANGE_PROOF_MMR_CP_BACKEND, flags) .build() - .map_err(|_| ChainStorageError::CriticalError)?; + .map_err(|err| ChainStorageError::CriticalError(format!("Could not create LMDB store:{}", err)))?; LMDBDatabase::::new(lmdb_store, mmr_cache_config) } @@ -617,6 +660,11 @@ where D: Digest + Send + Sync lmdb_for_each::(&self.env, &self.orphans_db, f) } + /// Returns the number of blocks in the block orphan pool. + fn get_orphan_count(&self) -> Result { + lmdb_len(&self.env, &self.orphans_db) + } + /// Iterate over all the stored transaction kernels and execute the function `f` for each kernel. fn for_each_kernel(&self, f: F) -> Result<(), ChainStorageError> where F: FnMut(Result<(HashOutput, TransactionKernel), ChainStorageError>) { @@ -645,6 +693,96 @@ where D: Digest + Send + Sync Ok(None) } } + + /// Returns the metadata of the chain. + fn fetch_metadata(&self) -> Result { + Ok(self.mem_metadata.clone()) + } + + /// Returns the set of target difficulties for the specified proof of work algorithm. + fn fetch_target_difficulties( + &self, + pow_algo: PowAlgorithm, + height: u64, + block_window: usize, + ) -> Result, ChainStorageError> + { + let mut target_difficulties = VecDeque::<(EpochTime, Difficulty)>::with_capacity(block_window); + let tip_height = self.mem_metadata.height_of_longest_chain.ok_or_else(|| { + ChainStorageError::InvalidQuery("Cannot retrieve chain height. Blockchain DB is empty".into()) + })?; + if height <= tip_height { + for height in (0..=height).rev() { + let header: BlockHeader = lmdb_get(&self.env, &self.headers_db, &height)? 
+ .ok_or_else(|| ChainStorageError::InvalidQuery("Cannot retrieve header.".into()))?; + if header.pow.pow_algo == pow_algo { + target_difficulties.push_front((header.timestamp, header.pow.target_difficulty)); + if target_difficulties.len() >= block_window { + break; + } + } + } + } + Ok(target_difficulties + .into_iter() + .collect::>()) + } +} + +// Fetches the chain height from the provided metadata db. +fn fetch_chain_height(env: &Environment, db: &Database) -> Result, ChainStorageError> { + let k = MetadataKey::ChainHeight; + let val: Option = lmdb_get(&env, &db, &(k as u32))?; + let val: Option = val.map(DbValue::Metadata); + Ok( + if let Some(DbValue::Metadata(MetadataValue::ChainHeight(height))) = val { + height + } else { + None + }, + ) +} + +// Fetches the best block hash from the provided metadata db. +fn fetch_best_block(env: &Environment, db: &Database) -> Result, ChainStorageError> { + let k = MetadataKey::BestBlock; + let val: Option = lmdb_get(&env, &db, &(k as u32))?; + let val: Option = val.map(DbValue::Metadata); + Ok( + if let Some(DbValue::Metadata(MetadataValue::BestBlock(best_block))) = val { + best_block + } else { + None + }, + ) +} + +// Fetches the accumulated work from the provided metadata db. +fn fetch_accumulated_work(env: &Environment, db: &Database) -> Result, ChainStorageError> { + let k = MetadataKey::AccumulatedWork; + let val: Option = lmdb_get(&env, &db, &(k as u32))?; + let val: Option = val.map(DbValue::Metadata); + Ok( + if let Some(DbValue::Metadata(MetadataValue::AccumulatedWork(accumulated_work))) = val { + accumulated_work + } else { + None + }, + ) +} + +// Fetches the pruning horizon from the provided metadata db. +fn fetch_pruning_horizon(env: &Environment, db: &Database) -> Result { + let k = MetadataKey::PruningHorizon; + let val: Option = lmdb_get(&env, &db, &(k as u32))?; + let val: Option = val.map(DbValue::Metadata); + Ok( + if let Some(DbValue::Metadata(MetadataValue::PruningHorizon(pruning_horizon))) = val { + pruning_horizon + } else { + 2880 + }, + ) } // Calculated the new checkpoint count after rewinding a set number of steps back. diff --git a/base_layer/core/src/chain_storage/lmdb_db/mod.rs b/base_layer/core/src/chain_storage/lmdb_db/mod.rs index 3915216817..2f5103fb5a 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/mod.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/mod.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod lmdb; +#[allow(clippy::module_inception)] mod lmdb_db; mod lmdb_vec; diff --git a/base_layer/core/src/chain_storage/memory_db/memory_db.rs b/base_layer/core/src/chain_storage/memory_db/memory_db.rs index a02e441376..67dd3bf169 100644 --- a/base_layer/core/src/chain_storage/memory_db/memory_db.rs +++ b/base_layer/core/src/chain_storage/memory_db/memory_db.rs @@ -23,13 +23,24 @@ //! 
This is a memory-based blockchain database, generally only useful for testing purposes use crate::{ - blocks::{Block, BlockHeader}, + blocks::{blockheader::BlockHash, Block, BlockHeader}, chain_storage::{ blockchain_database::BlockchainBackend, - db_transaction::{DbKey, DbKeyValuePair, DbTransaction, DbValue, MetadataValue, MmrTree, WriteOperation}, + db_transaction::{ + DbKey, + DbKeyValuePair, + DbTransaction, + DbValue, + MetadataKey, + MetadataValue, + MmrTree, + WriteOperation, + }, error::ChainStorageError, memory_db::MemDbVec, + ChainMetadata, }, + proof_of_work::{Difficulty, PowAlgorithm}, transactions::{ transaction::{TransactionKernel, TransactionOutput}, types::HashOutput, @@ -38,10 +49,10 @@ use crate::{ use croaring::Bitmap; use digest::Digest; use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}, }; -use tari_crypto::tari_utilities::hash::Hashable; +use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable}; use tari_mmr::{ functions::{prune_mutable_mmr, PrunedMutableMmr}, ArrayLike, @@ -95,7 +106,7 @@ where D: Digest } impl MemoryDatabase -where D: Digest +where D: Digest + Send + Sync { pub fn new(mmr_cache_config: MmrCacheConfig) -> Self { let utxo_checkpoints = MemDbVec::::new(); @@ -133,6 +144,58 @@ where D: Digest .read() .map_err(|e| ChainStorageError::AccessError(e.to_string())) } + + // Fetches the chain metadata chain height. + fn fetch_chain_height(&self) -> Result, ChainStorageError> { + Ok( + if let Some(DbValue::Metadata(MetadataValue::ChainHeight(height))) = + self.fetch(&DbKey::Metadata(MetadataKey::ChainHeight))? + { + height + } else { + None + }, + ) + } + + // Fetches the chain metadata best block hash. + fn fetch_best_block(&self) -> Result, ChainStorageError> { + Ok( + if let Some(DbValue::Metadata(MetadataValue::BestBlock(best_block))) = + self.fetch(&DbKey::Metadata(MetadataKey::BestBlock))? + { + best_block + } else { + None + }, + ) + } + + // Fetches the chain metadata accumulated work. + fn fetch_accumulated_work(&self) -> Result, ChainStorageError> { + Ok( + if let Some(DbValue::Metadata(MetadataValue::AccumulatedWork(accumulated_work))) = + self.fetch(&DbKey::Metadata(MetadataKey::AccumulatedWork))? + { + accumulated_work + } else { + None + }, + ) + } + + // Fetches the chain metadata pruning horizon. + fn fetch_pruning_horizon(&self) -> Result { + Ok( + if let Some(DbValue::Metadata(MetadataValue::PruningHorizon(pruning_horizon))) = + self.fetch(&DbKey::Metadata(MetadataKey::PruningHorizon))? + { + pruning_horizon + } else { + 2880 + }, + ) + } } impl BlockchainBackend for MemoryDatabase @@ -410,6 +473,12 @@ where D: Digest + Send + Sync Ok(()) } + /// Returns the number of blocks in the block orphan pool. + fn get_orphan_count(&self) -> Result { + let db = self.db_access()?; + Ok(db.orphans.len()) + } + /// Iterate over all the stored transaction kernels and execute the function `f` for each kernel. fn for_each_kernel(&self, mut f: F) -> Result<(), ChainStorageError> where F: FnMut(Result<(HashOutput, TransactionKernel), ChainStorageError>) { @@ -451,6 +520,48 @@ where D: Digest + Send + Sync Ok(None) } } + + /// Returns the metadata of the chain. 
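Before `fetch_metadata` (which follows), note that the four helpers above all share the same decode-or-default shape. A generic sketch of that shape; the helper name is hypothetical and the import path is illustrative, since `db_transaction` may not be publicly re-exported:

```
use tari_core::chain_storage::db_transaction::{DbValue, MetadataValue};

// Any other variant, or a missing entry, falls back to the supplied default;
// 2880 is the same pruning-horizon fallback both backends use.
fn decode_pruning_horizon(val: Option<DbValue>) -> u64 {
    match val {
        Some(DbValue::Metadata(MetadataValue::PruningHorizon(horizon))) => horizon,
        _ => 2880,
    }
}
```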
+ fn fetch_metadata(&self) -> Result { + Ok(ChainMetadata { + height_of_longest_chain: self.fetch_chain_height()?, + best_block: self.fetch_best_block()?, + pruning_horizon: self.fetch_pruning_horizon()?, + accumulated_difficulty: self.fetch_accumulated_work()?, + }) + } + + /// Returns the set of target difficulties for the specified proof of work algorithm. + fn fetch_target_difficulties( + &self, + pow_algo: PowAlgorithm, + height: u64, + block_window: usize, + ) -> Result, ChainStorageError> + { + let mut target_difficulties = VecDeque::<(EpochTime, Difficulty)>::with_capacity(block_window); + let tip_height = self.fetch_chain_height()?.ok_or_else(|| { + ChainStorageError::InvalidQuery("Cannot retrieve chain height. Blockchain DB is empty".into()) + })?; + if height <= tip_height { + let db = self.db_access()?; + for height in (0..=height).rev() { + let header = db + .headers + .get(&height) + .ok_or_else(|| ChainStorageError::InvalidQuery("Cannot retrieve header.".into()))?; + if header.pow.pow_algo == pow_algo { + target_difficulties.push_front((header.timestamp, header.pow.target_difficulty)); + if target_difficulties.len() >= block_window { + break; + } + } + } + } + Ok(target_difficulties + .into_iter() + .collect::>()) + } } impl Clone for MemoryDatabase diff --git a/base_layer/core/src/chain_storage/memory_db/mod.rs b/base_layer/core/src/chain_storage/memory_db/mod.rs index 650a886f3a..8348ec10db 100644 --- a/base_layer/core/src/chain_storage/memory_db/mod.rs +++ b/base_layer/core/src/chain_storage/memory_db/mod.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod mem_db_vec; +#[allow(clippy::module_inception)] mod memory_db; // Public API exports diff --git a/base_layer/core/src/chain_storage/mod.rs b/base_layer/core/src/chain_storage/mod.rs index 4e626963e5..3a8f9f4e1d 100644 --- a/base_layer/core/src/chain_storage/mod.rs +++ b/base_layer/core/src/chain_storage/mod.rs @@ -27,6 +27,7 @@ //! backed by LMDB, while the merkle trees are stored in flat files for example. 
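Both backends implement `fetch_target_difficulties` with the same windowed walk: scan headers from `height` down towards genesis, keep only headers mined with the requested algorithm, and stop once the window is full. A standalone sketch of that walk against the generic backend trait (the function name is hypothetical; `push_front` keeps the oldest sample first, as the difficulty adjustment expects):

```
use std::collections::VecDeque;
use tari_core::{
    chain_storage::{fetch_header, BlockchainBackend, ChainStorageError},
    proof_of_work::{Difficulty, PowAlgorithm},
};
use tari_crypto::tari_utilities::epoch_time::EpochTime;

fn target_difficulty_window<T: BlockchainBackend>(
    db: &T,
    pow_algo: PowAlgorithm,
    height: u64,
    block_window: usize,
) -> Result<Vec<(EpochTime, Difficulty)>, ChainStorageError> {
    let mut window = VecDeque::with_capacity(block_window);
    // Walk backwards from the requested height, sampling only matching-algorithm headers.
    for h in (0..=height).rev() {
        let header = fetch_header(db, h)?;
        if header.pow.pow_algo == pow_algo {
            window.push_front((header.timestamp, header.pow.target_difficulty));
            if window.len() >= block_window {
                break;
            }
        }
    }
    Ok(window.into_iter().collect())
}
```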
mod blockchain_database; +mod consts; mod db_transaction; mod error; mod historical_block; @@ -40,14 +41,15 @@ pub mod async_db; // Public API exports pub use blockchain_database::{ calculate_mmr_roots, - calculate_mmr_roots_writeguard, fetch_header, - fetch_header_writeguard, + fetch_headers, + fetch_target_difficulties, + is_stxo, is_utxo, - is_utxo_writeguard, BlockAddResult, BlockchainBackend, BlockchainDatabase, + BlockchainDatabaseConfig, MutableMmrState, Validators, }; diff --git a/base_layer/core/src/consensus/consensus_constants.rs b/base_layer/core/src/consensus/consensus_constants.rs index b2cb835652..83b8981a2b 100644 --- a/base_layer/core/src/consensus/consensus_constants.rs +++ b/base_layer/core/src/consensus/consensus_constants.rs @@ -22,7 +22,7 @@ use crate::{ consensus::network::Network, - proof_of_work::Difficulty, + proof_of_work::{Difficulty, PowAlgorithm}, transactions::tari_amount::{uT, MicroTari, T}, }; use chrono::{DateTime, Duration, Utc}; @@ -58,7 +58,7 @@ pub struct ConsensusConstants { /// This is the emission curve tail amount pub(in crate::consensus) emission_tail: MicroTari, /// This is the initial min difficulty for the difficulty adjustment - min_pow_difficulty: Difficulty, + min_pow_difficulty: (Difficulty, Difficulty), } // The target time used by the difficulty adjustment algorithms, their target time is the target block interval * PoW // algorithm count @@ -132,13 +132,17 @@ impl ConsensusConstants { } // This is the min initial difficulty that can be requested for the pow - pub fn min_pow_difficulty(&self) -> Difficulty { - self.min_pow_difficulty + pub fn min_pow_difficulty(&self, pow_algo: PowAlgorithm) -> Difficulty { + match pow_algo { + PowAlgorithm::Monero => self.min_pow_difficulty.0, + PowAlgorithm::Blake => self.min_pow_difficulty.1, + } } + #[allow(clippy::identity_op)] pub fn rincewind() -> Self { - let target_block_interval = 60; - let difficulty_block_window = 150; + let target_block_interval = 120; + let difficulty_block_window = 90; ConsensusConstants { coinbase_lock_height: 60, blockchain_version: 1, @@ -146,13 +150,13 @@ impl ConsensusConstants { target_block_interval, difficulty_block_window, difficulty_max_block_interval: target_block_interval * 60, - max_block_transaction_weight: 6250, + max_block_transaction_weight: 19500, pow_algo_count: 1, median_timestamp_count: 11, emission_initial: 5_538_846_115 * uT, emission_decay: 0.999_999_560_409_038_5, emission_tail: 1 * T, - min_pow_difficulty: 6_000_000.into(), + min_pow_difficulty: (1.into(), 60_000_000.into()), } } @@ -166,13 +170,13 @@ impl ConsensusConstants { target_block_interval, difficulty_max_block_interval: target_block_interval * 6, difficulty_block_window, - max_block_transaction_weight: 6250, + max_block_transaction_weight: 19500, pow_algo_count: 2, median_timestamp_count: 11, emission_initial: 10_000_000.into(), emission_decay: 0.999, emission_tail: 100.into(), - min_pow_difficulty: 1.into(), + min_pow_difficulty: (1.into(), 1.into()), } } @@ -187,13 +191,13 @@ impl ConsensusConstants { target_block_interval, difficulty_max_block_interval: target_block_interval * 6, difficulty_block_window, - max_block_transaction_weight: 6250, + max_block_transaction_weight: 19500, pow_algo_count: 2, median_timestamp_count: 11, emission_initial: 10_000_000.into(), emission_decay: 0.999, emission_tail: 100.into(), - min_pow_difficulty: 500_000_000.into(), + min_pow_difficulty: (1.into(), 500_000_000.into()), } } } diff --git a/base_layer/core/src/consensus/consensus_manager.rs 
b/base_layer/core/src/consensus/consensus_manager.rs index cf63dfd4df..4b27706dfe 100644 --- a/base_layer/core/src/consensus/consensus_manager.rs +++ b/base_layer/core/src/consensus/consensus_manager.rs @@ -30,21 +30,19 @@ use crate::{ }, Block, }, - chain_storage::{BlockchainBackend, ChainMetadata, ChainStorageError}, + chain_storage::ChainStorageError, consensus::{emission::EmissionSchedule, network::Network, ConsensusConstants}, - proof_of_work::{DiffAdjManager, DiffAdjManagerError, Difficulty, DifficultyAdjustmentError, PowAlgorithm}, + proof_of_work::DifficultyAdjustmentError, transactions::tari_amount::MicroTari, }; use derive_error::Error; -use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable}; +use std::sync::Arc; +use tari_crypto::tari_utilities::hash::Hashable; #[derive(Debug, Error, Clone, PartialEq)] pub enum ConsensusManagerError { /// Difficulty adjustment encountered an error DifficultyAdjustmentError(DifficultyAdjustmentError), - /// Difficulty adjustment manager encountered an error - DifficultyAdjustmentManagerError(DiffAdjManagerError), /// Problem with the DB backend storage ChainStorageError(ChainStorageError), /// There is no blockchain to query @@ -69,7 +67,7 @@ impl ConsensusManager { match self.inner.network { Network::MainNet => get_mainnet_genesis_block(), Network::Rincewind => get_rincewind_genesis_block(), - Network::LocalNet => (self.inner.gen_block.clone().unwrap_or(get_rincewind_genesis_block())), + Network::LocalNet => (self.inner.gen_block.clone().unwrap_or_else(get_rincewind_genesis_block)), } } @@ -78,7 +76,7 @@ impl ConsensusManager { match self.inner.network { Network::MainNet => get_mainnet_block_hash(), Network::Rincewind => get_rincewind_block_hash(), - Network::LocalNet => (self.inner.gen_block.as_ref().unwrap_or(&get_rincewind_genesis_block())).hash(), + Network::LocalNet => (self.inner.gen_block.clone().unwrap_or_else(get_rincewind_genesis_block)).hash(), } } @@ -92,131 +90,12 @@ impl ConsensusManager { &self.inner.consensus_constants } - /// This moves over a difficulty adjustment manager to the ConsensusManager to control. - pub fn set_diff_manager(&self, diff_manager: DiffAdjManager) -> Result<(), ConsensusManagerError> { - let mut lock = self - .inner - .diff_adj_manager - .write() - .map_err(|e| ConsensusManagerError::PoisonedAccess(e.to_string()))?; - *lock = Some(diff_manager); - Ok(()) - } - - /// This returns the difficulty adjustment manager back. This can safely be cloned as the Difficulty adjustment - /// manager wraps an ARC in side of it. - pub fn get_diff_manager(&self) -> Result { - match self.access_diff_adj()?.as_ref() { - Some(v) => Ok(v.clone()), - None => Err(ConsensusManagerError::MissingDifficultyAdjustmentManager), - } - } - - /// Returns the estimated target difficulty for the specified PoW algorithm at the chain tip. - pub fn get_target_difficulty( - &self, - metadata: &RwLockReadGuard, - db: &RwLockReadGuard, - pow_algo: PowAlgorithm, - ) -> Result - { - match self.access_diff_adj()?.as_ref() { - Some(v) => v - .get_target_difficulty(metadata, db, pow_algo) - .map_err(ConsensusManagerError::DifficultyAdjustmentManagerError), - None => Err(ConsensusManagerError::MissingDifficultyAdjustmentManager), - } - } - - /// Returns the estimated target difficulty for the specified PoW algorithm and provided height. 
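With the difficulty adjustment manager removed from the consensus manager (target difficulties are now recomputed from the per-block samples stored in headers), the per-algorithm floor lives in ConsensusConstants as a (Monero, Blake) pair, as changed in consensus_constants.rs above. A minimal sketch of that accessor, with a u64 stand-in for the crate's Difficulty newtype:

    #[derive(Clone, Copy)]
    enum PowAlgorithm { Monero, Blake }

    type Difficulty = u64; // stand-in for the crate's Difficulty newtype

    struct ConsensusConstants { min_pow_difficulty: (Difficulty, Difficulty) }

    impl ConsensusConstants {
        // Mirrors the new accessor: one minimum difficulty per PoW algorithm.
        fn min_pow_difficulty(&self, pow_algo: PowAlgorithm) -> Difficulty {
            match pow_algo {
                PowAlgorithm::Monero => self.min_pow_difficulty.0,
                PowAlgorithm::Blake => self.min_pow_difficulty.1,
            }
        }
    }

    fn main() {
        // Rincewind floors from the diff: Monero 1, Blake 60_000_000.
        let c = ConsensusConstants { min_pow_difficulty: (1, 60_000_000) };
        assert_eq!(c.min_pow_difficulty(PowAlgorithm::Blake), 60_000_000);
    }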
- pub fn get_target_difficulty_with_height( - &self, - db: &RwLockReadGuard, - pow_algo: PowAlgorithm, - height: u64, - ) -> Result - { - match self.access_diff_adj()?.as_ref() { - Some(v) => v - .get_target_difficulty_at_height(db, pow_algo, height) - .map_err(ConsensusManagerError::DifficultyAdjustmentManagerError), - None => Err(ConsensusManagerError::MissingDifficultyAdjustmentManager), - } - } - - pub fn get_target_difficulty_with_height_writeguard( - &self, - db: &RwLockWriteGuard, - pow_algo: PowAlgorithm, - height: u64, - ) -> Result - { - match self.access_diff_adj()?.as_ref() { - Some(v) => v - .get_target_difficulty_at_height_writeguard(db, pow_algo, height) - .map_err(ConsensusManagerError::DifficultyAdjustmentManagerError), - None => Err(ConsensusManagerError::MissingDifficultyAdjustmentManager), - } - } - - /// Returns the median timestamp of the past 11 blocks at the chain tip. - pub fn get_median_timestamp( - &self, - metadata: &RwLockReadGuard, - db: &RwLockReadGuard, - ) -> Result - { - match self.access_diff_adj()?.as_ref() { - Some(v) => v - .get_median_timestamp(metadata, db) - .map_err(ConsensusManagerError::DifficultyAdjustmentManagerError), - None => Err(ConsensusManagerError::MissingDifficultyAdjustmentManager), - } - } - - /// Returns the median timestamp of the past 11 blocks at the provided height. - pub fn get_median_timestamp_at_height( - &self, - db: &RwLockReadGuard, - height: u64, - ) -> Result - { - match self.access_diff_adj()?.as_ref() { - Some(v) => v - .get_median_timestamp_at_height(db, height) - .map_err(ConsensusManagerError::DifficultyAdjustmentManagerError), - None => Err(ConsensusManagerError::MissingDifficultyAdjustmentManager), - } - } - - pub fn get_median_timestamp_at_height_writeguard( - &self, - db: &RwLockWriteGuard, - height: u64, - ) -> Result - { - match self.access_diff_adj()?.as_ref() { - Some(v) => v - .get_median_timestamp_at_height_writeguard(db, height) - .map_err(ConsensusManagerError::DifficultyAdjustmentManagerError), - None => Err(ConsensusManagerError::MissingDifficultyAdjustmentManager), - } - } - /// Creates a total_coinbase offset containing all fees for the validation from block pub fn calculate_coinbase_and_fees(&self, block: &Block) -> MicroTari { let coinbase = self.emission_schedule().block_reward(block.header.height); coinbase + block.calculate_fees() } - // Inner helper function to access to the difficulty adjustment manager - fn access_diff_adj(&self) -> Result>, ConsensusManagerError> { - self.inner - .diff_adj_manager - .read() - .map_err(|e| ConsensusManagerError::PoisonedAccess(e.to_string())) - } - /// This is the currently configured chain network. pub fn network(&self) -> Network { self.inner.network @@ -233,11 +112,8 @@ impl Clone for ConsensusManager { /// This is the used to control all consensus values. struct ConsensusManagerInner { - /// Difficulty adjustment manager for the blockchain - pub diff_adj_manager: RwLock>, /// This is the inner struct used to control all consensus values. pub consensus_constants: ConsensusConstants, - /// The configured chain network. pub network: Network, /// The configuration for the emission schedule. @@ -248,8 +124,6 @@ struct ConsensusManagerInner { /// Constructor for the consensus manager struct pub struct ConsensusManagerBuilder { - /// Difficulty adjustment manager for the blockchain - pub diff_adj_manager: Option, /// This is the inner struct used to control all consensus values. pub consensus_constants: Option, /// The configured chain network. 
@@ -262,7 +136,6 @@ impl ConsensusManagerBuilder { /// Creates a new ConsensusManagerBuilder with the specified network pub fn new(network: Network) -> Self { ConsensusManagerBuilder { - diff_adj_manager: None, consensus_constants: None, network, gen_block: None, @@ -275,12 +148,6 @@ impl ConsensusManagerBuilder { self } - /// Adds in a difficulty adjustment manager to be used to be used - pub fn with_difficulty_adjustment_manager(mut self, difficulty_adj: DiffAdjManager) -> Self { - self.diff_adj_manager = Some(difficulty_adj); - self - } - /// Adds in a custom block to be used. This will be overwritten if the network is anything else than localnet pub fn with_block(mut self, block: Block) -> Self { self.gen_block = Some(block); @@ -288,6 +155,7 @@ impl ConsensusManagerBuilder { } /// Builds a consensus manager + #[allow(clippy::or_fun_call)] pub fn build(self) -> ConsensusManager { let consensus_constants = self .consensus_constants @@ -298,7 +166,6 @@ impl ConsensusManagerBuilder { consensus_constants.emission_tail, ); let inner = ConsensusManagerInner { - diff_adj_manager: RwLock::new(self.diff_adj_manager), consensus_constants, network: self.network, emission, diff --git a/base_layer/core/src/consensus/network.rs b/base_layer/core/src/consensus/network.rs index 479f3ed6d4..9bc840e0b0 100644 --- a/base_layer/core/src/consensus/network.rs +++ b/base_layer/core/src/consensus/network.rs @@ -35,7 +35,7 @@ pub enum Network { } impl Network { - pub fn create_consensus_constants(&self) -> ConsensusConstants { + pub fn create_consensus_constants(self) -> ConsensusConstants { match self { Network::MainNet => ConsensusConstants::mainnet(), Network::Rincewind => ConsensusConstants::rincewind(), diff --git a/base_layer/core/src/helpers/mock_backend.rs b/base_layer/core/src/helpers/mock_backend.rs index 2645185730..f3dde27577 100644 --- a/base_layer/core/src/helpers/mock_backend.rs +++ b/base_layer/core/src/helpers/mock_backend.rs @@ -19,16 +19,17 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// use crate::{ blocks::{Block, BlockHeader}, - chain_storage::{BlockchainBackend, ChainStorageError, DbKey, DbTransaction, DbValue, MmrTree}, + chain_storage::{BlockchainBackend, ChainMetadata, ChainStorageError, DbKey, DbTransaction, DbValue, MmrTree}, + proof_of_work::{Difficulty, PowAlgorithm}, transactions::{ transaction::{TransactionKernel, TransactionOutput}, types::HashOutput, }, }; +use tari_crypto::tari_utilities::epoch_time::EpochTime; use tari_mmr::{Hash, MerkleCheckPoint, MerkleProof}; // This is a test backend. This is used so that the ConsensusManager can be called without actually having a backend. 
@@ -86,6 +87,10 @@ impl BlockchainBackend for MockBackend {
         unimplemented!()
     }
 
+    fn get_orphan_count(&self) -> Result<usize, ChainStorageError> {
+        unimplemented!()
+    }
+
     fn for_each_kernel<F>(&self, _f: F) -> Result<(), ChainStorageError>
     where
         Self: Sized,
@@ -113,4 +118,18 @@ impl BlockchainBackend for MockBackend {
     fn fetch_last_header(&self) -> Result<Option<BlockHeader>, ChainStorageError> {
         unimplemented!()
     }
+
+    fn fetch_metadata(&self) -> Result<ChainMetadata, ChainStorageError> {
+        unimplemented!()
+    }
+
+    fn fetch_target_difficulties(
+        &self,
+        _pow_algo: PowAlgorithm,
+        _height: u64,
+        _block_window: usize,
+    ) -> Result<Vec<(EpochTime, Difficulty)>, ChainStorageError>
+    {
+        unimplemented!()
+    }
 }
diff --git a/base_layer/core/src/helpers/mod.rs b/base_layer/core/src/helpers/mod.rs
index 449cc26756..6c7058ad62 100644
--- a/base_layer/core/src/helpers/mod.rs
+++ b/base_layer/core/src/helpers/mod.rs
@@ -27,10 +27,10 @@ mod mock_backend;
 
 use crate::{
     blocks::{Block, BlockHeader},
-    chain_storage::{BlockchainDatabase, MemoryDatabase, Validators},
+    chain_storage::{BlockchainDatabase, BlockchainDatabaseConfig, MemoryDatabase, Validators},
     consensus::{ConsensusConstants, ConsensusManager},
     transactions::{transaction::Transaction, types::HashDigest},
-    validation::mocks::MockValidator,
+    validation::{accum_difficulty_validators::MockAccumDifficultyValidator, mocks::MockValidator},
 };
 pub use mock_backend::MockBackend;
@@ -49,7 +49,11 @@ pub fn create_orphan_block(
 }
 
 pub fn create_mem_db(consensus_manager: &ConsensusManager) -> BlockchainDatabase<MemoryDatabase<HashDigest>> {
-    let validators = Validators::new(MockValidator::new(true), MockValidator::new(true));
+    let validators = Validators::new(
+        MockValidator::new(true),
+        MockValidator::new(true),
+        MockAccumDifficultyValidator {},
+    );
     let db = MemoryDatabase::<HashDigest>::default();
-    BlockchainDatabase::new(db, consensus_manager, validators).unwrap()
+    BlockchainDatabase::new(db, consensus_manager, validators, BlockchainDatabaseConfig::default()).unwrap()
 }
diff --git a/base_layer/core/src/mempool/async_mempool.rs b/base_layer/core/src/mempool/async_mempool.rs
index 049729ce0c..0b5e390b45 100644
--- a/base_layer/core/src/mempool/async_mempool.rs
+++ b/base_layer/core/src/mempool/async_mempool.rs
@@ -23,7 +23,7 @@
 use crate::{
     blocks::Block,
     chain_storage::BlockchainBackend,
-    mempool::{error::MempoolError, Mempool, StatsResponse, TxStorageResponse},
+    mempool::{error::MempoolError, Mempool, StateResponse, StatsResponse, TxStorageResponse},
     transactions::{transaction::Transaction, types::Signature},
 };
 use std::sync::Arc;
@@ -69,3 +69,4 @@ make_async!(snapshot() -> Vec<Arc<Transaction>>);
 make_async!(retrieve(total_weight: u64) -> Vec<Arc<Transaction>>);
 make_async!(has_tx_with_excess_sig(excess_sig: Signature) -> TxStorageResponse);
 make_async!(stats() -> StatsResponse);
+make_async!(state() -> StateResponse);
diff --git a/base_layer/core/src/mempool/consts.rs b/base_layer/core/src/mempool/consts.rs
index ea8c2b7125..ec6e2cacca 100644
--- a/base_layer/core/src/mempool/consts.rs
+++ b/base_layer/core/src/mempool/consts.rs
@@ -19,7 +19,6 @@
 // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
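state() joins the existing make_async! wrappers here. The macro's expansion is not shown in this patch; a plausible hand-written equivalent of one such wrapper, assuming it offloads the blocking pool call via tokio's spawn_blocking (toy Mempool and error types, tokio as a dependency):

    // Toy stand-ins for the mempool handle and its error type.
    #[derive(Clone)]
    struct Mempool;

    #[derive(Debug)]
    enum MempoolError { BlockingTaskSpawnError(String) }

    impl Mempool {
        fn stats(&self) -> Result<usize, MempoolError> { Ok(0) }
    }

    // What a make_async!(stats() -> ...) wrapper could expand to:
    // run the blocking call on a worker thread and surface join failures
    // through the mempool's BlockingTaskSpawnError variant.
    async fn stats(mempool: Mempool) -> Result<usize, MempoolError> {
        tokio::task::spawn_blocking(move || mempool.stats())
            .await
            .map_err(|e| MempoolError::BlockingTaskSpawnError(e.to_string()))?
    }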
-// use std::time::Duration; diff --git a/base_layer/core/src/mempool/error.rs b/base_layer/core/src/mempool/error.rs index 281fe4ff81..e3d46dc976 100644 --- a/base_layer/core/src/mempool/error.rs +++ b/base_layer/core/src/mempool/error.rs @@ -44,4 +44,7 @@ pub enum MempoolError { ChainHeightUndefined, #[error(msg_embedded, non_std, no_from)] BlockingTaskSpawnError(String), + /// A problem has been encountered with the storage backend. + #[error(non_std, no_from)] + BackendError(String), } diff --git a/base_layer/core/src/mempool/mempool.rs b/base_layer/core/src/mempool/mempool.rs index 3ad70bfec9..fd2cbd9441 100644 --- a/base_layer/core/src/mempool/mempool.rs +++ b/base_layer/core/src/mempool/mempool.rs @@ -25,22 +25,16 @@ use crate::{ chain_storage::{BlockchainBackend, BlockchainDatabase}, mempool::{ error::MempoolError, - orphan_pool::OrphanPool, - pending_pool::PendingPool, - reorg_pool::ReorgPool, - unconfirmed_pool::UnconfirmedPool, + mempool_storage::MempoolStorage, MempoolConfig, + StateResponse, StatsResponse, TxStorageResponse, }, transactions::{transaction::Transaction, types::Signature}, - validation::{Validation, ValidationError, Validator}, + validation::{Validation, Validator}, }; -use log::*; -use std::sync::Arc; -use tari_crypto::tari_utilities::{hex::Hex, Hashable}; - -pub const LOG_TARGET: &str = "c::mp::mempool"; +use std::sync::{Arc, RwLock}; /// Struct containing the validators the mempool needs to run, It forces the correct amount of validators are given pub struct MempoolValidators { @@ -71,12 +65,7 @@ impl MempoolValidators { pub struct Mempool where T: BlockchainBackend { - blockchain_db: BlockchainDatabase, - unconfirmed_pool: UnconfirmedPool, - orphan_pool: OrphanPool, - pending_pool: PendingPool, - reorg_pool: ReorgPool, - validator: Arc>, + pool_storage: Arc>>, } impl Mempool @@ -84,172 +73,76 @@ where T: BlockchainBackend { /// Create a new Mempool with an UnconfirmedPool, OrphanPool, PendingPool and ReOrgPool. pub fn new(blockchain_db: BlockchainDatabase, config: MempoolConfig, validators: MempoolValidators) -> Self { - let (mempool_validator, orphan_validator) = validators.into_validators(); Self { - unconfirmed_pool: UnconfirmedPool::new(config.unconfirmed_pool_config), - orphan_pool: OrphanPool::new(config.orphan_pool_config, orphan_validator, blockchain_db.clone()), - pending_pool: PendingPool::new(config.pending_pool_config), - reorg_pool: ReorgPool::new(config.reorg_pool_config), - blockchain_db, - validator: Arc::new(mempool_validator), + pool_storage: Arc::new(RwLock::new(MempoolStorage::new(blockchain_db, config, validators))), } } /// Insert an unconfirmed transaction into the Mempool. 
The transaction *MUST* have passed through the validation /// pipeline already and will thus always be internally consistent by this stage pub fn insert(&self, tx: Arc) -> Result { - debug!( - target: LOG_TARGET, - "Inserting tx into mempool: {}", - tx.body.kernels()[0].excess_sig.get_signature().to_hex() - ); - // The transaction is already internally consistent - let (db, metadata) = self.blockchain_db.db_and_metadata_read_access()?; - - match self.validator.validate(&tx, &db, &metadata) { - Ok(()) => { - self.unconfirmed_pool.insert(tx)?; - Ok(TxStorageResponse::UnconfirmedPool) - }, - Err(ValidationError::UnknownInputs) => { - self.orphan_pool.insert(tx)?; - Ok(TxStorageResponse::OrphanPool) - }, - Err(ValidationError::MaturityError) => { - self.pending_pool.insert(tx)?; - Ok(TxStorageResponse::PendingPool) - }, - _ => Ok(TxStorageResponse::NotStored), - } - } - - // Insert a set of new transactions into the UTxPool. - fn insert_txs(&self, txs: Vec>) -> Result<(), MempoolError> { - for tx in txs { - self.insert(tx)?; - } - Ok(()) + self.pool_storage + .write() + .map_err(|e| MempoolError::BackendError(e.to_string()))? + .insert(tx) } /// Update the Mempool based on the received published block. pub fn process_published_block(&self, published_block: Block) -> Result<(), MempoolError> { - trace!(target: LOG_TARGET, "Mempool processing new block: {}", published_block); - // Move published txs to ReOrgPool and discard double spends - self.reorg_pool.insert_txs( - self.unconfirmed_pool - .remove_published_and_discard_double_spends(&published_block)?, - )?; - - // Move txs with valid input UTXOs and expired time-locks to UnconfirmedPool and discard double spends - self.unconfirmed_pool.insert_txs( - self.pending_pool - .remove_unlocked_and_discard_double_spends(&published_block)?, - )?; - - // Move txs with recently expired time-locks that have input UTXOs that have recently become valid to the - // UnconfirmedPool - let (txs, time_locked_txs) = self.orphan_pool.scan_for_and_remove_unorphaned_txs()?; - self.unconfirmed_pool.insert_txs(txs)?; - // Move Time-locked txs that have input UTXOs that have recently become valid to PendingPool. - self.pending_pool.insert_txs(time_locked_txs)?; - - Ok(()) - } - - // Update the Mempool based on the received set of published blocks. - fn process_published_blocks(&self, published_blocks: Vec) -> Result<(), MempoolError> { - for published_block in published_blocks { - self.process_published_block(published_block)?; - } - Ok(()) + self.pool_storage + .write() + .map_err(|e| MempoolError::BackendError(e.to_string()))? + .process_published_block(published_block) } /// In the event of a ReOrg, resubmit all ReOrged transactions into the Mempool and process each newly introduced /// block from the latest longest chain. 
pub fn process_reorg(&self, removed_blocks: Vec, new_blocks: Vec) -> Result<(), MempoolError> { - debug!(target: LOG_TARGET, "Mempool processing reorg"); - for block in &removed_blocks { - trace!( - target: LOG_TARGET, - "Mempool processing reorg removed block {} ({})", - block.header.height, - block.header.hash().to_hex(), - ); - } - for block in &new_blocks { - trace!( - target: LOG_TARGET, - "Mempool processing reorg added new block {} ({})", - block.header.height, - block.header.hash().to_hex(), - ); - } - - self.insert_txs( - self.reorg_pool - .remove_reorged_txs_and_discard_double_spends(removed_blocks, &new_blocks)?, - )?; - self.process_published_blocks(new_blocks)?; - Ok(()) + self.pool_storage + .write() + .map_err(|e| MempoolError::BackendError(e.to_string()))? + .process_reorg(removed_blocks, new_blocks) } /// Returns all unconfirmed transaction stored in the Mempool, except the transactions stored in the ReOrgPool. // TODO: Investigate returning an iterator rather than a large vector of transactions pub fn snapshot(&self) -> Result>, MempoolError> { - let mut txs = self.unconfirmed_pool.snapshot()?; - txs.append(&mut self.orphan_pool.snapshot()?); - txs.append(&mut self.pending_pool.snapshot()?); - Ok(txs) + self.pool_storage + .read() + .map_err(|e| MempoolError::BackendError(e.to_string()))? + .snapshot() } /// Returns a list of transaction ranked by transaction priority up to a given weight. pub fn retrieve(&self, total_weight: u64) -> Result>, MempoolError> { - Ok(self.unconfirmed_pool.highest_priority_txs(total_weight)?) + self.pool_storage + .read() + .map_err(|e| MempoolError::BackendError(e.to_string()))? + .retrieve(total_weight) } /// Check if the specified transaction is stored in the Mempool. pub fn has_tx_with_excess_sig(&self, excess_sig: Signature) -> Result { - if self.unconfirmed_pool.has_tx_with_excess_sig(&excess_sig)? { - Ok(TxStorageResponse::UnconfirmedPool) - } else if self.orphan_pool.has_tx_with_excess_sig(&excess_sig)? { - Ok(TxStorageResponse::OrphanPool) - } else if self.pending_pool.has_tx_with_excess_sig(&excess_sig)? { - Ok(TxStorageResponse::PendingPool) - } else if self.reorg_pool.has_tx_with_excess_sig(&excess_sig)? { - Ok(TxStorageResponse::ReorgPool) - } else { - Ok(TxStorageResponse::NotStored) - } - } - - // Returns the total number of transactions in the Mempool. - fn len(&self) -> Result { - Ok( - self.unconfirmed_pool.len()? + - self.orphan_pool.len()? + - self.pending_pool.len()? + - self.reorg_pool.len()?, - ) - } - - // Returns the total weight of all transactions stored in the Mempool. - fn calculate_weight(&self) -> Result { - Ok(self.unconfirmed_pool.calculate_weight()? + - self.orphan_pool.calculate_weight()? + - self.pending_pool.calculate_weight()? + - self.reorg_pool.calculate_weight()?) + self.pool_storage + .read() + .map_err(|e| MempoolError::BackendError(e.to_string()))? + .has_tx_with_excess_sig(excess_sig) } /// Gathers and returns the stats of the Mempool. pub fn stats(&self) -> Result { - Ok(StatsResponse { - total_txs: self.len()?, - unconfirmed_txs: self.unconfirmed_pool.len()?, - orphan_txs: self.orphan_pool.len()?, - timelocked_txs: self.pending_pool.len()?, - published_txs: self.reorg_pool.len()?, - total_weight: self.calculate_weight()?, - }) + self.pool_storage + .read() + .map_err(|e| MempoolError::BackendError(e.to_string()))? + .stats() + } + + /// Gathers and returns a breakdown of all the transaction in the Mempool. 
+    pub fn state(&self) -> Result<StateResponse, MempoolError> {
+        self.pool_storage
+            .read()
+            .map_err(|e| MempoolError::BackendError(e.to_string()))?
+            .state()
+    }
 }
@@ -258,12 +151,7 @@ where T: BlockchainBackend
 {
     fn clone(&self) -> Self {
         Mempool {
-            blockchain_db: self.blockchain_db.clone(),
-            unconfirmed_pool: self.unconfirmed_pool.clone(),
-            orphan_pool: self.orphan_pool.clone(),
-            pending_pool: self.pending_pool.clone(),
-            reorg_pool: self.reorg_pool.clone(),
-            validator: self.validator.clone(),
+            pool_storage: self.pool_storage.clone(),
         }
     }
 }
diff --git a/base_layer/core/src/mempool/mempool_storage.rs b/base_layer/core/src/mempool/mempool_storage.rs
new file mode 100644
index 0000000000..d9da04b4db
--- /dev/null
+++ b/base_layer/core/src/mempool/mempool_storage.rs
@@ -0,0 +1,289 @@
+// Copyright 2019. The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use crate::{
+    blocks::Block,
+    chain_storage::{BlockchainBackend, BlockchainDatabase},
+    mempool::{
+        error::MempoolError,
+        mempool::MempoolValidators,
+        orphan_pool::OrphanPool,
+        pending_pool::PendingPool,
+        reorg_pool::ReorgPool,
+        unconfirmed_pool::UnconfirmedPool,
+        MempoolConfig,
+        StateResponse,
+        StatsResponse,
+        TxStorageResponse,
+    },
+    transactions::{transaction::Transaction, types::Signature},
+    validation::{ValidationError, Validator},
+};
+use log::*;
+use std::sync::Arc;
+use tari_crypto::tari_utilities::{hex::Hex, Hashable};
+
+pub const LOG_TARGET: &str = "c::mp::mempool";
+
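Mempool is now a thin, cloneable facade: every public method takes the shared storage lock and maps lock poisoning into MempoolError::BackendError instead of panicking. That delegation shape, reduced to a toy counter (hypothetical Facade/Storage names, not the crate's types):

    use std::sync::{Arc, RwLock};

    #[derive(Debug)]
    enum PoolError { Backend(String) }

    struct Storage { count: usize }

    impl Storage {
        fn insert(&mut self) -> usize { self.count += 1; self.count }
    }

    #[derive(Clone)]
    struct Facade { inner: Arc<RwLock<Storage>> }

    impl Facade {
        // Same shape as Mempool::insert: write-lock, map poisoning to an error,
        // then delegate to the storage implementation.
        fn insert(&self) -> Result<usize, PoolError> {
            Ok(self
                .inner
                .write()
                .map_err(|e| PoolError::Backend(e.to_string()))?
                .insert())
        }
    }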
+/// The Mempool consists of an Unconfirmed Transaction Pool, Pending Pool, Orphan Pool and Reorg Pool, and is
+/// responsible for managing and maintaining all unconfirmed transactions that have not yet been included in a block,
+/// and transactions that have recently been included in a block.
+pub struct MempoolStorage<T>
+where T: BlockchainBackend
+{
+    blockchain_db: BlockchainDatabase<T>,
+    unconfirmed_pool: UnconfirmedPool,
+    orphan_pool: OrphanPool<T>,
+    pending_pool: PendingPool,
+    reorg_pool: ReorgPool,
+    validator: Arc<Validator<Transaction, T>>,
+}
+
+impl<T> MempoolStorage<T>
+where T: BlockchainBackend
+{
+    /// Create a new MempoolStorage with an UnconfirmedPool, OrphanPool, PendingPool and ReOrgPool.
+    pub fn new(blockchain_db: BlockchainDatabase<T>, config: MempoolConfig, validators: MempoolValidators<T>) -> Self {
+        let (mempool_validator, orphan_validator) = validators.into_validators();
+        Self {
+            unconfirmed_pool: UnconfirmedPool::new(config.unconfirmed_pool_config),
+            orphan_pool: OrphanPool::new(config.orphan_pool_config, orphan_validator, blockchain_db.clone()),
+            pending_pool: PendingPool::new(config.pending_pool_config),
+            reorg_pool: ReorgPool::new(config.reorg_pool_config),
+            blockchain_db,
+            validator: Arc::new(mempool_validator),
+        }
+    }
+
+    /// Insert an unconfirmed transaction into the Mempool. The transaction *MUST* have passed through the validation
+    /// pipeline already and will thus always be internally consistent by this stage.
+    pub fn insert(&mut self, tx: Arc<Transaction>) -> Result<TxStorageResponse, MempoolError> {
+        debug!(
+            target: LOG_TARGET,
+            "Inserting tx into mempool: {}",
+            tx.body.kernels()[0].excess_sig.get_signature().to_hex()
+        );
+        // The transaction is already internally consistent
+        let db = self.blockchain_db.db_read_access()?;
+
+        match self.validator.validate(&tx, &db) {
+            Ok(()) => {
+                self.unconfirmed_pool.insert(tx)?;
+                Ok(TxStorageResponse::UnconfirmedPool)
+            },
+            Err(ValidationError::UnknownInputs) => {
+                self.orphan_pool.insert(tx)?;
+                Ok(TxStorageResponse::OrphanPool)
+            },
+            Err(ValidationError::ContainsSTxO) => {
+                self.reorg_pool.insert(tx)?;
+                Ok(TxStorageResponse::ReorgPool)
+            },
+            Err(ValidationError::MaturityError) => {
+                self.pending_pool.insert(tx)?;
+                Ok(TxStorageResponse::PendingPool)
+            },
+            _ => Ok(TxStorageResponse::NotStored),
+        }
+    }
+
+    // Insert a set of new transactions into the mempool.
+    fn insert_txs(&mut self, txs: Vec<Arc<Transaction>>) -> Result<(), MempoolError> {
+        for tx in txs {
+            self.insert(tx)?;
+        }
+        Ok(())
+    }
+
+    /// Update the Mempool based on the received published block.
+    pub fn process_published_block(&mut self, published_block: Block) -> Result<(), MempoolError> {
+        trace!(target: LOG_TARGET, "Mempool processing new block: {}", published_block);
+        // Move published txs to ReOrgPool and discard double spends
+        self.reorg_pool.insert_txs(
+            self.unconfirmed_pool
+                .remove_published_and_discard_double_spends(&published_block),
+        )?;
+
+        // Move txs with valid input UTXOs and expired time-locks to UnconfirmedPool and discard double spends
+        self.unconfirmed_pool.insert_txs(
+            self.pending_pool
+                .remove_unlocked_and_discard_double_spends(&published_block)?,
+        )?;
+
+        // Move txs with recently expired time-locks that have input UTXOs that have recently become valid to the
+        // UnconfirmedPool
+        let (txs, time_locked_txs) = self.orphan_pool.scan_for_and_remove_unorphaned_txs()?;
+        self.unconfirmed_pool.insert_txs(txs)?;
+        // Move Time-locked txs that have input UTXOs that have recently become valid to PendingPool.
+        self.pending_pool.insert_txs(time_locked_txs)?;
+
+        Ok(())
+    }
+
+    // Update the Mempool based on the received set of published blocks.
+    fn process_published_blocks(&mut self, published_blocks: Vec<Block>) -> Result<(), MempoolError> {
+        for published_block in published_blocks {
+            self.process_published_block(published_block)?;
+        }
+        Ok(())
+    }
+
+    /// In the event of a ReOrg, resubmit all ReOrged transactions into the Mempool and process each newly introduced
+    /// block from the latest longest chain.
+    pub fn process_reorg(&mut self, removed_blocks: Vec<Block>, new_blocks: Vec<Block>) -> Result<(), MempoolError> {
+        debug!(target: LOG_TARGET, "Mempool processing reorg");
+        for block in &removed_blocks {
+            trace!(
+                target: LOG_TARGET,
+                "Mempool processing reorg removed block {} ({})",
+                block.header.height,
+                block.header.hash().to_hex(),
+            );
+        }
+        for block in &new_blocks {
+            trace!(
+                target: LOG_TARGET,
+                "Mempool processing reorg added new block {} ({})",
+                block.header.height,
+                block.header.hash().to_hex(),
+            );
+        }
+
+        let prev_tip_height = removed_blocks
+            .last()
+            .expect("Removed empty set of blocks on reorg.")
+            .header
+            .height;
+        let new_tip_height = new_blocks
+            .last()
+            .expect("Added empty set of blocks on reorg.")
+            .header
+            .height;
+        self.insert_txs(
+            self.reorg_pool
+                .remove_reorged_txs_and_discard_double_spends(removed_blocks, &new_blocks)?,
+        )?;
+        self.process_published_blocks(new_blocks)?;
+
+        if new_tip_height < prev_tip_height {
+            trace!(
+                target: LOG_TARGET,
+                "Checking for time locked transactions in unconfirmed pool as chain height was reduced from {} to {} \
+                 during reorg.",
+                prev_tip_height,
+                new_tip_height,
+            );
+            self.pending_pool
+                .insert_txs(self.unconfirmed_pool.remove_timelocked(new_tip_height))?;
+        }
+
+        Ok(())
+    }
+
+    /// Returns all unconfirmed transactions stored in the Mempool, except the transactions stored in the ReOrgPool.
+    // TODO: Investigate returning an iterator rather than a large vector of transactions
+    pub fn snapshot(&self) -> Result<Vec<Arc<Transaction>>, MempoolError> {
+        let mut txs = self.unconfirmed_pool.snapshot();
+        txs.append(&mut self.orphan_pool.snapshot()?);
+        txs.append(&mut self.pending_pool.snapshot());
+        Ok(txs)
+    }
+
+    /// Returns a list of transactions ranked by transaction priority up to a given weight.
+    pub fn retrieve(&self, total_weight: u64) -> Result<Vec<Arc<Transaction>>, MempoolError> {
+        Ok(self.unconfirmed_pool.highest_priority_txs(total_weight)?)
+    }
+
+    /// Check if the specified transaction is stored in the Mempool.
+    pub fn has_tx_with_excess_sig(&self, excess_sig: Signature) -> Result<TxStorageResponse, MempoolError> {
+        if self.unconfirmed_pool.has_tx_with_excess_sig(&excess_sig) {
+            Ok(TxStorageResponse::UnconfirmedPool)
+        } else if self.orphan_pool.has_tx_with_excess_sig(&excess_sig)? {
+            Ok(TxStorageResponse::OrphanPool)
+        } else if self.pending_pool.has_tx_with_excess_sig(&excess_sig) {
+            Ok(TxStorageResponse::PendingPool)
+        } else if self.reorg_pool.has_tx_with_excess_sig(&excess_sig)? {
+            Ok(TxStorageResponse::ReorgPool)
+        } else {
+            Ok(TxStorageResponse::NotStored)
+        }
+    }
+
+    // Returns the total number of transactions in the Mempool.
+    fn len(&self) -> Result<usize, MempoolError> {
+        Ok(self.unconfirmed_pool.len() + self.orphan_pool.len()? + self.pending_pool.len() + self.reorg_pool.len()?)
+    }
+
+    // Returns the total weight of all transactions stored in the Mempool.
+    fn calculate_weight(&self) -> Result<u64, MempoolError> {
+        Ok(self.unconfirmed_pool.calculate_weight() +
+            self.orphan_pool.calculate_weight()? +
+            self.pending_pool.calculate_weight() +
+            self.reorg_pool.calculate_weight()?)
+    }
+
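process_reorg above derives the previous and new tip heights with .last().expect(...), so a reorg event that delivers an empty removed or added block set would panic; callers must guarantee non-empty sets. A defensive variant of the same height check, returning Option instead of panicking (toy Block/Header types):

    struct Header { height: u64 }
    struct Block { header: Header }

    // Tip height of a block set, or None when the set is empty.
    fn tip_height(blocks: &[Block]) -> Option<u64> {
        blocks.last().map(|b| b.header.height)
    }

    // True when the reorg reduced the chain height, which is the condition that
    // triggers the time-lock re-check of the unconfirmed pool above.
    fn chain_shrank(removed: &[Block], added: &[Block]) -> bool {
        match (tip_height(removed), tip_height(added)) {
            (Some(prev), Some(new)) => new < prev,
            _ => false, // nothing to compare: no time-lock re-check needed
        }
    }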
+    /// Gathers and returns the stats of the Mempool.
+    pub fn stats(&self) -> Result<StatsResponse, MempoolError> {
+        Ok(StatsResponse {
+            total_txs: self.len()?,
+            unconfirmed_txs: self.unconfirmed_pool.len(),
+            orphan_txs: self.orphan_pool.len()?,
+            timelocked_txs: self.pending_pool.len(),
+            published_txs: self.reorg_pool.len()?,
+            total_weight: self.calculate_weight()?,
+        })
+    }
+
+    /// Gathers and returns a breakdown of all the transactions in the Mempool.
+    pub fn state(&self) -> Result<StateResponse, MempoolError> {
+        let unconfirmed_pool = self
+            .unconfirmed_pool
+            .snapshot()
+            .iter()
+            .map(|tx| tx.body.kernels()[0].excess_sig.clone())
+            .collect::<Vec<Signature>>();
+        let orphan_pool = self
+            .orphan_pool
+            .snapshot()?
+            .iter()
+            .map(|tx| tx.body.kernels()[0].excess_sig.clone())
+            .collect::<Vec<Signature>>();
+        let pending_pool = self
+            .pending_pool
+            .snapshot()
+            .iter()
+            .map(|tx| tx.body.kernels()[0].excess_sig.clone())
+            .collect::<Vec<Signature>>();
+        let reorg_pool = self
+            .reorg_pool
+            .snapshot()?
+            .iter()
+            .map(|tx| tx.body.kernels()[0].excess_sig.clone())
+            .collect::<Vec<Signature>>();
+        Ok(StateResponse {
+            unconfirmed_pool,
+            orphan_pool,
+            pending_pool,
+            reorg_pool,
+        })
+    }
+}
diff --git a/base_layer/core/src/mempool/mod.rs b/base_layer/core/src/mempool/mod.rs
index 1271ce8074..c4976227d0 100644
--- a/base_layer/core/src/mempool/mod.rs
+++ b/base_layer/core/src/mempool/mod.rs
@@ -26,9 +26,12 @@ mod config;
 mod consts;
 #[cfg(feature = "base_node")]
 mod error;
+#[allow(clippy::module_inception)]
 #[cfg(feature = "base_node")]
 mod mempool;
 #[cfg(feature = "base_node")]
+mod mempool_storage;
+#[cfg(feature = "base_node")]
 mod orphan_pool;
 #[cfg(feature = "base_node")]
 mod pending_pool;
@@ -58,10 +61,12 @@ pub mod proto;
 #[cfg(any(feature = "base_node", feature = "mempool_proto"))]
 pub mod service;
 
+use crate::transactions::types::Signature;
 use core::fmt::{Display, Error, Formatter};
 use serde::{Deserialize, Serialize};
+use tari_crypto::tari_utilities::hex::Hex;
 
-#[derive(Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub struct StatsResponse {
     pub total_txs: usize,
     pub unconfirmed_txs: usize,
@@ -87,7 +92,38 @@ impl Display for StatsResponse {
     }
 }
 
-#[derive(Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct StateResponse {
+    pub unconfirmed_pool: Vec<Signature>,
+    pub orphan_pool: Vec<Signature>,
+    pub pending_pool: Vec<Signature>,
+    pub reorg_pool: Vec<Signature>,
+}
+
+impl Display for StateResponse {
+    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), Error> {
+        fmt.write_str("----------------- Mempool -----------------\n")?;
+        fmt.write_str("--- Unconfirmed Pool ---\n")?;
+        for excess_sig in &self.unconfirmed_pool {
+            fmt.write_str(&format!("    {}\n", excess_sig.get_signature().to_hex()))?;
+        }
+        fmt.write_str("--- Orphan Pool ---\n")?;
+        for excess_sig in &self.orphan_pool {
+            fmt.write_str(&format!("    {}\n", excess_sig.get_signature().to_hex()))?;
+        }
+        fmt.write_str("--- Pending Pool ---\n")?;
+        for excess_sig in &self.pending_pool {
+            fmt.write_str(&format!("    {}\n", excess_sig.get_signature().to_hex()))?;
+        }
+        fmt.write_str("--- Reorg Pool ---\n")?;
+        for excess_sig in &self.reorg_pool {
+            fmt.write_str(&format!("    {}\n", excess_sig.get_signature().to_hex()))?;
+        }
+        Ok(())
+    }
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
 pub enum TxStorageResponse {
     UnconfirmedPool,
     OrphanPool,
diff --git a/base_layer/core/src/mempool/orphan_pool/mod.rs b/base_layer/core/src/mempool/orphan_pool/mod.rs
index a693315aac..f3e8acaf28 100644
--- a/base_layer/core/src/mempool/orphan_pool/mod.rs
+++
b/base_layer/core/src/mempool/orphan_pool/mod.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod error; +#[allow(clippy::module_inception)] mod orphan_pool; mod orphan_pool_storage; diff --git a/base_layer/core/src/mempool/orphan_pool/orphan_pool.rs b/base_layer/core/src/mempool/orphan_pool/orphan_pool.rs index 2a601a8fb4..c132e7e815 100644 --- a/base_layer/core/src/mempool/orphan_pool/orphan_pool.rs +++ b/base_layer/core/src/mempool/orphan_pool/orphan_pool.rs @@ -168,12 +168,12 @@ mod test { #[test] fn test_insert_rlu_and_ttl() { - let tx1 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(500), lock: 4000, inputs: 2, outputs: 1).0); - let tx2 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(300), lock: 3000, inputs: 2, outputs: 1).0); - let tx3 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(100), lock: 2500, inputs: 2, outputs: 1).0); - let tx4 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(200), lock: 1000, inputs: 2, outputs: 1).0); - let tx5 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(500), lock: 2000, inputs: 2, outputs: 1).0); - let tx6 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(600), lock: 5500, inputs: 2, outputs: 1).0); + let tx1 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(500), lock: 4000, inputs: 2, outputs: 1).0); + let tx2 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(300), lock: 3000, inputs: 2, outputs: 1).0); + let tx3 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(100), lock: 2500, inputs: 2, outputs: 1).0); + let tx4 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(200), lock: 1000, inputs: 2, outputs: 1).0); + let tx5 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(500), lock: 2000, inputs: 2, outputs: 1).0); + let tx6 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(600), lock: 5500, inputs: 2, outputs: 1).0); let network = Network::LocalNet; let consensus_manager = ConsensusManagerBuilder::new(network).build(); let store = create_mem_db(&consensus_manager); diff --git a/base_layer/core/src/mempool/orphan_pool/orphan_pool_storage.rs b/base_layer/core/src/mempool/orphan_pool/orphan_pool_storage.rs index f747256bee..03112ccb0d 100644 --- a/base_layer/core/src/mempool/orphan_pool/orphan_pool_storage.rs +++ b/base_layer/core/src/mempool/orphan_pool/orphan_pool_storage.rs @@ -102,9 +102,9 @@ where T: BlockchainBackend // We dont care about tx's that appeared in valid blocks. Those tx's will time out in orphan pool and remove // them selves. for (tx_key, tx) in self.txs_by_signature.iter() { - let (db, metadata) = self.blockchain_db.db_and_metadata_read_access()?; + let db = self.blockchain_db.db_read_access()?; - match self.validator.validate(&tx, &db, &metadata) { + match self.validator.validate(&tx, &db) { Ok(()) => { trace!( target: LOG_TARGET, diff --git a/base_layer/core/src/mempool/pending_pool/mod.rs b/base_layer/core/src/mempool/pending_pool/mod.rs index c0654f26da..7ef60a14ce 100644 --- a/base_layer/core/src/mempool/pending_pool/mod.rs +++ b/base_layer/core/src/mempool/pending_pool/mod.rs @@ -21,10 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 mod error;
+#[allow(clippy::module_inception)]
 mod pending_pool;
-mod pending_pool_storage;
 
 // Public re-exports
 pub use error::PendingPoolError;
 pub use pending_pool::{PendingPool, PendingPoolConfig};
-pub use pending_pool_storage::PendingPoolStorage;
diff --git a/base_layer/core/src/mempool/pending_pool/pending_pool.rs b/base_layer/core/src/mempool/pending_pool/pending_pool.rs
index 4b08c717ee..4535cf5ce4 100644
--- a/base_layer/core/src/mempool/pending_pool/pending_pool.rs
+++ b/base_layer/core/src/mempool/pending_pool/pending_pool.rs
@@ -24,11 +24,20 @@
 use crate::{
     blocks::Block,
     mempool::{
         consts::MEMPOOL_PENDING_POOL_STORAGE_CAPACITY,
-        pending_pool::{PendingPoolError, PendingPoolStorage},
+        pending_pool::PendingPoolError,
+        priority::{FeePriority, TimelockPriority, TimelockedTransaction},
     },
     transactions::{transaction::Transaction, types::Signature},
 };
-use std::sync::{Arc, RwLock};
+use log::*;
+use std::{
+    collections::{BTreeMap, HashMap},
+    convert::TryFrom,
+    sync::Arc,
+};
+use tari_crypto::tari_utilities::hex::Hex;
+
+pub const LOG_TARGET: &str = "c::mp::pending_pool::pending_pool_storage";
 
 /// Configuration for the PendingPool.
 #[derive(Clone, Copy)]
@@ -47,101 +56,188 @@ impl Default for PendingPoolConfig {
 }
 
 /// The Pending Pool contains all transactions that are restricted by time-locks. Once the time-locks have expired then
 /// the transactions can be moved to the Unconfirmed Transaction Pool for inclusion in future blocks.
+/// The txs_by_signature HashMap is used to find a transaction using its excess_sig; this is used to match
+/// transactions included in blocks with transactions stored in the pool.
+/// The txs_by_fee_priority BTreeMap prioritizes the transactions in the pool according to FeePriority, allowing
+/// transactions to be iterated in sorted order based on their priority. The txs_by_timelock_priority BTreeMap
+/// prioritizes the transactions in the pool according to TimelockPriority, allowing transactions to be iterated in
+/// sorted order based on the expiry of their time-locks.
 pub struct PendingPool {
-    pool_storage: Arc<RwLock<PendingPoolStorage>>,
+    config: PendingPoolConfig,
+    txs_by_signature: HashMap<Signature, TimelockedTransaction>,
+    txs_by_fee_priority: BTreeMap<FeePriority, Signature>,
+    txs_by_timelock_priority: BTreeMap<TimelockPriority, Signature>,
 }
 
 impl PendingPool {
     /// Create a new PendingPool with the specified configuration.
     pub fn new(config: PendingPoolConfig) -> Self {
         Self {
-            pool_storage: Arc::new(RwLock::new(PendingPoolStorage::new(config))),
+            config,
+            txs_by_signature: HashMap::new(),
+            txs_by_fee_priority: BTreeMap::new(),
+            txs_by_timelock_priority: BTreeMap::new(),
+        }
+    }
+
+    fn lowest_fee_priority(&self) -> &FeePriority {
+        self.txs_by_fee_priority.iter().next().unwrap().0
+    }
+
+    fn remove_tx_with_lowest_fee_priority(&mut self) {
+        if let Some((_, tx_key)) = self
+            .txs_by_fee_priority
+            .iter()
+            .next()
+            .map(|(p, s)| (p.clone(), s.clone()))
+        {
+            if let Some(removed_tx) = self.txs_by_signature.remove(&tx_key) {
+                trace!(
+                    target: LOG_TARGET,
+                    "Removing tx from pending pool: {:?}, {:?}",
+                    removed_tx.fee_priority,
+                    removed_tx.timelock_priority
+                );
+                self.txs_by_fee_priority.remove(&removed_tx.fee_priority);
+                self.txs_by_timelock_priority.remove(&removed_tx.timelock_priority);
+            }
+        }
+    }
 
     /// Insert a new transaction into the PendingPool. Low priority transactions will be removed to make space for
     /// higher priority transactions.
The lowest priority transactions will be removed when the maximum capacity is /// reached and the new transaction has a higher priority than the currently stored lowest priority transaction. - pub fn insert(&self, transaction: Arc) -> Result<(), PendingPoolError> { - self.pool_storage - .write() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? - .insert(transaction) + #[allow(clippy::map_entry)] + pub fn insert(&mut self, tx: Arc) -> Result<(), PendingPoolError> { + let tx_key = tx.body.kernels()[0].excess_sig.clone(); + if !self.txs_by_signature.contains_key(&tx_key) { + debug!( + target: LOG_TARGET, + "Inserting tx into pending pool: {}", + tx_key.get_signature().to_hex() + ); + trace!(target: LOG_TARGET, "Transaction inserted: {}", tx); + let prioritized_tx = TimelockedTransaction::try_from((*tx).clone())?; + if self.txs_by_signature.len() >= self.config.storage_capacity { + if prioritized_tx.fee_priority < *self.lowest_fee_priority() { + return Ok(()); + } + self.remove_tx_with_lowest_fee_priority(); + } + self.txs_by_fee_priority + .insert(prioritized_tx.fee_priority.clone(), tx_key.clone()); + self.txs_by_timelock_priority + .insert(prioritized_tx.timelock_priority.clone(), tx_key.clone()); + self.txs_by_signature.insert(tx_key, prioritized_tx); + } + Ok(()) } /// Insert a set of new transactions into the PendingPool. - pub fn insert_txs(&self, transactions: Vec>) -> Result<(), PendingPoolError> { - self.pool_storage - .write() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? - .insert_txs(transactions) + pub fn insert_txs(&mut self, txs: Vec>) -> Result<(), PendingPoolError> { + for tx in txs.into_iter() { + self.insert(tx)?; + } + Ok(()) } /// Check if a specific transaction is available in the PendingPool. - pub fn has_tx_with_excess_sig(&self, excess_sig: &Signature) -> Result { - Ok(self - .pool_storage - .read() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? - .has_tx_with_excess_sig(excess_sig)) + pub fn has_tx_with_excess_sig(&self, excess_sig: &Signature) -> bool { + self.txs_by_signature.contains_key(excess_sig) + } + + /// Remove double-spends from the PendingPoolStorage. These transactions were orphaned by the provided published + /// block. Check if any of the unspent transactions in the PendingPool has inputs that was spent by the provided + /// published block. + fn discard_double_spends(&mut self, published_block: &Block) { + let mut removed_tx_keys: Vec = Vec::new(); + for (tx_key, ptx) in self.txs_by_signature.iter() { + for input in ptx.transaction.body.inputs() { + if published_block.body.inputs().contains(input) { + self.txs_by_fee_priority.remove(&ptx.fee_priority); + self.txs_by_timelock_priority.remove(&ptx.timelock_priority); + removed_tx_keys.push(tx_key.clone()); + } + } + } + + for tx_key in &removed_tx_keys { + trace!(target: LOG_TARGET, "Removed double spends: {:?}", tx_key); + self.txs_by_signature.remove(&tx_key); + } } /// Remove transactions with expired time-locks so that they can be move to the UnconfirmedPool. Double spend /// transactions are also removed. pub fn remove_unlocked_and_discard_double_spends( - &self, + &mut self, published_block: &Block, ) -> Result>, PendingPoolError> { - self.pool_storage - .write() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? 
- .remove_unlocked_and_discard_double_spends(published_block) + self.discard_double_spends(published_block); + + let mut removed_txs: Vec> = Vec::new(); + let mut removed_tx_keys: Vec = Vec::new(); + for (_, tx_key) in self.txs_by_timelock_priority.iter() { + if self + .txs_by_signature + .get(tx_key) + .ok_or(PendingPoolError::StorageOutofSync)? + .max_timelock_height > + published_block.header.height + { + break; + } + + if let Some(removed_ptx) = self.txs_by_signature.remove(&tx_key) { + self.txs_by_fee_priority.remove(&removed_ptx.fee_priority); + removed_tx_keys.push(removed_ptx.timelock_priority); + removed_txs.push(removed_ptx.transaction); + } + } + + for tx_key in &removed_tx_keys { + trace!(target: LOG_TARGET, "Removed unlocked and double spends: {:?}", tx_key); + self.txs_by_timelock_priority.remove(&tx_key); + } + + Ok(removed_txs) } /// Returns the total number of time-locked transactions stored in the PendingPool. - pub fn len(&self) -> Result { - Ok(self - .pool_storage - .read() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? - .len()) + pub fn len(&self) -> usize { + self.txs_by_signature.len() } /// Returns all transaction stored in the PendingPool. - pub fn snapshot(&self) -> Result>, PendingPoolError> { - Ok(self - .pool_storage - .read() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? - .snapshot()) + pub fn snapshot(&self) -> Vec> { + self.txs_by_signature + .iter() + .map(|(_, ptx)| ptx.transaction.clone()) + .collect() } /// Returns the total weight of all transactions stored in the pool. - pub fn calculate_weight(&self) -> Result { - Ok(self - .pool_storage - .read() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? - .calculate_weight()) + pub fn calculate_weight(&self) -> u64 { + self.txs_by_signature + .iter() + .fold(0, |weight, (_, ptx)| weight + ptx.transaction.calculate_weight()) } #[cfg(test)] - /// Checks the consistency status of the Hashmap and BtreeMaps. - pub fn check_status(&self) -> Result { - Ok(self - .pool_storage - .read() - .map_err(|e| PendingPoolError::BackendError(e.to_string()))? 
- .check_status()) - } -} - -impl Clone for PendingPool { - fn clone(&self) -> Self { - PendingPool { - pool_storage: self.pool_storage.clone(), + /// Checks the consistency status of the Hashmap and the BtreeMaps + pub fn check_status(&self) -> bool { + if (self.txs_by_fee_priority.len() != self.txs_by_signature.len()) || + (self.txs_by_timelock_priority.len() != self.txs_by_signature.len()) + { + return false; } + self.txs_by_fee_priority + .iter() + .all(|(_, tx_key)| self.txs_by_signature.contains_key(tx_key)) && + self.txs_by_timelock_priority + .iter() + .all(|(_, tx_key)| self.txs_by_signature.contains_key(tx_key)) } } @@ -158,14 +254,14 @@ mod test { #[test] fn test_insert_and_lru() { - let tx1 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(50), lock: 500, inputs: 2, outputs: 1).0); + let tx1 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(49), lock: 500, inputs: 2, outputs: 1).0); let tx2 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(20), lock: 2150, inputs: 1, outputs: 2).0); let tx3 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(100), lock: 1000, inputs: 2, outputs: 1).0); let tx4 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(30), lock: 2450, inputs: 2, outputs: 2).0); let tx5 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(50), lock: 1000, inputs: 3, outputs: 3).0); let tx6 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(75), lock: 1850, inputs: 2, outputs: 2).0); - let pending_pool = PendingPool::new(PendingPoolConfig { storage_capacity: 3 }); + let mut pending_pool = PendingPool::new(PendingPoolConfig { storage_capacity: 3 }); pending_pool .insert_txs(vec![ tx1.clone(), @@ -177,45 +273,33 @@ mod test { ]) .unwrap(); // Check that lowest priority txs were removed to make room for higher priority transactions - assert_eq!(pending_pool.len().unwrap(), 3); + assert_eq!(pending_pool.len(), 3); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig) - .unwrap(), - true + pending_pool.has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig), + false ); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig) - .unwrap(), + pending_pool.has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig), false ); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig) - .unwrap(), + pending_pool.has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig), true ); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig) - .unwrap(), + pending_pool.has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig), false ); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig) - .unwrap(), - false + pending_pool.has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig), + true ); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx6.body.kernels()[0].excess_sig) - .unwrap(), + pending_pool.has_tx_with_excess_sig(&tx6.body.kernels()[0].excess_sig), true ); - assert!(pending_pool.check_status().unwrap()); + assert!(pending_pool.check_status()); } #[test] @@ -235,7 +319,7 @@ mod test { let tx6 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(75), lock: 1450, inputs: 2, maturity: 1400, outputs: 2).0); - let pending_pool = PendingPool::new(PendingPoolConfig { storage_capacity: 10 }); + let mut pending_pool = PendingPool::new(PendingPoolConfig { storage_capacity: 10 }); pending_pool .insert_txs(vec![ tx1.clone(), @@ -246,9 +330,9 @@ mod test { tx6.clone(), ]) .unwrap(); - assert_eq!(pending_pool.len().unwrap(), 6); + 
assert_eq!(pending_pool.len(), 6); - let snapshot_txs = pending_pool.snapshot().unwrap(); + let snapshot_txs = pending_pool.snapshot(); assert_eq!(snapshot_txs.len(), 6); assert!(snapshot_txs.contains(&tx1)); assert!(snapshot_txs.contains(&tx2)); @@ -262,17 +346,13 @@ mod test { .remove_unlocked_and_discard_double_spends(&published_block) .unwrap(); - assert_eq!(pending_pool.len().unwrap(), 2); + assert_eq!(pending_pool.len(), 2); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig) - .unwrap(), + pending_pool.has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig), true ); assert_eq!( - pending_pool - .has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig) - .unwrap(), + pending_pool.has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig), true ); @@ -281,6 +361,6 @@ mod test { assert!(unlocked_txs.contains(&tx3)); assert!(unlocked_txs.contains(&tx5)); - assert!(pending_pool.check_status().unwrap()); + assert!(pending_pool.check_status()); } } diff --git a/base_layer/core/src/mempool/pending_pool/pending_pool_storage.rs b/base_layer/core/src/mempool/pending_pool/pending_pool_storage.rs deleted file mode 100644 index ba4389b509..0000000000 --- a/base_layer/core/src/mempool/pending_pool/pending_pool_storage.rs +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2019 The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::{ - blocks::Block, - mempool::{ - pending_pool::{PendingPoolConfig, PendingPoolError}, - priority::{FeePriority, TimelockPriority, TimelockedTransaction}, - }, - transactions::{transaction::Transaction, types::Signature}, -}; -use log::*; -use std::{ - collections::{BTreeMap, HashMap}, - convert::TryFrom, - sync::Arc, -}; -use tari_crypto::tari_utilities::hex::Hex; - -pub const LOG_TARGET: &str = "c::mp::pending_pool::pending_pool_storage"; - -/// PendingPool makes use of PendingPoolStorage to provide thread safe access to its Hashmap and BTreeMaps. 
-/// The txs_by_signature HashMap is used to find a transaction using its excess_sig, this functionality is used to match -/// transactions included in blocks with transactions stored in the pool. -/// The txs_by_fee_priority BTreeMap prioritize the transactions in the pool according to FeePriority, it allows -/// transactions to be inserted in sorted order based on their priority. The txs_by_timelock_priority BTreeMap -/// prioritize the transactions in the pool according to TimelockPriority, it allows transactions to be inserted in -/// sorted order based on the expiry of their time-locks. -pub struct PendingPoolStorage { - config: PendingPoolConfig, - txs_by_signature: HashMap, - txs_by_fee_priority: BTreeMap, - txs_by_timelock_priority: BTreeMap, -} - -impl PendingPoolStorage { - /// Create a new PendingPoolStorage with the specified configuration - pub fn new(config: PendingPoolConfig) -> Self { - Self { - config, - txs_by_signature: HashMap::new(), - txs_by_fee_priority: BTreeMap::new(), - txs_by_timelock_priority: BTreeMap::new(), - } - } - - fn lowest_fee_priority(&self) -> &FeePriority { - self.txs_by_fee_priority.iter().next().unwrap().0 - } - - fn remove_tx_with_lowest_fee_priority(&mut self) { - if let Some((_, tx_key)) = self - .txs_by_fee_priority - .iter() - .next() - .map(|(p, s)| (p.clone(), s.clone())) - { - if let Some(removed_tx) = self.txs_by_signature.remove(&tx_key) { - trace!( - target: LOG_TARGET, - "Removing tx from pending pool: {:?}, {:?}", - removed_tx.fee_priority, - removed_tx.timelock_priority - ); - self.txs_by_fee_priority.remove(&removed_tx.fee_priority); - self.txs_by_timelock_priority.remove(&removed_tx.timelock_priority); - } - } - } - - /// Insert a new transaction into the PendingPoolStorage. Low priority transactions will be removed to make space - /// for higher priority transactions. - pub fn insert(&mut self, tx: Arc) -> Result<(), PendingPoolError> { - let tx_key = tx.body.kernels()[0].excess_sig.clone(); - if !self.txs_by_signature.contains_key(&tx_key) { - debug!( - target: LOG_TARGET, - "Inserting tx into pending pool: {}", - tx_key.get_signature().to_hex() - ); - trace!(target: LOG_TARGET, "Transaction inserted: {}", tx); - let prioritized_tx = TimelockedTransaction::try_from((*tx).clone())?; - if self.txs_by_signature.len() >= self.config.storage_capacity { - if prioritized_tx.fee_priority < *self.lowest_fee_priority() { - return Ok(()); - } - self.remove_tx_with_lowest_fee_priority(); - } - - self.txs_by_fee_priority - .insert(prioritized_tx.fee_priority.clone(), tx_key.clone()); - self.txs_by_timelock_priority - .insert(prioritized_tx.timelock_priority.clone(), tx_key.clone()); - self.txs_by_signature.insert(tx_key, prioritized_tx); - } - Ok(()) - } - - /// Insert a set of new transactions into the PendingPoolStorage - pub fn insert_txs(&mut self, txs: Vec>) -> Result<(), PendingPoolError> { - for tx in txs.into_iter() { - self.insert(tx)?; - } - Ok(()) - } - - /// Check if a transaction is stored in the PendingPoolStorage - pub fn has_tx_with_excess_sig(&self, excess_sig: &Signature) -> bool { - self.txs_by_signature.contains_key(excess_sig) - } - - /// Remove double-spends from the PendingPoolStorage. These transactions were orphaned by the provided published - /// block. Check if any of the unspent transactions in the PendingPool has inputs that was spent by the provided - /// published block. 
- fn discard_double_spends(&mut self, published_block: &Block) { - let mut removed_tx_keys: Vec = Vec::new(); - for (tx_key, ptx) in self.txs_by_signature.iter() { - for input in ptx.transaction.body.inputs() { - if published_block.body.inputs().contains(input) { - self.txs_by_fee_priority.remove(&ptx.fee_priority); - self.txs_by_timelock_priority.remove(&ptx.timelock_priority); - removed_tx_keys.push(tx_key.clone()); - } - } - } - - for tx_key in &removed_tx_keys { - trace!(target: LOG_TARGET, "Removed double spends: {:?}", tx_key); - self.txs_by_signature.remove(&tx_key); - } - } - - /// Remove all published transactions from the UnconfirmedPoolStorage and discard double spends - pub fn remove_unlocked_and_discard_double_spends( - &mut self, - published_block: &Block, - ) -> Result>, PendingPoolError> - { - self.discard_double_spends(published_block); - - let mut removed_txs: Vec> = Vec::new(); - let mut removed_tx_keys: Vec = Vec::new(); - for (_, tx_key) in self.txs_by_timelock_priority.iter() { - if self - .txs_by_signature - .get(tx_key) - .ok_or(PendingPoolError::StorageOutofSync)? - .max_timelock_height > - published_block.header.height - { - break; - } - - if let Some(removed_ptx) = self.txs_by_signature.remove(&tx_key) { - self.txs_by_fee_priority.remove(&removed_ptx.fee_priority); - removed_tx_keys.push(removed_ptx.timelock_priority); - removed_txs.push(removed_ptx.transaction); - } - } - - for tx_key in &removed_tx_keys { - trace!(target: LOG_TARGET, "Removed unlocked and double spends: {:?}", tx_key); - self.txs_by_timelock_priority.remove(&tx_key); - } - - Ok(removed_txs) - } - - /// Returns the total number of unconfirmed transactions stored in the PendingPoolStorage - pub fn len(&self) -> usize { - self.txs_by_signature.len() - } - - /// Returns all transaction stored in the PendingPoolStorage. - pub fn snapshot(&self) -> Vec> { - self.txs_by_signature - .iter() - .map(|(_, ptx)| ptx.transaction.clone()) - .collect() - } - - /// Returns the total weight of all transactions stored in the pool. 
- pub fn calculate_weight(&self) -> u64 { - self.txs_by_signature - .iter() - .fold(0, |weight, (_, ptx)| weight + ptx.transaction.calculate_weight()) - } - - #[cfg(test)] - /// Checks the consistency status of the Hashmap and the BtreeMaps - pub fn check_status(&self) -> bool { - if (self.txs_by_fee_priority.len() != self.txs_by_signature.len()) || - (self.txs_by_timelock_priority.len() != self.txs_by_signature.len()) - { - return false; - } - self.txs_by_fee_priority - .iter() - .all(|(_, tx_key)| self.txs_by_signature.contains_key(tx_key)) && - self.txs_by_timelock_priority - .iter() - .all(|(_, tx_key)| self.txs_by_signature.contains_key(tx_key)) - } -} diff --git a/base_layer/core/src/mempool/priority/prioritized_transaction.rs b/base_layer/core/src/mempool/priority/prioritized_transaction.rs index 78c28381d4..e697d564f8 100644 --- a/base_layer/core/src/mempool/priority/prioritized_transaction.rs +++ b/base_layer/core/src/mempool/priority/prioritized_transaction.rs @@ -32,6 +32,7 @@ pub struct FeePriority(Vec); impl FeePriority { pub fn try_from(transaction: &Transaction) -> Result { + // The weights have been normalised, so the fee priority is now equal to the fee per gram ± a few pct points let fee_per_byte = (transaction.calculate_ave_fee_per_gram() * 1000.0) as usize; // Include 3 decimal places before flooring let mut fee_priority = fee_per_byte.to_binary()?; fee_priority.reverse(); // Requires Big-endian for BtreeMap sorting diff --git a/base_layer/core/src/mempool/proto/mempool_request.rs b/base_layer/core/src/mempool/proto/mempool_request.rs index 2e4c78239b..7115432a27 100644 --- a/base_layer/core/src/mempool/proto/mempool_request.rs +++ b/base_layer/core/src/mempool/proto/mempool_request.rs @@ -36,6 +36,7 @@ impl TryInto for ProtoMempoolRequest { let request = match self { // Field was not specified GetStats(_) => MempoolRequest::GetStats, + GetState(_) => MempoolRequest::GetState, GetTxStateWithExcessSig(excess_sig) => MempoolRequest::GetTxStateWithExcessSig( excess_sig.try_into().map_err(|err: ByteArrayError| err.to_string())?, ), @@ -50,6 +51,7 @@ impl From for ProtoMempoolRequest { use MempoolRequest::*; match request { GetStats => ProtoMempoolRequest::GetStats(true), + GetState => ProtoMempoolRequest::GetState(true), GetTxStateWithExcessSig(excess_sig) => ProtoMempoolRequest::GetTxStateWithExcessSig(excess_sig.into()), SubmitTransaction(tx) => ProtoMempoolRequest::SubmitTransaction(tx.into()), } diff --git a/base_layer/core/src/mempool/proto/mempool_response.rs b/base_layer/core/src/mempool/proto/mempool_response.rs index 324a54aafe..f50bb4d4f8 100644 --- a/base_layer/core/src/mempool/proto/mempool_response.rs +++ b/base_layer/core/src/mempool/proto/mempool_response.rs @@ -37,6 +37,7 @@ impl TryInto for ProtoMempoolResponse { use ProtoMempoolResponse::*; let response = match self { Stats(stats_response) => MempoolResponse::Stats(stats_response.try_into()?), + State(state_response) => MempoolResponse::State(state_response.try_into()?), TxStorage(tx_storage_response) => { let tx_storage_response = ProtoTxStorageResponse::from_i32(tx_storage_response) .ok_or_else(|| "Invalid or unrecognised `TxStorageResponse` enum".to_string())?; @@ -66,6 +67,7 @@ impl From for ProtoMempoolResponse { use MempoolResponse::*; match response { Stats(stats_response) => ProtoMempoolResponse::Stats(stats_response.into()), + State(state_response) => ProtoMempoolResponse::State(state_response.into()), TxStorage(tx_storage_response) => { let tx_storage_response: ProtoTxStorageResponse = 
tx_storage_response.into(); ProtoMempoolResponse::TxStorage(tx_storage_response.into()) diff --git a/base_layer/core/src/mempool/proto/mod.rs b/base_layer/core/src/mempool/proto/mod.rs index 6b9ab69722..ed27a0e691 100644 --- a/base_layer/core/src/mempool/proto/mod.rs +++ b/base_layer/core/src/mempool/proto/mod.rs @@ -29,6 +29,7 @@ pub mod mempool { pub mod mempool_request; pub mod mempool_response; +pub mod state_response; pub mod stats_response; pub mod tx_storage_response; pub use mempool::{MempoolServiceRequest, MempoolServiceResponse}; diff --git a/base_layer/core/src/mempool/proto/service_request.proto b/base_layer/core/src/mempool/proto/service_request.proto index 2b14d73589..4f5d59229d 100644 --- a/base_layer/core/src/mempool/proto/service_request.proto +++ b/base_layer/core/src/mempool/proto/service_request.proto @@ -11,9 +11,11 @@ message MempoolServiceRequest { oneof request { // Indicates a GetStats request. The value of the bool should be ignored. bool get_stats = 2; + // Indicates a GetState request. The value of the bool should be ignored. + bool get_state = 3; // Indicates a GetTxStateWithExcessSig request. - tari.types.Signature get_tx_state_with_excess_sig = 3; + tari.types.Signature get_tx_state_with_excess_sig = 4; // Indicates a SubmitTransaction request. - tari.types.Transaction submit_transaction = 4; + tari.types.Transaction submit_transaction = 5; } } diff --git a/base_layer/core/src/mempool/proto/service_response.proto b/base_layer/core/src/mempool/proto/service_response.proto index 9d4809d965..bee472bfaa 100644 --- a/base_layer/core/src/mempool/proto/service_response.proto +++ b/base_layer/core/src/mempool/proto/service_response.proto @@ -1,6 +1,7 @@ syntax = "proto3"; import "stats_response.proto"; +import "state_response.proto"; import "tx_storage_response.proto"; package tari.mempool; @@ -10,7 +11,8 @@ message MempoolServiceResponse { uint64 request_key = 1; oneof response { StatsResponse stats = 2; - TxStorageResponse tx_storage = 3; + StateResponse state = 3; + TxStorageResponse tx_storage = 4; } } diff --git a/base_layer/core/src/mempool/proto/state_response.proto b/base_layer/core/src/mempool/proto/state_response.proto new file mode 100644 index 0000000000..d75dfa0d94 --- /dev/null +++ b/base_layer/core/src/mempool/proto/state_response.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package tari.mempool; + +// TODO: Remove duplicate Signature, transaction also has a Signature. +// Define the explicit Signature implementation for the Tari base layer. A different signature scheme can be +// employed by redefining this type. +message Signature { + bytes public_nonce = 1; + bytes signature = 2; +} + +message StateResponse { + // List of transactions in unconfirmed pool. + repeated Signature unconfirmed_pool = 1; + // List of transactions in orphan pool. + repeated Signature orphan_pool = 2; + // List of transactions in pending pool. + repeated Signature pending_pool = 3; + // List of transactions in reorg pool. + repeated Signature reorg_pool = 4; +} \ No newline at end of file diff --git a/base_layer/core/src/mempool/proto/state_response.rs b/base_layer/core/src/mempool/proto/state_response.rs new file mode 100644 index 0000000000..eff7206f3f --- /dev/null +++ b/base_layer/core/src/mempool/proto/state_response.rs @@ -0,0 +1,98 @@ +// Copyright 2019, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::mempool::{proto::mempool::StateResponse as ProtoStateResponse, StateResponse}; +use std::convert::{TryFrom, TryInto}; +// use crate::transactions::proto::types::Signature as ProtoSignature; +use crate::{ + mempool::proto::mempool::Signature as ProtoSignature, + transactions::types::{PrivateKey, PublicKey, Signature}, +}; +use tari_crypto::tari_utilities::{ByteArray, ByteArrayError}; + +//---------------------------------- Signature --------------------------------------------// +// TODO: Remove duplicate Signature, transaction also has a Signature. 
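The two impls that follow give the duplicated proto Signature a fallible conversion into the base-layer Signature (the bytes must decode to a valid public nonce and scalar) and an infallible conversion back. A minimal round-trip sketch, assuming these conversions and a well-formed `sig: Signature` are in scope, and that Signature implements Clone and PartialEq as it does in tari_crypto (illustrative only, not part of the patch):

use std::convert::TryFrom;

fn signature_roundtrip(sig: Signature) -> Result<(), ByteArrayError> {
    // Signature -> proto: serialises the nonce and scalar to raw bytes.
    let proto: ProtoSignature = sig.clone().into();
    // proto -> Signature: fails if either byte string is not a valid point/scalar.
    let restored = Signature::try_from(proto)?;
    assert_eq!(restored, sig); // lossless for well-formed signatures
    Ok(())
}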
+impl TryFrom<ProtoSignature> for Signature {
+    type Error = ByteArrayError;
+
+    fn try_from(sig: ProtoSignature) -> Result<Self, Self::Error> {
+        let public_nonce = PublicKey::from_bytes(&sig.public_nonce)?;
+        let signature = PrivateKey::from_bytes(&sig.signature)?;
+
+        Ok(Self::new(public_nonce, signature))
+    }
+}
+
+impl From<Signature> for ProtoSignature {
+    fn from(sig: Signature) -> Self {
+        Self {
+            public_nonce: sig.get_public_nonce().to_vec(),
+            signature: sig.get_signature().to_vec(),
+        }
+    }
+}
+
+//--------------------------------- StateResponse -------------------------------------------//
+
+impl TryFrom<ProtoStateResponse> for StateResponse {
+    type Error = String;
+
+    fn try_from(state: ProtoStateResponse) -> Result<Self, Self::Error> {
+        Ok(Self {
+            unconfirmed_pool: state
+                .unconfirmed_pool
+                .into_iter()
+                .map(TryInto::try_into)
+                .collect::<Result<Vec<_>, _>>()
+                .map_err(|err: ByteArrayError| err.to_string())?,
+            orphan_pool: state
+                .orphan_pool
+                .into_iter()
+                .map(TryInto::try_into)
+                .collect::<Result<Vec<_>, _>>()
+                .map_err(|err: ByteArrayError| err.to_string())?,
+            pending_pool: state
+                .pending_pool
+                .into_iter()
+                .map(TryInto::try_into)
+                .collect::<Result<Vec<_>, _>>()
+                .map_err(|err: ByteArrayError| err.to_string())?,
+            reorg_pool: state
+                .reorg_pool
+                .into_iter()
+                .map(TryInto::try_into)
+                .collect::<Result<Vec<_>, _>>()
+                .map_err(|err: ByteArrayError| err.to_string())?,
+        })
+    }
+}
+
+impl From<StateResponse> for ProtoStateResponse {
+    fn from(state: StateResponse) -> Self {
+        Self {
+            unconfirmed_pool: state.unconfirmed_pool.into_iter().map(Into::into).collect(),
+            orphan_pool: state.orphan_pool.into_iter().map(Into::into).collect(),
+            pending_pool: state.pending_pool.into_iter().map(Into::into).collect(),
+            reorg_pool: state.reorg_pool.into_iter().map(Into::into).collect(),
+        }
+    }
+}
diff --git a/base_layer/core/src/mempool/reorg_pool/mod.rs b/base_layer/core/src/mempool/reorg_pool/mod.rs
index e5221f0efe..31d15c8d68 100644
--- a/base_layer/core/src/mempool/reorg_pool/mod.rs
+++ b/base_layer/core/src/mempool/reorg_pool/mod.rs
@@ -21,6 +21,7 @@
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 mod error;
+#[allow(clippy::module_inception)]
 mod reorg_pool;
 mod reorg_pool_storage;
diff --git a/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs b/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs
index c952cedee2..ab084beae9 100644
--- a/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs
+++ b/base_layer/core/src/mempool/reorg_pool/reorg_pool.rs
@@ -78,6 +78,16 @@ impl ReorgPool {
         Ok(())
     }

+    /// Insert a new transaction into the ReorgPool. Published transactions will have a limited Time-to-live in
+    /// the ReorgPool and will be discarded once the Time-to-live threshold has been reached.
+    pub fn insert(&self, transaction: Arc<Transaction>) -> Result<(), ReorgPoolError> {
+        self.pool_storage
+            .write()
+            .map_err(|e| ReorgPoolError::BackendError(e.to_string()))?
+            .insert(transaction);
+        Ok(())
+    }
+
     /// Check if a transaction is stored in the ReorgPool
     pub fn has_tx_with_excess_sig(&self, excess_sig: &Signature) -> Result<bool, ReorgPoolError> {
         Ok(self
@@ -111,6 +121,15 @@ impl ReorgPool {
             .len())
     }

+    /// Returns all transactions stored in the ReorgPool.
+    pub fn snapshot(&self) -> Result<Vec<Arc<Transaction>>, ReorgPoolError> {
+        Ok(self
+            .pool_storage
+            .write()
+            .map_err(|e| ReorgPoolError::BackendError(e.to_string()))?
+            .snapshot())
+    }
+
     /// Returns the total weight of all transactions stored in the pool.
     pub fn calculate_weight(&self) -> Result<u64, ReorgPoolError> {
         Ok(self
@@ -137,12 +156,12 @@ mod test {

     #[test]
     fn test_insert_rlu_and_ttl() {
-        let tx1 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(500), lock: 4000, inputs: 2, outputs: 1).0);
-        let tx2 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(300), lock: 3000, inputs: 2, outputs: 1).0);
-        let tx3 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(100), lock: 2500, inputs: 2, outputs: 1).0);
-        let tx4 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(200), lock: 1000, inputs: 2, outputs: 1).0);
-        let tx5 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(500), lock: 2000, inputs: 2, outputs: 1).0);
-        let tx6 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(600), lock: 5500, inputs: 2, outputs: 1).0);
+        let tx1 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(500), lock: 4000, inputs: 2, outputs: 1).0);
+        let tx2 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(300), lock: 3000, inputs: 2, outputs: 1).0);
+        let tx3 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(100), lock: 2500, inputs: 2, outputs: 1).0);
+        let tx4 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(200), lock: 1000, inputs: 2, outputs: 1).0);
+        let tx5 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(500), lock: 2000, inputs: 2, outputs: 1).0);
+        let tx6 = Arc::new(tx!(MicroTari(100_000), fee: MicroTari(600), lock: 5500, inputs: 2, outputs: 1).0);

         let reorg_pool = ReorgPool::new(ReorgPoolConfig {
             storage_capacity: 3,
diff --git a/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs b/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs
index b1c0e6684d..1e9423785c 100644
--- a/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs
+++ b/base_layer/core/src/mempool/reorg_pool/reorg_pool_storage.rs
@@ -125,6 +125,11 @@ impl ReorgPoolStorage {
         self.txs_by_signature.iter().count()
     }

+    /// Returns all transactions stored in the ReorgPoolStorage.
+    pub fn snapshot(&mut self) -> Vec<Arc<Transaction>> {
+        self.txs_by_signature.iter().map(|(_, tx)| tx).cloned().collect()
+    }
+
     /// Returns the total weight of all transactions stored in the pool.
     pub fn calculate_weight(&mut self) -> u64 {
         self.txs_by_signature
diff --git a/base_layer/core/src/mempool/service/inbound_handlers.rs b/base_layer/core/src/mempool/service/inbound_handlers.rs
index a07aea0f7e..09f3fb61e9 100644
--- a/base_layer/core/src/mempool/service/inbound_handlers.rs
+++ b/base_layer/core/src/mempool/service/inbound_handlers.rs
@@ -62,6 +62,9 @@ where T: BlockchainBackend + 'static
             MempoolRequest::GetStats => Ok(MempoolResponse::Stats(
                 async_mempool::stats(self.mempool.clone()).await?,
             )),
+            MempoolRequest::GetState => Ok(MempoolResponse::State(
+                async_mempool::state(self.mempool.clone()).await?,
+            )),
             MempoolRequest::GetTxStateWithExcessSig(excess_sig) => Ok(MempoolResponse::TxStorage(
                 async_mempool::has_tx_with_excess_sig(self.mempool.clone(), excess_sig.clone()).await?,
             )),
diff --git a/base_layer/core/src/mempool/service/initializer.rs b/base_layer/core/src/mempool/service/initializer.rs
index 438846277a..8d62cc2dfc 100644
--- a/base_layer/core/src/mempool/service/initializer.rs
+++ b/base_layer/core/src/mempool/service/initializer.rs
@@ -28,6 +28,7 @@ use crate::{
         proto,
         service::{
             inbound_handlers::MempoolInboundHandlers,
+            local_service::LocalMempoolService,
             outbound_interface::OutboundMempoolServiceInterface,
             service::{MempoolService, MempoolStreams},
         },
@@ -120,10 +121,9 @@ async fn extract_transaction(msg: Arc<PeerMessage>) -> Option<Transaction> {
     let tx = match Transaction::try_from(tx) {
         Err(e) => {
-            let origin = msg.origin_public_key();
             warn!(
                 target: LOG_TARGET,
-                "Inbound transaction message from {} was ill-formed. {}", origin, e
+                "Inbound transaction message from {} was ill-formed. {}", msg.source_peer.public_key, e
             );
             return None;
         },
@@ -132,6 +132,7 @@ async fn extract_transaction(msg: Arc<PeerMessage>) -> Option<Transaction>
             Err(MempoolServiceError::UnexpectedApiResponse),
         }
     }
+
+    pub async fn get_mempool_state(&mut self) -> Result<StateResponse, MempoolServiceError> {
+        match self.request_sender.call(MempoolRequest::GetState).await??
{ + MempoolResponse::State(s) => Ok(s), + _ => Err(MempoolServiceError::UnexpectedApiResponse), + } + } } #[cfg(test)] @@ -96,7 +103,7 @@ mod test { MempoolRequest::GetStats => Ok(MempoolResponse::Stats(request_stats())), _ => Err(MempoolServiceError::UnexpectedApiResponse), }; - reply_channel.send(res); + reply_channel.send(res).unwrap(); } } diff --git a/base_layer/core/src/mempool/service/mod.rs b/base_layer/core/src/mempool/service/mod.rs index 9307a375e7..534e9bd87d 100644 --- a/base_layer/core/src/mempool/service/mod.rs +++ b/base_layer/core/src/mempool/service/mod.rs @@ -30,6 +30,7 @@ mod initializer; mod local_service; #[cfg(feature = "base_node")] mod outbound_interface; +#[allow(clippy::module_inception)] #[cfg(feature = "base_node")] mod service; @@ -39,6 +40,8 @@ pub use error::MempoolServiceError; #[cfg(feature = "base_node")] pub use initializer::MempoolServiceInitializer; #[cfg(feature = "base_node")] +pub use local_service::LocalMempoolService; +#[cfg(feature = "base_node")] pub use outbound_interface::OutboundMempoolServiceInterface; #[cfg(feature = "base_node")] pub use service::MempoolService; diff --git a/base_layer/core/src/mempool/service/request.rs b/base_layer/core/src/mempool/service/request.rs index b6bb6233af..1fcdab2568 100644 --- a/base_layer/core/src/mempool/service/request.rs +++ b/base_layer/core/src/mempool/service/request.rs @@ -32,6 +32,7 @@ use tari_crypto::tari_utilities::hex::Hex; #[derive(Debug, Serialize, Deserialize)] pub enum MempoolRequest { GetStats, + GetState, GetTxStateWithExcessSig(Signature), SubmitTransaction(Transaction), } @@ -40,6 +41,7 @@ impl Display for MempoolRequest { fn fmt(&self, f: &mut Formatter) -> Result<(), Error> { match self { MempoolRequest::GetStats => f.write_str("GetStats"), + MempoolRequest::GetState => f.write_str("GetState"), MempoolRequest::GetTxStateWithExcessSig(sig) => { f.write_str(&format!("GetTxStateWithExcessSig ({})", sig.get_signature().to_hex())) }, diff --git a/base_layer/core/src/mempool/service/response.rs b/base_layer/core/src/mempool/service/response.rs index b4d1fb1bf1..32c6daddf1 100644 --- a/base_layer/core/src/mempool/service/response.rs +++ b/base_layer/core/src/mempool/service/response.rs @@ -22,19 +22,20 @@ use crate::{ base_node::RequestKey, - mempool::{StatsResponse, TxStorageResponse}, + mempool::{StateResponse, StatsResponse, TxStorageResponse}, }; use serde::{Deserialize, Serialize}; /// API Response enum for Mempool responses. 
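The GetState plumbing added above rounds out a complete local query path. A minimal sketch of how a caller might drive it through the new LocalMempoolService handle, assuming StateResponse's pool fields are the signature lists defined in state_response.proto (illustrative only):

async fn log_mempool_state(mempool: &mut LocalMempoolService) -> Result<(), MempoolServiceError> {
    // Fetch the excess signatures held in each pool via the new GetState round trip.
    let state = mempool.get_mempool_state().await?;
    println!(
        "unconfirmed: {} orphan: {} pending: {} reorg: {}",
        state.unconfirmed_pool.len(),
        state.orphan_pool.len(),
        state.pending_pool.len(),
        state.reorg_pool.len()
    );
    Ok(())
}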
-#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum MempoolResponse { Stats(StatsResponse), + State(StateResponse), TxStorage(TxStorageResponse), } /// Response type for a received MempoolService requests -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct MempoolServiceResponse { pub request_key: RequestKey, pub response: MempoolResponse, diff --git a/base_layer/core/src/mempool/service/service.rs b/base_layer/core/src/mempool/service/service.rs index 4a3105c042..3c31972f50 100644 --- a/base_layer/core/src/mempool/service/service.rs +++ b/base_layer/core/src/mempool/service/service.rs @@ -63,21 +63,23 @@ use tokio::task; const LOG_TARGET: &str = "c::mempool::service::service"; /// A convenience struct to hold all the Mempool service streams -pub struct MempoolStreams { +pub struct MempoolStreams { outbound_request_stream: SOutReq, outbound_tx_stream: UnboundedReceiver<(Transaction, Vec)>, inbound_request_stream: SInReq, inbound_response_stream: SInRes, inbound_transaction_stream: STxIn, + local_request_stream: SLocalReq, block_event_stream: Subscriber, } -impl MempoolStreams +impl MempoolStreams where SOutReq: Stream>>, SInReq: Stream>, SInRes: Stream>, STxIn: Stream>, + SLocalReq: Stream>>, { pub fn new( outbound_request_stream: SOutReq, @@ -85,6 +87,7 @@ where inbound_request_stream: SInReq, inbound_response_stream: SInRes, inbound_transaction_stream: STxIn, + local_request_stream: SLocalReq, block_event_stream: Subscriber, ) -> Self { @@ -94,6 +97,7 @@ where inbound_request_stream, inbound_response_stream, inbound_transaction_stream, + local_request_stream, block_event_stream, } } @@ -130,15 +134,16 @@ where B: BlockchainBackend + 'static } } - pub async fn start( + pub async fn start( mut self, - streams: MempoolStreams, + streams: MempoolStreams, ) -> Result<(), MempoolServiceError> where SOutReq: Stream>>, SInReq: Stream>, SInRes: Stream>, STxIn: Stream>, + SLocalReq: Stream>>, { let outbound_request_stream = streams.outbound_request_stream.fuse(); pin_mut!(outbound_request_stream); @@ -150,6 +155,8 @@ where B: BlockchainBackend + 'static pin_mut!(inbound_response_stream); let inbound_transaction_stream = streams.inbound_transaction_stream.fuse(); pin_mut!(inbound_transaction_stream); + let local_request_stream = streams.local_request_stream.fuse(); + pin_mut!(local_request_stream); let block_event_stream = streams.block_event_stream.fuse(); pin_mut!(block_event_stream); let timeout_receiver_stream = self @@ -185,6 +192,11 @@ where B: BlockchainBackend + 'static self.spawn_handle_incoming_tx(transaction_msg); } + // Incoming local request messages from the LocalMempoolServiceInterface and other local services + local_request_context = local_request_stream.select_next_some() => { + self.spawn_handle_local_request(local_request_context); + }, + // Block events from local Base Node. 
block_event = block_event_stream.select_next_some() => { self.spawn_handle_block_event(block_event); @@ -291,6 +303,26 @@ where B: BlockchainBackend + 'static }); } + fn spawn_handle_local_request( + &self, + request_context: RequestContext>, + ) + { + let mut inbound_handlers = self.inbound_handlers.clone(); + task::spawn(async move { + let (request, reply_tx) = request_context.split(); + let _ = reply_tx + .send(inbound_handlers.handle_request(&request).await) + .or_else(|err| { + error!( + target: LOG_TARGET, + "MempoolService failed to send reply to local request {:?}", err + ); + Err(err) + }); + }); + } + fn spawn_handle_block_event(&self, block_event: Arc) { let inbound_handlers = self.inbound_handlers.clone(); task::spawn(async move { @@ -339,7 +371,7 @@ async fn handle_incoming_request( outbound_message_service .send_direct( origin_public_key, - OutboundEncryption::EncryptForPeer, + OutboundEncryption::None, OutboundDomainMessage::new(TariMessageType::MempoolResponse, message), ) .await?; @@ -389,7 +421,7 @@ async fn handle_outbound_request( .send_random( 1, NodeDestination::Unknown, - OutboundEncryption::EncryptForPeer, + OutboundEncryption::None, OutboundDomainMessage::new(TariMessageType::MempoolRequest, service_request), ) .await @@ -400,7 +432,7 @@ async fn handle_outbound_request( .map_err(|e| MempoolServiceError::OutboundMessageService(e.to_string()))?; match send_result.resolve_ok().await { - Some(tags) if !tags.is_empty() => { + Some(send_states) if !send_states.is_empty() => { // Spawn timeout and wait for matching response to arrive waiting_requests.insert(request_key, Some(reply_tx))?; // Spawn timeout for waiting_request @@ -484,7 +516,7 @@ async fn handle_outbound_tx( outbound_message_service .propagate( NodeDestination::Unknown, - OutboundEncryption::EncryptForPeer, + OutboundEncryption::None, exclude_peers, OutboundDomainMessage::new(TariMessageType::NewTransaction, ProtoTransaction::from(tx)), ) diff --git a/base_layer/core/src/mempool/unconfirmed_pool/error.rs b/base_layer/core/src/mempool/unconfirmed_pool/error.rs index b611005ad5..74577d67c8 100644 --- a/base_layer/core/src/mempool/unconfirmed_pool/error.rs +++ b/base_layer/core/src/mempool/unconfirmed_pool/error.rs @@ -27,8 +27,5 @@ use derive_error::Error; pub enum UnconfirmedPoolError { /// The HashMap and BTreeMap are out of sync StorageOutofSync, - /// A problem has been encountered with the storage backend. - #[error(non_std, no_from)] - BackendError(String), PriorityError(PriorityError), } diff --git a/base_layer/core/src/mempool/unconfirmed_pool/mod.rs b/base_layer/core/src/mempool/unconfirmed_pool/mod.rs index c5568aa70a..3ce8d80931 100644 --- a/base_layer/core/src/mempool/unconfirmed_pool/mod.rs +++ b/base_layer/core/src/mempool/unconfirmed_pool/mod.rs @@ -21,10 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
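A note on the #[allow(clippy::module_inception)] attributes added to the mempool mod.rs files, including the hunk just below: the lint fires when a module declares a child module with the same name, which is exactly the pool_name/pool_name.rs layout these modules keep. A self-contained toy illustration (module and function names hypothetical):

mod pool {
    // The child module shares its parent's name: this is what clippy::module_inception flags.
    #[allow(clippy::module_inception)]
    mod pool {
        pub fn capacity() -> usize {
            3
        }
    }
    // The re-export hides the doubled path from callers, as the mempool mod.rs files do.
    pub use pool::capacity;
}

fn main() {
    assert_eq!(pool::capacity(), 3);
}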
 mod error;
+#[allow(clippy::module_inception)]
 mod unconfirmed_pool;
-mod unconfirmed_pool_storage;

 // Public re-exports
 pub use error::UnconfirmedPoolError;
 pub use unconfirmed_pool::{UnconfirmedPool, UnconfirmedPoolConfig};
-pub use unconfirmed_pool_storage::UnconfirmedPoolStorage;
diff --git a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs b/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs
index 7797900d72..72c5ccccae 100644
--- a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs
+++ b/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool.rs
@@ -24,11 +24,20 @@
 use crate::{
     blocks::Block,
     mempool::{
         consts::{MEMPOOL_UNCONFIRMED_POOL_STORAGE_CAPACITY, MEMPOOL_UNCONFIRMED_POOL_WEIGHT_TRANSACTION_SKIP_COUNT},
-        unconfirmed_pool::{UnconfirmedPoolError, UnconfirmedPoolStorage},
+        priority::{FeePriority, PrioritizedTransaction},
+        unconfirmed_pool::UnconfirmedPoolError,
     },
     transactions::{transaction::Transaction, types::Signature},
 };
-use std::sync::{Arc, RwLock};
+use log::*;
+use std::{
+    collections::{BTreeMap, HashMap},
+    convert::TryFrom,
+    sync::Arc,
+};
+use tari_crypto::tari_utilities::hex::Hex;
+
+pub const LOG_TARGET: &str = "c::mp::unconfirmed_pool::unconfirmed_pool_storage";

 /// Configuration for the UnconfirmedPool
 #[derive(Clone, Copy)]
@@ -51,110 +60,197 @@ impl Default for UnconfirmedPoolConfig {

 /// The Unconfirmed Transaction Pool consists of all unconfirmed transactions that are ready to be included in a block
 /// and they are prioritised according to the priority metric.
+/// The txs_by_signature HashMap is used to find a transaction using its excess_sig; this functionality is used to match
+/// transactions included in blocks with transactions stored in the pool. The txs_by_priority BTreeMap prioritises the
+/// transactions in the pool according to TXPriority and allows transactions to be inserted in sorted order by their
+/// priority. The txs_by_priority BTreeMap makes it easier to select the set of highest priority transactions that can
+/// be included in a block. The excess_sig of a transaction is used as a key to uniquely identify a specific transaction
+/// in these containers.
 pub struct UnconfirmedPool {
-    pool_storage: Arc<RwLock<UnconfirmedPoolStorage>>,
+    config: UnconfirmedPoolConfig,
+    txs_by_signature: HashMap<Signature, PrioritizedTransaction>,
+    txs_by_priority: BTreeMap<FeePriority, Signature>,
 }

 impl UnconfirmedPool {
     /// Create a new UnconfirmedPool with the specified configuration
     pub fn new(config: UnconfirmedPoolConfig) -> Self {
         Self {
-            pool_storage: Arc::new(RwLock::new(UnconfirmedPoolStorage::new(config))),
+            config,
+            txs_by_signature: HashMap::new(),
+            txs_by_priority: BTreeMap::new(),
+        }
+    }
+
+    fn lowest_priority(&self) -> &FeePriority {
+        self.txs_by_priority.iter().next().unwrap().0
+    }
+
+    fn remove_lowest_priority_tx(&mut self) {
+        if let Some((priority, sig)) = self.txs_by_priority.iter().next().map(|(p, s)| (p.clone(), s.clone())) {
+            self.txs_by_signature.remove(&sig);
+            self.txs_by_priority.remove(&priority);
         }
     }

     /// Insert a new transaction into the UnconfirmedPool. Low priority transactions will be removed to make space for
     /// higher priority transactions. The lowest priority transactions will be removed when the maximum capacity is
     /// reached and the new transaction has a higher priority than the currently stored lowest priority transaction.
-    pub fn insert(&self, transaction: Arc<Transaction>) -> Result<(), UnconfirmedPoolError> {
-        self.pool_storage
-            .write()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .insert(transaction)
+    #[allow(clippy::map_entry)]
+    pub fn insert(&mut self, tx: Arc<Transaction>) -> Result<(), UnconfirmedPoolError> {
+        let tx_key = tx.body.kernels()[0].excess_sig.clone();
+        if !self.txs_by_signature.contains_key(&tx_key) {
+            debug!(
+                target: LOG_TARGET,
+                "Inserting tx into unconfirmed pool: {}",
+                tx_key.get_signature().to_hex()
+            );
+            trace!(target: LOG_TARGET, "Transaction inserted: {}", tx);
+            let prioritized_tx = PrioritizedTransaction::try_from((*tx).clone())?;
+            if self.txs_by_signature.len() >= self.config.storage_capacity {
+                if prioritized_tx.priority < *self.lowest_priority() {
+                    return Ok(());
+                }
+                self.remove_lowest_priority_tx();
+            }
+            self.txs_by_priority
+                .insert(prioritized_tx.priority.clone(), tx_key.clone());
+            self.txs_by_signature.insert(tx_key, prioritized_tx);
+        }
+        Ok(())
     }

-    /// Insert a set of new transactions into the UnconfirmedPool
-    pub fn insert_txs(&self, transactions: Vec<Arc<Transaction>>) -> Result<(), UnconfirmedPoolError> {
-        self.pool_storage
-            .write()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .insert_txs(transactions)
+    /// Insert a set of new transactions into the UnconfirmedPool
+    pub fn insert_txs(&mut self, txs: Vec<Arc<Transaction>>) -> Result<(), UnconfirmedPoolError> {
+        for tx in txs.into_iter() {
+            self.insert(tx)?;
+        }
+        Ok(())
     }

     /// Check if a transaction is available in the UnconfirmedPool
-    pub fn has_tx_with_excess_sig(&self, excess_sig: &Signature) -> Result<bool, UnconfirmedPoolError> {
-        Ok(self
-            .pool_storage
-            .read()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .has_tx_with_excess_sig(excess_sig))
+    pub fn has_tx_with_excess_sig(&self, excess_sig: &Signature) -> bool {
+        self.txs_by_signature.contains_key(excess_sig)
     }

     /// Returns a set of the highest priority unconfirmed transactions that can be included in a block
     pub fn highest_priority_txs(&self, total_weight: u64) -> Result<Vec<Arc<Transaction>>, UnconfirmedPoolError> {
-        self.pool_storage
-            .read()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .highest_priority_txs(total_weight)
+        let mut selected_txs: Vec<Arc<Transaction>> = Vec::new();
+        let mut curr_weight: u64 = 0;
+        let mut curr_skip_count: usize = 0;
+        for (_, tx_key) in self.txs_by_priority.iter().rev() {
+            let ptx = self
+                .txs_by_signature
+                .get(tx_key)
+                .ok_or_else(|| UnconfirmedPoolError::StorageOutofSync)?;
+
+            if curr_weight + ptx.weight <= total_weight {
+                curr_weight += ptx.weight;
+                selected_txs.push(ptx.transaction.clone());
+            } else {
+                // Check whether some of the next few txs with slightly lower priority will fit in the remaining space.
+                curr_skip_count += 1;
+                if curr_skip_count >= self.config.weight_tx_skip_count {
+                    break;
+                }
+            }
+        }
+        Ok(selected_txs)
     }

     /// Remove all published transactions from the UnconfirmedPool and discard all double spend transactions.
     /// Returns a list of all transactions that were removed from the unconfirmed pool as a result of appearing in the block.
-    pub fn remove_published_and_discard_double_spends(
-        &self,
-        published_block: &Block,
-    ) -> Result<Vec<Arc<Transaction>>, UnconfirmedPoolError>
-    {
-        Ok(self
-            .pool_storage
-            .write()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .remove_published_and_discard_double_spends(published_block))
+    fn discard_double_spends(&mut self, published_block: &Block) {
+        let mut removed_tx_keys: Vec<Signature> = Vec::new();
+        for (tx_key, ptx) in self.txs_by_signature.iter() {
+            for input in ptx.transaction.body.inputs() {
+                if published_block.body.inputs().contains(input) {
+                    self.txs_by_priority.remove(&ptx.priority);
+                    removed_tx_keys.push(tx_key.clone());
+                }
+            }
+        }
+
+        for tx_key in &removed_tx_keys {
+            trace!(
+                target: LOG_TARGET,
+                "Removing double spends from unconfirmed pool: {:?}",
+                tx_key
+            );
+            self.txs_by_signature.remove(&tx_key);
+        }
+    }
+
+    /// Remove all published transactions from the UnconfirmedPool and discard double spends
+    pub fn remove_published_and_discard_double_spends(&mut self, published_block: &Block) -> Vec<Arc<Transaction>> {
+        let mut removed_txs: Vec<Arc<Transaction>> = Vec::new();
+        published_block.body.kernels().iter().for_each(|kernel| {
+            if let Some(ptx) = self.txs_by_signature.get(&kernel.excess_sig) {
+                self.txs_by_priority.remove(&ptx.priority);
+                if let Some(ptx) = self.txs_by_signature.remove(&kernel.excess_sig) {
+                    removed_txs.push(ptx.transaction);
+                }
+            }
+        });
+        // First remove published transactions before discarding double spends
+        self.discard_double_spends(published_block);
+
+        removed_txs
+    }
+
+    /// Remove all unconfirmed transactions that have become time locked. This can happen when the chain height was
+    /// reduced on some reorgs.
+    pub fn remove_timelocked(&mut self, tip_height: u64) -> Vec<Arc<Transaction>> {
+        let mut removed_tx_keys: Vec<Signature> = Vec::new();
+        for (tx_key, ptx) in self.txs_by_signature.iter() {
+            if ptx.transaction.min_spendable_height() > tip_height + 1 {
+                self.txs_by_priority.remove(&ptx.priority);
+                removed_tx_keys.push(tx_key.clone());
+            }
+        }
+        let mut removed_txs: Vec<Arc<Transaction>> = Vec::new();
+        for tx_key in removed_tx_keys {
+            trace!(
+                target: LOG_TARGET,
+                "Removing time locked transaction from unconfirmed pool: {:?}",
+                tx_key
+            );
+            if let Some(ptx) = self.txs_by_signature.remove(&tx_key) {
+                removed_txs.push(ptx.transaction);
+            }
+        }
+        removed_txs
     }

-    /// Returns the total number of unconfirmed transactions stored in the UnconfirmedPool
-    pub fn len(&self) -> Result<usize, UnconfirmedPoolError> {
-        Ok(self
-            .pool_storage
-            .read()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .len())
+    /// Returns the total number of unconfirmed transactions stored in the UnconfirmedPool.
+    pub fn len(&self) -> usize {
+        self.txs_by_signature.len()
    }

     /// Returns all transactions stored in the UnconfirmedPool.
-    pub fn snapshot(&self) -> Result<Vec<Arc<Transaction>>, UnconfirmedPoolError> {
-        Ok(self
-            .pool_storage
-            .read()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .snapshot())
+    pub fn snapshot(&self) -> Vec<Arc<Transaction>> {
+        self.txs_by_signature
+            .iter()
+            .map(|(_, ptx)| ptx.transaction.clone())
+            .collect()
     }

     /// Returns the total weight of all transactions stored in the pool.
-    pub fn calculate_weight(&self) -> Result<u64, UnconfirmedPoolError> {
-        Ok(self
-            .pool_storage
-            .read()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
-            .calculate_weight())
+    pub fn calculate_weight(&self) -> u64 {
+        self.txs_by_signature
+            .iter()
+            .fold(0, |weight, (_, ptx)| weight + ptx.transaction.calculate_weight())
     }

     #[cfg(test)]
     /// Checks the consistency status of the HashMap and BTreeMap
-    pub fn check_status(&self) -> Result<bool, UnconfirmedPoolError> {
-        Ok(self
-            .pool_storage
-            .read()
-            .map_err(|e| UnconfirmedPoolError::BackendError(e.to_string()))?
- .check_status()) - } -} - -impl Clone for UnconfirmedPool { - fn clone(&self) -> Self { - UnconfirmedPool { - pool_storage: self.pool_storage.clone(), + pub fn check_status(&self) -> bool { + if self.txs_by_priority.len() != self.txs_by_signature.len() { + return false; } + self.txs_by_priority + .iter() + .all(|(_, tx_key)| self.txs_by_signature.contains_key(tx_key)) } } @@ -169,9 +265,9 @@ mod test { let tx2 = Arc::new(tx!(MicroTari(5_000), fee: MicroTari(20), inputs: 4, outputs: 1).0); let tx3 = Arc::new(tx!(MicroTari(5_000), fee: MicroTari(100), inputs: 5, outputs: 1).0); let tx4 = Arc::new(tx!(MicroTari(5_000), fee: MicroTari(30), inputs: 3, outputs: 1).0); - let tx5 = Arc::new(tx!(MicroTari(5_000), fee: MicroTari(50), inputs: 5, outputs: 1).0); + let tx5 = Arc::new(tx!(MicroTari(5_000), fee: MicroTari(55), inputs: 5, outputs: 1).0); - let unconfirmed_pool = UnconfirmedPool::new(UnconfirmedPoolConfig { + let mut unconfirmed_pool = UnconfirmedPool::new(UnconfirmedPoolConfig { storage_capacity: 4, weight_tx_skip_count: 3, }); @@ -180,46 +276,36 @@ mod test { .unwrap(); // Check that lowest priority tx was removed to make room for new incoming transactions assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig), true ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig), true ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig), true ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig), true ); // Retrieve the set of highest priority unspent transactions - let desired_weight = tx1.calculate_weight() + tx3.calculate_weight() + tx4.calculate_weight(); + let desired_weight = tx1.calculate_weight() + tx3.calculate_weight() + tx5.calculate_weight(); let selected_txs = unconfirmed_pool.highest_priority_txs(desired_weight).unwrap(); assert_eq!(selected_txs.len(), 3); assert!(selected_txs.contains(&tx1)); assert!(selected_txs.contains(&tx3)); - assert!(selected_txs.contains(&tx4)); + assert!(selected_txs.contains(&tx5)); // Note that transaction tx5 could not be included as its weight was to big to fit into the remaining allocated // space, the second best transaction was then included - assert!(unconfirmed_pool.check_status().unwrap()); + assert!(unconfirmed_pool.check_status()); } #[test] @@ -233,7 +319,7 @@ mod test { let tx5 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(50), inputs:3, outputs: 1).0); let tx6 = Arc::new(tx!(MicroTari(10_000), fee: MicroTari(75), inputs:2, outputs: 1).0); - let unconfirmed_pool = UnconfirmedPool::new(UnconfirmedPoolConfig { + let mut unconfirmed_pool = UnconfirmedPool::new(UnconfirmedPoolConfig { storage_capacity: 10, weight_tx_skip_count: 3, }); @@ -243,7 +329,7 @@ mod test { // utx6 should not be added to unconfirmed_pool as it is an unknown transactions that was included in the block // by another node - let snapshot_txs = 
unconfirmed_pool.snapshot().unwrap(); + let snapshot_txs = unconfirmed_pool.snapshot(); assert_eq!(snapshot_txs.len(), 5); assert!(snapshot_txs.contains(&tx1)); assert!(snapshot_txs.contains(&tx2)); @@ -259,43 +345,31 @@ mod test { let _ = unconfirmed_pool.remove_published_and_discard_double_spends(&published_block); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig), true ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig), true ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx6.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx6.body.kernels()[0].excess_sig), false ); - assert!(unconfirmed_pool.check_status().unwrap()); + assert!(unconfirmed_pool.check_status()); } #[test] @@ -314,7 +388,7 @@ mod test { let tx5 = Arc::new(tx5); let tx6 = Arc::new(tx6); - let unconfirmed_pool = UnconfirmedPool::new(UnconfirmedPoolConfig { + let mut unconfirmed_pool = UnconfirmedPool::new(UnconfirmedPoolConfig { storage_capacity: 10, weight_tx_skip_count: 3, }); @@ -336,47 +410,33 @@ mod test { &consensus_constants, ); - let _ = unconfirmed_pool - .remove_published_and_discard_double_spends(&published_block) - .unwrap(); // Double spends are discarded + let _ = unconfirmed_pool.remove_published_and_discard_double_spends(&published_block); // Double spends are discarded assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx1.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx2.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx3.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx4.body.kernels()[0].excess_sig), true ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx5.body.kernels()[0].excess_sig), false ); assert_eq!( - unconfirmed_pool - .has_tx_with_excess_sig(&tx6.body.kernels()[0].excess_sig) - .unwrap(), + unconfirmed_pool.has_tx_with_excess_sig(&tx6.body.kernels()[0].excess_sig), false ); - assert!(unconfirmed_pool.check_status().unwrap()); + assert!(unconfirmed_pool.check_status()); } } diff --git a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool_storage.rs 
b/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool_storage.rs deleted file mode 100644 index 13d7d8f346..0000000000 --- a/base_layer/core/src/mempool/unconfirmed_pool/unconfirmed_pool_storage.rs +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2019 The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::{ - blocks::Block, - mempool::{ - priority::{FeePriority, PrioritizedTransaction}, - unconfirmed_pool::{UnconfirmedPoolConfig, UnconfirmedPoolError}, - }, - transactions::{transaction::Transaction, types::Signature}, -}; -use log::*; -use std::{ - collections::{BTreeMap, HashMap}, - convert::TryFrom, - sync::Arc, -}; -use tari_crypto::tari_utilities::hex::Hex; - -pub const LOG_TARGET: &str = "c::mp::unconfirmed_pool::unconfirmed_pool_storage"; - -/// UnconfirmedPool makes use of UnconfirmedPoolStorage to provide thread save access to its Hashmap and BTreeMap. -/// The txs_by_signature HashMap is used to find a transaction using its excess_sig, this functionality is used to match -/// transactions included in blocks with transactions stored in the pool. The txs_by_priority BTreeMap prioritise the -/// transactions in the pool according to TXPriority, it allows transactions to be inserted in sorted order by their -/// priority. The txs_by_priority BTreeMap makes it easier to select the set of highest priority transactions that can -/// be included in a block. The excess_sig of a transaction is used a key to uniquely identify a specific transaction in -/// these containers. 
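The doc comment of this now-removed storage module describes the two-index layout that survives the refactor inside UnconfirmedPool itself: a HashMap for O(1) lookup by excess_sig plus a BTreeMap ordered by priority, so the lowest-priority entry is always first in iteration order. A stripped-down sketch of that pattern using plain standard-library types (all names and types here are illustrative stand-ins, not the codebase's own):

use std::collections::{BTreeMap, HashMap};

// Toy two-index pool: `by_key` answers membership queries, `by_priority`
// keeps entries sorted so eviction can always take the first (lowest) one.
// Priorities are assumed unique, as FeePriority effectively is in the real pools.
struct TwoIndexPool {
    capacity: usize,
    by_key: HashMap<String, u64>,
    by_priority: BTreeMap<u64, String>,
}

impl TwoIndexPool {
    fn insert(&mut self, key: String, priority: u64) {
        if self.by_key.contains_key(&key) {
            return; // already tracked
        }
        if self.by_key.len() >= self.capacity {
            // The lowest priority sits first in BTreeMap order; evict it,
            // unless the newcomer would itself be the lowest.
            if let Some((lowest, evicted)) = self.by_priority.iter().next().map(|(p, k)| (*p, k.clone())) {
                if priority < lowest {
                    return;
                }
                self.by_priority.remove(&lowest);
                self.by_key.remove(&evicted);
            }
        }
        self.by_priority.insert(priority, key.clone());
        self.by_key.insert(key, priority);
    }

    // Mirrors check_status(): both indexes must track exactly the same entries.
    fn is_consistent(&self) -> bool {
        self.by_key.len() == self.by_priority.len() &&
            self.by_priority.values().all(|k| self.by_key.contains_key(k))
    }
}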
-pub struct UnconfirmedPoolStorage { - config: UnconfirmedPoolConfig, - txs_by_signature: HashMap, - txs_by_priority: BTreeMap, -} - -impl UnconfirmedPoolStorage { - /// Create a new UnconfirmedPoolStorage with the specified configuration - pub fn new(config: UnconfirmedPoolConfig) -> Self { - Self { - config, - txs_by_signature: HashMap::new(), - txs_by_priority: BTreeMap::new(), - } - } - - fn lowest_priority(&self) -> &FeePriority { - self.txs_by_priority.iter().next().unwrap().0 - } - - fn remove_lowest_priority_tx(&mut self) { - if let Some((priority, sig)) = self.txs_by_priority.iter().next().map(|(p, s)| (p.clone(), s.clone())) { - self.txs_by_signature.remove(&sig); - self.txs_by_priority.remove(&priority); - } - } - - /// Insert a new transaction into the UnconfirmedPoolStorage. Low priority transactions will be removed to make - /// space for higher priority transactions. The lowest priority transactions will be removed when the maximum - /// capacity is reached and the new transaction has a higher priority than the currently stored lowest priority - /// transaction. - pub fn insert(&mut self, tx: Arc) -> Result<(), UnconfirmedPoolError> { - let tx_key = tx.body.kernels()[0].excess_sig.clone(); - if !self.txs_by_signature.contains_key(&tx_key) { - debug!( - target: LOG_TARGET, - "Inserting tx into unconfirmed pool: {}", - tx_key.get_signature().to_hex() - ); - trace!(target: LOG_TARGET, "Transaction inserted: {}", tx); - let prioritized_tx = PrioritizedTransaction::try_from((*tx).clone())?; - if self.txs_by_signature.len() >= self.config.storage_capacity { - if prioritized_tx.priority < *self.lowest_priority() { - return Ok(()); - } - self.remove_lowest_priority_tx(); - } - - self.txs_by_priority - .insert(prioritized_tx.priority.clone(), tx_key.clone()); - self.txs_by_signature.insert(tx_key, prioritized_tx); - } - Ok(()) - } - - /// Insert a set of new transactions into the UnconfirmedPoolStorage - pub fn insert_txs(&mut self, txs: Vec>) -> Result<(), UnconfirmedPoolError> { - for tx in txs.into_iter() { - self.insert(tx)?; - } - Ok(()) - } - - /// Check if a transaction is stored in the UnconfirmedPoolStorage - pub fn has_tx_with_excess_sig(&self, excess_sig: &Signature) -> bool { - self.txs_by_signature.contains_key(excess_sig) - } - - /// Returns a set of the highest priority unconfirmed transactions, that can be included in a block. - pub fn highest_priority_txs(&self, total_weight: u64) -> Result>, UnconfirmedPoolError> { - let mut selected_txs: Vec> = Vec::new(); - let mut curr_weight: u64 = 0; - let mut curr_skip_count: usize = 0; - for (_, tx_key) in self.txs_by_priority.iter().rev() { - let ptx = self - .txs_by_signature - .get(tx_key) - .ok_or_else(|| UnconfirmedPoolError::StorageOutofSync)?; - - if curr_weight + ptx.weight <= total_weight { - curr_weight += ptx.weight; - selected_txs.push(ptx.transaction.clone()); - } else { - // Check if some the next few txs with slightly lower priority wont fit in the remaining space. - curr_skip_count += 1; - if curr_skip_count >= self.config.weight_tx_skip_count { - break; - } - } - } - Ok(selected_txs) - } - - // Remove double-spends from the UnconfirmedPoolStorage. These transactions were orphaned by the provided published - // block. Check if any of the unspent transactions in the UnconfirmedPool has inputs that was spent by the provided - // published block. 
- fn discard_double_spends(&mut self, published_block: &Block) { - let mut removed_tx_keys: Vec = Vec::new(); - for (tx_key, ptx) in self.txs_by_signature.iter() { - for input in ptx.transaction.body.inputs() { - if published_block.body.inputs().contains(input) { - self.txs_by_priority.remove(&ptx.priority); - removed_tx_keys.push(tx_key.clone()); - } - } - } - - for tx_key in &removed_tx_keys { - trace!( - target: LOG_TARGET, - "Removing double spends from unconfirmed pool: {:?}", - tx_key - ); - self.txs_by_signature.remove(&tx_key); - } - } - - /// Remove all published transactions from the UnconfirmedPoolStorage and discard double spends - pub fn remove_published_and_discard_double_spends(&mut self, published_block: &Block) -> Vec> { - let mut removed_txs: Vec> = Vec::new(); - published_block.body.kernels().iter().for_each(|kernel| { - if let Some(ptx) = self.txs_by_signature.get(&kernel.excess_sig) { - self.txs_by_priority.remove(&ptx.priority); - removed_txs.push(self.txs_by_signature.remove(&kernel.excess_sig).unwrap().transaction); - } - }); - // First remove published transactions before discarding double spends - self.discard_double_spends(published_block); - - removed_txs - } - - /// Returns the total number of unconfirmed transactions stored in the UnconfirmedPoolStorage - pub fn len(&self) -> usize { - self.txs_by_signature.len() - } - - /// Returns all transaction stored in the UnconfirmedPoolStorage. - pub fn snapshot(&self) -> Vec> { - self.txs_by_signature - .iter() - .map(|(_, ptx)| ptx.transaction.clone()) - .collect() - } - - /// Returns the total weight of all transactions stored in the pool. - pub fn calculate_weight(&self) -> u64 { - self.txs_by_signature - .iter() - .fold(0, |weight, (_, ptx)| weight + ptx.transaction.calculate_weight()) - } - - #[cfg(test)] - /// Checks the consistency status of the Hashmap and BtreeMap - pub fn check_status(&self) -> bool { - if self.txs_by_priority.len() != self.txs_by_signature.len() { - return false; - } - self.txs_by_priority - .iter() - .all(|(_, tx_key)| self.txs_by_signature.contains_key(tx_key)) - } -} diff --git a/base_layer/core/src/mining/blake_miner.rs b/base_layer/core/src/mining/blake_miner.rs index d45e1150cd..65535eee6a 100644 --- a/base_layer/core/src/mining/blake_miner.rs +++ b/base_layer/core/src/mining/blake_miner.rs @@ -20,10 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - blocks::BlockHeader, - proof_of_work::{Difficulty, ProofOfWork}, -}; +use crate::{blocks::BlockHeader, proof_of_work::ProofOfWork}; use log::*; use rand::{rngs::OsRng, RngCore}; use serde::{Deserialize, Serialize}; @@ -48,20 +45,18 @@ pub struct CpuBlakePow; impl CpuBlakePow { /// A simple miner. It starts with a random nonce and iterates until it finds a header hash that meets the desired /// target - pub fn mine( - target_difficulty: Difficulty, - mut header: BlockHeader, - stop_flag: Arc, - ) -> Option - { + pub fn mine(mut header: BlockHeader, stop_flag: Arc) -> Option { let mut start = Instant::now(); let mut nonce: u64 = OsRng.next_u64(); let mut last_measured_nonce = nonce; // We're mining over here! 
let mut difficulty = ProofOfWork::achieved_difficulty(&header); info!(target: LOG_TARGET, "Mining started."); - debug!(target: LOG_TARGET, "Mining for difficulty: {:?}", target_difficulty); - while difficulty < target_difficulty { + debug!( + target: LOG_TARGET, + "Mining for difficulty: {:?}", header.pow.target_difficulty + ); + while difficulty < header.pow.target_difficulty { if start.elapsed() >= Duration::from_secs(60) { // nonce might have wrapped around let hashes = if nonce >= last_measured_nonce { diff --git a/base_layer/core/src/mining/miner.rs b/base_layer/core/src/mining/miner.rs index 65b6443902..4c0ea130be 100644 --- a/base_layer/core/src/mining/miner.rs +++ b/base_layer/core/src/mining/miner.rs @@ -29,7 +29,7 @@ use crate::{ chain_storage::BlockAddResult, consensus::ConsensusManager, mining::{blake_miner::CpuBlakePow, error::MinerError, CoinbaseBuilder}, - proof_of_work::{Difficulty, PowAlgorithm}, + proof_of_work::PowAlgorithm, transactions::{ transaction::UnblindedOutput, types::{CryptoFactories, PrivateKey}, @@ -125,11 +125,31 @@ impl Miner { async fn mining(mut self) -> Result { // Lets make sure its set to mine debug!(target: LOG_TARGET, "Miner asking for new candidate block to mine."); - let mut block_template = self.get_block_template().await?; - let output = self.add_coinbase(&mut block_template)?; - let mut block = self.get_block(block_template).await?; + let block_template = self.get_block_template().await; + if block_template.is_err() { + error!( + target: LOG_TARGET, + "Could not get block template from basenode {:?}.", block_template + ); + return Ok(self); + }; + let mut block_template = block_template.unwrap(); + let output = self.add_coinbase(&mut block_template); + if output.is_err() { + error!( + target: LOG_TARGET, + "Could not add coinbase to block template {:?}.", output + ); + return Ok(self); + }; + let output = output.unwrap(); + let block = self.get_block(block_template).await; + if block.is_err() { + error!(target: LOG_TARGET, "Could not get block from basenode {:?}.", block); + return Ok(self); + }; + let mut block = block.unwrap(); debug!(target: LOG_TARGET, "Miner got new block to mine."); - let difficulty = self.get_req_difficulty().await?; let (tx, mut rx): (Sender>, Receiver>) = mpsc::channel(self.threads); for _ in 0..self.threads { let stop_mining_flag = self.stop_mining_flag.clone(); @@ -137,7 +157,7 @@ impl Miner { let mut tx_channel = tx.clone(); trace!("spawning mining thread"); spawn_blocking(move || { - let result = CpuBlakePow::mine(difficulty, header, stop_mining_flag); + let result = CpuBlakePow::mine(header, stop_mining_flag); // send back what the miner found, None will be sent if the miner did not find a nonce if let Err(e) = tx_channel.try_send(result) { warn!(target: LOG_TARGET, "Could not return mining result: {}", e); @@ -151,17 +171,25 @@ impl Miner { // found block, lets ensure we kill all other threads self.stop_mining_flag.store(true, Ordering::Relaxed); block.header = r; - self.send_block(block).await.or_else(|e| { - error!(target: LOG_TARGET, "Could not send block to base node. {:?}.", e); - Err(e) - })?; - self.utxo_sender + if self + .send_block(block) + .await + .or_else(|e| { + error!(target: LOG_TARGET, "Could not send block to base node. {:?}.", e); + Err(e) + }) + .is_err() + { + break; + }; + let _ = self + .utxo_sender .try_send(output) .or_else(|e| { error!(target: LOG_TARGET, "Could not send utxo to wallet. 
{:?}.", e); Err(e) }) - .map_err(|e| MinerError::CommunicationError(e.to_string()))?; + .map_err(|e| MinerError::CommunicationError(e.to_string())); break; } } @@ -197,44 +225,56 @@ impl Miner { if !self.enabled.load(Ordering::Relaxed) { start_mining = false; } + #[allow(clippy::match_bool)] let mining_future = match start_mining { true => task::spawn(self.mining()), false => task::spawn(self.not_mining()), }; - futures::select! { - msg = block_event.select_next_some() => { - match *msg { - BlockEvent::Verified((_, ref result)) => { - if *result == BlockAddResult::Ok{ - stop_mining_flag.store(true, Ordering::Relaxed); - start_mining = true; - }; - }, - _ => (), - } - }, - event = state_event.select_next_some() => { - use StateEvent::*; - stop_mining_flag.store(true, Ordering::Relaxed); - match *event { - BlocksSynchronized | NetworkSilence => { - info!(target: LOG_TARGET, - "Our chain has synchronised with the network, or is a seed node. Starting miner"); - start_mining = true; - }, - FallenBehind(SyncStatus::Lagging(_, _)) => { - info!(target: LOG_TARGET, "Our chain has fallen behind the network. Pausing miner"); - start_mining = false; + // This flag will let the future select loop again if the miner has not been issued a shutdown command. + let mut wait_for_miner = false; + while !wait_for_miner { + futures::select! { + msg = block_event.select_next_some() => { + match *msg { + BlockEvent::Verified((_, ref result)) => { + //Miner does not care if the chain reorg'ed or just added a new block. Both cases means a new chain tip, so it needs to restart. + match *result { + BlockAddResult::Ok | BlockAddResult::ChainReorg(_) => { + stop_mining_flag.store(true, Ordering::Relaxed); + start_mining = true; + wait_for_miner = true; + }, + _ => {} + } }, - _ => {}, + _ => (), + } + }, + event = state_event.select_next_some() => { + use StateEvent::*; + stop_mining_flag.store(true, Ordering::Relaxed); + match *event { + BlocksSynchronized | NetworkSilence => { + info!(target: LOG_TARGET, + "Our chain has synchronised with the network, or is a seed node. Starting miner"); + start_mining = true; + wait_for_miner = true; + }, + FallenBehind(SyncStatus::Lagging(_, _)) => { + info!(target: LOG_TARGET, "Our chain has fallen behind the network. Pausing miner"); + start_mining = false; + wait_for_miner = true; + }, + _ => {wait_for_miner = true;}, + } + }, + _ = kill_signal => { + info!(target: LOG_TARGET, "Mining kill signal received! Miner is shutting down"); + stop_mining_flag.store(true, Ordering::Relaxed); + break 'main; } - }, - _ = kill_signal => { - info!(target: LOG_TARGET, "Mining kill signal received! Miner is shutting down"); - stop_mining_flag.store(true, Ordering::Relaxed); - break 'main; + }; } - }; self = mining_future.await.expect("Miner crashed").expect("Miner crashed"); } debug!(target: LOG_TARGET, "Mining thread stopped."); @@ -245,7 +285,7 @@ impl Miner { trace!(target: LOG_TARGET, "Requesting new block template from node."); Ok(self .node_interface - .get_new_block_template() + .get_new_block_template(PowAlgorithm::Blake) .await .or_else(|e| { error!( @@ -277,24 +317,6 @@ impl Miner { .map_err(|e| MinerError::CommunicationError(e.to_string()))?) 
} - /// function to get the required difficulty - pub async fn get_req_difficulty(&mut self) -> Result { - trace!(target: LOG_TARGET, "Requesting target difficulty from node"); - Ok(self - .node_interface - .get_target_difficulty(PowAlgorithm::Blake) - .await - .or_else(|e| { - error!( - target: LOG_TARGET, - "Could not get the required difficulty from the base node. {:?}.", e - ); - - Err(e) - }) - .map_err(|e| MinerError::CommunicationError(e.to_string()))?) - } - // add the coinbase to the NewBlockTemplate fn add_coinbase(&self, block: &mut NewBlockTemplate) -> Result { let fees = block.body.get_total_fee(); diff --git a/base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_manager.rs b/base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_manager.rs deleted file mode 100644 index ca1de439e3..0000000000 --- a/base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_manager.rs +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2019. The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::{ - chain_storage::{BlockchainBackend, ChainMetadata}, - consensus::ConsensusConstants, - proof_of_work::{ - diff_adj_manager::{diff_adj_storage::DiffAdjStorage, error::DiffAdjManagerError}, - Difficulty, - PowAlgorithm, - }, -}; -use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use tari_crypto::tari_utilities::epoch_time::EpochTime; - -/// The DiffAdjManager is used to calculate the current target difficulty based on PoW recorded in the latest blocks of -/// the current best chain. -pub struct DiffAdjManager { - diff_adj_storage: Arc>, -} - -impl DiffAdjManager { - /// Constructs a new DiffAdjManager with access to the blockchain db. - pub fn new(consensus_constants: &ConsensusConstants) -> Result { - Ok(Self { - diff_adj_storage: Arc::new(RwLock::new(DiffAdjStorage::new(consensus_constants))), - }) - } - - /// Returns the estimated target difficulty for the specified PoW algorithm at the chain tip. 
- pub fn get_target_difficulty( - &self, - metadata: &RwLockReadGuard, - db: &RwLockReadGuard, - pow_algo: PowAlgorithm, - ) -> Result - { - self.diff_adj_storage - .write() - .map_err(|_| DiffAdjManagerError::PoisonedAccess)? - .get_target_difficulty(metadata, db, pow_algo) - } - - /// Returns the estimated target difficulty for the specified PoW algorithm and provided height. - pub fn get_target_difficulty_at_height( - &self, - db: &RwLockReadGuard, - pow_algo: PowAlgorithm, - height: u64, - ) -> Result - { - self.diff_adj_storage - .write() - .map_err(|_| DiffAdjManagerError::PoisonedAccess)? - .get_target_difficulty_at_height(db, pow_algo, height) - } - - /// Returns the estimated target difficulty for the specified PoW algorithm and provided height. - pub fn get_target_difficulty_at_height_writeguard( - &self, - db: &RwLockWriteGuard, - pow_algo: PowAlgorithm, - height: u64, - ) -> Result - { - self.diff_adj_storage - .write() - .map_err(|_| DiffAdjManagerError::PoisonedAccess)? - .get_target_difficulty_at_height_writeguard(db, pow_algo, height) - } - - /// Returns the median timestamp of the past 11 blocks at the chain tip. - pub fn get_median_timestamp( - &self, - metadata: &RwLockReadGuard, - db: &RwLockReadGuard, - ) -> Result - { - self.diff_adj_storage - .write() - .map_err(|_| DiffAdjManagerError::PoisonedAccess)? - .get_median_timestamp(metadata, db) - } - - /// Returns the median timestamp of the past 11 blocks at the provided height. - pub fn get_median_timestamp_at_height( - &self, - db: &RwLockReadGuard, - height: u64, - ) -> Result - { - self.diff_adj_storage - .write() - .map_err(|_| DiffAdjManagerError::PoisonedAccess)? - .get_median_timestamp_at_height(db, height) - } - - /// Returns the median timestamp of the past 11 blocks at the provided height. - pub fn get_median_timestamp_at_height_writeguard( - &self, - db: &RwLockWriteGuard, - height: u64, - ) -> Result - { - self.diff_adj_storage - .write() - .map_err(|_| DiffAdjManagerError::PoisonedAccess)? - .get_median_timestamp_at_height_writeguard(db, height) - } -} - -impl Clone for DiffAdjManager { - fn clone(&self) -> Self { - Self { - diff_adj_storage: self.diff_adj_storage.clone(), - } - } -} diff --git a/base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_storage.rs b/base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_storage.rs deleted file mode 100644 index e367c20a19..0000000000 --- a/base_layer/core/src/proof_of_work/diff_adj_manager/diff_adj_storage.rs +++ /dev/null @@ -1,462 +0,0 @@ -// Copyright 2019 The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::{ - blocks::blockheader::BlockHash, - chain_storage::{fetch_header, fetch_header_writeguard, BlockchainBackend, ChainMetadata}, - consensus::ConsensusConstants, - proof_of_work::{ - diff_adj_manager::error::DiffAdjManagerError, - difficulty::DifficultyAdjustment, - lwma_diff::LinearWeightedMovingAverage, - Difficulty, - PowAlgorithm, - ProofOfWork, - }, -}; -use log::*; -use std::{ - cmp, - collections::VecDeque, - sync::{RwLockReadGuard, RwLockWriteGuard}, -}; -use tari_crypto::tari_utilities::{epoch_time::EpochTime, hash::Hashable}; - -pub const LOG_TARGET: &str = "c::pow::diff_adj_manager::diff_adj_storage"; - -/// The UpdateState enum is used to specify what update operation should be performed to keep the difficulty adjustment -/// system upto date with the blockchain db. -enum UpdateState { - FullSync, - SyncToTip, - Synced, -} - -/// DiffAdjManager makes use of DiffAdjStorage to provide thread save access to its LinearWeightedMovingAverages for -/// each PoW algorithm. -pub struct DiffAdjStorage { - monero_lwma: LinearWeightedMovingAverage, - blake_lwma: LinearWeightedMovingAverage, - sync_data: Option<(u64, BlockHash)>, - timestamps: VecDeque, - difficulty_block_window: u64, - diff_target_block_interval: u64, - difficulty_max_block_interval: u64, - median_timestamp_count: usize, - min_pow_difficulty: Difficulty, -} - -impl DiffAdjStorage { - /// Constructs a new DiffAdjStorage with access to the blockchain db. - pub fn new(consensus_constants: &ConsensusConstants) -> Self { - Self { - monero_lwma: LinearWeightedMovingAverage::new( - consensus_constants.get_difficulty_block_window() as usize, - consensus_constants.get_diff_target_block_interval(), - consensus_constants.min_pow_difficulty(), - consensus_constants.get_difficulty_max_block_interval(), - ), - blake_lwma: LinearWeightedMovingAverage::new( - consensus_constants.get_difficulty_block_window() as usize, - consensus_constants.get_diff_target_block_interval(), - consensus_constants.min_pow_difficulty(), - consensus_constants.get_difficulty_max_block_interval(), - ), - sync_data: None, - timestamps: VecDeque::new(), - difficulty_block_window: consensus_constants.get_difficulty_block_window(), - median_timestamp_count: consensus_constants.get_median_timestamp_count(), - diff_target_block_interval: consensus_constants.get_diff_target_block_interval(), - min_pow_difficulty: consensus_constants.min_pow_difficulty(), - difficulty_max_block_interval: consensus_constants.get_difficulty_max_block_interval(), - } - } - - // Check if the difficulty adjustment manager is in sync with specified height. It will also check if a full sync - // or update sync needs to be performed. 
- fn check_sync_state( - &self, - db: &RwLockReadGuard, - block_hash: &BlockHash, - height: u64, - ) -> Result - { - Ok(match &self.sync_data { - Some((sync_height, sync_block_hash)) => { - if *sync_block_hash != *block_hash { - if height < *sync_height { - UpdateState::FullSync - } else { - let header = fetch_header(db, *sync_height)?; - if *sync_block_hash == header.hash() { - UpdateState::SyncToTip - } else { - UpdateState::FullSync - } - } - } else { - UpdateState::Synced - } - }, - None => UpdateState::FullSync, - }) - } - - // Check if the difficulty adjustment manager is in sync with specified height. It will also check if a full sync - // or update sync needs to be performed. - fn check_sync_state_writeguard( - &self, - db: &RwLockWriteGuard, - block_hash: &BlockHash, - height: u64, - ) -> Result - { - Ok(match &self.sync_data { - Some((sync_height, sync_block_hash)) => { - if *sync_block_hash != *block_hash { - if height < *sync_height { - UpdateState::FullSync - } else { - let header = fetch_header_writeguard(db, *sync_height)?; - if *sync_block_hash == header.hash() { - UpdateState::SyncToTip - } else { - UpdateState::FullSync - } - } - } else { - UpdateState::Synced - } - }, - None => UpdateState::FullSync, - }) - } - - // Performs an update on the difficulty adjustment manager based on the detected sync state. - fn update( - &mut self, - db: &RwLockReadGuard, - height: u64, - ) -> Result<(), DiffAdjManagerError> - { - debug!( - target: LOG_TARGET, - "Updating difficulty adjustment manager to height:{}", height - ); - let block_hash = fetch_header(db, height)?.hash(); - match self.check_sync_state(db, &block_hash, height)? { - UpdateState::FullSync => self.sync_full_history(db, block_hash, height)?, - UpdateState::SyncToTip => self.sync_to_chain_tip(db, block_hash, height)?, - UpdateState::Synced => debug!( - target: LOG_TARGET, - "Difficulty adjustment manager is already synced to height:{}", height - ), - }; - Ok(()) - } - - // Performs an update on the difficulty adjustment manager based on the detected sync state. - fn update_writeguard( - &mut self, - db: &RwLockWriteGuard, - height: u64, - ) -> Result<(), DiffAdjManagerError> - { - let block_hash = fetch_header_writeguard(db, height)?.hash(); - match self.check_sync_state_writeguard(db, &block_hash, height)? { - UpdateState::FullSync => self.sync_full_history_writeguard(db, block_hash, height)?, - UpdateState::SyncToTip => self.sync_to_chain_tip_writeguard(db, block_hash, height)?, - UpdateState::Synced => debug!( - target: LOG_TARGET, - "Difficulty adjustment manager is already synced to height:{}", height - ), - }; - Ok(()) - } - - // Retrieves the height of the longest chain from the blockchain db - fn get_height_of_longest_chain( - &self, - metadata: &RwLockReadGuard, - ) -> Result - { - metadata - .height_of_longest_chain - .ok_or_else(|| DiffAdjManagerError::EmptyBlockchain) - } - - /// Returns the estimated target difficulty for the specified PoW algorithm at the chain tip. - pub fn get_target_difficulty( - &mut self, - metadata: &RwLockReadGuard, - db: &RwLockReadGuard, - pow_algo: PowAlgorithm, - ) -> Result - { - let height = self.get_height_of_longest_chain(metadata)?; - self.get_target_difficulty_at_height(db, pow_algo, height) - } - - /// Returns the estimated target difficulty for the specified PoW algorithm and provided height. 
- pub fn get_target_difficulty_at_height( - &mut self, - db: &RwLockReadGuard, - pow_algo: PowAlgorithm, - height: u64, - ) -> Result - { - self.update(db, height)?; - debug!( - target: LOG_TARGET, - "Getting target difficulty at height:{} for PoW:{}", height, pow_algo - ); - Ok(match pow_algo { - PowAlgorithm::Monero => self.monero_lwma.get_difficulty(), - PowAlgorithm::Blake => cmp::max(self.min_pow_difficulty, self.blake_lwma.get_difficulty()), - }) - } - - /// Returns the estimated target difficulty for the specified PoW algorithm and provided height. - pub fn get_target_difficulty_at_height_writeguard( - &mut self, - db: &RwLockWriteGuard, - pow_algo: PowAlgorithm, - height: u64, - ) -> Result - { - self.update_writeguard(db, height)?; - debug!( - target: LOG_TARGET, - "Getting target difficulty at height:{} for PoW:{}", height, pow_algo - ); - Ok(match pow_algo { - PowAlgorithm::Monero => self.monero_lwma.get_difficulty(), - PowAlgorithm::Blake => cmp::max(self.min_pow_difficulty, self.blake_lwma.get_difficulty()), - }) - } - - /// Returns the median timestamp of the past 11 blocks at the chain tip. - pub fn get_median_timestamp( - &mut self, - metadata: &RwLockReadGuard, - db: &RwLockReadGuard, - ) -> Result - { - let height = self.get_height_of_longest_chain(metadata)?; - self.get_median_timestamp_at_height(db, height) - } - - /// Returns the median timestamp of the past 11 blocks at the provided height. - pub fn get_median_timestamp_at_height( - &mut self, - db: &RwLockReadGuard, - height: u64, - ) -> Result - { - self.update(db, height)?; - let mut length = self.timestamps.len(); - if length == 0 { - return Err(DiffAdjManagerError::EmptyBlockchain); - } - let mut sorted_timestamps: Vec = self.timestamps.clone().into(); - sorted_timestamps.sort(); - trace!(target: LOG_TARGET, "sorted median timestamps: {:?}", sorted_timestamps); - length /= 2; // we want the median, should be index (MEDIAN_TIMESTAMP_COUNT/2) - Ok(sorted_timestamps[length]) - } - - /// Returns the median timestamp of the past 11 blocks at the provided height. - pub fn get_median_timestamp_at_height_writeguard( - &mut self, - db: &RwLockWriteGuard, - height: u64, - ) -> Result - { - self.update_writeguard(db, height)?; - let mut length = self.timestamps.len(); - if length == 0 { - return Err(DiffAdjManagerError::EmptyBlockchain); - } - let mut sorted_timestamps: Vec = self.timestamps.clone().into(); - sorted_timestamps.sort(); - trace!(target: LOG_TARGET, "sorted median timestamps: {:?}", sorted_timestamps); - length /= 2; // we want the median, should be index (MEDIAN_TIMESTAMP_COUNT/2) - Ok(sorted_timestamps[length]) - } - - // Resets the DiffAdjStorage. - fn reset(&mut self) { - debug!(target: LOG_TARGET, "Resetting difficulty adjustment manager LWMAs"); - self.monero_lwma = LinearWeightedMovingAverage::new( - self.difficulty_block_window as usize, - self.diff_target_block_interval, - self.min_pow_difficulty, - self.difficulty_max_block_interval, - ); - self.blake_lwma = LinearWeightedMovingAverage::new( - self.difficulty_block_window as usize, - self.diff_target_block_interval, - self.min_pow_difficulty, - self.difficulty_max_block_interval, - ); - self.sync_data = None; - self.timestamps = VecDeque::new(); - } - - // Adds the new PoW sample to the specific LinearWeightedMovingAverage specified by the PoW algorithm. 
- fn add(&mut self, timestamp: EpochTime, pow: ProofOfWork) -> Result<(), DiffAdjManagerError> { - debug!( - target: LOG_TARGET, - "Adding timestamp {} for {}", timestamp, pow.pow_algo - ); - match pow.pow_algo { - PowAlgorithm::Monero => { - let target_difficulty = self.monero_lwma.get_difficulty(); - self.monero_lwma.add(timestamp, target_difficulty)? - }, - - PowAlgorithm::Blake => { - let target_difficulty = cmp::max(self.min_pow_difficulty, self.blake_lwma.get_difficulty()); - self.blake_lwma.add(timestamp, target_difficulty)? - }, - } - Ok(()) - } - - // Resets the DiffAdjStorage and perform a full sync using the blockchain db. - fn sync_full_history( - &mut self, - db: &RwLockReadGuard, - best_block: BlockHash, - height_of_longest_chain: u64, - ) -> Result<(), DiffAdjManagerError> - { - self.reset(); - debug!( - target: LOG_TARGET, - "Syncing full difficulty adjustment manager history to height:{}", height_of_longest_chain - ); - - // TODO: Store the target difficulty so that we don't have to calculate it for the whole chain - for height in 0..=height_of_longest_chain { - let header = fetch_header(db, height)?; - // keep MEDIAN_TIMESTAMP_COUNT blocks for median timestamp - // we need to keep the last bunch - self.timestamps.push_back(header.timestamp); - if self.timestamps.len() > self.median_timestamp_count { - let _ = self.timestamps.remove(0); - } - self.add(header.timestamp, header.pow)?; - } - self.sync_data = Some((height_of_longest_chain, best_block)); - - Ok(()) - } - - // Resets the DiffAdjStorage and perform a full sync using the blockchain db. - fn sync_full_history_writeguard( - &mut self, - db: &RwLockWriteGuard, - best_block: BlockHash, - height_of_longest_chain: u64, - ) -> Result<(), DiffAdjManagerError> - { - self.reset(); - debug!( - target: LOG_TARGET, - "Syncing full difficulty adjustment manager history to height:{}", height_of_longest_chain - ); - - // TODO: Store the target difficulty so that we don't have to calculate it for the whole chain - for height in 0..=height_of_longest_chain { - let header = fetch_header_writeguard(db, height)?; - // keep MEDIAN_TIMESTAMP_COUNT blocks for median timestamp - // we need to keep the last bunch - self.timestamps.push_back(header.timestamp); - if self.timestamps.len() > self.median_timestamp_count { - let _ = self.timestamps.remove(0); - } - self.add(header.timestamp, header.pow)?; - } - self.sync_data = Some((height_of_longest_chain, best_block)); - - Ok(()) - } - - // The difficulty adjustment manager has fallen behind, perform an update to the chain tip. - fn sync_to_chain_tip( - &mut self, - db: &RwLockReadGuard, - best_block: BlockHash, - height_of_longest_chain: u64, - ) -> Result<(), DiffAdjManagerError> - { - if let Some((sync_height, _)) = self.sync_data { - debug!( - target: LOG_TARGET, - "Syncing difficulty adjustment manager from height:{} to height:{}", - sync_height, - height_of_longest_chain - ); - for height in (sync_height + 1)..=height_of_longest_chain { - let header = fetch_header(db, height)?; - // add new timestamps - self.timestamps.push_back(header.timestamp); - if self.timestamps.len() > self.median_timestamp_count { - self.timestamps.remove(0); // remove oldest - } - self.add(header.timestamp, header.pow)?; - } - self.sync_data = Some((height_of_longest_chain, best_block)); - } - Ok(()) - } - - // The difficulty adjustment manager has fallen behind, perform an update to the chain tip. 
- fn sync_to_chain_tip_writeguard( - &mut self, - db: &RwLockWriteGuard, - best_block: BlockHash, - height_of_longest_chain: u64, - ) -> Result<(), DiffAdjManagerError> - { - if let Some((sync_height, _)) = self.sync_data { - debug!( - target: LOG_TARGET, - "Syncing difficulty adjustment manager from height:{} to height:{}", - sync_height, - height_of_longest_chain - ); - for height in (sync_height + 1)..=height_of_longest_chain { - let header = fetch_header_writeguard(db, height)?; - // add new timestamps - self.timestamps.push_back(header.timestamp); - if self.timestamps.len() > self.median_timestamp_count { - self.timestamps.remove(0); // remove oldest - } - self.add(header.timestamp, header.pow)?; - } - self.sync_data = Some((height_of_longest_chain, best_block)); - } - Ok(()) - } -} diff --git a/base_layer/core/src/proof_of_work/error.rs b/base_layer/core/src/proof_of_work/error.rs index 4fa477bd38..69689f21d6 100644 --- a/base_layer/core/src/proof_of_work/error.rs +++ b/base_layer/core/src/proof_of_work/error.rs @@ -28,6 +28,8 @@ pub enum PowError { InvalidProofOfWork, // Target difficulty not achieved AchievedDifficultyTooLow, + // Invalid target difficulty + InvalidTargetDifficulty, } #[derive(Debug, Error, Clone, PartialEq)] diff --git a/base_layer/core/src/proof_of_work/lwma_diff.rs b/base_layer/core/src/proof_of_work/lwma_diff.rs index 3fb3d74c6e..fc9b597be7 100644 --- a/base_layer/core/src/proof_of_work/lwma_diff.rs +++ b/base_layer/core/src/proof_of_work/lwma_diff.rs @@ -104,16 +104,18 @@ impl LinearWeightedMovingAverage { panic!("Difficulty target has overflowed"); } let target = target.ceil() as u64; // difficulty difference of 1 should not matter much, but difficulty should never be below 1, ceil(0.9) = 1 - debug!(target: LOG_TARGET, "New target difficulty: {}", target); + trace!(target: LOG_TARGET, "New target difficulty: {}", target); target.into() } } impl DifficultyAdjustment for LinearWeightedMovingAverage { fn add(&mut self, timestamp: EpochTime, target_difficulty: Difficulty) -> Result<(), DifficultyAdjustmentError> { - debug!( + trace!( target: LOG_TARGET, - "Adding new timestamp and difficulty requested: {:?}, {:?}", timestamp, target_difficulty + "Adding new timestamp and difficulty requested: {:?}, {:?}", + timestamp, + target_difficulty ); self.timestamps.push_back(timestamp); diff --git a/base_layer/core/src/proof_of_work/median_timestamp.rs b/base_layer/core/src/proof_of_work/median_timestamp.rs new file mode 100644 index 0000000000..d638b39959 --- /dev/null +++ b/base_layer/core/src/proof_of_work/median_timestamp.rs @@ -0,0 +1,43 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use log::*; +use tari_crypto::tari_utilities::epoch_time::EpochTime; + +pub const LOG_TARGET: &str = "c::pow::median_timestamp"; + +/// Returns the median timestamp for the provided set of timestamps. +pub fn get_median_timestamp(mut timestamps: Vec<EpochTime>) -> Option<EpochTime> { + if timestamps.is_empty() { + return None; + } + timestamps.sort(); + trace!(target: LOG_TARGET, "Sorted median timestamps: {:?}", timestamps); + let mid_index = timestamps.len() / 2; + let median_timestamp = if timestamps.len() % 2 == 0 { + (timestamps[mid_index - 1] + timestamps[mid_index]) / 2 + } else { + timestamps[mid_index] + }; + trace!(target: LOG_TARGET, "Median timestamp: {}", median_timestamp); + Some(median_timestamp) +} diff --git a/base_layer/core/src/proof_of_work/mod.rs b/base_layer/core/src/proof_of_work/mod.rs index feec825e57..e4f5737626 100644 --- a/base_layer/core/src/proof_of_work/mod.rs +++ b/base_layer/core/src/proof_of_work/mod.rs @@ -21,11 +21,14 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
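Stepping back to the `get_median_timestamp` helper introduced just above: it is a pure function over its input vector, so its behaviour pins down easily in a unit-style snippet. A sketch, assuming `EpochTime` converts from a u64 Unix time via `From<u64>` as tari_utilities exposes it, and using the re-export this mod.rs hunk adds:

    use tari_core::proof_of_work::get_median_timestamp;
    use tari_crypto::tari_utilities::epoch_time::EpochTime;

    // Odd count: the middle element of the sorted list is the median.
    let stamps = vec![EpochTime::from(400), EpochTime::from(100), EpochTime::from(200)];
    assert_eq!(get_median_timestamp(stamps), Some(EpochTime::from(200)));

    // Even count: the two middle elements are averaged.
    let stamps = vec![EpochTime::from(100), EpochTime::from(300)];
    assert_eq!(get_median_timestamp(stamps), Some(EpochTime::from(200)));

    // An empty set has no median.
    assert_eq!(get_median_timestamp(Vec::new()), None);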
mod blake_pow; -mod diff_adj_manager; mod difficulty; mod error; +mod median_timestamp; +#[allow(clippy::enum_variant_names)] mod monero_rx; +#[allow(clippy::module_inception)] mod proof_of_work; +mod target_difficulty; #[cfg(test)] pub use blake_pow::test as blake_test; @@ -33,8 +36,9 @@ pub use blake_pow::test as blake_test; pub mod lwma_diff; pub use blake_pow::{blake_difficulty, blake_difficulty_with_hash}; -pub use diff_adj_manager::{DiffAdjManager, DiffAdjManagerError}; pub use difficulty::{Difficulty, DifficultyAdjustment}; pub use error::{DifficultyAdjustmentError, PowError}; +pub use median_timestamp::get_median_timestamp; pub use monero_rx::monero_difficulty; pub use proof_of_work::{PowAlgorithm, ProofOfWork}; +pub use target_difficulty::get_target_difficulty; diff --git a/base_layer/core/src/proof_of_work/monero_rx.rs b/base_layer/core/src/proof_of_work/monero_rx.rs index 9a6bda110f..06c6d4d221 100644 --- a/base_layer/core/src/proof_of_work/monero_rx.rs +++ b/base_layer/core/src/proof_of_work/monero_rx.rs @@ -24,7 +24,7 @@ use crate::{blocks::BlockHeader, proof_of_work::Difficulty}; use bigint::uint::U256; use derive_error::Error; use monero::blockdata::{block::BlockHeader as MoneroBlockHeader, Transaction as MoneroTransaction}; -use randomx_rs::{RandomXCache, RandomXError, RandomXFlag, RandomXVM}; +use randomx_rs::{RandomXCache, RandomXDataset, RandomXError, RandomXFlag, RandomXVM}; use serde::{Deserialize, Serialize}; use tari_mmr::MerkleProof; @@ -78,13 +78,13 @@ pub fn monero_difficulty(header: &BlockHeader) -> Difficulty { fn monero_difficulty_calculation(header: &BlockHeader) -> Result { let monero = MoneroData::new(header)?; verify_header(&header, &monero)?; - let flags = RandomXFlag::FLAG_DEFAULT; + let flags = RandomXFlag::get_recommended_flags(); let key = monero.key.clone(); let input = create_input_blob(&monero)?; let cache = RandomXCache::new(flags, &key)?; - let vm = RandomXVM::new(flags, &cache, None)?; + let dataset = RandomXDataset::new(flags, &cache, 0)?; + let vm = RandomXVM::new(flags, Some(&cache), Some(&dataset))?; let hash = vm.calculate_hash(&input)?; - let scalar = U256::from_big_endian(&hash); // Big endian so the hash has leading zeroes let result = MAX_TARGET / scalar; let difficulty = u64::from(result).into(); diff --git a/base_layer/core/src/proof_of_work/proof_of_work.rs b/base_layer/core/src/proof_of_work/proof_of_work.rs index 7a6bb3b870..1713bc0325 100644 --- a/base_layer/core/src/proof_of_work/proof_of_work.rs +++ b/base_layer/core/src/proof_of_work/proof_of_work.rs @@ -70,6 +70,8 @@ pub struct ProofOfWork { /// but not including this block, tracked separately. pub accumulated_monero_difficulty: Difficulty, pub accumulated_blake_difficulty: Difficulty, + /// The target difficulty for solving the current block using the specified proof of work algorithm. + pub target_difficulty: Difficulty, /// The algorithm used to mine this block pub pow_algo: PowAlgorithm, /// Supplemental proof of work data. 
For example for Blake, this would be empty (only the block header is @@ -82,6 +84,7 @@ impl Default for ProofOfWork { Self { accumulated_monero_difficulty: Difficulty::default(), accumulated_blake_difficulty: Difficulty::default(), + target_difficulty: Difficulty::default(), pow_algo: PowAlgorithm::Blake, pow_data: vec![], } @@ -96,6 +99,7 @@ impl ProofOfWork { pow_algo, accumulated_monero_difficulty: Difficulty::default(), accumulated_blake_difficulty: Difficulty::default(), + target_difficulty: Difficulty::default(), pow_data: vec![], } } @@ -141,8 +145,8 @@ impl ProofOfWork { self.accumulated_monero_difficulty = pow.accumulated_monero_difficulty; } - /// Creates anew proof of work from the provided proof of work difficulty with the sum of this proof of work's total - /// cumulative difficulty and the provided `added_difficulty`. + /// Creates a new proof of work from the provided proof of work difficulty with the sum of this proof of work's + /// total cumulative difficulty and the provided `added_difficulty`. pub fn new_from_difficulty(pow: &ProofOfWork, added_difficulty: Difficulty) -> ProofOfWork { let (m, b) = match pow.pow_algo { PowAlgorithm::Monero => ( @@ -157,6 +161,7 @@ impl ProofOfWork { ProofOfWork { accumulated_monero_difficulty: m, accumulated_blake_difficulty: b, + target_difficulty: pow.target_difficulty, pow_algo: pow.pow_algo, pow_data: pow.pow_data.clone(), } diff --git a/base_layer/core/src/proof_of_work/target_difficulty.rs b/base_layer/core/src/proof_of_work/target_difficulty.rs new file mode 100644 index 0000000000..3c2c8ccaeb --- /dev/null +++ b/base_layer/core/src/proof_of_work/target_difficulty.rs @@ -0,0 +1,44 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::{ + consensus::ConsensusManagerError, + proof_of_work::{difficulty::DifficultyAdjustment, lwma_diff::LinearWeightedMovingAverage, Difficulty}, +}; +use tari_crypto::tari_utilities::epoch_time::EpochTime; + +/// Returns the estimated target difficulty for the provided set of target difficulties. 
+pub fn get_target_difficulty( + target_difficulties: Vec<(EpochTime, Difficulty)>, + block_window: usize, + target_time: u64, + initial_difficulty: Difficulty, + max_block_time: u64, +) -> Result<Difficulty, ConsensusManagerError> +{ + let mut lwma = LinearWeightedMovingAverage::new(block_window, target_time, initial_difficulty, max_block_time); + for target_difficulty in target_difficulties { + lwma.add(target_difficulty.0, target_difficulty.1)? + } + let target_difficulty = lwma.get_difficulty(); + Ok(target_difficulty) +} diff --git a/base_layer/core/src/proto/block.proto b/base_layer/core/src/proto/block.proto index 82fab7046b..969deb9865 100644 --- a/base_layer/core/src/proto/block.proto +++ b/base_layer/core/src/proto/block.proto @@ -15,6 +15,7 @@ message ProofOfWork { uint64 accumulated_monero_difficulty = 2; uint64 accumulated_blake_difficulty = 3; bytes pow_data = 4; + uint64 target_difficulty = 5; } // The BlockHeader contains all the metadata for the block, including proof of work, a link to the previous block diff --git a/base_layer/core/src/proto/block.rs b/base_layer/core/src/proto/block.rs index 0de01d6140..436f9a2069 100644 --- a/base_layer/core/src/proto/block.rs +++ b/base_layer/core/src/proto/block.rs @@ -134,6 +134,7 @@ impl TryFrom<proto::ProofOfWork> for ProofOfWork { pow_algo: PowAlgorithm::try_from(pow.pow_algo)?, accumulated_monero_difficulty: Difficulty::from(pow.accumulated_monero_difficulty), accumulated_blake_difficulty: Difficulty::from(pow.accumulated_blake_difficulty), + target_difficulty: Difficulty::from(pow.target_difficulty), pow_data: pow.pow_data, }) } @@ -145,6 +146,7 @@ impl From<ProofOfWork> for proto::ProofOfWork { pow_algo: pow.pow_algo as u64, accumulated_monero_difficulty: pow.accumulated_monero_difficulty.as_u64(), accumulated_blake_difficulty: pow.accumulated_blake_difficulty.as_u64(), + target_difficulty: pow.target_difficulty.as_u64(), pow_data: pow.pow_data, } } diff --git a/base_layer/core/src/transactions/aggregated_body.rs b/base_layer/core/src/transactions/aggregated_body.rs index f3503dc831..fee9493afe 100644 --- a/base_layer/core/src/transactions/aggregated_body.rs +++ b/base_layer/core/src/transactions/aggregated_body.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::transactions::{ + fee::Fee, tari_amount::*, transaction::*, types::{BlindingFactor, Commitment, CommitmentFactory, CryptoFactories, PrivateKey, RangeProofService}, @@ -266,6 +267,11 @@ impl AggregateBody { } Ok(()) } + + /// Returns the byte size or weight of a body + pub fn calculate_weight(&self) -> u64 { + Fee::calculate_weight(self.kernels().len(), self.inputs().len(), self.outputs().len()) + } } /// This will strip away the offset of the transaction returning a pure aggregate body diff --git a/base_layer/core/src/transactions/fee.rs b/base_layer/core/src/transactions/fee.rs index 2d642ecc61..7680f0281a 100644 --- a/base_layer/core/src/transactions/fee.rs +++ b/base_layer/core/src/transactions/fee.rs @@ -25,19 +25,25 @@ use crate::transactions::{tari_amount::*, transaction::MINIMUM_TRANSACTION_FEE}; pub struct Fee {} pub const WEIGHT_PER_INPUT: u64 = 1; -pub const WEIGHT_PER_OUTPUT: u64 = 4; -pub const BASE_COST: u64 = 1; +pub const WEIGHT_PER_OUTPUT: u64 = 13; +pub const KERNEL_WEIGHT: u64 = 3; // Constant weight per transaction; covers kernel and part of header.
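With the reworked constants the weight is kernel-aware and the old flat `BASE_COST` is gone, so the fee becomes a plain product of weight and fee-per-gram. Worked through for a one-kernel, one-input, two-output payment (a sketch against the reworked `Fee` methods shown in the hunks that follow; crate paths are assumed from the diff headers):

    use tari_core::transactions::{fee::Fee, tari_amount::MicroTari};

    // weight = KERNEL_WEIGHT * 1 + WEIGHT_PER_INPUT * 1 + WEIGHT_PER_OUTPUT * 2
    //        = 3 + 1 + 26
    //        = 30 grams
    assert_eq!(Fee::calculate_weight(1, 1, 2), 30);

    // At 20 µT per gram the absolute fee is 30 * 20 = 600 µT.
    assert_eq!(Fee::calculate(MicroTari(20), 1, 1, 2), MicroTari(600));

The same arithmetic explains the `fee == 340` comment in the updated initializer test further down: a 1-kernel, 1-input, 1-output transaction weighs 3 + 1 + 13 = 17 grams, and 17 * 20 = 340 µT.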
impl Fee { /// Computes the absolute transaction fee given the fee-per-gram, and the size of the transaction - pub fn calculate(fee_per_gram: MicroTari, num_inputs: usize, num_outputs: usize) -> MicroTari { - (BASE_COST + Fee::calculate_weight(num_inputs, num_outputs) * u64::from(fee_per_gram)).into() + pub fn calculate(fee_per_gram: MicroTari, num_kernels: usize, num_inputs: usize, num_outputs: usize) -> MicroTari { + (Fee::calculate_weight(num_kernels, num_inputs, num_outputs) * u64::from(fee_per_gram)).into() } /// Computes the absolute transaction fee using `calculate`, but the resulting fee will always be at least the /// minimum network transaction fee. - pub fn calculate_with_minimum(fee_per_gram: MicroTari, num_inputs: usize, num_outputs: usize) -> MicroTari { - let fee = Fee::calculate(fee_per_gram, num_inputs, num_outputs); + pub fn calculate_with_minimum( + fee_per_gram: MicroTari, + num_kernels: usize, + num_inputs: usize, + num_outputs: usize, + ) -> MicroTari + { + let fee = Fee::calculate(fee_per_gram, num_kernels, num_inputs, num_outputs); if fee < MINIMUM_TRANSACTION_FEE { MINIMUM_TRANSACTION_FEE } else { @@ -46,7 +52,9 @@ impl Fee { } /// Calculate the weight of a transaction based on the number of inputs and outputs - pub fn calculate_weight(num_inputs: usize, num_outputs: usize) -> u64 { - WEIGHT_PER_INPUT * num_inputs as u64 + WEIGHT_PER_OUTPUT * num_outputs as u64 + pub fn calculate_weight(num_kernels: usize, num_inputs: usize, num_outputs: usize) -> u64 { + KERNEL_WEIGHT * num_kernels as u64 + + WEIGHT_PER_INPUT * num_inputs as u64 + + WEIGHT_PER_OUTPUT * num_outputs as u64 } } diff --git a/base_layer/core/src/transactions/helpers.rs b/base_layer/core/src/transactions/helpers.rs index 0eca71307e..fd6a83b9c8 100644 --- a/base_layer/core/src/transactions/helpers.rs +++ b/base_layer/core/src/transactions/helpers.rs @@ -264,7 +264,7 @@ pub fn create_tx( unblinded_inputs.push(input.clone()); stx_builder.with_input(utxo, input); - let estimated_fee = Fee::calculate(fee_per_gram, input_count as usize, output_count as usize); + let estimated_fee = Fee::calculate(fee_per_gram, 1, input_count as usize, output_count as usize); let amount_per_output = (amount - estimated_fee) / output_count; let amount_for_last_output = (amount - estimated_fee) - amount_per_output * (output_count - 1); for i in 0..output_count { diff --git a/base_layer/core/src/transactions/transaction.rs b/base_layer/core/src/transactions/transaction.rs index 33968ede67..356d315b53 100644 --- a/base_layer/core/src/transactions/transaction.rs +++ b/base_layer/core/src/transactions/transaction.rs @@ -25,7 +25,6 @@ use crate::transactions::{ aggregated_body::AggregateBody, - fee::Fee, tari_amount::{uT, MicroTari}, transaction_protocol::{build_challenge, TransactionMetadata}, types::{ @@ -46,6 +45,7 @@ use digest::Input; use serde::{Deserialize, Serialize}; use std::{ cmp::{max, min, Ordering}, + fmt, fmt::{Display, Formatter}, hash::{Hash, Hasher}, ops::Add, @@ -134,6 +134,16 @@ impl Ord for OutputFeatures { } } +impl Display for OutputFeatures { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "OutputFeatures: Flags = {:?}, Maturity = {}", + self.flags, self.maturity + ) + } +} + bitflags! 
{ #[derive(Deserialize, Serialize)] pub struct OutputFlags: u8 { @@ -623,7 +633,7 @@ impl Transaction { /// Returns the byte size or weight of a transaction pub fn calculate_weight(&self) -> u64 { - Fee::calculate_weight(self.body.inputs().len(), self.body.outputs().len()) + self.body.calculate_weight() } /// Returns the total fee allocated to each byte of the transaction diff --git a/base_layer/core/src/transactions/transaction_protocol/mod.rs b/base_layer/core/src/transactions/transaction_protocol/mod.rs index af31eab3b2..244486405b 100644 --- a/base_layer/core/src/transactions/transaction_protocol/mod.rs +++ b/base_layer/core/src/transactions/transaction_protocol/mod.rs @@ -13,6 +13,37 @@ //! illustrates the progression of the two state machines and shows where the public data messages are constructed and //! accepted in each state machine //! +//! The sequence diagram for the single receiver protocol is: +//! +//!
+//! sequenceDiagram +//! participant Sender +//! participant Receiver +//! # +//! activate Sender +//! Sender-->>Sender: initialize transaction +//! deactivate Sender +//! # +//! activate Sender +//! Sender-->>+Receiver: partial tx info +//! Receiver-->>Receiver: validate tx info +//! Receiver-->>Receiver: create new output and sign +//! Receiver-->>-Sender: signed partial transaction +//! deactivate Sender +//! # +//! activate Sender +//! Sender-->>Sender: validate and sign +//! deactivate Sender +//! # +//! alt tx is valid +//! Sender-->>Network: Broadcast transaction +//! else tx is invalid +//! Sender--XSender: Failed +//! end +//!
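In code, that single round trip maps onto the sender and receiver protocol objects roughly as follows. This is a sketch reconstructed from the test code elsewhere in these files; the constructor and method names (`build_single_round_message`, `ReceiverTransactionProtocol::new`, `get_signed_data`, `add_single_recipient_info`, `finalize`) and their exact signatures are assumptions and should be checked against `sender.rs` and `recipient.rs`:

    // Assumes `builder`, `nonce`, `spending_key` and `factories` are set up
    // as in the tests later in this patch.

    // Stage 1: the sender assembles the partial transaction info.
    let mut alice = builder.build::<Blake256>(&factories)?;
    let msg = alice.build_single_round_message()?;

    // Stage 2: the receiver validates the info, creates and signs its output, and replies.
    let bob = ReceiverTransactionProtocol::new(msg, nonce, spending_key, OutputFeatures::default(), &factories);
    let reply = bob.get_signed_data()?.clone();

    // Stage 3: the sender incorporates the signature, validates, and finalizes.
    alice.add_single_recipient_info(reply, &factories.range_proof)?;
    alice.finalize(KernelFeatures::empty(), &factories)?;
    let tx = alice.get_transaction()?;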
+//! +//! If there are multiple recipients, the protocol is more involved and requires three rounds of communication: +//! //!
//! sequenceDiagram //! participant Sender diff --git a/base_layer/core/src/transactions/transaction_protocol/recipient.rs b/base_layer/core/src/transactions/transaction_protocol/recipient.rs index 12ed59b623..354fa8a5b1 100644 --- a/base_layer/core/src/transactions/transaction_protocol/recipient.rs +++ b/base_layer/core/src/transactions/transaction_protocol/recipient.rs @@ -30,7 +30,7 @@ use crate::transactions::{ types::{CryptoFactories, MessageHash, PrivateKey, PublicKey, Signature}, }; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::{collections::HashMap, fmt}; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum RecipientState { @@ -38,6 +38,20 @@ pub enum RecipientState { Failed(TransactionProtocolError), } +impl fmt::Display for RecipientState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use RecipientState::*; + match self { + Finalized(signed_message) => write!( + f, + "Finalized({:?}, maturity = {})", + signed_message.output.features.flags, signed_message.output.features.maturity + ), + Failed(err) => write!(f, "Failed({:?})", err), + } + } +} + /// An enum describing the types of information that a recipient can send back to the receiver #[derive(Debug, Clone, PartialEq)] pub(super) enum RecipientInfo { diff --git a/base_layer/core/src/transactions/transaction_protocol/sender.rs b/base_layer/core/src/transactions/transaction_protocol/sender.rs index af83ba6af0..6cabdf4dab 100644 --- a/base_layer/core/src/transactions/transaction_protocol/sender.rs +++ b/base_layer/core/src/transactions/transaction_protocol/sender.rs @@ -44,6 +44,7 @@ use crate::transactions::{ }; use digest::Digest; use serde::{Deserialize, Serialize}; +use std::fmt; use tari_crypto::{ristretto::pedersen::PedersenCommitment, tari_utilities::ByteArray}; //---------------------------------------- Local Data types ----------------------------------------------------// @@ -458,6 +459,12 @@ impl SenderTransactionProtocol { } } +impl fmt::Display for SenderTransactionProtocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.state) + } +} + pub fn calculate_tx_id(pub_nonce: &PublicKey, index: usize) -> u64 { let hash = D::new().chain(pub_nonce.as_bytes()).chain(index.to_le_bytes()).result(); let mut bytes: [u8; 8] = [0u8; 8]; @@ -502,6 +509,45 @@ impl SenderState { } } +impl fmt::Display for SenderState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use SenderState::*; + match self { + Initializing(info) => write!( + f, + "Initializing({} input(s), {} output(s))", + info.inputs.len(), + info.outputs.len() + ), + SingleRoundMessageReady(info) => write!( + f, + "SingleRoundMessageReady({} input(s), {} output(s))", + info.inputs.len(), + info.outputs.len() + ), + CollectingSingleSignature(info) => write!( + f, + "CollectingSingleSignature({} input(s), {} output(s))", + info.inputs.len(), + info.outputs.len() + ), + Finalizing(info) => write!( + f, + "Finalizing({} input(s), {} output(s))", + info.inputs.len(), + info.outputs.len() + ), + FinalizedTransaction(txn) => write!( + f, + "FinalizedTransaction({} input(s), {} output(s))", + txn.body.inputs().len(), + txn.body.outputs().len() + ), + Failed(err) => write!(f, "Failed({:?})", err), + } + } +} + //---------------------------------------- Tests ----------------------------------------------------// #[cfg(test)] @@ -557,7 +603,7 @@ mod test { let b = TestParams::new(); let (utxo, input) = make_input(&mut OsRng, MicroTari(1200), &factories.commitment); 
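The `Display` implementations added in this file let the wallet and tests log protocol progress compactly instead of debug-printing whole structs. A hypothetical logging helper (the module path is assumed from the diff headers; the test it interrupts resumes just below):

    use tari_core::transactions::transaction_protocol::sender::SenderTransactionProtocol;

    fn log_progress(alice: &SenderTransactionProtocol) {
        // Renders the state name plus input/output counts,
        // e.g. "SingleRoundMessageReady(1 input(s), 2 output(s))".
        println!("{}", alice);
    }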
let mut builder = SenderTransactionProtocol::builder(1); - let fee = Fee::calculate(MicroTari(20), 1, 1); + let fee = Fee::calculate(MicroTari(20), 1, 1, 1); builder .with_lock_height(0) .with_fee_per_gram(MicroTari(20)) @@ -615,7 +661,7 @@ mod test { let b = TestParams::new(); let (utxo, input) = make_input(&mut OsRng, MicroTari(2500), &factories.commitment); let mut builder = SenderTransactionProtocol::builder(1); - let fee = Fee::calculate(MicroTari(20), 1, 2); + let fee = Fee::calculate(MicroTari(20), 1, 1, 2); builder .with_lock_height(0) .with_fee_per_gram(MicroTari(20)) diff --git a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs index bdc87ecfaf..7804667d54 100644 --- a/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs +++ b/base_layer/core/src/transactions/transaction_protocol/transaction_initializer.rs @@ -39,6 +39,7 @@ use crate::transactions::{ }; use digest::Digest; use std::{ + cmp::max, collections::HashMap, fmt::{Debug, Error, Formatter}, }; @@ -166,11 +167,10 @@ impl SenderTransactionInitializer { let num_inputs = self.inputs.len(); let total_being_spent = self.unblinded_inputs.iter().map(|i| i.value).sum::(); let total_to_self = self.outputs.iter().map(|o| o.value).sum::(); - let total_amount = self.amounts.sum().ok_or_else(|| "Not all amounts have been provided")?; let fee_per_gram = self.fee_per_gram.ok_or_else(|| "Fee per gram was not provided")?; - let fee_without_change = Fee::calculate(fee_per_gram, num_inputs, num_outputs); - let fee_with_change = Fee::calculate(fee_per_gram, num_inputs, num_outputs + 1); + let fee_without_change = Fee::calculate(fee_per_gram, 1, num_inputs, num_outputs); + let fee_with_change = Fee::calculate(fee_per_gram, 1, num_inputs, num_outputs + 1); let extra_fee = fee_with_change - fee_without_change; // Subtract with a check on going negative let change_amount = total_being_spent.checked_sub(total_to_self + total_amount + fee_without_change); @@ -272,8 +272,9 @@ impl SenderTransactionInitializer { 1 => RecipientInfo::Single(None), _ => RecipientInfo::Multiple(HashMap::new()), }; - let mut ids = Vec::with_capacity(self.num_recipients); - for i in 0..self.num_recipients { + let num_ids = max(1, self.num_recipients); + let mut ids = Vec::with_capacity(num_ids); + for i in 0..num_ids { ids.push(calculate_tx_id::(&public_nonce, i)); } let sender_info = RawTransactionInfo { @@ -313,7 +314,7 @@ impl SenderTransactionInitializer { #[cfg(test)] mod test { use crate::transactions::{ - fee::{Fee, BASE_COST, WEIGHT_PER_INPUT, WEIGHT_PER_OUTPUT}, + fee::{Fee, KERNEL_WEIGHT, WEIGHT_PER_INPUT, WEIGHT_PER_OUTPUT}, helpers::{make_input, TestParams}, tari_amount::*, transaction::{UnblindedOutput, MAX_TRANSACTION_INPUTS}, @@ -347,10 +348,10 @@ mod test { .with_offset(p.offset) .with_private_nonce(p.nonce); builder.with_output(UnblindedOutput::new(MicroTari(100), p.spend_key, None)); - let (utxo, input) = make_input(&mut OsRng, MicroTari(500), &factories.commitment); + let (utxo, input) = make_input(&mut OsRng, MicroTari(5_000), &factories.commitment); builder.with_input(utxo, input); builder.with_fee_per_gram(MicroTari(20)); - let expected_fee = Fee::calculate(MicroTari(20), 1, 2); + let expected_fee = Fee::calculate(MicroTari(20), 1, 1, 2); // We needed a change input, so this should fail let err = builder.build::(&factories).unwrap_err(); assert_eq!(err.message, "Change spending key was not provided"); @@ -362,7 
+363,7 @@ mod test { if let SenderState::Finalizing(info) = result.state { assert_eq!(info.num_recipients, 0, "Number of receivers"); assert_eq!(info.signatures.len(), 0, "Number of signatures"); - assert_eq!(info.ids.len(), 0, "Number of tx_ids"); + assert_eq!(info.ids.len(), 1, "Number of tx_ids"); assert_eq!(info.amounts.len(), 0, "Number of external payment amounts"); assert_eq!(info.metadata.lock_height, 100, "Lock height"); assert_eq!(info.metadata.fee, expected_fee, "Fee"); @@ -380,7 +381,7 @@ mod test { let factories = CryptoFactories::default(); let p = TestParams::new(); let (utxo, input) = make_input(&mut OsRng, MicroTari(500), &factories.commitment); - let expected_fee = Fee::calculate(MicroTari(20), 1, 1); + let expected_fee = Fee::calculate(MicroTari(20), 1, 1, 1); let output = UnblindedOutput::new(MicroTari(500) - expected_fee, p.spend_key, None); // Start the builder let mut builder = SenderTransactionInitializer::new(0); @@ -396,7 +397,7 @@ mod test { if let SenderState::Finalizing(info) = result.state { assert_eq!(info.num_recipients, 0, "Number of receivers"); assert_eq!(info.signatures.len(), 0, "Number of signatures"); - assert_eq!(info.ids.len(), 0, "Number of tx_ids"); + assert_eq!(info.ids.len(), 1, "Number of tx_ids"); assert_eq!(info.amounts.len(), 0, "Number of external payment amounts"); assert_eq!(info.metadata.lock_height, 0, "Lock height"); assert_eq!(info.metadata.fee, expected_fee, "Fee"); @@ -414,8 +415,9 @@ mod test { let factories = CryptoFactories::default(); let p = TestParams::new(); let (utxo, input) = make_input(&mut OsRng, MicroTari(500), &factories.commitment); - let expected_fee = MicroTari::from(BASE_COST + (WEIGHT_PER_INPUT + 1 * WEIGHT_PER_OUTPUT) * 20); // 101, output = 80 - // Pay out so that I should get change, but not enough to pay for the output + let expected_fee = MicroTari::from((KERNEL_WEIGHT + WEIGHT_PER_INPUT + 1 * WEIGHT_PER_OUTPUT) * 20); + // fee == 340, output = 80 + // Pay out so that I should get change, but not enough to pay for the output let output = UnblindedOutput::new(MicroTari(500) - expected_fee - MicroTari(50), p.spend_key, None); // Start the builder let mut builder = SenderTransactionInitializer::new(0); @@ -431,7 +433,7 @@ mod test { if let SenderState::Finalizing(info) = result.state { assert_eq!(info.num_recipients, 0, "Number of receivers"); assert_eq!(info.signatures.len(), 0, "Number of signatures"); - assert_eq!(info.ids.len(), 0, "Number of tx_ids"); + assert_eq!(info.ids.len(), 1, "Number of tx_ids"); assert_eq!(info.amounts.len(), 0, "Number of external payment amounts"); assert_eq!(info.metadata.lock_height, 0, "Lock height"); assert_eq!(info.metadata.fee, expected_fee + MicroTari(50), "Fee"); @@ -511,7 +513,7 @@ mod test { // Create some inputs let factories = CryptoFactories::default(); let p = TestParams::new(); - let (utxo, input) = make_input(&mut OsRng, MicroTari(1000), &factories.commitment); + let (utxo, input) = make_input(&mut OsRng, MicroTari(100_000), &factories.commitment); let output = UnblindedOutput::new(MicroTari(150), p.spend_key, None); // Start the builder let mut builder = SenderTransactionInitializer::new(2); @@ -542,7 +544,7 @@ mod test { let (utxo1, input1) = make_input(&mut OsRng, MicroTari(2000), &factories.commitment); let (utxo2, input2) = make_input(&mut OsRng, MicroTari(3000), &factories.commitment); let weight = MicroTari(30); - let expected_fee = Fee::calculate(weight, 2, 3); + let expected_fee = Fee::calculate(weight, 1, 2, 3); let output = 
UnblindedOutput::new(MicroTari(1500) - expected_fee, p.spend_key, None); // Start the builder let mut builder = SenderTransactionInitializer::new(1); diff --git a/base_layer/core/src/validation/accum_difficulty_validators.rs b/base_layer/core/src/validation/accum_difficulty_validators.rs new file mode 100644 index 0000000000..c511fc2514 --- /dev/null +++ b/base_layer/core/src/validation/accum_difficulty_validators.rs @@ -0,0 +1,63 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::{ + chain_storage::BlockchainBackend, + proof_of_work::Difficulty, + validation::{Validation, ValidationError}, +}; + +/// This validator will check if a provided accumulated difficulty is stronger than the chain tip. +#[derive(Clone)] +pub struct AccumDifficultyValidator {} + +impl Validation for AccumDifficultyValidator { + fn validate(&self, accum_difficulty: &Difficulty, db: &B) -> Result<(), ValidationError> { + let tip_header = db + .fetch_last_header() + .map_err(|e| ValidationError::CustomError(e.to_string()))? + .ok_or_else(|| ValidationError::CustomError("Cannot retrieve tip header. Blockchain DB is empty".into()))?; + if *accum_difficulty <= tip_header.total_accumulated_difficulty_inclusive() { + return Err(ValidationError::WeakerAccumulatedDifficulty); + } + Ok(()) + } +} + +/// This a mock validator that can be used for testing, it will check if a provided accumulated difficulty is equal or +/// stronger than the chain tip. This will simplify testing where small testing blockchains need to be constructed as +/// the accumulated difficulty of preceding blocks don't have to have an increasing accumulated difficulty. +#[derive(Clone)] +pub struct MockAccumDifficultyValidator {} + +impl Validation for MockAccumDifficultyValidator { + fn validate(&self, accum_difficulty: &Difficulty, db: &B) -> Result<(), ValidationError> { + let tip_header = db + .fetch_last_header() + .map_err(|e| ValidationError::CustomError(e.to_string()))? + .ok_or_else(|| ValidationError::CustomError("Cannot retrieve tip header. 
Blockchain DB is empty".into()))?; + if *accum_difficulty < tip_header.total_accumulated_difficulty_inclusive() { + return Err(ValidationError::WeakerAccumulatedDifficulty); + } + Ok(()) + } +} diff --git a/base_layer/core/src/validation/block_validators.rs b/base_layer/core/src/validation/block_validators.rs index cab6fb83a4..f47158aa5a 100644 --- a/base_layer/core/src/validation/block_validators.rs +++ b/base_layer/core/src/validation/block_validators.rs @@ -27,18 +27,17 @@ use crate::{ BlockValidationError, NewBlockTemplate, }, - chain_storage::{calculate_mmr_roots_writeguard, is_utxo_writeguard, BlockchainBackend, ChainMetadata}, + chain_storage::{calculate_mmr_roots, is_utxo, BlockchainBackend}, consensus::{ConsensusConstants, ConsensusManager}, transactions::{transaction::OutputFlags, types::CryptoFactories}, validation::{ - helpers::{check_achieved_difficulty, check_median_timestamp}, + helpers::{check_achieved_and_target_difficulty, check_median_timestamp}, StatelessValidation, + Validation, ValidationError, - ValidationWriteGuard, }, }; use log::*; -use std::sync::RwLockWriteGuard; use tari_crypto::tari_utilities::{hash::Hashable, hex::Hex}; pub const LOG_TARGET: &str = "c::val::block_validators"; @@ -64,6 +63,7 @@ impl StatelessValidation for StatelessBlockValidator { /// 1. Are all inputs allowed to be spent (Are the feature flags satisfied) fn validate(&self, block: &Block) -> Result<(), ValidationError> { check_coinbase_output(block, &self.consensus_constants)?; + check_block_weight(block, &self.consensus_constants)?; // Check that the inputs are are allowed to be spent block.check_stxo_rules().map_err(BlockValidationError::from)?; check_cut_through(block)?; @@ -84,7 +84,7 @@ impl FullConsensusValidator { } } -impl ValidationWriteGuard for FullConsensusValidator { +impl Validation for FullConsensusValidator { /// The consensus checks that are done (in order of cheapest to verify to most expensive): /// 1. Does the block satisfy the stateless checks? /// 1. Are all inputs currently in the UTXO set? @@ -93,13 +93,7 @@ impl ValidationWriteGuard for FullConsensusValid /// 1. Is the block header timestamp greater than the median timestamp? /// 1. Is the Proof of Work valid? /// 1. Is the achieved difficulty of this block >= the target difficulty for this block? - fn validate( - &self, - block: &Block, - db: &RwLockWriteGuard, - metadata: &RwLockWriteGuard, - ) -> Result<(), ValidationError> - { + fn validate(&self, block: &Block, db: &B) -> Result<(), ValidationError> { trace!( target: LOG_TARGET, "Validating block at height {} with hash: {}", @@ -107,15 +101,20 @@ impl ValidationWriteGuard for FullConsensusValid block.hash().to_hex() ); check_coinbase_output(block, &self.rules.consensus_constants())?; + check_block_weight(block, &self.rules.consensus_constants())?; check_cut_through(block)?; block.check_stxo_rules().map_err(BlockValidationError::from)?; check_accounting_balance(block, self.rules.clone(), &self.factories)?; - check_inputs_are_utxos(block, &db)?; - check_mmr_roots(block, &db)?; + check_inputs_are_utxos(block, db)?; + check_mmr_roots(block, db)?; check_timestamp_ftl(&block.header, &self.rules)?; - let tip_height = metadata.height_of_longest_chain.unwrap_or(0); + let tip_height = db + .fetch_metadata() + .map_err(|e| ValidationError::CustomError(e.to_string()))? 
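StatelessBlockValidator now also enforces the maximum block weight before any database access. A minimal sketch of running just the stateless checks, assuming the LocalNet genesis block (which is exempt from the weight rule) passes them:

    use tari_core::{
        consensus::{ConsensusManagerBuilder, Network},
        validation::{block_validators::StatelessBlockValidator, StatelessValidation},
    };

    fn main() {
        let rules = ConsensusManagerBuilder::new(Network::LocalNet).build();
        let validator = StatelessBlockValidator::new(&rules.consensus_constants());
        // Coinbase rules, block weight, STxO rules and cut-through: no DB needed.
        let block = rules.get_genesis_block();
        assert!(validator.validate(&block).is_ok());
    }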
+ .height_of_longest_chain + .unwrap_or(0); check_median_timestamp(db, &block.header, tip_height, self.rules.clone())?; - check_achieved_difficulty(db, &block.header, tip_height, self.rules.clone())?; + check_achieved_and_target_difficulty(db, &block.header, tip_height, self.rules.clone())?; Ok(()) } } @@ -148,6 +147,22 @@ fn check_accounting_balance( }) } +fn check_block_weight(block: &Block, consensus_constants: &ConsensusConstants) -> Result<(), ValidationError> { + trace!( + target: LOG_TARGET, + "Checking weight of block with hash {}", + block.hash().to_hex() + ); + // The genesis block has a larger weight than other blocks may have so we have to exclude it here + if block.body.calculate_weight() <= consensus_constants.get_max_block_transaction_weight() || + block.header.height == 0 + { + Ok(()) + } else { + Err(BlockValidationError::BlockTooLarge).map_err(ValidationError::from) + } +} + fn check_coinbase_output(block: &Block, consensus_constants: &ConsensusConstants) -> Result<(), ValidationError> { trace!( target: LOG_TARGET, @@ -160,15 +175,11 @@ fn check_coinbase_output(block: &Block, consensus_constants: &ConsensusConstants } /// This function checks that all inputs in the blocks are valid UTXO's to be spend -fn check_inputs_are_utxos( - block: &Block, - db: &RwLockWriteGuard, -) -> Result<(), ValidationError> -{ +fn check_inputs_are_utxos(block: &Block, db: &B) -> Result<(), ValidationError> { trace!(target: LOG_TARGET, "Checking input UXTOs exist",); for utxo in block.body.inputs() { if !(utxo.features.flags.contains(OutputFlags::COINBASE_OUTPUT)) && - !(is_utxo_writeguard(db, utxo.hash())).map_err(|e| ValidationError::CustomError(e.to_string()))? + !(is_utxo(db, utxo.hash())).map_err(|e| ValidationError::CustomError(e.to_string()))? { warn!( target: LOG_TARGET, @@ -203,32 +214,43 @@ fn check_timestamp_ftl( Ok(()) } -fn check_mmr_roots(block: &Block, db: &RwLockWriteGuard) -> Result<(), ValidationError> { +fn check_mmr_roots(block: &Block, db: &B) -> Result<(), ValidationError> { trace!(target: LOG_TARGET, "Checking MMR roots match",); let template = NewBlockTemplate::from(block.clone()); - let tmp_block = - calculate_mmr_roots_writeguard(db, template).map_err(|e| ValidationError::CustomError(e.to_string()))?; + let tmp_block = calculate_mmr_roots(db, template).map_err(|e| ValidationError::CustomError(e.to_string()))?; let tmp_header = &tmp_block.header; let header = &block.header; - if header.kernel_mr != tmp_header.kernel_mr || - header.output_mr != tmp_header.output_mr || - header.range_proof_mr != tmp_header.range_proof_mr - { + if header.kernel_mr != tmp_header.kernel_mr { warn!( target: LOG_TARGET, - "Block header MMR roots in {} do not match calculated roots", + "Block header kernel MMR roots in {} do not match calculated roots", block.hash().to_hex() ); - Err(ValidationError::BlockError(BlockValidationError::MismatchedMmrRoots)) - } else { - Ok(()) - } + return Err(ValidationError::BlockError(BlockValidationError::MismatchedMmrRoots)); + }; + if header.output_mr != tmp_header.output_mr { + warn!( + target: LOG_TARGET, + "Block header output MMR roots in {} do not match calculated roots", + block.hash().to_hex() + ); + return Err(ValidationError::BlockError(BlockValidationError::MismatchedMmrRoots)); + }; + if header.range_proof_mr != tmp_header.range_proof_mr { + warn!( + target: LOG_TARGET, + "Block header range_proof MMR roots in {} do not match calculated roots", + block.hash().to_hex() + ); + return 
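With the lock guards gone, the tip height is read straight off the backend. The same four-line pattern recurs in the transaction validators further down, so it could be hoisted into a helper along these lines (hypothetical, not part of this patch):

    use tari_core::{chain_storage::BlockchainBackend, validation::ValidationError};

    /// Tip height from backend metadata, defaulting to 0 for an empty chain.
    fn tip_height<B: BlockchainBackend>(db: &B) -> Result<u64, ValidationError> {
        Ok(db
            .fetch_metadata()
            .map_err(|e| ValidationError::CustomError(e.to_string()))?
            .height_of_longest_chain
            .unwrap_or(0))
    }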
Err(ValidationError::BlockError(BlockValidationError::MismatchedMmrRoots)); + }; + Ok(()) } fn check_cut_through(block: &Block) -> Result<(), ValidationError> { trace!( target: LOG_TARGET, - "Checking coinbase output on block with hash {}", + "Checking cut through on block with hash {}", block.hash().to_hex() ); if !block.body.cut_through_check() { diff --git a/base_layer/core/src/validation/error.rs b/base_layer/core/src/validation/error.rs index d04a5af1cd..2cccee31e9 100644 --- a/base_layer/core/src/validation/error.rs +++ b/base_layer/core/src/validation/error.rs @@ -44,4 +44,8 @@ pub enum ValidationError { // The total expected supply plus the total accumulated (offset) excess does not equal the sum of all UTXO // commitments. InvalidAccountingBalance, + // Transaction contains already spent inputs + ContainsSTxO, + // The recorded chain accumulated difficulty was stronger + WeakerAccumulatedDifficulty, } diff --git a/base_layer/core/src/validation/helpers.rs b/base_layer/core/src/validation/helpers.rs index ad5e291a0c..fc7c997e10 100644 --- a/base_layer/core/src/validation/helpers.rs +++ b/base_layer/core/src/validation/helpers.rs @@ -22,20 +22,20 @@ use crate::{ blocks::blockheader::{BlockHeader, BlockHeaderValidationError}, - chain_storage::BlockchainBackend, + chain_storage::{fetch_headers, BlockchainBackend}, consensus::ConsensusManager, - proof_of_work::PowError, + proof_of_work::{get_target_difficulty, PowError}, validation::ValidationError, }; use log::*; use tari_crypto::tari_utilities::hash::Hashable; pub const LOG_TARGET: &str = "c::val::helpers"; -use std::sync::RwLockWriteGuard; +use crate::{chain_storage::fetch_target_difficulties, proof_of_work::get_median_timestamp}; use tari_crypto::tari_utilities::hex::Hex; /// This function tests that the block timestamp is greater than the median timestamp at the specified height. pub fn check_median_timestamp( - db: &RwLockWriteGuard, + db: &B, block_header: &BlockHeader, height: u64, rules: ConsensusManager, @@ -45,14 +45,18 @@ pub fn check_median_timestamp( if block_header.height == 0 || rules.get_genesis_block_hash() == block_header.hash() { return Ok(()); // Its the genesis block, so we dont have to check median } - let median_timestamp = rules - .get_median_timestamp_at_height_writeguard(db, height) - .or_else(|e| { - error!(target: LOG_TARGET, "Validation could not get median timestamp"); - - Err(e) - }) - .map_err(|_| ValidationError::BlockHeaderError(BlockHeaderValidationError::InvalidTimestamp))?; + trace!(target: LOG_TARGET, "Calculating median timestamp to height:{}", height); + let min_height = height.saturating_sub(rules.consensus_constants().get_median_timestamp_count() as u64); + let block_nums = (min_height..=height).collect(); + let timestamps = fetch_headers(db, block_nums) + .map_err(|e| ValidationError::CustomError(e.to_string()))? + .iter() + .map(|h| h.timestamp) + .collect::>(); + let median_timestamp = get_median_timestamp(timestamps).ok_or({ + error!(target: LOG_TARGET, "Validation could not get median timestamp"); + ValidationError::BlockHeaderError(BlockHeaderValidationError::InvalidTimestamp) + })?; if block_header.timestamp < median_timestamp { warn!( target: LOG_TARGET, @@ -69,8 +73,8 @@ pub fn check_median_timestamp( } /// Calculates the achieved and target difficulties at the specified height and compares them. 
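check_median_timestamp now assembles its own window: the headers from height minus median_timestamp_count up to height are fetched and their timestamps handed to get_median_timestamp. A standalone sketch of the median rule it relies on (the even-window averaging here is an assumption about get_median_timestamp's exact semantics):

    // Median of a window of timestamps, with plain u64 seconds for the sketch.
    fn median_timestamp(mut timestamps: Vec<u64>) -> Option<u64> {
        if timestamps.is_empty() {
            return None; // mirrors the None -> InvalidTimestamp path above
        }
        timestamps.sort_unstable();
        let mid = timestamps.len() / 2;
        Some(if timestamps.len() % 2 == 0 {
            (timestamps[mid - 1] + timestamps[mid]) / 2
        } else {
            timestamps[mid]
        })
    }

    fn main() {
        assert_eq!(median_timestamp(vec![30, 10, 20]), Some(20));
        // A block is rejected when its own timestamp is below this median.
        assert!(median_timestamp(vec![]).is_none());
    }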
-pub fn check_achieved_difficulty( - db: &RwLockWriteGuard, +pub fn check_achieved_and_target_difficulty( + db: &B, block_header: &BlockHeader, height: u64, rules: ConsensusManager, @@ -81,19 +85,39 @@ pub fn check_achieved_difficulty( "Checking block has acheived the required difficulty", ); let achieved = block_header.achieved_difficulty(); - let mut target = 1.into(); - if block_header.height > 0 || rules.get_genesis_block_hash() != block_header.hash() { - target = rules - .get_target_difficulty_with_height_writeguard(db, block_header.pow.pow_algo, height) - .or_else(|e| { - error!(target: LOG_TARGET, "Validation could not get achieved difficulty"); - Err(e) - }) - .map_err(|_| { - ValidationError::BlockHeaderError(BlockHeaderValidationError::ProofOfWorkError( - PowError::InvalidProofOfWork, - )) - })?; + let pow_algo = block_header.pow.pow_algo; + let target = if block_header.height > 0 || rules.get_genesis_block_hash() != block_header.hash() { + let constants = rules.consensus_constants(); + let target_difficulties = + fetch_target_difficulties(db, pow_algo, height, constants.get_difficulty_block_window() as usize) + .map_err(|e| ValidationError::CustomError(e.to_string()))?; + get_target_difficulty( + target_difficulties, + constants.get_difficulty_block_window() as usize, + constants.get_diff_target_block_interval(), + constants.min_pow_difficulty(pow_algo), + constants.get_difficulty_max_block_interval(), + ) + .or_else(|e| { + error!(target: LOG_TARGET, "Validation could not get target difficulty"); + Err(e) + }) + .map_err(|_| { + ValidationError::BlockHeaderError(BlockHeaderValidationError::ProofOfWorkError( + PowError::InvalidProofOfWork, + )) + })? + } else { + 1.into() + }; + if block_header.pow.target_difficulty != target { + warn!( + target: LOG_TARGET, + "Recorded header target difficulty {} was incorrect: {}", block_header.pow.target_difficulty, target + ); + return Err(ValidationError::BlockHeaderError( + BlockHeaderValidationError::ProofOfWorkError(PowError::InvalidTargetDifficulty), + )); } if achieved < target { warn!( diff --git a/base_layer/core/src/validation/mocks.rs b/base_layer/core/src/validation/mocks.rs index 0d15828099..4eacb12bbc 100644 --- a/base_layer/core/src/validation/mocks.rs +++ b/base_layer/core/src/validation/mocks.rs @@ -20,12 +20,8 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
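Because headers now record their target difficulty, the check is two-fold rather than one: the recorded target must equal the independently recomputed target, and the achieved difficulty must reach it. Distilled below with plain u64 stand-ins for the Difficulty type:

    #[derive(Debug, PartialEq)]
    enum DifficultyError {
        InvalidTargetDifficulty, // recorded header target does not match the recomputed target
        InvalidProofOfWork,      // achieved difficulty below target
    }

    fn check_difficulty(achieved: u64, recorded: u64, computed: u64) -> Result<(), DifficultyError> {
        if recorded != computed {
            return Err(DifficultyError::InvalidTargetDifficulty);
        }
        if achieved < recorded {
            return Err(DifficultyError::InvalidProofOfWork);
        }
        Ok(())
    }

    fn main() {
        assert_eq!(check_difficulty(5000, 4000, 4000), Ok(()));
        assert_eq!(check_difficulty(5000, 4000, 3000), Err(DifficultyError::InvalidTargetDifficulty));
        assert_eq!(check_difficulty(2000, 4000, 4000), Err(DifficultyError::InvalidProofOfWork));
    }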
-use super::{StatelessValidation, Validation, ValidationWriteGuard}; -use crate::{ - chain_storage::{BlockchainBackend, ChainMetadata}, - validation::error::ValidationError, -}; -use std::sync::{RwLockReadGuard, RwLockWriteGuard}; +use super::{StatelessValidation, Validation}; +use crate::{chain_storage::BlockchainBackend, validation::error::ValidationError}; #[derive(Clone)] pub struct MockValidator { @@ -39,31 +35,7 @@ impl MockValidator { } impl Validation for MockValidator { - fn validate( - &self, - _item: &T, - _db: &RwLockReadGuard, - _metadata: &RwLockReadGuard, - ) -> Result<(), ValidationError> - { - if self.result { - Ok(()) - } else { - Err(ValidationError::CustomError( - "This mock validator always returns an error".into(), - )) - } - } -} - -impl ValidationWriteGuard for MockValidator { - fn validate( - &self, - _item: &T, - _db: &RwLockWriteGuard, - _metadata: &RwLockWriteGuard, - ) -> Result<(), ValidationError> - { + fn validate(&self, _item: &T, _db: &B) -> Result<(), ValidationError> { if self.result { Ok(()) } else { diff --git a/base_layer/core/src/validation/mod.rs b/base_layer/core/src/validation/mod.rs index 947595fb5d..e7e33b4d66 100644 --- a/base_layer/core/src/validation/mod.rs +++ b/base_layer/core/src/validation/mod.rs @@ -34,12 +34,6 @@ mod traits; pub mod block_validators; pub mod mocks; pub use error::ValidationError; -pub use traits::{ - StatelessValidation, - StatelessValidator, - Validation, - ValidationWriteGuard, - Validator, - ValidatorWriteGuard, -}; +pub use traits::{StatelessValidation, StatelessValidator, Validation, Validator}; +pub mod accum_difficulty_validators; pub mod transaction_validators; diff --git a/base_layer/core/src/validation/traits.rs b/base_layer/core/src/validation/traits.rs index 68f18b3d76..16c7c243c9 100644 --- a/base_layer/core/src/validation/traits.rs +++ b/base_layer/core/src/validation/traits.rs @@ -20,14 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - chain_storage::{BlockchainBackend, ChainMetadata}, - validation::error::ValidationError, -}; -use std::sync::{RwLockReadGuard, RwLockWriteGuard}; +use crate::{chain_storage::BlockchainBackend, validation::error::ValidationError}; pub type Validator = Box>; -pub type ValidatorWriteGuard = Box>; pub type StatelessValidator = Box>; /// The core validation trait. Multiple `Validation` implementors can be chained together in a [ValidatorPipeline] to @@ -37,25 +32,7 @@ pub trait Validation: Send + Sync where B: BlockchainBackend { /// General validation code that can run independent of external state - fn validate( - &self, - item: &T, - db: &RwLockReadGuard, - metadata: &RwLockReadGuard, - ) -> Result<(), ValidationError>; -} - -/// A write guard version of the core validation trait that allows access to the db backend using a lock write guard. -pub trait ValidationWriteGuard: Send + Sync -where B: BlockchainBackend -{ - /// General validation code that can run independent of external state - fn validate( - &self, - item: &T, - db: &RwLockWriteGuard, - metadata: &RwLockWriteGuard, - ) -> Result<(), ValidationError>; + fn validate(&self, item: &T, db: &B) -> Result<(), ValidationError>; } /// Stateless version of the core validation trait. 
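The slimmed-down trait makes custom validators short to write. A minimal sketch of an implementor, mirroring the built-in MockValidator::new(true):

    use tari_core::{
        chain_storage::BlockchainBackend,
        validation::{Validation, ValidationError},
    };

    /// A validator that accepts any item against any backend.
    struct AlwaysValid;

    impl<T, B: BlockchainBackend> Validation<T, B> for AlwaysValid {
        fn validate(&self, _item: &T, _db: &B) -> Result<(), ValidationError> {
            Ok(())
        }
    }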
diff --git a/base_layer/core/src/validation/transaction_validators.rs b/base_layer/core/src/validation/transaction_validators.rs index e6be4ae5b7..a3e49d9b84 100644 --- a/base_layer/core/src/validation/transaction_validators.rs +++ b/base_layer/core/src/validation/transaction_validators.rs @@ -21,12 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - chain_storage::{is_utxo, BlockchainBackend, ChainMetadata}, + chain_storage::{is_stxo, is_utxo, BlockchainBackend}, transactions::{transaction::Transaction, types::CryptoFactories}, validation::{StatelessValidation, Validation, ValidationError}, }; use log::*; -use std::sync::RwLockReadGuard; use tari_crypto::tari_utilities::hash::Hashable; pub const LOG_TARGET: &str = "c::val::transaction_validators"; @@ -63,16 +62,14 @@ impl FullTxValidator { } impl Validation for FullTxValidator { - fn validate( - &self, - tx: &Transaction, - db: &RwLockReadGuard, - metadata: &RwLockReadGuard, - ) -> Result<(), ValidationError> - { + fn validate(&self, tx: &Transaction, db: &B) -> Result<(), ValidationError> { verify_tx(tx, &self.factories)?; - verify_inputs(tx, &db)?; - let tip_height = metadata.height_of_longest_chain.unwrap_or(0); + verify_inputs(tx, db)?; + let tip_height = db + .fetch_metadata() + .map_err(|e| ValidationError::CustomError(e.to_string()))? + .height_of_longest_chain + .unwrap_or(0); verify_timelocks(tx, tip_height)?; Ok(()) } @@ -83,15 +80,13 @@ impl Validation for FullTxValidator { pub struct TxInputAndMaturityValidator {} impl Validation for TxInputAndMaturityValidator { - fn validate( - &self, - tx: &Transaction, - db: &RwLockReadGuard, - metadata: &RwLockReadGuard, - ) -> Result<(), ValidationError> - { - verify_inputs(tx, &db)?; - let tip_height = metadata.height_of_longest_chain.unwrap_or(0); + fn validate(&self, tx: &Transaction, db: &B) -> Result<(), ValidationError> { + verify_inputs(tx, db)?; + let tip_height = db + .fetch_metadata() + .map_err(|e| ValidationError::CustomError(e.to_string()))? + .height_of_longest_chain + .unwrap_or(0); verify_timelocks(tx, tip_height)?; Ok(()) } @@ -101,14 +96,8 @@ impl Validation for TxInputAndMaturityVali pub struct InputTxValidator {} impl Validation for InputTxValidator { - fn validate( - &self, - tx: &Transaction, - db: &RwLockReadGuard, - _metadata: &RwLockReadGuard, - ) -> Result<(), ValidationError> - { - verify_inputs(tx, &db)?; + fn validate(&self, tx: &Transaction, db: &B) -> Result<(), ValidationError> { + verify_inputs(tx, db)?; Ok(()) } } @@ -117,14 +106,12 @@ impl Validation for InputTxValidator { pub struct TimeLockTxValidator {} impl Validation for TimeLockTxValidator { - fn validate( - &self, - tx: &Transaction, - _db: &RwLockReadGuard, - metadata: &RwLockReadGuard, - ) -> Result<(), ValidationError> - { - let tip_height = metadata.height_of_longest_chain.unwrap_or(0); + fn validate(&self, tx: &Transaction, db: &B) -> Result<(), ValidationError> { + let tip_height = db + .fetch_metadata() + .map_err(|e| ValidationError::CustomError(e.to_string()))? 
+ .height_of_longest_chain + .unwrap_or(0); verify_timelocks(tx, tip_height)?; Ok(()) } @@ -147,8 +134,16 @@ fn verify_timelocks(tx: &Transaction, current_height: u64) -> Result<(), Validat } // This function checks that all inputs exist in the provided database backend -fn verify_inputs(tx: &Transaction, db: &RwLockReadGuard) -> Result<(), ValidationError> { +fn verify_inputs(tx: &Transaction, db: &B) -> Result<(), ValidationError> { for input in tx.body.inputs() { + if is_stxo(db, input.hash()).map_err(|e| ValidationError::CustomError(e.to_string()))? { + // we dont want to log this as a node or wallet might retransmit a transaction + debug!( + target: LOG_TARGET, + "Transaction validation failed due to already spent input: {}", input + ); + return Err(ValidationError::ContainsSTxO); + } if !(is_utxo(db, input.hash())).map_err(|e| ValidationError::CustomError(e.to_string()))? { warn!( target: LOG_TARGET, diff --git a/base_layer/core/tests/async_db.rs b/base_layer/core/tests/async_db.rs index 3b8ad8751b..469458741a 100644 --- a/base_layer/core/tests/async_db.rs +++ b/base_layer/core/tests/async_db.rs @@ -134,7 +134,7 @@ fn fetch_async_utxo() { } #[test] -fn async_is_utxo() { +fn async_is_utxo_stxo() { let (db, blocks, outputs, _) = create_blockchain_db_no_cut_through(); let factory = CommitmentFactory::default(); blocks.iter().for_each(|b| println!("{}", b)); @@ -144,6 +144,10 @@ fn async_is_utxo() { // Check using sync functions assert_eq!(db.is_utxo(utxo.hash()), Ok(true)); assert_eq!(db.is_utxo(stxo.hash()), Ok(false)); + + assert_eq!(db.is_stxo(utxo.hash()), Ok(false)); + assert_eq!(db.is_stxo(stxo.hash()), Ok(true)); + test_async(move |rt| { let db = db.clone(); let db2 = db.clone(); @@ -151,10 +155,16 @@ fn async_is_utxo() { rt.spawn(async move { let is_utxo = async_db::is_utxo(db.clone(), utxo.hash()).await; assert_eq!(is_utxo, Ok(true)); + + let is_stxo = async_db::is_stxo(db.clone(), utxo.hash()).await; + assert_eq!(is_stxo, Ok(false)); }); rt.spawn(async move { let is_utxo = async_db::is_utxo(db2.clone(), stxo.hash()).await; assert_eq!(is_utxo, Ok(false)); + + let is_stxo = async_db::is_stxo(db2.clone(), stxo.hash()).await; + assert_eq!(is_stxo, Ok(true)); }); }); } diff --git a/base_layer/core/tests/block_validation.rs b/base_layer/core/tests/block_validation.rs index dc57c3fc62..41a66e0c0b 100644 --- a/base_layer/core/tests/block_validation.rs +++ b/base_layer/core/tests/block_validation.rs @@ -21,11 +21,13 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
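verify_inputs now distinguishes three input states: spent (rejected with the new ContainsSTxO error and logged only at debug level, since nodes and wallets are expected to retransmit), unspent (accepted), and unknown (warned about). A sketch of that classification, assuming is_stxo and is_utxo take the hash by value as in the calls above:

    use tari_core::{
        chain_storage::{is_stxo, is_utxo, BlockchainBackend},
        validation::ValidationError,
    };

    enum InputStatus {
        Spent,   // maps to ValidationError::ContainsSTxO
        Unspent, // valid input
        Unknown, // not in the UTXO set at all
    }

    fn classify_input<B: BlockchainBackend>(db: &B, hash: Vec<u8>) -> Result<InputStatus, ValidationError> {
        if is_stxo(db, hash.clone()).map_err(|e| ValidationError::CustomError(e.to_string()))? {
            return Ok(InputStatus::Spent);
        }
        if is_utxo(db, hash).map_err(|e| ValidationError::CustomError(e.to_string()))? {
            return Ok(InputStatus::Unspent);
        }
        Ok(InputStatus::Unknown)
    }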
use tari_core::{ - chain_storage::{BlockchainDatabase, MemoryDatabase, Validators}, + chain_storage::{BlockchainDatabase, BlockchainDatabaseConfig, MemoryDatabase, Validators}, consensus::{ConsensusManagerBuilder, Network}, - proof_of_work::DiffAdjManager, transactions::types::{CryptoFactories, HashDigest}, - validation::block_validators::{FullConsensusValidator, StatelessBlockValidator}, + validation::{ + accum_difficulty_validators::AccumDifficultyValidator, + block_validators::{FullConsensusValidator, StatelessBlockValidator}, + }, }; #[test] @@ -37,10 +39,9 @@ fn test_genesis_block() { let validators = Validators::new( FullConsensusValidator::new(rules.clone(), factories), StatelessBlockValidator::new(&rules.consensus_constants()), + AccumDifficultyValidator {}, ); - let db = BlockchainDatabase::new(backend, &rules, validators).unwrap(); - let diff_adj_manager = DiffAdjManager::new(&rules.consensus_constants()).unwrap(); - rules.set_diff_manager(diff_adj_manager).unwrap(); + let db = BlockchainDatabase::new(backend, &rules, validators, BlockchainDatabaseConfig::default()).unwrap(); let block = rules.get_genesis_block(); let result = db.add_block(block); assert!(result.is_ok()); diff --git a/base_layer/core/tests/chain_storage_tests/chain_backend.rs b/base_layer/core/tests/chain_storage_tests/chain_backend.rs index 0092d9bbdf..4a6d4a7368 100644 --- a/base_layer/core/tests/chain_storage_tests/chain_backend.rs +++ b/base_layer/core/tests/chain_storage_tests/chain_backend.rs @@ -37,6 +37,7 @@ use tari_core::{ }, consensus::{ConsensusConstants, Network}, helpers::create_orphan_block, + proof_of_work::{Difficulty, PowAlgorithm}, transactions::{ helpers::{create_test_kernel, create_utxo}, tari_amount::MicroTari, @@ -44,7 +45,7 @@ use tari_core::{ }, tx, }; -use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tari_crypto::tari_utilities::{epoch_time::EpochTime, hex::Hex, Hashable}; use tari_mmr::{MmrCacheConfig, MutableMmr}; use tari_test_utils::paths::create_temporary_data_path; @@ -86,8 +87,19 @@ fn memory_insert_contains_delete_and_fetch_header() { #[test] fn lmdb_insert_contains_delete_and_fetch_header() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - insert_contains_delete_and_fetch_header(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + insert_contains_delete_and_fetch_header(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn insert_contains_delete_and_fetch_utxo(mut db: T) { @@ -120,8 +132,19 @@ fn memory_insert_contains_delete_and_fetch_utxo() { #[test] fn lmdb_insert_contains_delete_and_fetch_utxo() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - insert_contains_delete_and_fetch_utxo(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + insert_contains_delete_and_fetch_utxo(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn 
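Constructing a BlockchainDatabase now takes the third (accumulated-difficulty) validator plus an explicit config, replacing the old DiffAdjManager wiring. Assembled from the test above into one sketch; Network::LocalNet is an assumption here, as the hunk elides the test's actual network setup:

    use tari_core::{
        chain_storage::{BlockchainDatabase, BlockchainDatabaseConfig, MemoryDatabase, Validators},
        consensus::{ConsensusManagerBuilder, Network},
        transactions::types::{CryptoFactories, HashDigest},
        validation::{
            accum_difficulty_validators::AccumDifficultyValidator,
            block_validators::{FullConsensusValidator, StatelessBlockValidator},
        },
    };

    fn main() {
        let rules = ConsensusManagerBuilder::new(Network::LocalNet).build();
        let factories = CryptoFactories::default();
        let validators = Validators::new(
            FullConsensusValidator::new(rules.clone(), factories),
            StatelessBlockValidator::new(&rules.consensus_constants()),
            AccumDifficultyValidator {},
        );
        let backend = MemoryDatabase::<HashDigest>::default();
        let db = BlockchainDatabase::new(backend, &rules, validators, BlockchainDatabaseConfig::default()).unwrap();
        // As in test_genesis_block above, the genesis block should be accepted.
        assert!(db.add_block(rules.get_genesis_block()).is_ok());
    }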
insert_contains_delete_and_fetch_kernel(mut db: T) { @@ -155,8 +178,19 @@ fn memory_insert_contains_delete_and_fetch_kernel() { #[test] fn lmdb_insert_contains_delete_and_fetch_kernel() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - insert_contains_delete_and_fetch_kernel(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + insert_contains_delete_and_fetch_kernel(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn insert_contains_delete_and_fetch_orphan(mut db: T, consensus_constants: &ConsensusConstants) { @@ -195,10 +229,21 @@ fn memory_insert_contains_delete_and_fetch_orphan() { #[test] fn lmdb_insert_contains_delete_and_fetch_orphan() { - let network = Network::LocalNet; - let consensus_constants = network.create_consensus_constants(); - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - insert_contains_delete_and_fetch_orphan(db, &consensus_constants); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let network = Network::LocalNet; + let consensus_constants = network.create_consensus_constants(); + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + insert_contains_delete_and_fetch_orphan(db, &consensus_constants); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn spend_utxo_and_unspend_stxo(mut db: T) { @@ -258,8 +303,19 @@ fn memory_spend_utxo_and_unspend_stxo() { #[test] fn lmdb_spend_utxo_and_unspend_stxo() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - spend_utxo_and_unspend_stxo(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + spend_utxo_and_unspend_stxo(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn insert_fetch_metadata(mut db: T) { @@ -337,8 +393,19 @@ fn memory_insert_fetch_metadata() { #[test] fn lmdb_insert_fetch_metadata() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - insert_fetch_metadata(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + insert_fetch_metadata(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn fetch_mmr_root_and_proof_for_utxo_and_rp(mut db: T) { @@ -412,8 +479,19 @@ fn memory_fetch_mmr_root_and_proof_for_utxo_and_rp() { #[test] fn lmdb_fetch_mmr_root_and_proof_for_utxo_and_rp() { - let db = 
create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - fetch_mmr_root_and_proof_for_utxo_and_rp(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + fetch_mmr_root_and_proof_for_utxo_and_rp(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn fetch_mmr_root_and_proof_for_kernel(mut db: T) { @@ -462,8 +540,19 @@ fn memory_fetch_mmr_root_and_proof_for_kernel() { #[test] fn lmdb_fetch_mmr_root_and_proof_for_kernel() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - fetch_mmr_root_and_proof_for_kernel(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + fetch_mmr_root_and_proof_for_kernel(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn fetch_future_mmr_root_for_utxo_and_rp(mut db: T) { @@ -513,8 +602,19 @@ fn memory_fetch_future_mmr_root_for_utxo_and_rp() { #[test] fn lmdb_fetch_future_mmr_root_for_utxo_and_rp() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - fetch_future_mmr_root_for_utxo_and_rp(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + fetch_future_mmr_root_for_utxo_and_rp(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn fetch_future_mmr_root_for_for_kernel(mut db: T) { @@ -552,8 +652,19 @@ fn memory_fetch_future_mmr_root_for_for_kernel() { #[test] fn lmdb_fetch_future_mmr_root_for_for_kernel() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - fetch_future_mmr_root_for_for_kernel(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + fetch_future_mmr_root_for_for_kernel(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn commit_block_and_create_fetch_checkpoint_and_rewind_mmr(mut db: T) { @@ -649,8 +760,19 @@ fn memory_commit_block_and_create_fetch_checkpoint_and_rewind_mmr() { #[test] fn lmdb_commit_block_and_create_fetch_checkpoint_and_rewind_mmr() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - commit_block_and_create_fetch_checkpoint_and_rewind_mmr(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + 
commit_block_and_create_fetch_checkpoint_and_rewind_mmr(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } // TODO: Test Needed: fetch_mmr_node @@ -712,10 +834,21 @@ fn memory_for_each_orphan() { #[test] fn lmdb_for_each_orphan() { - let network = Network::LocalNet; - let consensus_constants = network.create_consensus_constants(); - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - for_each_orphan(db, &consensus_constants); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let network = Network::LocalNet; + let consensus_constants = network.create_consensus_constants(); + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + for_each_orphan(db, &consensus_constants); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn for_each_kernel(mut db: T) { @@ -761,8 +894,19 @@ fn memory_for_each_kernel() { #[test] fn lmdb_for_each_kernel() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - for_each_kernel(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + for_each_kernel(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn for_each_header(mut db: T) { @@ -808,8 +952,19 @@ fn memory_for_each_header() { #[test] fn lmdb_for_each_header() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - for_each_header(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + for_each_header(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn for_each_utxo(mut db: T) { @@ -856,8 +1011,19 @@ fn memory_for_each_utxo() { #[test] fn lmdb_for_each_utxo() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - for_each_utxo(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + for_each_utxo(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } #[test] @@ -882,120 +1048,138 @@ fn lmdb_backend_restore() { // Create backend storage let path = create_temporary_data_path(); { - let mut db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); - let mut txn = DbTransaction::new(); - txn.insert_orphan(orphan.clone()); - txn.insert_utxo(utxo1, 
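Each lmdb_* test above now follows the same create-scope-cleanup shape: the database handle is dropped at the end of an inner scope before the directory is removed, because on Windows LMDB's set_mapsize preallocates the full map size on disk while Linux uses sparse files. A hypothetical helper (not part of this patch) that would collapse the repetition:

    use std::{env, fs, path::PathBuf};

    /// Run `test` against a fresh temporary directory, then remove the directory.
    /// The closure must drop any LMDB handles before returning, exactly as the
    /// inner scopes above do, or the files may still be held open on Windows.
    fn with_temp_dir<F: FnOnce(&PathBuf)>(test: F) {
        let dir = env::temp_dir().join(format!("tari_lmdb_test_{}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        test(&dir);
        if dir.exists() {
            fs::remove_dir_all(&dir).unwrap();
        }
    }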
true); - txn.insert_utxo(utxo2, true); - txn.insert_kernel(kernel, true); - txn.insert_header(header.clone()); - txn.commit_block(); - assert!(db.write(txn).is_ok()); - let mut txn = DbTransaction::new(); - txn.spend_utxo(stxo_hash.clone()); - assert!(db.write(txn).is_ok()); - + { + let mut db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); + let mut txn = DbTransaction::new(); + txn.insert_orphan(orphan.clone()); + txn.insert_utxo(utxo1, true); + txn.insert_utxo(utxo2, true); + txn.insert_kernel(kernel, true); + txn.insert_header(header.clone()); + txn.commit_block(); + assert!(db.write(txn).is_ok()); + let mut txn = DbTransaction::new(); + txn.spend_utxo(stxo_hash.clone()); + assert!(db.write(txn).is_ok()); + + assert_eq!(db.contains(&DbKey::BlockHeader(header.height)), Ok(true)); + assert_eq!(db.contains(&DbKey::BlockHash(header_hash.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::SpentOutput(stxo_hash.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::OrphanBlock(orphan_hash.clone())), Ok(true)); + } + // Restore backend storage + let db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); assert_eq!(db.contains(&DbKey::BlockHeader(header.height)), Ok(true)); - assert_eq!(db.contains(&DbKey::BlockHash(header_hash.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::SpentOutput(stxo_hash.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::OrphanBlock(orphan_hash.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::BlockHash(header_hash)), Ok(true)); + assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash)), Ok(true)); + assert_eq!(db.contains(&DbKey::SpentOutput(stxo_hash)), Ok(true)); + assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash)), Ok(true)); + assert_eq!(db.contains(&DbKey::OrphanBlock(orphan_hash)), Ok(true)); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&path).exists() { + std::fs::remove_dir_all(&path).unwrap(); } - // Restore backend storage - let db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); - assert_eq!(db.contains(&DbKey::BlockHeader(header.height)), Ok(true)); - assert_eq!(db.contains(&DbKey::BlockHash(header_hash)), Ok(true)); - assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash)), Ok(true)); - assert_eq!(db.contains(&DbKey::SpentOutput(stxo_hash)), Ok(true)); - assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash)), Ok(true)); - assert_eq!(db.contains(&DbKey::OrphanBlock(orphan_hash)), Ok(true)); } #[test] fn lmdb_mmr_reset_and_commit() { - let factories = CryptoFactories::default(); - let mut db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); + // Create temporary test folder + let temp_path = create_temporary_data_path(); - let (utxo1, _) = create_utxo(MicroTari(10_000), &factories, None); - let (utxo2, _) = create_utxo(MicroTari(15_000), &factories, None); - let kernel1 = create_test_kernel(100.into(), 0); - let kernel2 = create_test_kernel(200.into(), 0); - let mut header1 = BlockHeader::new(0); - header1.height = 1; - let utxo_hash1 = utxo1.hash(); - let utxo_hash2 = utxo2.hash(); - let kernel_hash1 
= kernel1.hash(); - let kernel_hash2 = kernel2.hash(); - let rp_hash1 = utxo1.proof.hash(); - let header_hash1 = header1.hash(); - - let mut txn = DbTransaction::new(); - txn.insert_utxo(utxo1, true); - txn.insert_kernel(kernel1, true); - txn.insert_header(header1); - txn.commit_block(); - assert!(db.write(txn).is_ok()); - - // Reset mmrs as a mmr txn failed without applying storage txns. - let mut txn = DbTransaction::new(); - txn.spend_utxo(utxo_hash2.clone()); - txn.commit_block(); - assert!(db.write(txn).is_err()); + // Perform test + { + let factories = CryptoFactories::default(); + let mut db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + + let (utxo1, _) = create_utxo(MicroTari(10_000), &factories, None); + let (utxo2, _) = create_utxo(MicroTari(15_000), &factories, None); + let kernel1 = create_test_kernel(100.into(), 0); + let kernel2 = create_test_kernel(200.into(), 0); + let mut header1 = BlockHeader::new(0); + header1.height = 1; + let utxo_hash1 = utxo1.hash(); + let utxo_hash2 = utxo2.hash(); + let kernel_hash1 = kernel1.hash(); + let kernel_hash2 = kernel2.hash(); + let rp_hash1 = utxo1.proof.hash(); + let header_hash1 = header1.hash(); - assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash1.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash2.clone())), Ok(false)); - assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash1.clone())), Ok(false)); - assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash2.clone())), Ok(false)); - assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash1.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash2.clone())), Ok(false)); - assert_eq!(db.contains(&DbKey::BlockHash(header_hash1.clone())), Ok(true)); - assert_eq!( - db.fetch_checkpoint(MmrTree::Utxo, 0).unwrap().nodes_added()[0], - utxo_hash1 - ); - assert_eq!( - db.fetch_checkpoint(MmrTree::Kernel, 0).unwrap().nodes_added()[0], - kernel_hash1 - ); - assert_eq!( - db.fetch_checkpoint(MmrTree::RangeProof, 0).unwrap().nodes_added()[0], - rp_hash1 - ); - assert!(db.fetch_checkpoint(MmrTree::Utxo, 1).is_err()); - assert!(db.fetch_checkpoint(MmrTree::Kernel, 1).is_err()); - assert!(db.fetch_checkpoint(MmrTree::RangeProof, 1).is_err()); + let mut txn = DbTransaction::new(); + txn.insert_utxo(utxo1, true); + txn.insert_kernel(kernel1, true); + txn.insert_header(header1); + txn.commit_block(); + assert!(db.write(txn).is_ok()); - // Reset mmrs as a storage txn failed after the mmr txns were applied, ensure the previous state was preserved. - let mut txn = DbTransaction::new(); - txn.spend_utxo(utxo_hash1.clone()); - txn.delete(DbKey::TransactionKernel(kernel_hash1.clone())); - txn.delete(DbKey::TransactionKernel(kernel_hash2.clone())); - txn.commit_block(); - assert!(db.write(txn).is_err()); + // Reset mmrs as a mmr txn failed without applying storage txns. 
+ let mut txn = DbTransaction::new(); + txn.spend_utxo(utxo_hash2.clone()); + txn.commit_block(); + assert!(db.write(txn).is_err()); + + assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash1.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash2.clone())), Ok(false)); + assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash1.clone())), Ok(false)); + assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash2.clone())), Ok(false)); + assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash1.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash2.clone())), Ok(false)); + assert_eq!(db.contains(&DbKey::BlockHash(header_hash1.clone())), Ok(true)); + assert_eq!( + db.fetch_checkpoint(MmrTree::Utxo, 0).unwrap().nodes_added()[0], + utxo_hash1 + ); + assert_eq!( + db.fetch_checkpoint(MmrTree::Kernel, 0).unwrap().nodes_added()[0], + kernel_hash1 + ); + assert_eq!( + db.fetch_checkpoint(MmrTree::RangeProof, 0).unwrap().nodes_added()[0], + rp_hash1 + ); + assert!(db.fetch_checkpoint(MmrTree::Utxo, 1).is_err()); + assert!(db.fetch_checkpoint(MmrTree::Kernel, 1).is_err()); + assert!(db.fetch_checkpoint(MmrTree::RangeProof, 1).is_err()); + + // Reset mmrs as a storage txn failed after the mmr txns were applied, ensure the previous state was preserved. + let mut txn = DbTransaction::new(); + txn.spend_utxo(utxo_hash1.clone()); + txn.delete(DbKey::TransactionKernel(kernel_hash1.clone())); + txn.delete(DbKey::TransactionKernel(kernel_hash2.clone())); + txn.commit_block(); + assert!(db.write(txn).is_err()); + + assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash1.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash2.clone())), Ok(false)); + assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash1.clone())), Ok(false)); + assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash2)), Ok(false)); + assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash1.clone())), Ok(true)); + assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash2)), Ok(false)); + assert_eq!(db.contains(&DbKey::BlockHash(header_hash1.clone())), Ok(true)); + assert_eq!( + db.fetch_checkpoint(MmrTree::Utxo, 0).unwrap().nodes_added()[0], + utxo_hash1 + ); + assert_eq!( + db.fetch_checkpoint(MmrTree::Kernel, 0).unwrap().nodes_added()[0], + kernel_hash1 + ); + assert_eq!( + db.fetch_checkpoint(MmrTree::RangeProof, 0).unwrap().nodes_added()[0], + rp_hash1 + ); + assert!(db.fetch_checkpoint(MmrTree::Utxo, 1).is_err()); + assert!(db.fetch_checkpoint(MmrTree::Kernel, 1).is_err()); + assert!(db.fetch_checkpoint(MmrTree::RangeProof, 1).is_err()); + } - assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash1.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::UnspentOutput(utxo_hash2.clone())), Ok(false)); - assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash1.clone())), Ok(false)); - assert_eq!(db.contains(&DbKey::SpentOutput(utxo_hash2)), Ok(false)); - assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash1.clone())), Ok(true)); - assert_eq!(db.contains(&DbKey::TransactionKernel(kernel_hash2)), Ok(false)); - assert_eq!(db.contains(&DbKey::BlockHash(header_hash1.clone())), Ok(true)); - assert_eq!( - db.fetch_checkpoint(MmrTree::Utxo, 0).unwrap().nodes_added()[0], - utxo_hash1 - ); - assert_eq!( - db.fetch_checkpoint(MmrTree::Kernel, 0).unwrap().nodes_added()[0], - kernel_hash1 - ); - assert_eq!( - db.fetch_checkpoint(MmrTree::RangeProof, 0).unwrap().nodes_added()[0], - rp_hash1 - ); - assert!(db.fetch_checkpoint(MmrTree::Utxo, 1).is_err()); - 
assert!(db.fetch_checkpoint(MmrTree::Kernel, 1).is_err()); - assert!(db.fetch_checkpoint(MmrTree::RangeProof, 1).is_err()); + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn fetch_checkpoint(mut db: T) { @@ -1085,9 +1269,20 @@ fn memory_fetch_checkpoint() { #[test] fn lmdb_fetch_checkpoint() { - let mmr_cache_config = MmrCacheConfig { rewind_hist_len: 1 }; - let db = create_lmdb_database(&create_temporary_data_path(), mmr_cache_config).unwrap(); - fetch_checkpoint(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let mmr_cache_config = MmrCacheConfig { rewind_hist_len: 1 }; + let db = create_lmdb_database(&temp_path, mmr_cache_config).unwrap(); + fetch_checkpoint(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn duplicate_utxo(mut db: T) { @@ -1123,8 +1318,19 @@ fn memory_duplicate_utxo() { #[test] fn lmdb_duplicate_utxo() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - duplicate_utxo(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + duplicate_utxo(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } fn fetch_last_header(mut db: T) { @@ -1156,6 +1362,113 @@ fn memory_fetch_last_header() { #[test] fn lmdb_fetch_last_header() { - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - fetch_last_header(db); + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + fetch_last_header(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } +} + +fn fetch_target_difficulties(mut db: T) { + let mut header0 = BlockHeader::new(0); + header0.pow.pow_algo = PowAlgorithm::Blake; + header0.pow.target_difficulty = Difficulty::from(100); + let mut header1 = BlockHeader::from_previous(&header0); + header1.pow.pow_algo = PowAlgorithm::Monero; + header1.pow.target_difficulty = Difficulty::from(1000); + let mut header2 = BlockHeader::from_previous(&header1); + header2.pow.pow_algo = PowAlgorithm::Blake; + header2.pow.target_difficulty = Difficulty::from(2000); + let mut header3 = BlockHeader::from_previous(&header2); + header3.pow.pow_algo = PowAlgorithm::Blake; + header3.pow.target_difficulty = Difficulty::from(3000); + let mut header4 = BlockHeader::from_previous(&header3); + header4.pow.pow_algo = PowAlgorithm::Monero; + header4.pow.target_difficulty = Difficulty::from(200); + let mut header5 = BlockHeader::from_previous(&header4); + header5.pow.pow_algo = PowAlgorithm::Blake; + header5.pow.target_difficulty = Difficulty::from(4000); + 
assert!(db.fetch_target_difficulties(PowAlgorithm::Blake, 5, 100).is_err()); + assert!(db.fetch_target_difficulties(PowAlgorithm::Monero, 5, 100).is_err()); + + let mut txn = DbTransaction::new(); + txn.insert_header(header0.clone()); + txn.insert_header(header1.clone()); + txn.insert_header(header2.clone()); + txn.insert_header(header3.clone()); + txn.insert_header(header4.clone()); + txn.insert_header(header5.clone()); + txn.insert(DbKeyValuePair::Metadata( + MetadataKey::ChainHeight, + MetadataValue::ChainHeight(Some(header5.height)), + )); + assert!(db.write(txn).is_ok()); + + // Check block window constraint + let desired_targets: Vec<(EpochTime, Difficulty)> = vec![ + (header2.timestamp, header2.pow.target_difficulty), + (header3.timestamp, header3.pow.target_difficulty), + ]; + assert_eq!( + db.fetch_target_difficulties(PowAlgorithm::Blake, header4.height, 2), + Ok(desired_targets) + ); + let desired_targets: Vec<(EpochTime, Difficulty)> = vec![ + (header1.timestamp, header1.pow.target_difficulty), + (header4.timestamp, header4.pow.target_difficulty), + ]; + assert_eq!( + db.fetch_target_difficulties(PowAlgorithm::Monero, header4.height, 2), + Ok(desired_targets) + ); + // Check search from tip to genesis block + let desired_targets: Vec<(EpochTime, Difficulty)> = vec![ + (header0.timestamp, header0.pow.target_difficulty), + (header2.timestamp, header2.pow.target_difficulty), + (header3.timestamp, header3.pow.target_difficulty), + (header5.timestamp, header5.pow.target_difficulty), + ]; + assert_eq!( + db.fetch_target_difficulties(PowAlgorithm::Blake, header5.height, 100), + Ok(desired_targets) + ); + let desired_targets: Vec<(EpochTime, Difficulty)> = vec![ + (header1.timestamp, header1.pow.target_difficulty), + (header4.timestamp, header4.pow.target_difficulty), + ]; + assert_eq!( + db.fetch_target_difficulties(PowAlgorithm::Monero, header5.height, 100), + Ok(desired_targets) + ); +} + +#[test] +fn memory_fetch_target_difficulties() { + let db = MemoryDatabase::::default(); + fetch_target_difficulties(db); +} + +#[test] +fn lmdb_fetch_target_difficulties() { + // Create temporary test folder + let temp_path = create_temporary_data_path(); + + // Perform test + { + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + fetch_target_difficulties(db); + } + + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } diff --git a/base_layer/core/tests/chain_storage_tests/chain_storage.rs b/base_layer/core/tests/chain_storage_tests/chain_storage.rs index 0ee6859dd3..5c836b881d 100644 --- a/base_layer/core/tests/chain_storage_tests/chain_storage.rs +++ b/base_layer/core/tests/chain_storage_tests/chain_storage.rs @@ -38,7 +38,9 @@ use tari_core::{ chain_storage::{ create_lmdb_database, BlockAddResult, + BlockchainBackend, BlockchainDatabase, + BlockchainDatabaseConfig, ChainStorageError, DbKey, DbTransaction, @@ -56,7 +58,11 @@ use tari_core::{ }, tx, txn_schema, - validation::{block_validators::StatelessBlockValidator, mocks::MockValidator}, + validation::{ + accum_difficulty_validators::MockAccumDifficultyValidator, + block_validators::StatelessBlockValidator, + mocks::MockValidator, + }, }; use tari_crypto::tari_utilities::{hex::Hex, Hashable}; use tari_mmr::{MmrCacheConfig, MutableMmr}; @@ -753,11 +759,15 @@ fn handle_reorg() { #[test] fn store_and_retrieve_blocks() { let mmr_cache_config = 
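The expected values in the fetch_target_difficulties test follow a simple rule: walk back from the requested height, keep only headers of the requested PoW algorithm, and return at most the newest `block_window` entries in ascending height order. A standalone sketch of that selection, with heights and plain u64s standing in for the real (EpochTime, Difficulty) pairs:

    /// headers: (height, algo, target) triples; returns (height, target) pairs.
    fn target_window(headers: &[(u64, u8, u64)], algo: u8, height: u64, window: usize) -> Vec<(u64, u64)> {
        let mut selected: Vec<(u64, u64)> = headers
            .iter()
            .filter(|(h, a, _)| *h <= height && *a == algo)
            .map(|(h, _, t)| (*h, *t))
            .collect();
        selected.sort_by_key(|(h, _)| *h);
        // Keep only the newest `window` samples, still oldest first.
        let skip = selected.len().saturating_sub(window);
        selected.split_off(skip)
    }

    fn main() {
        const BLAKE: u8 = 0;
        const MONERO: u8 = 1; // algo tags are illustrative only
        let headers = [
            (0, BLAKE, 100), (1, MONERO, 1000), (2, BLAKE, 2000),
            (3, BLAKE, 3000), (4, MONERO, 200), (5, BLAKE, 4000),
        ];
        // Block window constraint: only header2 and header3 qualify.
        assert_eq!(target_window(&headers, BLAKE, 4, 2), vec![(2, 2000), (3, 3000)]);
        // Search from tip back to the genesis block.
        assert_eq!(
            target_window(&headers, BLAKE, 5, 100),
            vec![(0, 100), (2, 2000), (3, 3000), (5, 4000)]
        );
    }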
MmrCacheConfig { rewind_hist_len: 2 }; - let validators = Validators::new(MockValidator::new(true), MockValidator::new(true)); + let validators = Validators::new( + MockValidator::new(true), + MockValidator::new(true), + MockAccumDifficultyValidator {}, + ); let network = Network::LocalNet; let rules = ConsensusManagerBuilder::new(network).build(); let db = MemoryDatabase::::new(mmr_cache_config); - let store = BlockchainDatabase::new(db, &rules, validators).unwrap(); + let store = BlockchainDatabase::new(db, &rules, validators, BlockchainDatabaseConfig::default()).unwrap(); let block0 = store.fetch_block(0).unwrap().block().clone(); let block1 = append_block(&store, &block0, vec![], &rules.consensus_constants(), 1.into()).unwrap(); @@ -776,11 +786,15 @@ fn store_and_retrieve_blocks() { #[test] fn store_and_retrieve_chain_and_orphan_blocks_with_hashes() { let mmr_cache_config = MmrCacheConfig { rewind_hist_len: 2 }; - let validators = Validators::new(MockValidator::new(true), MockValidator::new(true)); + let validators = Validators::new( + MockValidator::new(true), + MockValidator::new(true), + MockAccumDifficultyValidator {}, + ); let network = Network::LocalNet; let rules = ConsensusManagerBuilder::new(network).build(); let db = MemoryDatabase::::new(mmr_cache_config); - let store = BlockchainDatabase::new(db, &rules, validators).unwrap(); + let store = BlockchainDatabase::new(db, &rules, validators, BlockchainDatabaseConfig::default()).unwrap(); let block0 = store.fetch_block(0).unwrap().block().clone(); let block1 = append_block(&store, &block0, vec![], &rules.consensus_constants(), 1.into()).unwrap(); @@ -798,221 +812,349 @@ fn store_and_retrieve_chain_and_orphan_blocks_with_hashes() { } #[test] -fn total_kernel_excess() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let block0 = store.fetch_block(0).unwrap().block().clone(); +fn restore_metadata() { + let path = create_temporary_data_path(); - let kernel1 = create_test_kernel(100.into(), 0); - let kernel2 = create_test_kernel(200.into(), 0); - let kernel3 = create_test_kernel(300.into(), 0); + // Perform test + { + let validators = Validators::new( + MockValidator::new(true), + MockValidator::new(true), + MockAccumDifficultyValidator {}, + ); + let network = Network::LocalNet; + let rules = ConsensusManagerBuilder::new(network).build(); + let block_hash: BlockHash; + { + let db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); + let db = + BlockchainDatabase::new(db, &rules, validators.clone(), BlockchainDatabaseConfig::default()).unwrap(); + + let block0 = db.fetch_block(0).unwrap().block().clone(); + let block1 = append_block(&db, &block0, vec![], &rules.consensus_constants(), 1.into()).unwrap(); + db.add_block(block1.clone()).unwrap(); + block_hash = block1.hash(); + let metadata = db.get_metadata().unwrap(); + assert_eq!(metadata.height_of_longest_chain, Some(1)); + assert_eq!(metadata.best_block, Some(block_hash.clone())); + } + // Restore blockchain db + let db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); + let db = BlockchainDatabase::new(db, &rules, validators, BlockchainDatabaseConfig::default()).unwrap(); - let mut txn = DbTransaction::new(); - txn.insert_kernel(kernel1.clone(), false); - txn.insert_kernel(kernel2.clone(), false); - txn.insert_kernel(kernel3.clone(), false); - assert!(store.commit(txn).is_ok()); + let metadata = db.get_metadata().unwrap(); + 
assert_eq!(metadata.height_of_longest_chain, Some(1)); + assert_eq!(metadata.best_block, Some(block_hash)); + } - let total_kernel_excess = store.total_kernel_excess().unwrap(); - assert_eq!( - total_kernel_excess, - &(&(block0.body.kernels()[0].excess) + &kernel1.excess) + &(&kernel2.excess + &kernel3.excess) - ); + // Cleanup test data - in Windows the LMBD `set_mapsize` sets file size equals to map size; Linux use sparse files + if std::path::Path::new(&path).exists() { + std::fs::remove_dir_all(&path).unwrap(); + } } #[test] -fn total_kernel_offset() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let block0 = store.fetch_block(0).unwrap().block().clone(); - - let header2 = BlockHeader::from_previous(&block0.header); - let header3 = BlockHeader::from_previous(&header2); - let mut txn = DbTransaction::new(); - txn.insert_header(header2.clone()); - txn.insert_header(header3.clone()); - assert!(store.commit(txn).is_ok()); +fn invalid_block() { + let temp_path = create_temporary_data_path(); + { + let factories = CryptoFactories::default(); + let network = Network::LocalNet; + let consensus_constants = ConsensusConstantsBuilder::new(network) + .with_emission_amounts(100_000_000.into(), 0.999, 100.into()) + .build(); + let (block0, output) = create_genesis_block(&factories, &consensus_constants); + let consensus_manager = ConsensusManagerBuilder::new(network) + .with_consensus_constants(consensus_constants.clone()) + .with_block(block0.clone()) + .build(); + let validators = Validators::new( + MockValidator::new(true), + StatelessBlockValidator::new(&consensus_manager.consensus_constants()), + MockAccumDifficultyValidator {}, + ); + let db = create_lmdb_database(&temp_path, MmrCacheConfig::default()).unwrap(); + let mut store = + BlockchainDatabase::new(db, &consensus_manager, validators, BlockchainDatabaseConfig::default()).unwrap(); + let mut blocks = vec![block0]; + let mut outputs = vec![vec![output]]; + let block0_hash = blocks[0].hash(); + let metadata = store.get_metadata().unwrap(); + let utxo_root0 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); + let kernel_root0 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); + let rp_root0 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); + assert_eq!(metadata.height_of_longest_chain, Some(0)); + assert_eq!(metadata.best_block, Some(block0_hash.clone())); + assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); + assert!(store.fetch_block(1).is_err()); + + // Block 1 + let txs = vec![txn_schema!( + from: vec![outputs[0][0].clone()], + to: vec![10 * T, 5 * T, 10 * T, 15 * T] + )]; + let coinbase_value = consensus_manager.emission_schedule().block_reward(1); + assert_eq!( + generate_new_block_with_coinbase( + &mut store, + &factories, + &mut blocks, + &mut outputs, + txs, + coinbase_value, + &consensus_manager.consensus_constants() + ), + Ok(BlockAddResult::Ok) + ); + let block1_hash = blocks[1].hash(); + let metadata = store.get_metadata().unwrap(); + let utxo_root1 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); + let kernel_root1 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); + let rp_root1 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); + assert_eq!(metadata.height_of_longest_chain, Some(1)); + assert_eq!(metadata.best_block, Some(block1_hash.clone())); + assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); + assert_eq!(store.fetch_block(1).unwrap().block().hash(), block1_hash); + 
assert!(store.fetch_block(2).is_err()); + assert_ne!(utxo_root0, utxo_root1); + assert_ne!(kernel_root0, kernel_root1); + assert_ne!(rp_root0, rp_root1); + + // Invalid Block 2 - Double spends genesis block output + let txs = vec![txn_schema!(from: vec![outputs[0][0].clone()], to: vec![20 * T, 20 * T])]; + let coinbase_value = consensus_manager.emission_schedule().block_reward(2); + assert_eq!( + generate_new_block_with_coinbase( + &mut store, + &factories, + &mut blocks, + &mut outputs, + txs, + coinbase_value, + &consensus_manager.consensus_constants() + ), + Err(ChainStorageError::UnspendableInput) + ); + let metadata = store.get_metadata().unwrap(); + let utxo_root2 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); + let kernel_root2 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); + let rp_root2 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); + assert_eq!(metadata.height_of_longest_chain, Some(1)); + assert_eq!(metadata.best_block, Some(block1_hash.clone())); + assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); + assert_eq!(store.fetch_block(1).unwrap().block().hash(), block1_hash); + assert!(store.fetch_block(2).is_err()); + assert_eq!(utxo_root1, utxo_root2); + assert_eq!(kernel_root1, kernel_root2); + assert_eq!(rp_root1, rp_root2); + + // Valid Block 2 + let txs = vec![txn_schema!(from: vec![outputs[1][0].clone()], to: vec![4 * T, 4 * T])]; + let coinbase_value = consensus_manager.emission_schedule().block_reward(2); + assert_eq!( + generate_new_block_with_coinbase( + &mut store, + &factories, + &mut blocks, + &mut outputs, + txs, + coinbase_value, + &consensus_manager.consensus_constants() + ), + Ok(BlockAddResult::Ok) + ); + let block2_hash = blocks[2].hash(); + let metadata = store.get_metadata().unwrap(); + let utxo_root2 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); + let kernel_root2 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); + let rp_root2 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); + assert_eq!(metadata.height_of_longest_chain, Some(2)); + assert_eq!(metadata.best_block, Some(block2_hash.clone())); + assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); + assert_eq!(store.fetch_block(1).unwrap().block().hash(), block1_hash); + assert_eq!(store.fetch_block(2).unwrap().block().hash(), block2_hash); + assert!(store.fetch_block(3).is_err()); + assert_ne!(utxo_root1, utxo_root2); + assert_ne!(kernel_root1, kernel_root2); + assert_ne!(rp_root1, rp_root2); + } - let total_kernel_offset = store.total_kernel_offset().unwrap(); - assert_eq!( - total_kernel_offset, - &(&block0.header.total_kernel_offset + &header2.total_kernel_offset) + &header3.total_kernel_offset - ); + // Cleanup test data - in Windows the LMDB `set_mapsize` sets the file size equal to the map size; Linux uses sparse files + if std::path::Path::new(&temp_path).exists() { + std::fs::remove_dir_all(&temp_path).unwrap(); + } } #[test] -fn total_utxo_commitment() { - let factories = CryptoFactories::default(); +fn orphan_cleanup_on_block_add() { let network = Network::LocalNet; - let gen_block = genesis_block::get_rincewind_genesis_block_raw(); - let consensus_manager = ConsensusManagerBuilder::new(network).with_block(gen_block).build(); - let store = create_mem_db(&consensus_manager); - let block0 = store.fetch_block(0).unwrap().block().clone(); - - let (utxo1, _) = create_utxo(MicroTari(10_000), &factories, None); - let (utxo2, _) = create_utxo(MicroTari(15_000), &factories, None); - let (utxo3, _) = create_utxo(MicroTari(20_000), &factories, None); - - let mut
txn = DbTransaction::new(); - txn.insert_utxo(utxo1.clone(), true); - txn.insert_utxo(utxo2.clone(), true); - txn.insert_utxo(utxo3.clone(), true); - assert!(store.commit(txn).is_ok()); - - let total_utxo_commitment = store.total_utxo_commitment().unwrap(); - assert_eq!( - total_utxo_commitment, - &(&(block0.body.outputs()[0].commitment) + &utxo1.commitment) + &(&utxo2.commitment + &utxo3.commitment) + let consensus_manager = ConsensusManagerBuilder::new(network).build(); + let validators = Validators::new( + MockValidator::new(true), + MockValidator::new(true), + MockAccumDifficultyValidator {}, ); + let db = MemoryDatabase::::default(); + let config = BlockchainDatabaseConfig { + orphan_storage_capacity: 3, + }; + let store = BlockchainDatabase::new(db, &consensus_manager, validators, config).unwrap(); + + let orphan1 = create_orphan_block(500, vec![], &consensus_manager.consensus_constants()); + let orphan2 = create_orphan_block(5, vec![], &consensus_manager.consensus_constants()); + let orphan3 = create_orphan_block(30, vec![], &consensus_manager.consensus_constants()); + let orphan4 = create_orphan_block(700, vec![], &consensus_manager.consensus_constants()); + let orphan5 = create_orphan_block(43, vec![], &consensus_manager.consensus_constants()); + let orphan6 = create_orphan_block(75, vec![], &consensus_manager.consensus_constants()); + let orphan7 = create_orphan_block(150, vec![], &consensus_manager.consensus_constants()); + let orphan1_hash = orphan1.hash(); + let orphan2_hash = orphan2.hash(); + let orphan3_hash = orphan3.hash(); + let orphan4_hash = orphan4.hash(); + let orphan5_hash = orphan5.hash(); + let orphan6_hash = orphan6.hash(); + let orphan7_hash = orphan7.hash(); + assert_eq!(store.add_block(orphan1.clone()), Ok(BlockAddResult::OrphanBlock)); + assert_eq!(store.add_block(orphan2), Ok(BlockAddResult::OrphanBlock)); + assert_eq!(store.add_block(orphan3), Ok(BlockAddResult::OrphanBlock)); + assert_eq!(store.add_block(orphan4.clone()), Ok(BlockAddResult::OrphanBlock)); + assert_eq!(store.add_block(orphan5), Ok(BlockAddResult::OrphanBlock)); + assert_eq!(store.add_block(orphan6), Ok(BlockAddResult::OrphanBlock)); + assert_eq!(store.add_block(orphan7.clone()), Ok(BlockAddResult::OrphanBlock)); + + assert_eq!(store.db_read_access().unwrap().get_orphan_count(), Ok(3)); + assert_eq!(store.fetch_orphan(orphan1_hash), Ok(orphan1)); + assert!(store.fetch_orphan(orphan2_hash).is_err()); + assert!(store.fetch_orphan(orphan3_hash).is_err()); + assert_eq!(store.fetch_orphan(orphan4_hash), Ok(orphan4)); + assert!(store.fetch_orphan(orphan5_hash).is_err()); + assert!(store.fetch_orphan(orphan6_hash).is_err()); + assert_eq!(store.fetch_orphan(orphan7_hash), Ok(orphan7)); } #[test] -fn restore_metadata() { - let validators = Validators::new(MockValidator::new(true), MockValidator::new(true)); +fn orphan_cleanup_on_reorg() { + // Create Main Chain let network = Network::LocalNet; - let rules = ConsensusManagerBuilder::new(network).build(); - let block_hash: BlockHash; - let path = create_temporary_data_path(); - { - let db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); - let db = BlockchainDatabase::new(db, &rules, validators.clone()).unwrap(); - - let block0 = db.fetch_block(0).unwrap().block().clone(); - let block1 = append_block(&db, &block0, vec![], &rules.consensus_constants(), 1.into()).unwrap(); - db.add_block(block1.clone()).unwrap(); - block_hash = block1.hash(); - let metadata = db.get_metadata().unwrap(); - assert_eq!(metadata.height_of_longest_chain, 
Some(1)); - assert_eq!(metadata.best_block, Some(block_hash.clone())); - } - // Restore blockchain db - let db = create_lmdb_database(&path, MmrCacheConfig::default()).unwrap(); - let db = BlockchainDatabase::new(db, &rules, validators).unwrap(); - - let metadata = db.get_metadata().unwrap(); - assert_eq!(metadata.height_of_longest_chain, Some(1)); - assert_eq!(metadata.best_block, Some(block_hash)); -} - -#[test] -fn invalid_block() { let factories = CryptoFactories::default(); - let network = Network::LocalNet; - let consensus_constants = ConsensusConstantsBuilder::new(network) - .with_emission_amounts(100_000_000.into(), 0.999, 100.into()) - .build(); + let consensus_constants = ConsensusConstantsBuilder::new(network).build(); let (block0, output) = create_genesis_block(&factories, &consensus_constants); let consensus_manager = ConsensusManagerBuilder::new(network) - .with_consensus_constants(consensus_constants.clone()) + .with_consensus_constants(consensus_constants) .with_block(block0.clone()) .build(); let validators = Validators::new( MockValidator::new(true), - StatelessBlockValidator::new(&consensus_manager.consensus_constants()), + MockValidator::new(true), + MockAccumDifficultyValidator {}, ); - let db = create_lmdb_database(&create_temporary_data_path(), MmrCacheConfig::default()).unwrap(); - let mut store = BlockchainDatabase::new(db, &consensus_manager, validators).unwrap(); + let db = MemoryDatabase::::default(); + let config = BlockchainDatabaseConfig { + orphan_storage_capacity: 3, + }; + let mut store = BlockchainDatabase::new(db, &consensus_manager, validators, config).unwrap(); let mut blocks = vec![block0]; let mut outputs = vec![vec![output]]; - let block0_hash = blocks[0].hash(); - let metadata = store.get_metadata().unwrap(); - let utxo_root0 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); - let kernel_root0 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); - let rp_root0 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); - assert_eq!(metadata.height_of_longest_chain, Some(0)); - assert_eq!(metadata.best_block, Some(block0_hash.clone())); - assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); - assert!(store.fetch_block(1).is_err()); - // Block 1 - let txs = vec![txn_schema!( - from: vec![outputs[0][0].clone()], - to: vec![10 * T, 5 * T, 10 * T, 15 * T] - )]; - let coinbase_value = consensus_manager.emission_schedule().block_reward(1); - assert_eq!( - generate_new_block_with_coinbase( - &mut store, - &factories, - &mut blocks, - &mut outputs, - txs, - coinbase_value, - &consensus_manager.consensus_constants() - ), - Ok(BlockAddResult::Ok) - ); - let block1_hash = blocks[1].hash(); - let metadata = store.get_metadata().unwrap(); - let utxo_root1 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); - let kernel_root1 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); - let rp_root1 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); - assert_eq!(metadata.height_of_longest_chain, Some(1)); - assert_eq!(metadata.best_block, Some(block1_hash.clone())); - assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); - assert_eq!(store.fetch_block(1).unwrap().block().hash(), block1_hash); - assert!(store.fetch_block(2).is_err()); - assert_ne!(utxo_root0, utxo_root1); - assert_ne!(kernel_root0, kernel_root1); - assert_ne!(rp_root0, rp_root1); - - // Invalid Block 2 - Double spends genesis block output - let txs = vec![txn_schema!(from: vec![outputs[0][0].clone()], to: vec![20 * T, 20 * T])]; - let coinbase_value = 
consensus_manager.emission_schedule().block_reward(2); - assert_eq!( - generate_new_block_with_coinbase( - &mut store, - &factories, - &mut blocks, - &mut outputs, - txs, - coinbase_value, - &consensus_manager.consensus_constants() - ), - Err(ChainStorageError::UnspendableInput) - ); - let metadata = store.get_metadata().unwrap(); - let utxo_root2 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); - let kernel_root2 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); - let rp_root2 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); - assert_eq!(metadata.height_of_longest_chain, Some(1)); - assert_eq!(metadata.best_block, Some(block1_hash.clone())); - assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); - assert_eq!(store.fetch_block(1).unwrap().block().hash(), block1_hash); - assert!(store.fetch_block(2).is_err()); - assert_eq!(utxo_root1, utxo_root2); - assert_eq!(kernel_root1, kernel_root2); - assert_eq!(rp_root1, rp_root2); - - // Valid Block 2 - let txs = vec![txn_schema!(from: vec![outputs[1][0].clone()], to: vec![4 * T, 4 * T])]; - let coinbase_value = consensus_manager.emission_schedule().block_reward(2); + // Block A1 + assert!(generate_new_block_with_achieved_difficulty( + &mut store, + &mut blocks, + &mut outputs, + vec![], + Difficulty::from(2), + &consensus_manager.consensus_constants() + ) + .is_ok()); + // Block A2 + assert!(generate_new_block_with_achieved_difficulty( + &mut store, + &mut blocks, + &mut outputs, + vec![], + Difficulty::from(3), + &consensus_manager.consensus_constants() + ) + .is_ok()); + // Block A3 + assert!(generate_new_block_with_achieved_difficulty( + &mut store, + &mut blocks, + &mut outputs, + vec![], + Difficulty::from(3), + &consensus_manager.consensus_constants() + ) + .is_ok()); + // Block A4 + assert!(generate_new_block_with_achieved_difficulty( + &mut store, + &mut blocks, + &mut outputs, + vec![], + Difficulty::from(3), + &consensus_manager.consensus_constants() + ) + .is_ok()); + + // Create Forked Chain + let consensus_manager_fork = ConsensusManagerBuilder::new(network) + .with_block(blocks[0].clone()) + .build(); + let mut orphan_store = create_mem_db(&consensus_manager_fork); + let mut orphan_blocks = vec![blocks[0].clone()]; + let mut orphan_outputs = vec![outputs[0].clone()]; + // Block B1 + assert!(generate_new_block_with_achieved_difficulty( + &mut orphan_store, + &mut orphan_blocks, + &mut orphan_outputs, + vec![], + Difficulty::from(2), + &consensus_manager_fork.consensus_constants() + ) + .is_ok()); + // Block B2 + assert!(generate_new_block_with_achieved_difficulty( + &mut orphan_store, + &mut orphan_blocks, + &mut orphan_outputs, + vec![], + Difficulty::from(10), + &consensus_manager_fork.consensus_constants() + ) + .is_ok()); + // Block B3 + assert!(generate_new_block_with_achieved_difficulty( + &mut orphan_store, + &mut orphan_blocks, + &mut orphan_outputs, + vec![], + Difficulty::from(15), + &consensus_manager_fork.consensus_constants() + ) + .is_ok()); + + // Fill orphan block pool + let orphan1 = create_orphan_block(1, vec![], &consensus_manager.consensus_constants()); + let orphan2 = create_orphan_block(1, vec![], &consensus_manager.consensus_constants()); + assert_eq!(store.add_block(orphan1.clone()), Ok(BlockAddResult::OrphanBlock)); + assert_eq!(store.add_block(orphan2.clone()), Ok(BlockAddResult::OrphanBlock)); + + // Adding B1 and B2 to the main chain will produce a reorg from GB->A1->A2->A3->A4 to GB->B1->B2->B3. 
assert_eq!( - generate_new_block_with_coinbase( - &mut store, - &factories, - &mut blocks, - &mut outputs, - txs, - coinbase_value, - &consensus_manager.consensus_constants() - ), - Ok(BlockAddResult::Ok) + store.add_block(orphan_blocks[1].clone()), + Ok(BlockAddResult::OrphanBlock) ); - let block2_hash = blocks[2].hash(); - let metadata = store.get_metadata().unwrap(); - let utxo_root2 = store.fetch_mmr_root(MmrTree::Utxo).unwrap(); - let kernel_root2 = store.fetch_mmr_root(MmrTree::Kernel).unwrap(); - let rp_root2 = store.fetch_mmr_root(MmrTree::RangeProof).unwrap(); - assert_eq!(metadata.height_of_longest_chain, Some(2)); - assert_eq!(metadata.best_block, Some(block2_hash.clone())); - assert_eq!(store.fetch_block(0).unwrap().block().hash(), block0_hash); - assert_eq!(store.fetch_block(1).unwrap().block().hash(), block1_hash); - assert_eq!(store.fetch_block(2).unwrap().block().hash(), block2_hash); - assert!(store.fetch_block(3).is_err()); - assert_ne!(utxo_root1, utxo_root2); - assert_ne!(kernel_root1, kernel_root2); - assert_ne!(rp_root1, rp_root2); + assert!(matches!( + store.add_block(orphan_blocks[2].clone()), + Ok(BlockAddResult::ChainReorg(_)) + )); + + // Check that A2, A3 and A4 are in the orphan block pool; A1 and the other orphans were discarded by the orphan + // cleanup. + assert_eq!(store.db_read_access().unwrap().get_orphan_count(), Ok(3)); + assert_eq!(store.fetch_orphan(blocks[2].hash()), Ok(blocks[2].clone())); + assert_eq!(store.fetch_orphan(blocks[3].hash()), Ok(blocks[3].clone())); + assert_eq!(store.fetch_orphan(blocks[4].hash()), Ok(blocks[4].clone())); } diff --git a/base_layer/core/tests/diff_adj_manager.rs b/base_layer/core/tests/diff_adj_manager.rs deleted file mode 100644 index ac466ad400..0000000000 --- a/base_layer/core/tests/diff_adj_manager.rs +++ /dev/null @@ -1,548 +0,0 @@ -// Copyright 2019. The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
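// ----------------------------------------------------------------------------
// [Editor's note] Both orphan tests above are driven by the new
// `BlockchainDatabaseConfig::orphan_storage_capacity` knob: cleanup runs as
// orphans are added, and the assertions in `orphan_cleanup_on_block_add`
// suggest that only the strongest (highest) orphans survive eviction. A minimal
// sketch of wiring up a store with a small orphan pool, assuming the in-memory
// backend and mock validators used throughout these tests:
fn small_orphan_pool_sketch() {
    let rules = ConsensusManagerBuilder::new(Network::LocalNet).build();
    let validators = Validators::new(
        MockValidator::new(true),
        MockValidator::new(true),
        MockAccumDifficultyValidator {},
    );
    let config = BlockchainDatabaseConfig {
        orphan_storage_capacity: 3, // cleanup evicts down to at most 3 orphans
    };
    let store =
        BlockchainDatabase::new(MemoryDatabase::<HashDigest>::default(), &rules, validators, config).unwrap();
    let orphan = create_orphan_block(500, vec![], &rules.consensus_constants());
    assert_eq!(store.add_block(orphan.clone()), Ok(BlockAddResult::OrphanBlock));
    assert_eq!(store.fetch_orphan(orphan.hash()), Ok(orphan));
}
// ----------------------------------------------------------------------------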
- -#[allow(dead_code)] -mod helpers; - -use helpers::block_builders::chain_block; -use tari_core::{ - blocks::Block, - chain_storage::{BlockchainDatabase, MemoryDatabase}, - consensus::{ConsensusConstants, ConsensusManagerBuilder, Network}, - helpers::create_mem_db, - proof_of_work::{ - lwma_diff::LinearWeightedMovingAverage, - DiffAdjManager, - Difficulty, - DifficultyAdjustment, - PowAlgorithm, - }, - transactions::types::HashDigest, -}; -use tari_crypto::tari_utilities::epoch_time::EpochTime; - -fn create_test_pow_blockchain( - db: &BlockchainDatabase>, - mut pow_algos: Vec, - consensus_constants: &ConsensusConstants, -) -{ - // Remove the first as it will be replaced by the genesis block - pow_algos.remove(0); - let block0 = db.fetch_block(0).unwrap().block().clone(); - append_to_pow_blockchain(db, block0, pow_algos, consensus_constants); -} - -fn append_to_pow_blockchain( - db: &BlockchainDatabase>, - chain_tip: Block, - pow_algos: Vec, - consensus: &ConsensusConstants, -) -{ - let mut prev_block = chain_tip; - for pow_algo in pow_algos { - let new_block = chain_block(&prev_block, Vec::new(), consensus); - let mut new_block = db.calculate_mmr_roots(new_block).unwrap(); - new_block.header.timestamp = prev_block - .header - .timestamp - .increase(consensus.get_target_block_interval()); - new_block.header.pow.pow_algo = pow_algo; - db.add_block(new_block.clone()).unwrap(); - prev_block = new_block; - } -} - -// Calculated the accumulated difficulty for the selected blocks in the blockchain db. -fn calculate_accumulated_difficulty( - db: &BlockchainDatabase>, - heights: Vec, - consensus_constants: &ConsensusConstants, -) -> Difficulty -{ - let mut lwma = LinearWeightedMovingAverage::new( - consensus_constants.get_difficulty_block_window() as usize, - consensus_constants.get_diff_target_block_interval(), - consensus_constants.min_pow_difficulty(), - consensus_constants.get_difficulty_max_block_interval(), - ); - for height in heights { - let header = db.fetch_header(height).unwrap(); - - lwma.add(header.timestamp, lwma.get_difficulty()).unwrap(); - } - lwma.get_difficulty() -} - -#[test] -fn test_initial_sync() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - - let pow_algos = vec![ - PowAlgorithm::Blake, // GB default - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - ]; - create_test_pow_blockchain(&store, pow_algos.clone(), &consensus_manager.consensus_constants()); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - - assert_eq!( - diff_adj_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![2, 5, 6], - &consensus_manager.consensus_constants() - )) - ); - assert_eq!( - diff_adj_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Blake - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![0, 1, 3, 4, 7], - &consensus_manager.consensus_constants() - )) - ); -} - -#[test] -fn test_sync_to_chain_tip() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let 
diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - let _ = consensus_manager.set_diff_manager(diff_adj_manager); - - let pow_algos = vec![ - PowAlgorithm::Blake, // Genesis block default - PowAlgorithm::Monero, - PowAlgorithm::Blake, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - ]; - create_test_pow_blockchain(&store, pow_algos, &consensus_manager.consensus_constants()); - assert_eq!(store.get_height(), Ok(Some(5))); - assert_eq!( - consensus_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![1, 4], - &consensus_manager.consensus_constants() - )) - ); - assert_eq!( - consensus_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Blake - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![0, 2, 3, 5], - &consensus_manager.consensus_constants() - )) - ); - - let pow_algos = vec![ - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - ]; - let tip = store.fetch_block(store.get_height().unwrap().unwrap()).unwrap().block; - append_to_pow_blockchain(&store, tip, pow_algos, &consensus_manager.consensus_constants()); - assert_eq!(store.get_height(), Ok(Some(9))); - assert_eq!( - consensus_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![1, 4, 7, 9], - &consensus_manager.consensus_constants() - )) - ); - assert_eq!( - consensus_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Blake - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![0, 2, 3, 5, 6, 8], - &consensus_manager.consensus_constants() - )) - ); -} - -#[test] -fn test_target_difficulty_with_height() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - let _ = consensus_manager.set_diff_manager(diff_adj_manager); - assert!(consensus_manager - .get_target_difficulty_with_height(&store.db_and_metadata_read_access().unwrap().0, PowAlgorithm::Monero, 5) - .is_err()); - assert!(consensus_manager - .get_target_difficulty_with_height(&store.db_and_metadata_read_access().unwrap().0, PowAlgorithm::Blake, 5) - .is_err()); - - let pow_algos = vec![ - PowAlgorithm::Blake, // GB default - PowAlgorithm::Monero, - PowAlgorithm::Blake, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - ]; - create_test_pow_blockchain(&store, pow_algos, &consensus_manager.consensus_constants()); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - let _ = consensus_manager.set_diff_manager(diff_adj_manager); - - assert_eq!( - consensus_manager.get_target_difficulty_with_height( - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero, - 5 - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![1, 4], - &consensus_manager.consensus_constants() - )) - ); - assert_eq!( - consensus_manager.get_target_difficulty_with_height( - &store.db_and_metadata_read_access().unwrap().0, - 
PowAlgorithm::Blake, - 5 - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![0, 2, 3, 5], - &consensus_manager.consensus_constants() - )) - ); - - assert_eq!( - consensus_manager.get_target_difficulty_with_height( - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero, - 2 - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![1], - &consensus_manager.consensus_constants() - )) - ); - assert_eq!( - consensus_manager.get_target_difficulty_with_height( - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Blake, - 2 - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![0, 2], - &consensus_manager.consensus_constants() - )) - ); - - assert_eq!( - consensus_manager.get_target_difficulty_with_height( - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero, - 3 - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![1], - &consensus_manager.consensus_constants() - )) - ); - assert_eq!( - consensus_manager.get_target_difficulty_with_height( - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Blake, - 3 - ), - Ok(calculate_accumulated_difficulty( - &store, - vec![0, 2, 3], - &consensus_manager.consensus_constants() - )) - ); -} - -#[test] -#[ignore] // TODO Wait for reorg logic to be refactored -fn test_full_sync_on_reorg() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - - let pow_algos = vec![ - PowAlgorithm::Blake, // GB default - PowAlgorithm::Blake, - PowAlgorithm::Blake, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - ]; - create_test_pow_blockchain(&store, pow_algos, &consensus_manager.consensus_constants()); - assert_eq!(store.get_height(), Ok(Some(4))); - assert_eq!( - diff_adj_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero - ), - Ok(Difficulty::from(1)) - ); - assert_eq!( - diff_adj_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Blake - ), - Ok(Difficulty::from(18)) - ); - - let pow_algos = vec![ - PowAlgorithm::Blake, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - ]; - assert_eq!(store.get_height(), Ok(Some(8))); - let tip = store.fetch_block(8).unwrap().block; - append_to_pow_blockchain(&store, tip, pow_algos, &consensus_manager.consensus_constants()); - assert_eq!( - diff_adj_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Monero - ), - Ok(Difficulty::from(2)) - ); - assert_eq!( - diff_adj_manager.get_target_difficulty( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - PowAlgorithm::Blake - ), - Ok(Difficulty::from(9)) - ); -} - -#[test] -fn test_median_timestamp() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - let pow_algos = vec![PowAlgorithm::Blake]; // GB default - create_test_pow_blockchain(&store, pow_algos, 
&consensus_manager.consensus_constants()); - let start_timestamp = store.fetch_block(0).unwrap().block().header.timestamp.clone(); - let mut timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - assert_eq!(timestamp, start_timestamp); - - let pow_algos = vec![PowAlgorithm::Blake]; - // lets add 1 - let tip = store.fetch_block(store.get_height().unwrap().unwrap()).unwrap().block; - append_to_pow_blockchain(&store, tip, pow_algos.clone(), &consensus_manager.consensus_constants()); - let mut prev_timestamp: EpochTime = - start_timestamp.increase(consensus_manager.consensus_constants().get_target_block_interval()); - timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - assert_eq!(timestamp, prev_timestamp); - // lets add 1 - let tip = store.fetch_block(store.get_height().unwrap().unwrap()).unwrap().block; - append_to_pow_blockchain(&store, tip, pow_algos.clone(), &consensus_manager.consensus_constants()); - prev_timestamp = start_timestamp.increase(consensus_manager.consensus_constants().get_target_block_interval()); - timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - assert_eq!(timestamp, prev_timestamp); - - // lets build up 11 blocks - for i in 4..12 { - let tip = store.fetch_block(store.get_height().unwrap().unwrap()).unwrap().block; - append_to_pow_blockchain(&store, tip, pow_algos.clone(), &consensus_manager.consensus_constants()); - prev_timestamp = - start_timestamp.increase(consensus_manager.consensus_constants().get_target_block_interval() * (i / 2)); - timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - assert_eq!(timestamp, prev_timestamp); - } - - // lets add many1 blocks - for _i in 1..20 { - let tip = store.fetch_block(store.get_height().unwrap().unwrap()).unwrap().block; - append_to_pow_blockchain(&store, tip, pow_algos.clone(), &consensus_manager.consensus_constants()); - prev_timestamp = prev_timestamp.increase(consensus_manager.consensus_constants().get_target_block_interval()); - timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - assert_eq!(timestamp, prev_timestamp); - } -} - -#[test] -fn test_median_timestamp_with_height() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - let pow_algos = vec![ - PowAlgorithm::Blake, // GB default - PowAlgorithm::Monero, - PowAlgorithm::Blake, - PowAlgorithm::Monero, - PowAlgorithm::Blake, - ]; - create_test_pow_blockchain(&store, pow_algos, &consensus_manager.consensus_constants()); - - let header0_timestamp = store.fetch_header(0).unwrap().timestamp; - let header1_timestamp = store.fetch_header(1).unwrap().timestamp; - let header2_timestamp = store.fetch_header(2).unwrap().timestamp; - - let timestamp = diff_adj_manager - 
.get_median_timestamp_at_height(&store.db_and_metadata_read_access().unwrap().0, 0) - .expect("median returned an error"); - assert_eq!(timestamp, header0_timestamp); - - let timestamp = diff_adj_manager - .get_median_timestamp_at_height(&store.db_and_metadata_read_access().unwrap().0, 3) - .expect("median returned an error"); - assert_eq!(timestamp, header2_timestamp); - - let timestamp = diff_adj_manager - .get_median_timestamp_at_height(&store.db_and_metadata_read_access().unwrap().0, 2) - .expect("median returned an error"); - assert_eq!(timestamp, header1_timestamp); - - let timestamp = diff_adj_manager - .get_median_timestamp_at_height(&store.db_and_metadata_read_access().unwrap().0, 4) - .expect("median returned an error"); - assert_eq!(timestamp, header2_timestamp); -} - -#[test] -fn test_median_timestamp_odd_order() { - let network = Network::LocalNet; - let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let store = create_mem_db(&consensus_manager); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - let pow_algos = vec![PowAlgorithm::Blake]; // GB default - create_test_pow_blockchain(&store, pow_algos, &consensus_manager.consensus_constants()); - let start_timestamp = store.fetch_block(0).unwrap().block().header.timestamp.clone(); - let mut timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - assert_eq!(timestamp, start_timestamp); - let pow_algos = vec![PowAlgorithm::Blake]; - // lets add 1 - let tip = store.fetch_block(store.get_height().unwrap().unwrap()).unwrap().block; - append_to_pow_blockchain(&store, tip, pow_algos.clone(), &consensus_manager.consensus_constants()); - let mut prev_timestamp: EpochTime = - start_timestamp.increase(consensus_manager.consensus_constants().get_target_block_interval()); - timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - assert_eq!(timestamp, prev_timestamp); - - // lets add 1 that's further back then - let append_height = store.get_height().unwrap().unwrap(); - let prev_block = store.fetch_block(append_height).unwrap().block().clone(); - let new_block = chain_block(&prev_block, Vec::new(), &consensus_manager.consensus_constants()); - let mut new_block = store.calculate_mmr_roots(new_block).unwrap(); - new_block.header.timestamp = - start_timestamp.increase(&consensus_manager.consensus_constants().get_target_block_interval() / 2); - new_block.header.pow.pow_algo = PowAlgorithm::Blake; - store.add_block(new_block).unwrap(); - - prev_timestamp = start_timestamp.increase(consensus_manager.consensus_constants().get_target_block_interval() / 2); - timestamp = diff_adj_manager - .get_median_timestamp( - &store.metadata_read_access().unwrap(), - &store.db_and_metadata_read_access().unwrap().0, - ) - .expect("median returned an error"); - // Median timestamp should be block 3 and not block 2 - assert_eq!(timestamp, prev_timestamp); -} diff --git a/base_layer/core/tests/helpers/mod.rs b/base_layer/core/tests/helpers/mod.rs index ba67626291..e346459716 100644 --- a/base_layer/core/tests/helpers/mod.rs +++ b/base_layer/core/tests/helpers/mod.rs @@ -7,4 +7,5 @@ pub mod block_builders; pub mod chain_metadata; pub mod event_stream; pub mod nodes; +pub mod pow_blockchain; pub mod sample_blockchains; diff --git 
a/base_layer/core/tests/helpers/nodes.rs b/base_layer/core/tests/helpers/nodes.rs index 985b69e407..f2c77a44af 100644 --- a/base_layer/core/tests/helpers/nodes.rs +++ b/base_layer/core/tests/helpers/nodes.rs @@ -37,7 +37,7 @@ use tari_core::{ OutboundNodeCommsInterface, }, blocks::Block, - chain_storage::{BlockchainDatabase, MemoryDatabase, Validators}, + chain_storage::{BlockchainDatabase, BlockchainDatabaseConfig, MemoryDatabase, Validators}, consensus::{ConsensusManager, ConsensusManagerBuilder, Network}, mempool::{ Mempool, @@ -47,13 +47,14 @@ use tari_core::{ MempoolValidators, OutboundMempoolServiceInterface, }, - proof_of_work::DiffAdjManager, + proof_of_work::Difficulty, transactions::types::HashDigest, validation::{ + accum_difficulty_validators::MockAccumDifficultyValidator, mocks::MockValidator, transaction_validators::TxInputAndMaturityValidator, StatelessValidation, - ValidationWriteGuard, + Validation, }, }; use tari_mmr::MmrCacheConfig; @@ -159,11 +160,12 @@ impl BaseNodeBuilder { pub fn with_validators( mut self, - block: impl ValidationWriteGuard> + 'static, + block: impl Validation> + 'static, orphan: impl StatelessValidation + 'static, + accum_difficulty: impl Validation> + 'static, ) -> Self { - let validators = Validators::new(block, orphan); + let validators = Validators::new(block, orphan, accum_difficulty); self.validators = Some(validators); self } @@ -177,22 +179,23 @@ impl BaseNodeBuilder { /// Build the test base node and start its services. pub fn start(self, runtime: &mut Runtime, data_path: &str) -> (NodeInterfaces, ConsensusManager) { let mmr_cache_config = self.mmr_cache_config.unwrap_or(MmrCacheConfig { rewind_hist_len: 10 }); - let validators = self - .validators - .unwrap_or(Validators::new(MockValidator::new(true), MockValidator::new(true))); + let validators = self.validators.unwrap_or(Validators::new( + MockValidator::new(true), + MockValidator::new(true), + MockAccumDifficultyValidator {}, + )); let consensus_manager = self .consensus_manager .unwrap_or(ConsensusManagerBuilder::new(self.network).build()); let db = MemoryDatabase::::new(mmr_cache_config); - let blockchain_db = BlockchainDatabase::new(db, &consensus_manager, validators).unwrap(); + let blockchain_db = + BlockchainDatabase::new(db, &consensus_manager, validators, BlockchainDatabaseConfig::default()).unwrap(); let mempool_validator = MempoolValidators::new(TxInputAndMaturityValidator {}, TxInputAndMaturityValidator {}); let mempool = Mempool::new( blockchain_db.clone(), self.mempool_config.unwrap_or(MempoolConfig::default()), mempool_validator, ); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - consensus_manager.set_diff_manager(diff_adj_manager).unwrap(); let node_identity = self.node_identity.unwrap_or(random_node_identity()); let ( outbound_nci, @@ -515,6 +518,7 @@ fn setup_base_node_services( liveness_service_config, Arc::clone(&subscription_factory), dht.dht_requester(), + comms.connection_manager(), )) .add_initializer(BaseNodeServiceInitializer::new( subscription_factory.clone(), diff --git a/base_layer/core/tests/helpers/pow_blockchain.rs b/base_layer/core/tests/helpers/pow_blockchain.rs new file mode 100644 index 0000000000..1671ee55bc --- /dev/null +++ b/base_layer/core/tests/helpers/pow_blockchain.rs @@ -0,0 +1,104 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use super::block_builders::chain_block; +use tari_core::{ + blocks::Block, + chain_storage::{BlockchainBackend, BlockchainDatabase, MemoryDatabase}, + consensus::{ConsensusConstants, ConsensusManager}, + proof_of_work::{ + get_target_difficulty, + lwma_diff::LinearWeightedMovingAverage, + Difficulty, + DifficultyAdjustment, + PowAlgorithm, + }, + transactions::types::HashDigest, +}; + +pub fn create_test_pow_blockchain( + db: &BlockchainDatabase, + mut pow_algos: Vec, + consensus_manager: &ConsensusManager, +) +{ + // Remove the first as it will be replaced by the genesis block + pow_algos.remove(0); + let block0 = db.fetch_block(0).unwrap().block().clone(); + append_to_pow_blockchain(db, block0, pow_algos, consensus_manager); +} + +pub fn append_to_pow_blockchain( + db: &BlockchainDatabase, + chain_tip: Block, + pow_algos: Vec, + consensus_manager: &ConsensusManager, +) +{ + let constants = consensus_manager.consensus_constants(); + let mut prev_block = chain_tip; + for pow_algo in pow_algos { + let new_block = chain_block(&prev_block, Vec::new(), constants); + let mut new_block = db.calculate_mmr_roots(new_block).unwrap(); + new_block.header.timestamp = prev_block + .header + .timestamp + .increase(constants.get_target_block_interval()); + new_block.header.pow.pow_algo = pow_algo; + + let height = db.get_metadata().unwrap().height_of_longest_chain.unwrap(); + let target_difficulties = db + .fetch_target_difficulties(pow_algo, height, constants.get_difficulty_block_window() as usize) + .unwrap(); + new_block.header.pow.target_difficulty = get_target_difficulty( + target_difficulties, + constants.get_difficulty_block_window() as usize, + constants.get_diff_target_block_interval(), + constants.min_pow_difficulty(pow_algo), + constants.get_difficulty_max_block_interval(), + ) + .unwrap(); + db.add_block(new_block.clone()).unwrap(); + prev_block = new_block; + } +} + +// Calculated the accumulated difficulty for the selected blocks in the blockchain db. 
+pub fn calculate_accumulated_difficulty( + db: &BlockchainDatabase>, + pow_algo: PowAlgorithm, + heights: Vec, + consensus_constants: &ConsensusConstants, +) -> Difficulty +{ + let mut lwma = LinearWeightedMovingAverage::new( + consensus_constants.get_difficulty_block_window() as usize, + consensus_constants.get_diff_target_block_interval(), + consensus_constants.min_pow_difficulty(pow_algo), + consensus_constants.get_difficulty_max_block_interval(), + ); + for height in heights { + let header = db.fetch_header(height).unwrap(); + lwma.add(header.timestamp, header.pow.target_difficulty).unwrap(); + } + lwma.get_difficulty() +} diff --git a/base_layer/core/tests/median_timestamp.rs b/base_layer/core/tests/median_timestamp.rs new file mode 100644 index 0000000000..427a6b760f --- /dev/null +++ b/base_layer/core/tests/median_timestamp.rs @@ -0,0 +1,136 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
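// ----------------------------------------------------------------------------
// [Editor's note] The tests below cover `get_median_timestamp`, which replaces
// the median-timestamp API of the removed `DiffAdjManager`: the caller now
// collects the last N header timestamps and passes them in directly. A minimal
// sketch of that flow, assuming `fetch_headers` and the `get_header_timestamps`
// helper defined in this new test file:
fn median_at_tip_sketch<B: BlockchainBackend>(db: &B, tip_height: u64, window: u64) -> EpochTime {
    let timestamps = get_header_timestamps(db, tip_height, window);
    // For an even number of timestamps the median is the mean of the two middle
    // values, which is what the `(header1_timestamp + header2_timestamp) / 2`
    // assertion below relies on.
    get_median_timestamp(timestamps).expect("timestamp window must not be empty")
}
// ----------------------------------------------------------------------------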
+ +#[allow(dead_code)] +mod helpers; + +use helpers::{ + block_builders::chain_block, + pow_blockchain::{append_to_pow_blockchain, create_test_pow_blockchain}, +}; +use tari_core::{ + chain_storage::{fetch_headers, BlockchainBackend}, + consensus::{ConsensusManagerBuilder, Network}, + helpers::create_mem_db, + proof_of_work::{get_median_timestamp, PowAlgorithm}, +}; +use tari_crypto::tari_utilities::epoch_time::EpochTime; + +pub fn get_header_timestamps(db: &B, height: u64, timestamp_count: u64) -> Vec { + let min_height = height.checked_sub(timestamp_count).unwrap_or(0); + let block_nums = (min_height..=height).collect(); + fetch_headers(db, block_nums) + .unwrap() + .iter() + .map(|h| h.timestamp) + .collect::>() +} + +#[test] +fn test_median_timestamp_with_height() { + let network = Network::LocalNet; + let consensus_manager = ConsensusManagerBuilder::new(network).build(); + let store = create_mem_db(&consensus_manager); + let pow_algos = vec![ + PowAlgorithm::Blake, // GB default + PowAlgorithm::Monero, + PowAlgorithm::Blake, + PowAlgorithm::Monero, + PowAlgorithm::Blake, + ]; + create_test_pow_blockchain(&store, pow_algos, &consensus_manager); + let timestamp_count = 10; + + let header0_timestamp = store.fetch_header(0).unwrap().timestamp; + let header1_timestamp = store.fetch_header(1).unwrap().timestamp; + let header2_timestamp = store.fetch_header(2).unwrap().timestamp; + + let db = &*store.db_read_access().unwrap(); + let median_timestamp = + get_median_timestamp(get_header_timestamps(db, 0, timestamp_count)).expect("median returned an error"); + assert_eq!(median_timestamp, header0_timestamp); + + let median_timestamp = + get_median_timestamp(get_header_timestamps(db, 3, timestamp_count)).expect("median returned an error"); + assert_eq!(median_timestamp, (header1_timestamp + header2_timestamp) / 2); + + let median_timestamp = + get_median_timestamp(get_header_timestamps(db, 2, timestamp_count)).expect("median returned an error"); + assert_eq!(median_timestamp, header1_timestamp); + + let median_timestamp = + get_median_timestamp(get_header_timestamps(db, 4, timestamp_count)).expect("median returned an error"); + assert_eq!(median_timestamp, header2_timestamp); +} + +#[test] +fn test_median_timestamp_odd_order() { + let network = Network::LocalNet; + let consensus_manager = ConsensusManagerBuilder::new(network).build(); + let timestamp_count = consensus_manager.consensus_constants().get_median_timestamp_count() as u64; + let store = create_mem_db(&consensus_manager); + let pow_algos = vec![PowAlgorithm::Blake]; // GB default + create_test_pow_blockchain(&store, pow_algos, &consensus_manager); + let mut timestamps = vec![store.fetch_block(0).unwrap().block().header.timestamp.clone()]; + let height = store.get_metadata().unwrap().height_of_longest_chain.unwrap(); + let mut median_timestamp = get_median_timestamp(get_header_timestamps( + &*store.db_read_access().unwrap(), + height, + timestamp_count, + )) + .expect("median returned an error"); + assert_eq!(median_timestamp, timestamps[0]); + let pow_algos = vec![PowAlgorithm::Blake]; + // lets add 1 + let tip = store.fetch_block(store.get_height().unwrap().unwrap()).unwrap().block; + append_to_pow_blockchain(&store, tip, pow_algos.clone(), &consensus_manager); + timestamps.push(timestamps[0].increase(consensus_manager.consensus_constants().get_target_block_interval())); + let height = store.get_metadata().unwrap().height_of_longest_chain.unwrap(); + median_timestamp = get_median_timestamp(get_header_timestamps( + 
&*store.db_read_access().unwrap(), + height, + timestamp_count, + )) + .expect("median returned an error"); + assert_eq!(median_timestamp, (timestamps[0] + timestamps[1]) / 2); + + // let's add 1 that's further back in time + let append_height = store.get_height().unwrap().unwrap(); + let prev_block = store.fetch_block(append_height).unwrap().block().clone(); + let new_block = chain_block(&prev_block, Vec::new(), &consensus_manager.consensus_constants()); + let mut new_block = store.calculate_mmr_roots(new_block).unwrap(); + timestamps.push(timestamps[0].increase(&consensus_manager.consensus_constants().get_target_block_interval() / 2)); + new_block.header.timestamp = timestamps[2]; + new_block.header.pow.pow_algo = PowAlgorithm::Blake; + store.add_block(new_block).unwrap(); + + timestamps.push(timestamps[2].increase(consensus_manager.consensus_constants().get_target_block_interval() / 2)); + let height = store.get_metadata().unwrap().height_of_longest_chain.unwrap(); + median_timestamp = get_median_timestamp(get_header_timestamps( + &*store.db_read_access().unwrap(), + height, + timestamp_count, + )) + .expect("median returned an error"); + // Median timestamp should be block 3 and not block 2 + assert_eq!(median_timestamp, timestamps[2]); +} diff --git a/base_layer/core/tests/mempool.rs b/base_layer/core/tests/mempool.rs index 32b6acb81a..86c38263b1 100644 --- a/base_layer/core/tests/mempool.rs +++ b/base_layer/core/tests/mempool.rs @@ -164,7 +164,7 @@ fn test_insert_and_process_published_block() { assert_eq!(stats.orphan_txs, 1); assert_eq!(stats.timelocked_txs, 2); assert_eq!(stats.published_txs, 0); - assert_eq!(stats.total_weight, 36); + assert_eq!(stats.total_weight, 120); // Spend tx2, so it goes in Reorg pool, tx5 matures, so goes in Unconfirmed pool generate_block( @@ -219,7 +219,7 @@ fn test_insert_and_process_published_block() { assert_eq!(stats.orphan_txs, 1); assert_eq!(stats.timelocked_txs, 1); assert_eq!(stats.published_txs, 1); - assert_eq!(stats.total_weight, 36); + assert_eq!(stats.total_weight, 120); } #[test] @@ -382,10 +382,16 @@ fn test_reorg() { let stats = mempool.stats().unwrap(); assert_eq!(stats.unconfirmed_txs, 0); assert_eq!(stats.timelocked_txs, 1); + assert_eq!(stats.published_txs, 5); db.rewind_to_height(2).unwrap(); - mempool.process_reorg(vec![blocks[3].clone()], vec![]).unwrap(); + let template = chain_block(&blocks[2], vec![], consensus_manager.consensus_constants()); + let reorg_block3 = db.calculate_mmr_roots(template).unwrap(); + + mempool + .process_reorg(vec![blocks[3].clone()], vec![reorg_block3]) + .unwrap(); let stats = mempool.stats().unwrap(); assert_eq!(stats.unconfirmed_txs, 2); assert_eq!(stats.timelocked_txs, 1); @@ -501,7 +507,7 @@ fn request_response_get_stats() { bob.mempool.insert(orphan1.clone()).unwrap(); bob.mempool.insert(orphan2.clone()).unwrap(); - // The coinbase tx cannot be spent until maturity, so rxn1 will be in the timelocked pool. The other 2 txns are + // The coinbase tx cannot be spent until maturity, so txn1 will be in the timelocked pool. The other 2 txns are + // orphans.
let stats = bob.mempool.stats().unwrap(); assert_eq!(stats.total_txs, 3); @@ -509,7 +515,7 @@ fn request_response_get_stats() { assert_eq!(stats.unconfirmed_txs, 0); assert_eq!(stats.timelocked_txs, 1); assert_eq!(stats.published_txs, 0); - assert_eq!(stats.total_weight, 35); + assert_eq!(stats.total_weight, 116); runtime.block_on(async { // Alice will request mempool stats from Bob, and thus should be identical @@ -519,7 +525,7 @@ fn request_response_get_stats() { assert_eq!(received_stats.orphan_txs, 2); assert_eq!(received_stats.timelocked_txs, 1); assert_eq!(received_stats.published_txs, 0); - assert_eq!(received_stats.total_weight, 35); + assert_eq!(received_stats.total_weight, 116); alice.comms.shutdown().await; bob.comms.shutdown().await; diff --git a/base_layer/core/tests/node_comms_interface.rs b/base_layer/core/tests/node_comms_interface.rs index 094ab0b7d5..5ffae1dfef 100644 --- a/base_layer/core/tests/node_comms_interface.rs +++ b/base_layer/core/tests/node_comms_interface.rs @@ -36,7 +36,6 @@ use tari_core::{ consensus::{ConsensusManagerBuilder, Network}, helpers::create_mem_db, mempool::{Mempool, MempoolConfig, MempoolValidators}, - proof_of_work::DiffAdjManager, transactions::{ helpers::{create_test_kernel, create_utxo}, tari_amount::MicroTari, @@ -92,9 +91,7 @@ fn inbound_get_metadata() { let network = Network::LocalNet; let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); let (block_event_publisher, _block_event_subscriber) = bounded(100); - assert!(consensus_manager.set_diff_manager(diff_adj_manager).is_ok()); let (request_sender, _) = reply_channel::unbounded(); let (block_sender, _) = futures_mpsc_channel_unbounded(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender.clone()); @@ -147,9 +144,7 @@ fn inbound_fetch_kernels() { let (mempool, store) = new_mempool(); let network = Network::LocalNet; let consensus_manager = ConsensusManagerBuilder::new(network).build(); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); let (block_event_publisher, _block_event_subscriber) = bounded(100); - assert!(consensus_manager.set_diff_manager(diff_adj_manager).is_ok()); let (request_sender, _) = reply_channel::unbounded(); let (block_sender, _) = futures_mpsc_channel_unbounded(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); @@ -210,9 +205,7 @@ fn inbound_fetch_headers() { let consensus_manager = ConsensusManagerBuilder::new(network) .with_consensus_constants(consensus_constants) .build(); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); let (block_event_publisher, _block_event_subscriber) = bounded(100); - assert!(consensus_manager.set_diff_manager(diff_adj_manager).is_ok()); let (request_sender, _) = reply_channel::unbounded(); let (block_sender, _) = futures_mpsc_channel_unbounded(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); @@ -271,8 +264,6 @@ fn inbound_fetch_utxos() { let consensus_manager = ConsensusManagerBuilder::new(network) .with_consensus_constants(consensus_constants) .build(); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - assert!(consensus_manager.set_diff_manager(diff_adj_manager).is_ok()); let (request_sender, _) = reply_channel::unbounded(); let (block_sender, _) = futures_mpsc_channel_unbounded(); let 
outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); @@ -335,8 +326,6 @@ fn inbound_fetch_blocks() { let consensus_manager = ConsensusManagerBuilder::new(network) .with_consensus_constants(consensus_constants) .build(); - let diff_adj_manager = DiffAdjManager::new(&consensus_manager.consensus_constants()).unwrap(); - assert!(consensus_manager.set_diff_manager(diff_adj_manager).is_ok()); let (request_sender, _) = reply_channel::unbounded(); let (block_sender, _) = futures_mpsc_channel_unbounded(); let outbound_nci = OutboundNodeCommsInterface::new(request_sender, block_sender); diff --git a/base_layer/core/tests/node_service.rs b/base_layer/core/tests/node_service.rs index e38653dca3..de97380e8a 100644 --- a/base_layer/core/tests/node_service.rs +++ b/base_layer/core/tests/node_service.rs @@ -60,7 +60,11 @@ use tari_core::{ types::CryptoFactories, }, txn_schema, - validation::{block_validators::StatelessBlockValidator, mocks::MockValidator}, + validation::{ + accum_difficulty_validators::MockAccumDifficultyValidator, + block_validators::StatelessBlockValidator, + mocks::MockValidator, + }, }; use tari_crypto::tari_utilities::hash::Hashable; use tari_mmr::MmrCacheConfig; @@ -504,6 +508,7 @@ fn propagate_and_forward_invalid_block() { .build(); let stateless_block_validator = StatelessBlockValidator::new(&rules.consensus_constants()); let mock_validator = MockValidator::new(true); + let mock_accum_difficulty_validator = MockAccumDifficultyValidator {}; let (mut alice_node, rules) = BaseNodeBuilder::new(network) .with_node_identity(alice_node_identity.clone()) .with_peers(vec![bob_node_identity.clone(), carol_node_identity.clone()]) @@ -513,13 +518,21 @@ fn propagate_and_forward_invalid_block() { .with_node_identity(bob_node_identity.clone()) .with_peers(vec![alice_node_identity.clone(), dan_node_identity.clone()]) .with_consensus_manager(rules) - .with_validators(mock_validator.clone(), stateless_block_validator.clone()) + .with_validators( + mock_validator.clone(), + stateless_block_validator.clone(), + mock_accum_difficulty_validator.clone(), + ) .start(&mut runtime, temp_dir.path().to_str().unwrap()); let (carol_node, rules) = BaseNodeBuilder::new(network) .with_node_identity(carol_node_identity.clone()) .with_peers(vec![alice_node_identity, dan_node_identity.clone()]) .with_consensus_manager(rules) - .with_validators(mock_validator.clone(), stateless_block_validator) + .with_validators( + mock_validator.clone(), + stateless_block_validator, + mock_accum_difficulty_validator.clone(), + ) .start(&mut runtime, temp_dir.path().to_str().unwrap()); let (dan_node, rules) = BaseNodeBuilder::new(network) .with_node_identity(dan_node_identity) @@ -653,7 +666,11 @@ fn local_get_new_block_template_and_get_new_block() { assert!(node.mempool.insert(txs[1].clone()).is_ok()); runtime.block_on(async { - let block_template = node.local_nci.get_new_block_template().await.unwrap(); + let block_template = node + .local_nci + .get_new_block_template(PowAlgorithm::Blake) + .await + .unwrap(); assert_eq!(block_template.header.height, 1); assert_eq!(block_template.body.kernels().len(), 2); diff --git a/base_layer/core/tests/node_state_machine.rs b/base_layer/core/tests/node_state_machine.rs index 357d7f9286..9f344f5491 100644 --- a/base_layer/core/tests/node_state_machine.rs +++ b/base_layer/core/tests/node_state_machine.rs @@ -62,7 +62,11 @@ use tari_core::{ helpers::create_mem_db, mempool::MempoolServiceConfig, transactions::types::CryptoFactories, - 
validation::{block_validators::StatelessBlockValidator, mocks::MockValidator}, + validation::{ + accum_difficulty_validators::MockAccumDifficultyValidator, + block_validators::StatelessBlockValidator, + mocks::MockValidator, + }, }; use tari_mmr::MmrCacheConfig; use tari_p2p::services::liveness::LivenessConfig; @@ -92,9 +96,8 @@ fn test_listening_lagging() { MempoolServiceConfig::default(), LivenessConfig { enable_auto_join: false, - enable_auto_stored_message_request: false, auto_ping_interval: Some(Duration::from_millis(100)), - refresh_neighbours_interval: Duration::from_secs(60), + ..Default::default() }, consensus_manager, temp_dir.path().to_str().unwrap(), @@ -611,7 +614,11 @@ fn test_sync_peer_banning() { .with_mempool_service_config(mempool_service_config) .with_liveness_service_config(liveness_service_config) .with_consensus_manager(consensus_manager) - .with_validators(mock_validator, stateless_block_validator) + .with_validators( + mock_validator, + stateless_block_validator, + MockAccumDifficultyValidator {}, + ) .start(&mut runtime, data_path); let (bob_node, consensus_manager) = BaseNodeBuilder::new(network) .with_node_identity(bob_node_identity) diff --git a/base_layer/core/tests/target_difficulty.rs b/base_layer/core/tests/target_difficulty.rs new file mode 100644 index 0000000000..7796676fc5 --- /dev/null +++ b/base_layer/core/tests/target_difficulty.rs @@ -0,0 +1,206 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
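
For orientation before the new test file: get_target_difficulty consumes only the sample window returned by fetch_target_difficulties, so the tests below can check the target at arbitrary heights without replaying the chain. A minimal usage sketch, using only names that appear in the tests that follow (this is editorial, not part of the patch):

// Sketch: derive the next Monero target at the chain tip.
// get_target_difficulty returns a Result; the tests compare it against Ok(..).
let height = store.get_metadata().unwrap().height_of_longest_chain.unwrap();
let samples = store
    .fetch_target_difficulties(PowAlgorithm::Monero, height, block_window)
    .unwrap();
let target = get_target_difficulty(
    samples,
    block_window,                                       // max number of samples considered
    target_time,                                        // desired block interval
    constants.min_pow_difficulty(PowAlgorithm::Monero), // per-algorithm floor
    max_block_time,                                     // clamp on inter-block spacing
);
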
+ +#[allow(dead_code)] +mod helpers; + +use helpers::pow_blockchain::{calculate_accumulated_difficulty, create_test_pow_blockchain}; +use tari_core::{ + consensus::{ConsensusManagerBuilder, Network}, + helpers::create_mem_db, + proof_of_work::{get_target_difficulty, PowAlgorithm}, +}; + +#[test] +fn test_target_difficulty_at_tip() { + let network = Network::LocalNet; + let consensus_manager = ConsensusManagerBuilder::new(network).build(); + let constants = consensus_manager.consensus_constants(); + let block_window = constants.get_difficulty_block_window() as usize; + let target_time = constants.get_diff_target_block_interval(); + let max_block_time = constants.get_difficulty_max_block_interval(); + let store = create_mem_db(&consensus_manager); + + let pow_algos = vec![ + PowAlgorithm::Blake, // GB default + PowAlgorithm::Blake, + PowAlgorithm::Monero, + PowAlgorithm::Blake, + PowAlgorithm::Blake, + PowAlgorithm::Monero, + PowAlgorithm::Monero, + PowAlgorithm::Blake, + PowAlgorithm::Monero, + PowAlgorithm::Blake, + ]; + create_test_pow_blockchain(&store, pow_algos.clone(), &consensus_manager); + + let height = store.get_metadata().unwrap().height_of_longest_chain.unwrap(); + let pow_algo = PowAlgorithm::Monero; + let target_difficulties = store.fetch_target_difficulties(pow_algo, height, block_window).unwrap(); + assert_eq!( + get_target_difficulty( + target_difficulties, + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty( + &store, + pow_algo, + vec![2, 5, 6, 8], + &constants + )) + ); + + let pow_algo = PowAlgorithm::Blake; + let target_difficulties = store.fetch_target_difficulties(pow_algo, height, block_window).unwrap(); + assert_eq!( + get_target_difficulty( + target_difficulties, + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty( + &store, + pow_algo, + vec![0, 1, 3, 4, 7, 9], + &constants + )) + ); +} + +#[test] +fn test_target_difficulty_with_height() { + let network = Network::LocalNet; + let consensus_manager = ConsensusManagerBuilder::new(network).build(); + let constants = consensus_manager.consensus_constants(); + let block_window = constants.get_difficulty_block_window() as usize; + let target_time = constants.get_diff_target_block_interval(); + let max_block_time = constants.get_difficulty_max_block_interval(); + let store = create_mem_db(&consensus_manager); + + let pow_algos = vec![ + PowAlgorithm::Blake, // GB default + PowAlgorithm::Monero, + PowAlgorithm::Blake, + PowAlgorithm::Blake, + PowAlgorithm::Monero, + PowAlgorithm::Blake, + ]; + create_test_pow_blockchain(&store, pow_algos, &consensus_manager); + + let pow_algo = PowAlgorithm::Monero; + assert_eq!( + get_target_difficulty( + store.fetch_target_difficulties(pow_algo, 5, block_window).unwrap(), + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty( + &store, + pow_algo, + vec![1, 4], + &constants + )) + ); + + let pow_algo = PowAlgorithm::Blake; + assert_eq!( + get_target_difficulty( + store.fetch_target_difficulties(pow_algo, 5, block_window).unwrap(), + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty( + &store, + pow_algo, + vec![0, 2, 3, 5], + &constants + )) + ); + + let pow_algo = PowAlgorithm::Monero; + assert_eq!( + get_target_difficulty( + store.fetch_target_difficulties(pow_algo, 
2, block_window).unwrap(), + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty(&store, pow_algo, vec![1], &constants)) + ); + + let pow_algo = PowAlgorithm::Blake; + assert_eq!( + get_target_difficulty( + store.fetch_target_difficulties(pow_algo, 2, block_window).unwrap(), + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty( + &store, + pow_algo, + vec![0, 2], + &constants + )) + ); + + let pow_algo = PowAlgorithm::Monero; + assert_eq!( + get_target_difficulty( + store.fetch_target_difficulties(pow_algo, 3, block_window).unwrap(), + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty(&store, pow_algo, vec![1], &constants)) + ); + + let pow_algo = PowAlgorithm::Blake; + assert_eq!( + get_target_difficulty( + store.fetch_target_difficulties(pow_algo, 3, block_window).unwrap(), + block_window, + target_time, + constants.min_pow_difficulty(pow_algo), + max_block_time + ), + Ok(calculate_accumulated_difficulty( + &store, + pow_algo, + vec![0, 2, 3], + &constants + )) + ); +} diff --git a/base_layer/core/tests/wallet.rs b/base_layer/core/tests/wallet.rs index cf45160204..1727ecc515 100644 --- a/base_layer/core/tests/wallet.rs +++ b/base_layer/core/tests/wallet.rs @@ -149,6 +149,7 @@ fn wallet_base_node_integration_test() { ContactsServiceMemoryDatabase::new(), ) .unwrap(); + let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); alice_wallet .set_base_node_peer( @@ -247,14 +248,13 @@ fn wallet_base_node_integration_test() { )) .unwrap(); - let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut delay = delay_for(Duration::from_secs(60)).fuse(); let mut broadcast = false; loop { futures::select! { event = alice_event_stream.select_next_some() => { - if let TransactionEvent::TransactionBroadcast(_e) = (*event).clone() { + if let TransactionEvent::TransactionBroadcast(_e) = (*event.unwrap()).clone() { broadcast = true; break; } @@ -323,15 +323,13 @@ fn wallet_base_node_integration_test() { assert!(found_tx_outputs == transaction.body.outputs().len()); }); - let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); - runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(30)).fuse(); let mut mined = false; loop { futures::select! 
{ event = alice_event_stream.select_next_some() => { - if let TransactionEvent::TransactionMined(_e) = (*event).clone() { + if let TransactionEvent::TransactionMined(_e) = (*event.unwrap()).clone() { mined = true; break; } diff --git a/base_layer/mmr/Cargo.toml b/base_layer/mmr/Cargo.toml index 3c88df2493..74d707dfa6 100644 --- a/base_layer/mmr/Cargo.toml +++ b/base_layer/mmr/Cargo.toml @@ -4,7 +4,7 @@ authors = ["The Tari Development Community"] description = "A Merkle Mountain Range implementation" repository = "https://github.com/tari-project/tari" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" [dependencies] @@ -14,7 +14,7 @@ digest = "0.8.0" log = "0.4" serde = { version = "1.0.97", features = ["derive"] } croaring = "=0.3.9" -tari_storage = { path = "../../infrastructure/storage", version = "^0.0" } +tari_storage = { path = "../../infrastructure/storage", version = "^0.1" } [dev-dependencies] criterion = "0.2" diff --git a/base_layer/mmr/src/mmr_cache.rs b/base_layer/mmr/src/mmr_cache.rs index 39b6f08ff9..27bf71ff3c 100644 --- a/base_layer/mmr/src/mmr_cache.rs +++ b/base_layer/mmr/src/mmr_cache.rs @@ -172,8 +172,9 @@ where .checkpoints .len() .map_err(|e| MerkleMountainRangeError::BackendError(e.to_string()))?; - if cp_count < self.base_cp_index { - // Checkpoint before the base MMR index, this will require a full reconstruction of the cache. + if cp_count <= self.base_cp_index { + // Checkpoint before or the same as the base MMR index, this will require a full reconstruction of the + // cache. self.create_base_mmr()?; self.create_curr_mmr()?; } else if cp_count < self.curr_cp_index { diff --git a/base_layer/mmr/tests/mmr_cache.rs b/base_layer/mmr/tests/mmr_cache.rs index c002f94642..7244d88fa4 100644 --- a/base_layer/mmr/tests/mmr_cache.rs +++ b/base_layer/mmr/tests/mmr_cache.rs @@ -89,3 +89,71 @@ fn create_cache_update_and_rewind() { assert!(mmr_cache.update().is_ok()); assert_eq!(mmr_cache.get_mmr_only_root(), Ok(cp1_mmr_only_root)); } + +#[test] +fn multiple_rewinds() { + let config = MmrCacheConfig { rewind_hist_len: 2 }; + let mut checkpoint_db = MemBackendVec::::new(); + let mut mmr_cache = MmrCache::::new(Vec::new(), checkpoint_db.clone(), config).unwrap(); + + // Add h1, h2, h3 and h4 checkpoints + let h1 = int_to_hash(1); + let h2 = int_to_hash(2); + let h3 = int_to_hash(3); + let h4 = int_to_hash(4); + let h5 = int_to_hash(5); + checkpoint_db + .push(MerkleCheckPoint::new(vec![h1.clone()], Bitmap::create())) + .unwrap(); + assert!(mmr_cache.update().is_ok()); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1]).clone())); + + checkpoint_db + .push(MerkleCheckPoint::new(vec![h2.clone()], Bitmap::create())) + .unwrap(); + assert!(mmr_cache.update().is_ok()); + let h1h2 = combine_hashes(&[&h1, &h2]); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1h2]).clone())); + + checkpoint_db + .push(MerkleCheckPoint::new(vec![h3.clone()], Bitmap::create())) + .unwrap(); + assert!(mmr_cache.update().is_ok()); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1h2, &h3]).clone())); + + checkpoint_db + .push(MerkleCheckPoint::new(vec![h4.clone()], Bitmap::create())) + .unwrap(); + assert!(mmr_cache.update().is_ok()); + let h3h4 = combine_hashes(&[&h3, &h4]); + let h1h2h3h4 = combine_hashes(&[&h1h2, &h3h4]); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1h2h3h4]).clone())); + assert_eq!(checkpoint_db.len(), Ok(4)); + + // Remove h4 checkpoint + checkpoint_db.truncate(3).unwrap(); + 
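// Editorial note (inferred from the `cp_count <= self.base_cp_index` fix above): after
// repeated truncations the checkpoint count can land exactly on the cached base MMR
// index. Previously only `<` forced a rebuild, so that boundary case rewound from a
// stale base; the truncate/update/assert cycles below walk the cache through exactly
// those boundaries and check the recomputed roots.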
assert_eq!(checkpoint_db.len(), Ok(3)); + assert!(mmr_cache.update().is_ok()); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1h2, &h3]).clone())); + + // Add h5 checkpoint + checkpoint_db + .push(MerkleCheckPoint::new(vec![h5.clone()], Bitmap::create())) + .unwrap(); + assert!(mmr_cache.update().is_ok()); + let h3h5 = combine_hashes(&[&h3, &h5]); + let h1h2h3h5 = combine_hashes(&[&h1h2, &h3h5]); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1h2h3h5]).clone())); + + // Remove h5 checkpoint + checkpoint_db.truncate(3).unwrap(); + assert_eq!(checkpoint_db.len(), Ok(3)); + assert!(mmr_cache.update().is_ok()); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1h2, &h3]).clone())); + + // Remove h3 checkpoint + checkpoint_db.truncate(2).unwrap(); + assert_eq!(checkpoint_db.len(), Ok(2)); + assert!(mmr_cache.update().is_ok()); + assert_eq!(mmr_cache.get_mmr_only_root(), Ok(combine_hashes(&[&h1h2]).clone())); +} diff --git a/base_layer/p2p/Cargo.toml b/base_layer/p2p/Cargo.toml index b378442bc2..008bdac60f 100644 --- a/base_layer/p2p/Cargo.toml +++ b/base_layer/p2p/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tari_p2p" -version = "0.0.10" +version = "0.1.0" authors = ["The Tari Development community"] description = "Tari base layer-specific peer-to-peer communication features" repository = "https://github.com/tari-project/tari" @@ -14,13 +14,13 @@ test-mocks = [] [dependencies] tari_broadcast_channel = "^0.1" -tari_comms = { version = "^0.0", path = "../../comms"} -tari_comms_dht = { version = "^0.0", path = "../../comms/dht"} +tari_comms = { version = "^0.1", path = "../../comms"} +tari_comms_dht = { version = "^0.1", path = "../../comms/dht"} tari_crypto = { version = "^0.3" } tari_pubsub = "^0.1" tari_service_framework = { version = "^0.0", path = "../service_framework"} tari_shutdown = { version = "^0.0", path="../../infrastructure/shutdown" } -tari_storage = {version = "^0.0", path = "../../infrastructure/storage"} +tari_storage = {version = "^0.1", path = "../../infrastructure/storage"} tari_utilities = "^0.1" bytes = "0.4.12" @@ -60,4 +60,4 @@ default-features = false version = "0.12.0" [build-dependencies] -tari_common = { version = "^0.0", path="../../common"} +tari_common = { version = "^0.1", path="../../common"} diff --git a/base_layer/p2p/examples/pingpong.rs b/base_layer/p2p/examples/pingpong.rs index 226a122e92..657e83cbf9 100644 --- a/base_layer/p2p/examples/pingpong.rs +++ b/base_layer/p2p/examples/pingpong.rs @@ -202,12 +202,12 @@ mod pingpong { .add_initializer(LivenessInitializer::new( LivenessConfig { auto_ping_interval: None, // Some(Duration::from_secs(5)), - enable_auto_stored_message_request: false, enable_auto_join: true, ..Default::default() }, Arc::clone(&subscription_factory), dht.dht_requester(), + comms.connection_manager(), )) .finish(), ) diff --git a/base_layer/p2p/src/comms_connector/inbound_connector.rs b/base_layer/p2p/src/comms_connector/inbound_connector.rs index 9113d7699d..0c09a00602 100644 --- a/base_layer/p2p/src/comms_connector/inbound_connector.rs +++ b/base_layer/p2p/src/comms_connector/inbound_connector.rs @@ -59,7 +59,7 @@ where fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { let mut sink = self.sink.clone(); async move { - let peer_message = Self::to_peer_message(msg)?; + let peer_message = Self::do_peer_message(msg)?; // If this fails there is something wrong with the sink and the pubsub middleware should not // continue sink.send(Arc::new(peer_message)) @@ -72,7 +72,7 @@ where 
} impl InboundDomainConnector { - fn to_peer_message(mut inbound_message: DecryptedDhtMessage) -> Result { + fn do_peer_message(mut inbound_message: DecryptedDhtMessage) -> Result { let envelope_body = inbound_message .success_mut() .ok_or_else(|| "Message failed to decrypt")?; @@ -88,12 +88,14 @@ impl InboundDomainConnector { let DecryptedDhtMessage { source_peer, dht_header, + authenticated_origin, .. } = inbound_message; let peer_message = PeerMessage { message_header: header, source_peer: Clone::clone(&*source_peer), + authenticated_origin, dht_header, body: msg_bytes, }; @@ -116,7 +118,7 @@ where } fn start_send(mut self: Pin<&mut Self>, item: DecryptedDhtMessage) -> Result<(), Self::Error> { - let item = Self::to_peer_message(item)?; + let item = Self::do_peer_message(item)?; Pin::new(&mut self.sink) .start_send(Arc::new(item)) .map_err(PipelineError::from_debug) @@ -141,21 +143,17 @@ mod test { use crate::test_utils::{make_dht_inbound_message, make_node_identity}; use futures::{channel::mpsc, executor::block_on, StreamExt}; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; - use tari_comms_dht::{domain_message::MessageHeader, envelope::DhtMessageFlags}; + use tari_comms_dht::domain_message::MessageHeader; use tower::ServiceExt; #[tokio_macros::test_basic] async fn handle_message() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); - - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); + + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap(); let peer_message = block_on(rx.next()).unwrap(); @@ -167,14 +165,10 @@ mod test { async fn send_on_sink() { let (tx, mut rx) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).send(decrypted).await.unwrap(); @@ -187,14 +181,10 @@ mod test { async fn handle_message_fail_deserialize() { let (tx, mut rx) = mpsc::channel(1); let header = b"dodgy header".to_vec(); - let msg = wrap_in_envelope_body!(header, b"message".to_vec()).unwrap(); - - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"message".to_vec()); + + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); InboundDomainConnector::new(tx).oneshot(decrypted).await.unwrap_err(); 
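// The header bytes above are not a valid MessageHeader, so do_peer_message fails to
// deserialize and the connector surfaces an error instead of forwarding; the check
// below confirms that nothing reached the sink.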
assert!(rx.try_next().unwrap().is_none()); @@ -206,13 +196,9 @@ mod test { // from it's call function let (tx, _) = mpsc::channel(1); let header = MessageHeader::new(123); - let msg = wrap_in_envelope_body!(header, b"my message".to_vec()).unwrap(); - let inbound_message = make_dht_inbound_message( - &make_node_identity(), - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let decrypted = DecryptedDhtMessage::succeeded(msg, inbound_message); + let msg = wrap_in_envelope_body!(header, b"my message".to_vec()); + let inbound_message = make_dht_inbound_message(&make_node_identity(), msg.to_encoded_bytes()); + let decrypted = DecryptedDhtMessage::succeeded(msg, None, inbound_message); let result = InboundDomainConnector::new(tx).oneshot(decrypted).await; assert!(result.is_err()); } diff --git a/base_layer/p2p/src/comms_connector/peer_message.rs b/base_layer/p2p/src/comms_connector/peer_message.rs index f1d53c205f..4a62e393b6 100644 --- a/base_layer/p2p/src/comms_connector/peer_message.rs +++ b/base_layer/p2p/src/comms_connector/peer_message.rs @@ -34,28 +34,13 @@ pub struct PeerMessage { pub source_peer: Peer, /// Domain message header pub message_header: MessageHeader, + /// This messages authenticated origin, otherwise None + pub authenticated_origin: Option, /// Serialized message data pub body: Vec, } impl PeerMessage { - pub fn new(dht_header: DhtMessageHeader, source_peer: Peer, message_header: MessageHeader, body: Vec) -> Self { - Self { - body, - message_header, - dht_header, - source_peer, - } - } - - pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin - .as_ref() - .map(|o| &o.public_key) - .unwrap_or(&self.source_peer.public_key) - } - pub fn decode_message(&self) -> Result where T: prost::Message + Default { let msg = T::decode(self.body.as_slice())?; diff --git a/base_layer/p2p/src/domain_message.rs b/base_layer/p2p/src/domain_message.rs index 7aea2a50c8..0299a466fb 100644 --- a/base_layer/p2p/src/domain_message.rs +++ b/base_layer/p2p/src/domain_message.rs @@ -32,6 +32,8 @@ pub struct DomainMessage { /// This DHT header of this message. If `DhtMessageHeader::origin_public_key` is different from the /// `source_peer.public_key`, this message was forwarded. pub dht_header: DhtMessageHeader, + /// The authenticated origin public key of this message or None a message origin was not provided. + pub authenticated_origin: Option, /// The domain-level message pub inner: T, } @@ -48,32 +50,15 @@ impl DomainMessage { /// Consumes this object returning the public key of the original sender of this message and the message itself pub fn into_origin_and_inner(self) -> (CommsPublicKey, T) { let inner = self.inner; - let pk = self - .dht_header - .origin - .map(|o| o.public_key) - .unwrap_or(self.source_peer.public_key); + let pk = self.authenticated_origin.unwrap_or(self.source_peer.public_key); (pk, inner) } - /// Returns true of this message was forwarded from another peer, otherwise false - pub fn is_forwarded(&self) -> bool { - self.dht_header - .origin - .as_ref() - // If the source and origin are different, then the message was forwarded - .map(|o| o.public_key != self.source_peer.public_key) - // Otherwise, if no origin is specified, the message was sent directly from the peer - .unwrap_or(false) - } - /// Returns the public key that sent this message. If no origin is specified, then the source peer /// sent this message. 
pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin + self.authenticated_origin .as_ref() - .map(|o| &o.public_key) .unwrap_or(&self.source_peer.public_key) } @@ -88,6 +73,7 @@ impl DomainMessage { DomainMessage { source_peer: self.source_peer, dht_header: self.dht_header, + authenticated_origin: self.authenticated_origin, inner, } } @@ -103,6 +89,7 @@ impl DomainMessage { Ok(DomainMessage { source_peer: self.source_peer, dht_header: self.dht_header, + authenticated_origin: self.authenticated_origin, inner, }) } diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 345b3e3f6c..01b651acc7 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -41,7 +41,7 @@ use tari_comms::{ CommsBuilderError, CommsNode, }; -use tari_comms_dht::{Dht, DhtBuilder, DhtConfig}; +use tari_comms_dht::{Dht, DhtBuilder, DhtConfig, DhtInitializationError}; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use tower::ServiceBuilder; @@ -50,6 +50,7 @@ const LOG_TARGET: &str = "b::p2p::initialization"; #[derive(Debug, Error)] pub enum CommsInitializationError { CommsBuilderError(CommsBuilderError), + DhtInitializationError(DhtInitializationError), HiddenServiceBuilderError(tor::HiddenServiceBuilderError), #[error(non_std, no_from, msg_embedded)] InvalidLivenessCidrs(String), @@ -104,7 +105,7 @@ where }; let datastore = LMDBBuilder::new() .set_path(data_path) - .set_environment_size(10) + .set_environment_size(50) .set_max_number_of_databases(1) .add_database(&peer_database_name, lmdb_zero::db::CREATE) .build() @@ -136,7 +137,8 @@ where ) .local_test() .with_discovery_timeout(discovery_request_timeout) - .finish(); + .finish() + .await?; let dht_outbound_layer = dht.outbound_middleware_layer(); @@ -268,7 +270,7 @@ where { let datastore = LMDBBuilder::new() .set_path(&config.datastore_path) - .set_environment_size(10) + .set_environment_size(50) .set_max_number_of_databases(1) .add_database(&config.peer_database_name, lmdb_zero::db::CREATE) .build() @@ -297,7 +299,8 @@ where comms.shutdown_signal(), ) .with_config(config.dht) - .finish(); + .finish() + .await?; let dht_outbound_layer = dht.outbound_middleware_layer(); diff --git a/base_layer/p2p/src/services/liveness/config.rs b/base_layer/p2p/src/services/liveness/config.rs index e5bdd611b1..7f3e3ffb78 100644 --- a/base_layer/p2p/src/services/liveness/config.rs +++ b/base_layer/p2p/src/services/liveness/config.rs @@ -29,10 +29,12 @@ pub struct LivenessConfig { pub auto_ping_interval: Option, /// Set to true to enable automatically joining the network on node startup (default: false) pub enable_auto_join: bool, - /// Set to true to enable a request for stored messages on node startup (default: false) - pub enable_auto_stored_message_request: bool, - /// The length of time between querying peer manager for closest neighbours. (default: 5mins) + /// The length of time between querying peer manager for closest neighbours. (default: 2 minutes) pub refresh_neighbours_interval: Duration, + /// The length of time between querying peer manager for random neighbours. 
(default: 2 hours) + pub refresh_random_pool_interval: Duration, + /// The ratio of random to neighbouring peers to include in ping rounds (Default: 0) + pub random_peer_selection_ratio: f32, } impl Default for LivenessConfig { @@ -40,8 +42,9 @@ impl Default for LivenessConfig { Self { auto_ping_interval: None, enable_auto_join: false, - enable_auto_stored_message_request: false, - refresh_neighbours_interval: Duration::from_secs(3 * 60), + refresh_neighbours_interval: Duration::from_secs(2 * 60), + refresh_random_pool_interval: Duration::from_secs(2 * 60 * 60), + random_peer_selection_ratio: 0.0, } } } diff --git a/base_layer/p2p/src/services/liveness/handle.rs b/base_layer/p2p/src/services/liveness/handle.rs index 6709028a18..de36868d9e 100644 --- a/base_layer/p2p/src/services/liveness/handle.rs +++ b/base_layer/p2p/src/services/liveness/handle.rs @@ -41,8 +41,6 @@ pub enum LivenessRequest { GetAvgLatency(NodeId), /// Set the metadata attached to each pong message SetPongMetadata(MetadataKey, Vec), - /// Request the number of active neighbours - GetNumActiveNeighbours, /// Add NodeId to be monitored AddNodeId(NodeId), /// Get stats for a monitored NodeId diff --git a/base_layer/p2p/src/services/liveness/mock.rs b/base_layer/p2p/src/services/liveness/mock.rs index cc6d923456..b95996e2a1 100644 --- a/base_layer/p2p/src/services/liveness/mock.rs +++ b/base_layer/p2p/src/services/liveness/mock.rs @@ -135,9 +135,6 @@ impl LivenessMock { SetPongMetadata(_, _) => { reply_tx.send(Ok(LivenessResponse::Ok)).unwrap(); }, - GetNumActiveNeighbours => { - reply_tx.send(Ok(LivenessResponse::NumActiveNeighbours(8))).unwrap(); - }, AddNodeId(_n) => reply_tx.send(Ok(LivenessResponse::NodeIdAdded)).unwrap(), GetNodeIdStats(_n) => reply_tx .send(Ok(LivenessResponse::NodeIdStats(NodeStats::new()))) diff --git a/base_layer/p2p/src/services/liveness/mod.rs b/base_layer/p2p/src/services/liveness/mod.rs index badd4eb846..d8f6bb47e2 100644 --- a/base_layer/p2p/src/services/liveness/mod.rs +++ b/base_layer/p2p/src/services/liveness/mod.rs @@ -39,7 +39,7 @@ mod config; pub mod error; mod handle; mod message; -mod neighbours; +mod peer_pool; mod service; mod state; @@ -75,6 +75,7 @@ pub use self::{ state::Metadata, }; pub use crate::proto::liveness::MetadataKey; +use tari_comms::connection_manager::ConnectionManagerRequester; const LOG_TARGET: &str = "p2p::services::liveness"; @@ -83,6 +84,7 @@ pub struct LivenessInitializer { config: Option, inbound_message_subscription_factory: Arc>>, dht_requester: Option, + connection_manager_requester: Option, } impl LivenessInitializer { @@ -91,12 +93,14 @@ impl LivenessInitializer { config: LivenessConfig, inbound_message_subscription_factory: Arc>>, dht_requester: DhtRequester, + connection_manager_requester: ConnectionManagerRequester, ) -> Self { Self { config: Some(config), inbound_message_subscription_factory, dht_requester: Some(dht_requester), + connection_manager_requester: Some(connection_manager_requester), } } @@ -136,6 +140,11 @@ impl ServiceInitializer for LivenessInitializer { .take() .expect("Liveness service initialized more than once."); + let connection_manager_requester = self + .connection_manager_requester + .take() + .expect("Liveness service initialized without a ConnectionManagerRequester"); + // Register handle before waiting for handles to be ready handles_fut.register(liveness_handle); @@ -165,25 +174,6 @@ impl ServiceInitializer for LivenessInitializer { } } - if config.enable_auto_stored_message_request { - // TODO: Record when store message 
request was last requested - // and request messages from after that time - match dht_requester.send_request_stored_messages().await { - Ok(_) => { - trace!( - target: LOG_TARGET, - "Stored message request has been sent to closest peers", - ); - }, - Err(err) => { - error!( - target: LOG_TARGET, - "Failed to send stored message on startup because '{}'", err - ); - }, - } - } - let state = LivenessState::new(); let service = LivenessService::new( @@ -192,6 +182,7 @@ impl ServiceInitializer for LivenessInitializer { ping_stream, state, dht_requester, + connection_manager_requester, outbound_handle, publisher, shutdown, diff --git a/base_layer/p2p/src/services/liveness/neighbours.rs b/base_layer/p2p/src/services/liveness/neighbours.rs deleted file mode 100644 index b7ac397c73..0000000000 --- a/base_layer/p2p/src/services/liveness/neighbours.rs +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2019, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
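// Editorial note on the deletion below: the old Neighbours tracker stored full Peer
// structs and answered is_fresh(), defaulting to false when never updated. Its
// replacement, PeerPool (added further down), flips the default -- is_stale() returns
// true until set_node_ids() is first called -- keeps only NodeIds, and adds
// choose_multiple-based random sampling for ping rounds.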
- -use chrono::{NaiveDateTime, Utc}; -use std::time::Duration; -use tari_comms::peer_manager::{NodeId, Peer}; - -pub struct Neighbours { - last_updated: Option, - peers: Vec, - stale_interval: Duration, -} - -impl Neighbours { - pub fn new(stale_interval: Duration) -> Self { - Self { - last_updated: None, - peers: Vec::default(), - stale_interval, - } - } - - pub fn is_fresh(&self) -> bool { - self.last_updated - .map(|dt| { - let chrono_dt = chrono::Duration::from_std(self.stale_interval) - .expect("Neighbours::stale_interval is too large (overflows chrono::Duration::from_std)"); - dt.checked_add_signed(chrono_dt) - .map(|dt| dt < Utc::now().naive_utc()) - .expect("Neighbours::stale_interval is too large (overflows i32 when added to NaiveDateTime)") - }) - .unwrap_or(false) - } - - pub fn set_peers(&mut self, peers: Vec) { - self.peers = peers; - self.last_updated = Some(Utc::now().naive_utc()); - } - - pub fn peers(&self) -> &[Peer] { - &self.peers - } - - pub fn contains(&self, node_id: &NodeId) -> bool { - self.peers.iter().map(|p| &p.node_id).any(|n| n == node_id) - } -} diff --git a/base_layer/p2p/src/services/liveness/peer_pool.rs b/base_layer/p2p/src/services/liveness/peer_pool.rs new file mode 100644 index 0000000000..7cc2187153 --- /dev/null +++ b/base_layer/p2p/src/services/liveness/peer_pool.rs @@ -0,0 +1,120 @@ +// Copyright 2019, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
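
A quick usage sketch of the PeerPool added below (names taken from the new module; the refresh source and `fresh_node_ids` binding are assumed for illustration):

// Sketch (not part of the patch):
let mut pool = PeerPool::new(Duration::from_secs(2 * 60)); // stale after two minutes
if pool.is_stale() {
    // refresh from the peer manager / DHT; `fresh_node_ids: Vec<NodeId>` is assumed
    pool.set_node_ids(fresh_node_ids);
}
let targets = pool.sample(4); // up to four random NodeIds for this ping round
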
+ +use chrono::{NaiveDateTime, Utc}; +use rand::{rngs::OsRng, seq::SliceRandom}; +use std::time::Duration; +use tari_comms::peer_manager::NodeId; + +pub struct PeerPool { + last_updated: Option, + node_ids: Vec, + stale_interval: Duration, +} + +impl PeerPool { + pub fn new(stale_interval: Duration) -> Self { + Self { + last_updated: None, + node_ids: Vec::default(), + stale_interval, + } + } + + pub fn len(&self) -> usize { + self.node_ids.len() + } + + pub fn is_stale(&self) -> bool { + self.last_updated + .map(|dt| { + let chrono_dt = chrono::Duration::from_std(self.stale_interval) + .expect("PeerPool::stale_interval is too large (overflows chrono::Duration::from_std)"); + dt.checked_add_signed(chrono_dt) + .map(|dt| dt < Utc::now().naive_utc()) + .expect("PeerPool::stale_interval is too large (overflows i32 when added to NaiveDateTime)") + }) + .unwrap_or(true) + } + + pub fn set_node_ids(&mut self, node_ids: Vec) { + self.node_ids = node_ids; + self.last_updated = Some(Utc::now().naive_utc()); + } + + pub fn remove(&mut self, node_id: &NodeId) -> Option { + let pos = self.node_ids.iter().position(|n| n == node_id)?; + Some(self.node_ids.remove(pos)) + } + + pub fn push(&mut self, node_id: NodeId) { + self.node_ids.push(node_id) + } + + pub fn node_ids(&self) -> &[NodeId] { + &self.node_ids + } + + pub fn sample(&self, n: usize) -> Vec<&NodeId> { + self.node_ids.choose_multiple(&mut OsRng, n).collect() + } + + pub fn contains(&self, node_id: &NodeId) -> bool { + self.node_ids.iter().any(|n| n == node_id) + } +} +#[cfg(test)] +mod test { + use super::*; + use crate::test_utils::make_node_id; + use std::iter::repeat_with; + + #[test] + fn is_stale() { + let mut pool = PeerPool::new(Duration::from_secs(100)); + assert_eq!(pool.is_stale(), true); + pool.set_node_ids(vec![]); + assert_eq!(pool.is_stale(), false); + pool.last_updated = Some( + Utc::now() + .naive_utc() + .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(101)).unwrap()) + .unwrap(), + ); + assert_eq!(pool.is_stale(), true); + } + + #[test] + fn sample() { + let mut pool = PeerPool::new(Duration::from_secs(100)); + let node_ids = repeat_with(make_node_id).take(10).collect::>(); + pool.set_node_ids(node_ids.clone()); + let mut sample = pool.sample(4); + assert_eq!(sample.len(), 4); + node_ids.into_iter().for_each(|node_id| { + if let Some(pos) = sample.iter().position(|n| *n == &node_id) { + sample.remove(pos); + } + }); + assert_eq!(sample.len(), 0); + } +} diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index 8d71de1c83..71f4ed9518 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -31,16 +31,18 @@ use super::{ }; use crate::{ domain_message::DomainMessage, - services::liveness::{neighbours::Neighbours, LivenessEvent, PongEvent}, + services::liveness::{peer_pool::PeerPool, LivenessEvent, PongEvent}, tari_message::TariMessageType, }; -use futures::{pin_mut, stream::StreamExt, task::Context, SinkExt, Stream}; +use futures::{future::Either, pin_mut, stream::StreamExt, SinkExt, Stream}; use log::*; -use std::{pin::Pin, task::Poll, time::Instant}; +use std::{cmp, time::Instant}; use tari_broadcast_channel::Publisher; use tari_comms::{ - peer_manager::{NodeId, Peer}, + connection_manager::ConnectionManagerRequester, + peer_manager::NodeId, types::CommsPublicKey, + ConnectionManagerEvent, }; use tari_comms_dht::{ broadcast_strategy::BroadcastStrategy, @@ -52,10 +54,7 @@ use 
tari_service_framework::RequestContext; use tari_shutdown::ShutdownSignal; use tokio::time; -/// Service responsible for testing Liveness for Peers. -/// -/// Very basic global ping and pong counter stats are implemented. In future, -/// peer latency and availability stats will be added. +/// Service responsible for testing Liveness of Peers. pub struct LivenessService { config: LivenessConfig, request_rx: Option, @@ -64,11 +63,18 @@ pub struct LivenessService { dht_requester: DhtRequester, oms_handle: OutboundMessageRequester, event_publisher: Publisher, + connection_manager: ConnectionManagerRequester, shutdown_signal: Option, - neighbours: Neighbours, + neighbours: PeerPool, + random_peers: PeerPool, + active_pool: PeerPool, } -impl LivenessService { +impl LivenessService +where + TPingStream: Stream>, + THandleStream: Stream>>, +{ #[allow(clippy::too_many_arguments)] pub fn new( config: LivenessConfig, @@ -76,6 +82,7 @@ impl LivenessService { ping_stream: TPingStream, state: LivenessState, dht_requester: DhtRequester, + connection_manager: ConnectionManagerRequester, oms_handle: OutboundMessageRequester, event_publisher: Publisher, shutdown_signal: ShutdownSignal, @@ -87,20 +94,19 @@ impl LivenessService { state, dht_requester, oms_handle, + connection_manager, event_publisher, shutdown_signal: Some(shutdown_signal), - neighbours: Neighbours::new(config.refresh_neighbours_interval), + neighbours: PeerPool::new(config.refresh_neighbours_interval), + random_peers: PeerPool::new(config.refresh_random_pool_interval), + active_pool: PeerPool::new(config.refresh_neighbours_interval), config, } } -} -impl LivenessService -where - TPingStream: Stream>, - THandleStream: Stream>>, -{ pub async fn run(mut self) { + info!(target: LOG_TARGET, "Liveness service started"); + debug!(target: LOG_TARGET, "Config = {:?}", self.config); let ping_stream = self.ping_stream.take().expect("ping_stream cannot be None").fuse(); pin_mut!(ping_stream); @@ -108,11 +114,13 @@ where pin_mut!(request_stream); let mut ping_tick = match self.config.auto_ping_interval { - Some(interval) => EitherStream::Left(time::interval_at((Instant::now() + interval).into(), interval)), - None => EitherStream::Right(futures::stream::iter(Vec::new())), + Some(interval) => Either::Left(time::interval_at((Instant::now() + interval).into(), interval)), + None => Either::Right(futures::stream::iter(Vec::new())), } .fuse(); + let mut connection_manager_events = self.connection_manager.get_event_subscription().fuse(); + let mut shutdown_signal = self .shutdown_signal .take() @@ -129,15 +137,24 @@ where }); }, - _ = ping_tick.select_next_some() => { - let _ = self.ping_neighbours().await.or_else(|err| { - error!(target: LOG_TARGET, "Error when pinging neighbours: {:?}", err); - Err(err) - }); - let _ = self.ping_monitored_node_ids().await.or_else(|err| { - error!(target: LOG_TARGET, "Error when pinging monitored nodes: {:?}", err); + event = connection_manager_events.select_next_some() => { + if let Ok(event) = event { + let _ = self.handle_connection_manager_event(&*event).await.or_else(|err| { + error!(target: LOG_TARGET, "Error when handling connection manager event: {:?}", err); Err(err) }); + } + }, + + _ = ping_tick.select_next_some() => { + let _ = self.ping_active_pool().await.or_else(|err| { + error!(target: LOG_TARGET, "Error when pinging peers: {:?}", err); + Err(err) + }); + let _ = self.ping_monitored_node_ids().await.or_else(|err| { + error!(target: LOG_TARGET, "Error when pinging monitored nodes: {:?}", err); + Err(err) + }); 
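// Note on the stream handling above: futures 0.3 now provides future::Either, which
// implements Stream when both variants do, so the optional ping interval becomes
// Either::Left(time::interval_at(..)) or Either::Right(stream::iter(Vec::new())),
// fused into a single stream for the select! loop; the hand-rolled EitherStream shim
// is deleted further down in this file.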
}, // Incoming messages from the Comms layer msg = ping_stream.select_next_some() => { @@ -173,9 +190,18 @@ where self.publish_event(LivenessEvent::ReceivedPing).await?; }, PingPong::Pong => { - self.update_neighbours_if_stale().await?; - let maybe_latency = self.state.record_pong(ping_pong_msg.nonce); + if !self.state.is_inflight(ping_pong_msg.nonce) { + warn!( + target: LOG_TARGET, + "Received Pong that was not requested from '{}'. Ignoring it.", + node_id.short_str() + ); + return Ok(()); + } + let is_neighbour = self.neighbours.contains(&node_id); + self.refresh_peer_pools_if_stale().await?; + let maybe_latency = self.state.record_pong(ping_pong_msg.nonce); let is_monitored = self.state.is_monitored_node_id(&node_id); trace!( @@ -201,6 +227,18 @@ where Ok(()) } + async fn handle_connection_manager_event(&mut self, event: &ConnectionManagerEvent) -> Result<(), LivenessError> { + use ConnectionManagerEvent::*; + match event { + PeerDisconnected(node_id) | PeerConnectFailed(node_id, _) => { + self.replace_failed_peer_if_required(node_id).await?; + }, + _ => {}, + } + + Ok(()) + } + async fn send_ping(&mut self, node_id: NodeId) -> Result<(), LivenessError> { let msg = PingPongMessage::ping(); self.state.add_inflight_ping(msg.nonce, &node_id); @@ -261,10 +299,6 @@ where self.state.set_pong_metadata_entry(key, value); Ok(LivenessResponse::Ok) }, - GetNumActiveNeighbours => { - let num_active_neighbours = self.state.num_active_neighbours(); - Ok(LivenessResponse::NumActiveNeighbours(num_active_neighbours)) - }, AddNodeId(node_id) => { self.state.add_node_id(&node_id); self.send_ping(node_id.clone()).await?; @@ -277,38 +311,148 @@ where } } - async fn update_neighbours_if_stale(&mut self) -> Result<&[Peer], LivenessError> { - if self.neighbours.is_fresh() { - return Ok(self.neighbours.peers()); + async fn replace_failed_peer_if_required(&mut self, node_id: &NodeId) -> Result<(), LivenessError> { + if self.neighbours.contains(node_id) { + self.refresh_neighbour_pool().await?; + return Ok(()); + } + + if self.should_include_random_peers() && self.random_peers.contains(node_id) { + // Replace the peer in the random peer pool with another random peer + let excluded = self + .neighbours + .node_ids() + .into_iter() + .chain(vec![node_id]) + .cloned() + .collect(); + + if let Some(peer) = self + .dht_requester + .select_peers(BroadcastStrategy::Random(1, excluded)) + .await? 
+ .pop() + { + self.random_peers.remove(node_id); + self.random_peers.push(peer.node_id) + } + } + + Ok(()) + } + + fn should_include_random_peers(&self) -> bool { + self.config.random_peer_selection_ratio > 0.0 + } + + async fn refresh_peer_pools_if_stale(&mut self) -> Result<(), LivenessError> { + let is_stale = self.neighbours.is_stale(); + if is_stale { + self.refresh_neighbour_pool().await?; + } + + if self.should_include_random_peers() && self.random_peers.is_stale() { + self.refresh_random_peer_pool().await?; + } + + if is_stale { + self.refresh_active_peer_pool(); + + info!( + target: LOG_TARGET, + "Selected {} active peers liveness neighbourhood out of a pool of {} neighbouring peers and {} random \ + peers", + self.active_pool.len(), + self.neighbours.len(), + self.random_peers.len() + ); } - let peers = self + Ok(()) + } + + fn refresh_active_peer_pool(&mut self) { + let rand_peer_ratio = 1.0f32.min(0.0f32.max(self.config.random_peer_selection_ratio)); + let desired_neighbours = (self.neighbours.len() as f32 * (1.0 - rand_peer_ratio)).ceil() as usize; + let desired_random = (self.neighbours.len() as f32 * rand_peer_ratio).ceil() as usize; + + let num_random = cmp::min(desired_random, self.random_peers.len()); + let num_neighbours = self.neighbours.len() - num_random; + debug!( + target: LOG_TARGET, + "Adding {} neighbouring peers (wanted = {}) and {} randomly selected (wanted = {}) peer(s) to active peer \ + pool", + num_neighbours, + desired_neighbours, + num_random, + desired_random + ); + + let mut active_node_ids = self.neighbours.sample(num_neighbours); + active_node_ids.extend(self.random_peers.sample(num_random)); + self.active_pool + .set_node_ids(active_node_ids.into_iter().cloned().collect()); + self.state.set_num_active_peers(self.active_pool.len()); + } + + async fn refresh_random_peer_pool(&mut self) -> Result<(), LivenessError> { + let excluded = self.neighbours.node_ids().into_iter().cloned().collect(); + + // Select a pool of random peers the same length as neighbouring peers + let random_peers = self .dht_requester - .select_peers(BroadcastStrategy::Neighbours(Vec::new(), false)) + .select_peers(BroadcastStrategy::Random(self.neighbours.len(), excluded)) .await?; - self.state.set_num_active_neighbours(peers.len()); - self.neighbours.set_peers(peers); + if random_peers.is_empty() { + warn!(target: LOG_TARGET, "No random peers selected for this round of pings"); + } + let new_node_ids = random_peers.into_iter().map(|p| p.node_id).collect::>(); + let removed = new_node_ids + .iter() + .filter(|n| self.random_peers.contains(*n)) + .collect::>(); + debug!(target: LOG_TARGET, "Removed {} random peer(s)", removed.len()); + for node_id in removed { + if let Err(err) = self.connection_manager.disconnect_peer(node_id.clone()).await { + error!(target: LOG_TARGET, "Failed to disconnect peer: {:?}", err); + } + } + + self.random_peers.set_node_ids(new_node_ids); - Ok(self.neighbours.peers()) + Ok(()) } - async fn ping_neighbours(&mut self) -> Result<(), LivenessError> { - self.update_neighbours_if_stale().await?; - let peers = self.neighbours.peers(); - let len_peers = peers.len(); - trace!( + async fn refresh_neighbour_pool(&mut self) -> Result<(), LivenessError> { + let neighbours = self + .dht_requester + .select_peers(BroadcastStrategy::Neighbours(Vec::new(), false)) + .await?; + + debug!( target: LOG_TARGET, - "Sending liveness ping to {} neighbour(s)", - len_peers + "Setting active peers ({} peer(s))", + neighbours.len() ); + self.neighbours + 
.set_node_ids(neighbours.into_iter().map(|p| p.node_id).collect()); + + Ok(()) + } + + async fn ping_active_pool(&mut self) -> Result<(), LivenessError> { + self.refresh_peer_pools_if_stale().await?; + let node_ids = self.active_pool.node_ids(); + let len_peers = node_ids.len(); + trace!(target: LOG_TARGET, "Sending liveness ping to {} peer(s)", len_peers); - for peer in peers { + for node_id in node_ids { let msg = PingPongMessage::ping(); - self.state.add_inflight_ping(msg.nonce, &peer.node_id); + self.state.add_inflight_ping(msg.nonce, &node_id); self.oms_handle - .send_direct( - peer.public_key.clone(), + .send_direct_node_id( + node_id.clone(), OutboundEncryption::None, OutboundDomainMessage::new(TariMessageType::PingPong, msg), ) @@ -364,33 +508,6 @@ where } } -// Unfortunately, `stream::Either` doesn't exist yet in futures-0.3.0 -enum EitherStream { - Left(A), - Right(B), -} - -impl Stream for EitherStream -where - A: Stream + Unpin, - B: Stream + Unpin, -{ - type Item = A::Item; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - EitherStream::Left(stream) => { - pin_mut!(stream); - stream.poll_next(cx) - }, - EitherStream::Right(stream) => { - pin_mut!(stream); - stream.poll_next(cx) - }, - } - } -} - #[cfg(test)] mod test { use super::*; @@ -414,7 +531,8 @@ mod test { use tari_crypto::keys::PublicKey; use tari_service_framework::reply_channel; use tari_shutdown::Shutdown; - use tokio::task; + use tari_test_utils::collect_stream; + use tokio::{sync::broadcast, task}; #[tokio_macros::test_basic] async fn get_ping_pong_count() { @@ -435,6 +553,10 @@ mod test { let (dht_tx, _) = mpsc::channel(10); let dht_requester = DhtRequester::new(dht_tx); + let (tx, _) = mpsc::channel(0); + let (event_tx, _) = broadcast::channel(1); + let connection_manager = ConnectionManagerRequester::new(tx, event_tx); + let shutdown = Shutdown::new(); let service = LivenessService::new( Default::default(), @@ -442,6 +564,7 @@ mod test { stream::empty(), state, dht_requester, + connection_manager, oms_handle, publisher, shutdown.to_signal(), @@ -473,6 +596,10 @@ mod test { let (dht_tx, _) = mpsc::channel(10); let dht_requester = DhtRequester::new(dht_tx); + let (tx, _) = mpsc::channel(0); + let (event_tx, _) = broadcast::channel(1); + let connection_manager = ConnectionManagerRequester::new(tx, event_tx); + let shutdown = Shutdown::new(); let service = LivenessService::new( Default::default(), @@ -480,6 +607,7 @@ mod test { stream::empty(), state, dht_requester, + connection_manager, oms_handle, publisher, shutdown.to_signal(), @@ -494,7 +622,7 @@ mod test { task::spawn(async move { match outbound_rx.select_next_some().await { DhtOutboundRequest::SendMessage(_, _, reply_tx) => { - reply_tx.send(SendMessageResponse::Queued(vec![])).unwrap(); + reply_tx.send(SendMessageResponse::Queued(vec![].into())).unwrap(); }, } }); @@ -513,13 +641,16 @@ mod test { &[], ); DomainMessage { - dht_header: DhtMessageHeader::new( - Default::default(), - DhtMessageType::None, - None, - Network::LocalTest, - Default::default(), - ), + dht_header: DhtMessageHeader { + version: 0, + destination: Default::default(), + origin_mac: Vec::new(), + ephemeral_public_key: None, + message_type: DhtMessageType::None, + network: Network::LocalTest, + flags: Default::default(), + }, + authenticated_origin: None, source_peer, inner, } @@ -541,6 +672,11 @@ mod test { let dht_requester = DhtRequester::new(dht_tx); // Setup liveness service let (publisher, _subscriber) = 
broadcast_channel::bounded(100); + + let (tx, _) = mpsc::channel(0); + let (event_tx, _) = broadcast::channel(1); + let connection_manager = ConnectionManagerRequester::new(tx, event_tx); + let shutdown = Shutdown::new(); let service = LivenessService::new( Default::default(), @@ -548,6 +684,7 @@ mod test { pingpong_stream, state, dht_requester, + connection_manager, oms_handle, publisher, shutdown.to_signal(), @@ -568,12 +705,13 @@ mod test { let mut metadata = Metadata::new(); metadata.insert(MetadataKey::ChainMetadata, b"dummy-data".to_vec()); - let msg = create_dummy_message(PingPongMessage::pong_with_metadata(123, metadata)); + let msg = create_dummy_message(PingPongMessage::pong_with_metadata(123, metadata.clone())); let peer = msg.source_peer.clone(); state.add_inflight_ping(msg.inner.nonce, &msg.source_peer.node_id); - // A stream which emits one message and then closes - let pingpong_stream = stream::iter(std::iter::once(msg)); + // A stream which emits an inflight pong message and an unexpected one + let malicious_msg = create_dummy_message(PingPongMessage::pong_with_metadata(321, metadata)); + let pingpong_stream = stream::iter(vec![msg, malicious_msg]); let (dht_tx, mut dht_rx) = mpsc::channel(10); let dht_requester = DhtRequester::new(dht_tx); @@ -591,15 +729,20 @@ mod test { } }); + let (tx, _) = mpsc::channel(0); + let (event_tx, _) = broadcast::channel(1); + let connection_manager = ConnectionManagerRequester::new(tx, event_tx); + // Setup liveness service let (publisher, subscriber) = broadcast_channel::bounded(100); - let shutdown = Shutdown::new(); + let mut shutdown = Shutdown::new(); let service = LivenessService::new( Default::default(), stream::empty(), pingpong_stream, state, dht_requester, + connection_manager, oms_handle, publisher, shutdown.to_signal(), @@ -608,7 +751,8 @@ mod test { task::spawn(service.run()); // Listen for the pong event - let event = time::timeout(Duration::from_secs(10), subscriber.fuse().select_next_some()) + let mut subscriber = subscriber.fuse(); + let event = time::timeout(Duration::from_secs(10), subscriber.select_next_some()) .await .unwrap(); @@ -618,5 +762,11 @@ mod test { }, _ => panic!("Unexpected event"), } + + shutdown.trigger().unwrap(); + + // No further events (malicious_msg was ignored) + let events = collect_stream!(subscriber, timeout = Duration::from_secs(10)); + assert_eq!(events.len(), 0); } } diff --git a/base_layer/p2p/src/services/liveness/state.rs b/base_layer/p2p/src/services/liveness/state.rs index da1f004abb..a6206f0d4d 100644 --- a/base_layer/p2p/src/services/liveness/state.rs +++ b/base_layer/p2p/src/services/liveness/state.rs @@ -74,7 +74,7 @@ pub struct LivenessState { pongs_received: AtomicUsize, pings_sent: AtomicUsize, pongs_sent: AtomicUsize, - num_active_neighbours: AtomicUsize, + num_active_peers: AtomicUsize, pong_metadata: Metadata, nodes_to_monitor: HashMap, @@ -109,13 +109,12 @@ impl LivenessState { self.pongs_received.load(Ordering::Relaxed) } - pub fn num_active_neighbours(&self) -> usize { - self.num_active_neighbours.load(Ordering::Relaxed) + pub fn num_active_peers(&self) -> usize { + self.num_active_peers.load(Ordering::Relaxed) } - pub fn set_num_active_neighbours(&self, num_active_neighbours: usize) { - self.num_active_neighbours - .store(num_active_neighbours, Ordering::Relaxed); + pub fn set_num_active_peers(&self, n: usize) { + self.num_active_peers.store(n, Ordering::Relaxed); } #[cfg(test)] @@ -157,6 +156,11 @@ impl LivenessState { .collect(); } + /// Returns true if the nonce is 
inflight, otherwise false + pub fn is_inflight(&self, nonce: u64) -> bool { + self.inflight_pings.get(&nonce).is_some() + } + /// Records a pong. Specifically, the pong counter is incremented and /// a latency sample is added and calculated. pub fn record_pong(&mut self, nonce: u64) -> Option { diff --git a/base_layer/p2p/src/services/utils.rs b/base_layer/p2p/src/services/utils.rs index 111926ab58..70699d9f10 100644 --- a/base_layer/p2p/src/services/utils.rs +++ b/base_layer/p2p/src/services/utils.rs @@ -43,6 +43,7 @@ where T: prost::Message + Default { Ok(DomainMessage { source_peer: serialized.source_peer.clone(), dht_header: serialized.dht_header.clone(), + authenticated_origin: serialized.authenticated_origin.clone(), inner: serialized.decode_message()?, }) } diff --git a/base_layer/p2p/src/test_utils.rs b/base_layer/p2p/src/test_utils.rs index 1582449e63..97fbebe7e3 100644 --- a/base_layer/p2p/src/test_utils.rs +++ b/base_layer/p2p/src/test_utils.rs @@ -23,15 +23,16 @@ use rand::rngs::OsRng; use std::sync::Arc; use tari_comms::{ + message::MessageTag, multiaddr::Multiaddr, - peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerFlags}, - utils::signature, + peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags}, + types::CommsPublicKey, }; use tari_comms_dht::{ - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, DhtMessageType, Network, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, Network, NodeDestination}, inbound::DhtInboundMessage, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; +use tari_crypto::keys::PublicKey; macro_rules! unwrap_oms_send_msg { ($var:expr, reply_value=$reply_value:expr) => { @@ -45,11 +46,16 @@ macro_rules! unwrap_oms_send_msg { ($var:expr) => { unwrap_oms_send_msg!( $var, - reply_value = tari_comms_dht::outbound::SendMessageResponse::Queued(vec![]) + reply_value = tari_comms_dht::outbound::SendMessageResponse::Queued(vec![].into()) ); }; } +pub fn make_node_id() -> NodeId { + let (public_key, _) = CommsPublicKey::random_keypair(&mut OsRng); + NodeId::from_key(&public_key).unwrap() +} + pub fn make_node_identity() -> Arc { Arc::new( NodeIdentity::random( @@ -61,31 +67,22 @@ pub fn make_node_identity() -> Arc { ) } -pub fn make_dht_header(node_identity: &NodeIdentity, message: &Vec, flags: DhtMessageFlags) -> DhtMessageHeader { +pub fn make_dht_header() -> DhtMessageHeader { DhtMessageHeader { version: 0, destination: NodeDestination::Unknown, - origin: Some(DhtMessageOrigin { - public_key: node_identity.public_key().clone(), - signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), message) - .unwrap() - .to_binary() - .unwrap(), - }), + origin_mac: Vec::new(), + ephemeral_public_key: None, message_type: DhtMessageType::None, network: Network::LocalTest, - flags, + flags: DhtMessageFlags::NONE, } } -pub fn make_dht_inbound_message( - node_identity: &NodeIdentity, - message: Vec, - flags: DhtMessageFlags, -) -> DhtInboundMessage -{ +pub fn make_dht_inbound_message(node_identity: &NodeIdentity, message: Vec) -> DhtInboundMessage { DhtInboundMessage::new( - make_dht_header(node_identity, &message, flags), + MessageTag::new(), + make_dht_header(), Arc::new(Peer::new( node_identity.public_key().clone(), node_identity.node_id().clone(), diff --git a/base_layer/p2p/tests/services/liveness.rs b/base_layer/p2p/tests/services/liveness.rs index a413801313..3f368f24aa 100644 --- a/base_layer/p2p/tests/services/liveness.rs +++ b/base_layer/p2p/tests/services/liveness.rs @@ 
-21,9 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::support::comms_and_services::setup_comms_services; +use futures::channel::mpsc; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_comms::{ + connection_manager::ConnectionManagerRequester, peer_manager::{NodeIdentity, PeerFeatures}, transports::MemoryTransport, CommsNode, @@ -33,13 +35,13 @@ use tari_p2p::{ comms_connector::pubsub_connector, services::{ comms_outbound::CommsOutboundServiceInitializer, - liveness::{LivenessConfig, LivenessEvent, LivenessHandle, LivenessInitializer}, + liveness::{LivenessEvent, LivenessHandle, LivenessInitializer}, }, }; use tari_service_framework::StackBuilder; use tari_test_utils::{collect_stream, random::string}; use tempdir::TempDir; -use tokio::runtime; +use tokio::{runtime, sync::broadcast}; pub async fn setup_liveness_service( node_identity: Arc, @@ -52,17 +54,17 @@ pub async fn setup_liveness_service( let subscription_factory = Arc::new(subscription_factory); let (comms, dht) = setup_comms_services(node_identity.clone(), peers, publisher, data_path).await; + let (tx, _) = mpsc::channel(0); + let (event_tx, _) = broadcast::channel(1); + let connection_manager = ConnectionManagerRequester::new(tx, event_tx); + let handles = StackBuilder::new(rt_handle.clone(), comms.shutdown_signal()) .add_initializer(CommsOutboundServiceInitializer::new(dht.outbound_requester())) .add_initializer(LivenessInitializer::new( - LivenessConfig { - enable_auto_join: false, - enable_auto_stored_message_request: false, - auto_ping_interval: None, - refresh_neighbours_interval: Duration::from_secs(60), - }, + Default::default(), Arc::clone(&subscription_factory), dht.dht_requester(), + connection_manager, )) .finish() .await diff --git a/base_layer/wallet/Cargo.toml b/base_layer/wallet/Cargo.toml index 1f95b1a518..018fa0d4bb 100644 --- a/base_layer/wallet/Cargo.toml +++ b/base_layer/wallet/Cargo.toml @@ -3,7 +3,7 @@ name = "tari_wallet" authors = ["The Tari Development Community"] description = "Tari cryptocurrency wallet library" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" [features] @@ -12,15 +12,15 @@ c_integration = [] [dependencies] tari_broadcast_channel = "^0.1" -tari_comms = { path = "../../comms", version = "^0.0"} -tari_comms_dht = { path = "../../comms/dht", version = "^0.0"} +tari_comms = { path = "../../comms", version = "^0.1"} +tari_comms_dht = { path = "../../comms/dht", version = "^0.1"} tari_crypto = { version = "^0.3" } tari_key_manager = {path = "../key_manager", version = "^0.0"} -tari_p2p = {path = "../p2p", version = "^0.0"} +tari_p2p = {path = "../p2p", version = "^0.1"} tari_pubsub = "^0.1" tari_service_framework = { version = "^0.0", path = "../service_framework"} tari_shutdown = { path = "../../infrastructure/shutdown", version = "^0.0"} -tari_storage = { version = "^0.0", path = "../../infrastructure/storage"} +tari_storage = { version = "^0.1", path = "../../infrastructure/storage"} chrono = { version = "0.4.6", features = ["serde"]} time = {version = "0.1.39"} @@ -45,12 +45,12 @@ tari_test_utils = { path = "../../infrastructure/test_utils", version = "^0.0", [dependencies.tari_core] path = "../../base_layer/core" -version = "^0.0" +version = "^0.1" default-features = false features = ["transactions", "mempool_proto", "base_node_proto"] [dev-dependencies] -tari_comms_dht = { path = "../../comms/dht", version = "^0.0", features=["test-mocks"]} +tari_comms_dht = { path = "../../comms/dht", 
version = "^0.1", features=["test-mocks"]} tari_test_utils = { path = "../../infrastructure/test_utils", version = "^0.0"} lazy_static = "1.3.0" env_logger = "0.7.1" diff --git a/base_layer/wallet/src/error.rs b/base_layer/wallet/src/error.rs index ac1d195f6b..409de76988 100644 --- a/base_layer/wallet/src/error.rs +++ b/base_layer/wallet/src/error.rs @@ -31,6 +31,7 @@ use diesel::result::Error as DieselError; use log::SetLoggerError; use serde_json::Error as SerdeJsonError; use tari_comms::{multiaddr, peer_manager::PeerManagerError}; +use tari_comms_dht::store_forward::StoreAndForwardError; use tari_p2p::{initialization::CommsInitializationError, services::liveness::error::LivenessError}; #[derive(Debug, Error)] @@ -44,6 +45,7 @@ pub enum WalletError { SetLoggerError(SetLoggerError), ContactsServiceError(ContactsServiceError), LivenessServiceError(LivenessError), + StoreAndForwardError(StoreAndForwardError), } #[derive(Debug, Error)] diff --git a/base_layer/wallet/src/output_manager_service/handle.rs b/base_layer/wallet/src/output_manager_service/handle.rs index 66e84ad12e..9a1969c3bb 100644 --- a/base_layer/wallet/src/output_manager_service/handle.rs +++ b/base_layer/wallet/src/output_manager_service/handle.rs @@ -31,7 +31,7 @@ use tari_broadcast_channel::Subscriber; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{ tari_amount::MicroTari, - transaction::{TransactionInput, TransactionOutput, UnblindedOutput}, + transaction::{Transaction, TransactionInput, TransactionOutput, UnblindedOutput}, types::PrivateKey, SenderTransactionProtocol, }; @@ -57,6 +57,7 @@ pub enum OutputManagerRequest { GetSeedWords, SetBaseNodePublicKey(CommsPublicKey), SyncWithBaseNode, + CreateCoinSplit((MicroTari, usize, MicroTari, Option)), } impl fmt::Display for OutputManagerRequest { @@ -80,6 +81,7 @@ impl fmt::Display for OutputManagerRequest { Self::GetSeedWords => f.write_str("GetSeedWords"), Self::SetBaseNodePublicKey(k) => f.write_str(&format!("SetBaseNodePublicKey ({})", k)), Self::SyncWithBaseNode => f.write_str("SyncWithBaseNode"), + Self::CreateCoinSplit(v) => f.write_str(&format!("CreateCoinSplit ({})", v.0)), } } } @@ -102,6 +104,7 @@ pub enum OutputManagerResponse { SeedWords(Vec), BaseNodePublicKeySet, StartedBaseNodeSync(u64), + Transaction((u64, Transaction, MicroTari, MicroTari)), } /// Events that can be published on the Text Message Service Event Stream @@ -309,4 +312,27 @@ impl OutputManagerHandle { _ => Err(OutputManagerError::UnexpectedApiResponse), } } + + pub async fn create_coin_split( + &mut self, + amount_per_split: MicroTari, + split_count: usize, + fee_per_gram: MicroTari, + lock_height: Option, + ) -> Result<(u64, Transaction, MicroTari, MicroTari), OutputManagerError> + { + match self + .handle + .call(OutputManagerRequest::CreateCoinSplit(( + amount_per_split, + split_count, + fee_per_gram, + lock_height, + ))) + .await?? 
+ { + OutputManagerResponse::Transaction(ct) => Ok(ct), + _ => Err(OutputManagerError::UnexpectedApiResponse), + } + } } diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index 9cfb086817..7f61500427 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -52,7 +52,14 @@ use tari_core::{ transactions::{ fee::Fee, tari_amount::MicroTari, - transaction::{OutputFeatures, TransactionInput, TransactionOutput, UnblindedOutput}, + transaction::{ + KernelFeatures, + OutputFeatures, + Transaction, + TransactionInput, + TransactionOutput, + UnblindedOutput, + }, types::{CryptoFactories, PrivateKey}, SenderTransactionProtocol, }, @@ -273,6 +280,10 @@ where .fetch_invalid_outputs() .await .map(OutputManagerResponse::InvalidOutputs), + OutputManagerRequest::CreateCoinSplit((amount_per_split, split_count, fee_per_gram, lock_height)) => self + .create_coin_split(amount_per_split, split_count, fee_per_gram, lock_height) + .await + .map(OutputManagerResponse::Transaction), } } @@ -294,7 +305,7 @@ where // Only process requests with a request_key that we are expecting. let queried_hashes: Vec> = match self.pending_utxo_query_keys.remove(&request_key) { None => { - debug!( + trace!( target: LOG_TARGET, "Ignoring Base Node Response with unexpected request key ({}), it was not meant for this service.", request_key @@ -343,10 +354,18 @@ where "Handled Base Node response for Query {}", request_key ); - self.event_publisher + let _ = self + .event_publisher .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key)) .await - .map_err(|_| OutputManagerError::EventStreamError)?; + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); Ok(()) } @@ -363,10 +382,18 @@ where self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; // TODO Remove this once this bug is fixed trace!(target: LOG_TARGET, "Finished queueing new Base Node query timeout"); - self.event_publisher + let _ = self + .event_publisher .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) .await - .map_err(|_| OutputManagerError::EventStreamError)?; + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); } Ok(()) } @@ -402,7 +429,7 @@ where self.outbound_message_service .send_direct( pk.clone(), - OutboundEncryption::EncryptForPeer, + OutboundEncryption::None, OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), ) .await?; @@ -449,6 +476,7 @@ where .accept_incoming_pending_transaction(tx_id, amount, key.clone(), OutputFeatures::default()) .await?; + self.confirm_encumberance(tx_id).await?; Ok(key) } @@ -521,9 +549,7 @@ where message: String, ) -> Result { - let outputs = self - .select_outputs(amount, fee_per_gram, UTXOSelectionStrategy::MaturityThenSmallest) - .await?; + let (outputs, _) = self.select_utxos(amount, fee_per_gram, 1, None).await?; let total = outputs.iter().fold(MicroTari::from(0), |acc, x| acc + x.value); let offset = PrivateKey::random(&mut OsRng); @@ -545,7 +571,7 @@ where ); } - let fee_without_change = Fee::calculate(fee_per_gram, outputs.len(), 1); + let fee_without_change = Fee::calculate(fee_per_gram, 1, outputs.len(), 1); let mut change_key: Option = None; // If the input values > the amount to be sent + fees_without_change then we will need to include a 
change // output @@ -565,14 +591,14 @@ where .map_err(|e| OutputManagerError::BuildError(e.message))?; // If a change output was created add it to the pending_outputs list. - let change_output = match change_key { - Some(key) => Some(UnblindedOutput { + let mut change_output = Vec::<UnblindedOutput>::new(); + if let Some(key) = change_key { + change_output.push(UnblindedOutput { value: stp.get_amount_to_self()?, spending_key: key, features: OutputFeatures::default(), - }), - None => None, - }; + }); + } // The Transaction Protocol built successfully so we will pull the unspent outputs out of the unspent list and // store them until the transaction times out OR is confirmed @@ -638,6 +664,10 @@ where /// Cancel a pending transaction and place the encumbered outputs back into the unspent pool pub async fn cancel_transaction(&mut self, tx_id: u64) -> Result<(), OutputManagerError> { + trace!( + target: LOG_TARGET, + "Cancelling pending transaction outputs for TxId: {}", tx_id + ); Ok(self.db.cancel_pending_transaction_outputs(tx_id).await?) } @@ -646,22 +676,38 @@ where Ok(self.db.timeout_pending_transaction_outputs(period).await?) } - /// Select which outputs to use to send a transaction of the specified amount. Use the specified selection strategy - /// to choose the outputs - async fn select_outputs( + /// Select which unspent transaction outputs to use to send a transaction of the specified amount. Use the specified + /// selection strategy to choose the outputs. It also determines if a change output is required. + async fn select_utxos( &mut self, amount: MicroTari, fee_per_gram: MicroTari, - strategy: UTXOSelectionStrategy, - ) -> Result<Vec<UnblindedOutput>, OutputManagerError> + output_count: usize, + strategy: Option<UTXOSelectionStrategy>, + ) -> Result<(Vec<UnblindedOutput>, bool), OutputManagerError> { - let mut outputs = Vec::new(); + let mut utxos = Vec::new(); let mut total = MicroTari::from(0); let mut fee_without_change = MicroTari::from(0); let mut fee_with_change = MicroTari::from(0); let uo = self.db.fetch_sorted_unspent_outputs().await?; + // Heuristic for selecting a strategy: default to MaturityThenSmallest, but if the amount exceeds the + // largest UTXO, use Largest + let strategy = match (strategy, uo.is_empty()) { + (Some(s), _) => s, + (None, true) => UTXOSelectionStrategy::Smallest, + (None, false) => { + let largest_utxo = &uo[uo.len() - 1]; + if amount > largest_utxo.value { + UTXOSelectionStrategy::Largest + } else { + UTXOSelectionStrategy::MaturityThenSmallest + } + }, + }; + let uo = match strategy { UTXOSelectionStrategy::Smallest => uo, // TODO: We should pass in the current height and group @@ -675,16 +721,21 @@ where }); new_uo }, + UTXOSelectionStrategy::Largest => uo.into_iter().rev().collect(), }; + let mut require_change_output = false; for o in uo.iter() { - outputs.push(o.clone()); + utxos.push(o.clone()); total += o.value; // I am assuming that the only output will be the payment output and change if required - fee_without_change = Fee::calculate(fee_per_gram, outputs.len(), 1); - fee_with_change = Fee::calculate(fee_per_gram, outputs.len(), 2); - - if total == amount + fee_without_change || total >= amount + fee_with_change { + fee_without_change = Fee::calculate(fee_per_gram, 1, utxos.len(), output_count); + if total == amount + fee_without_change { + break; + } + fee_with_change = Fee::calculate(fee_per_gram, 1, utxos.len(), output_count + 1); + if total >= amount + fee_with_change { + require_change_output = true; break; } } @@ -693,7 +744,7 @@ where return Err(OutputManagerError::NotEnoughFunds); } - Ok(outputs) + Ok((utxos,
require_change_output)) } /// Set the base node public key to the list that will be used to check the status of UTXO's on the base chain. If @@ -732,6 +783,96 @@ where Ok(self.db.get_invalid_outputs().await?) } + pub async fn create_coin_split( + &mut self, + amount_per_split: MicroTari, + split_count: usize, + fee_per_gram: MicroTari, + lock_height: Option, + ) -> Result<(u64, Transaction, MicroTari, MicroTari), OutputManagerError> + { + trace!( + target: LOG_TARGET, + "Select UTXOs and estimate coin split transaction fee." + ); + let mut output_count = split_count; + let total_split_amount = amount_per_split * split_count as u64; + let (inputs, require_change_output) = self + .select_utxos( + total_split_amount, + fee_per_gram, + output_count, + Some(UTXOSelectionStrategy::Largest), + ) + .await?; + let utxo_total = inputs.iter().fold(MicroTari::from(0), |acc, x| acc + x.value); + let input_count = inputs.len(); + if require_change_output { + output_count = split_count + 1 + }; + let fee = Fee::calculate(fee_per_gram, 1, input_count, output_count); + + trace!(target: LOG_TARGET, "Construct coin split transaction."); + let offset = PrivateKey::random(&mut OsRng); + let nonce = PrivateKey::random(&mut OsRng); + let mut builder = SenderTransactionProtocol::builder(0); + builder + .with_lock_height(lock_height.unwrap_or(0)) + .with_fee_per_gram(fee_per_gram) + .with_offset(offset.clone()) + .with_private_nonce(nonce.clone()); + trace!(target: LOG_TARGET, "Add inputs to coin split transaction."); + for uo in inputs.iter() { + builder.with_input( + uo.as_transaction_input(&self.factories.commitment, uo.clone().features), + uo.clone(), + ); + } + trace!(target: LOG_TARGET, "Add outputs to coin split transaction."); + let mut outputs = Vec::with_capacity(output_count); + let change_output = utxo_total + .checked_sub(fee) + .ok_or(OutputManagerError::NotEnoughFunds)? 
+ .checked_sub(total_split_amount) + .ok_or(OutputManagerError::NotEnoughFunds)?; + for i in 0..output_count { + let output_amount = if i < split_count { + amount_per_split + } else { + change_output + }; + + let mut spend_key = PrivateKey::default(); + { + let mut km = acquire_lock!(self.key_manager); + spend_key = km.next_key()?.k; + } + self.db.increment_key_index().await?; + let utxo = UnblindedOutput::new(output_amount, spend_key, None); + outputs.push(utxo.clone()); + builder.with_output(utxo); + } + trace!(target: LOG_TARGET, "Build coin split transaction."); + let factories = CryptoFactories::default(); + let mut stp = builder + .build::<HashDigest>(&self.factories) + .map_err(|e| OutputManagerError::BuildError(e.message))?; + // The Transaction Protocol built successfully so we will pull the unspent outputs out of the unspent list and + // store them until the transaction times out OR is confirmed + let tx_id = stp.get_tx_id()?; + trace!( + target: LOG_TARGET, + "Encumber coin split transaction ({}) outputs.", + tx_id + ); + self.db.encumber_outputs(tx_id, inputs, outputs).await?; + self.confirm_encumberance(tx_id).await?; + trace!(target: LOG_TARGET, "Finalize coin split transaction ({}).", tx_id); + stp.finalize(KernelFeatures::empty(), &factories)?; + let tx = stp.get_transaction().map(Clone::clone)?; + Ok((tx_id, tx, fee, utxo_total)) + } + /// Return the Seed words for the current Master Key set in the Key Manager pub fn get_seed_words(&self) -> Result<Vec<String>, OutputManagerError> { Ok(from_secret_key( @@ -749,6 +890,8 @@ pub enum UTXOSelectionStrategy { Smallest, // Start from oldest maturity to reduce the likelihood of grabbing locked up UTXOs MaturityThenSmallest, + // A strategy that selects the largest UTXOs first. Preferred when the amount is large + Largest, } /// This struct holds the detailed balance of the Output Manager Service. diff --git a/base_layer/wallet/src/output_manager_service/storage/database.rs b/base_layer/wallet/src/output_manager_service/storage/database.rs index 51f32072f2..97e010547e 100644 --- a/base_layer/wallet/src/output_manager_service/storage/database.rs +++ b/base_layer/wallet/src/output_manager_service/storage/database.rs @@ -57,7 +57,7 @@ pub trait OutputManagerBackend: Send + Sync { &self, tx_id: TxId, outputs_to_send: &[UnblindedOutput], - change_output: Option<UnblindedOutput>, + outputs_to_receive: &[UnblindedOutput], ) -> Result<(), OutputManagerStorageError>; /// This method confirms that a transaction negotiation is complete and outputs can be fully encumbered. This /// reserves these outputs until the transaction is confirmed or cancelled @@ -66,12 +66,13 @@ /// transaction negotiation fn clear_short_term_encumberances(&self) -> Result<(), OutputManagerStorageError>; /// This method must take all the `outputs_to_be_spent` from the specified transaction and move them back into the - /// `UnspentOutputs` pool. + /// `UnspentOutputs` pool. The `outputs_to_be_received` will be marked as cancelled inbound outputs in case they + /// need to be recovered. fn cancel_pending_transaction(&self, tx_id: TxId) -> Result<(), OutputManagerStorageError>; /// This method must run through all the `PendingTransactionOutputs` and test if any have existed for longer than /// the specified duration. If they have they should be cancelled. fn timeout_pending_transactions(&self, period: Duration) -> Result<(), OutputManagerStorageError>; - /// This method will increment the currently stored key index for the key manager config.
Increment this after eac + /// This method will increment the currently stored key index for the key manager config. Increment this after each /// key is generated fn increment_key_index(&self) -> Result<(), OutputManagerStorageError>; /// If an unspent output is detected as invalid (i.e. not available on the blockchain) then it should be moved to @@ -335,12 +336,12 @@ where T: OutputManagerBackend + 'static &self, tx_id: TxId, outputs_to_send: Vec, - change_output: Option, + outputs_to_receive: Vec, ) -> Result<(), OutputManagerStorageError> { let db_clone = self.db.clone(); tokio::task::spawn_blocking(move || { - db_clone.short_term_encumber_outputs(tx_id, &outputs_to_send, change_output) + db_clone.short_term_encumber_outputs(tx_id, &outputs_to_send, &outputs_to_receive) }) .await .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string()))) diff --git a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs index 4446b2dd4f..7a9f946da7 100644 --- a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs @@ -140,7 +140,7 @@ impl OutputManagerBackend for OutputManagerMemoryDatabase { db.unspent_outputs.push(*o); }, DbKeyValuePair::PendingTransactionOutputs(t, p) => { - db.pending_transactions.insert(t, *p); + db.short_term_pending_transactions.insert(t, *p); }, DbKeyValuePair::KeyManagerState(km) => db.key_manager_state = Some(km), }, @@ -204,7 +204,7 @@ impl OutputManagerBackend for OutputManagerMemoryDatabase { &self, tx_id: TxId, outputs_to_send: &[UnblindedOutput], - change_output: Option, + outputs_to_receive: &[UnblindedOutput], ) -> Result<(), OutputManagerStorageError> { let mut db = acquire_write_lock!(self.db); @@ -224,8 +224,8 @@ impl OutputManagerBackend for OutputManagerMemoryDatabase { timestamp: Utc::now().naive_utc(), }; - if let Some(co) = change_output { - pending_transaction.outputs_to_be_received.push(co); + for co in outputs_to_receive { + pending_transaction.outputs_to_be_received.push(co.clone()); } db.short_term_pending_transactions.insert(tx_id, pending_transaction); diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs index 00478cf465..92c3611eff 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs @@ -269,7 +269,7 @@ impl OutputManagerBackend for OutputManagerSqliteDatabase { &self, tx_id: u64, outputs_to_send: &[UnblindedOutput], - change_output: Option, + outputs_to_receive: &[UnblindedOutput], ) -> Result<(), OutputManagerStorageError> { let conn = acquire_lock!(self.database_connection); @@ -295,8 +295,8 @@ impl OutputManagerBackend for OutputManagerSqliteDatabase { )?; } - if let Some(co) = change_output { - OutputSql::new(co, OutputStatus::EncumberedToBeReceived, Some(tx_id)).commit(&(*conn))?; + for co in outputs_to_receive { + OutputSql::new(co.clone(), OutputStatus::EncumberedToBeReceived, Some(tx_id)).commit(&(*conn))?; } Ok(()) @@ -346,7 +346,13 @@ impl OutputManagerBackend for OutputManagerSqliteDatabase { for o in outputs { if o.status == (OutputStatus::EncumberedToBeReceived as i32) { - o.delete(&(*conn))?; + o.update( + UpdateOutput { + status: Some(OutputStatus::CancelledInbound), + tx_id: None, + }, + &(*conn), + )?; } else if o.status == (OutputStatus::EncumberedToBeSpent 
as i32) { o.update( UpdateOutput { @@ -446,6 +452,7 @@ enum OutputStatus { EncumberedToBeReceived, EncumberedToBeSpent, Invalid, + CancelledInbound, } impl TryFrom for OutputStatus { @@ -458,6 +465,7 @@ impl TryFrom for OutputStatus { 2 => Ok(OutputStatus::EncumberedToBeReceived), 3 => Ok(OutputStatus::EncumberedToBeSpent), 4 => Ok(OutputStatus::Invalid), + 5 => Ok(OutputStatus::CancelledInbound), _ => Err(OutputManagerStorageError::ConversionError), } } @@ -650,7 +658,7 @@ impl From for UpdateOutputSql { /// This struct represents a PendingTransactionOutputs in the Sql database. A distinct struct is required to define the /// Sql friendly equivalent datatypes for the members. -#[derive(Clone, Queryable, Insertable)] +#[derive(Debug, Clone, Queryable, Insertable)] #[table_name = "pending_transaction_outputs"] struct PendingTransactionOutputSql { tx_id: i64, diff --git a/base_layer/wallet/src/testnet_utils.rs b/base_layer/wallet/src/testnet_utils.rs index 24df8dcedc..88c64b86b3 100644 --- a/base_layer/wallet/src/testnet_utils.rs +++ b/base_layer/wallet/src/testnet_utils.rs @@ -41,6 +41,7 @@ use crate::{ Wallet, }; use chrono::{Duration as ChronoDuration, Utc}; +use futures::{FutureExt, StreamExt}; use log::*; use rand::{distributions::Alphanumeric, rngs::OsRng, CryptoRng, Rng, RngCore}; use std::{ @@ -68,8 +69,7 @@ use tari_crypto::{ tari_utilities::hex::Hex, }; use tari_p2p::{initialization::CommsConfig, transport::TransportType}; -use tari_test_utils::collect_stream; -use tokio::runtime::Runtime; +use tokio::{runtime::Runtime, time::delay_for}; // Used to generate test wallet data @@ -134,7 +134,7 @@ pub fn create_wallet( max_concurrent_inbound_tasks: 100, outbound_buffer_size: 100, dht: DhtConfig { - discovery_request_timeout: Duration::from_millis(500), + discovery_request_timeout: Duration::from_secs(30), ..Default::default() }, allow_test_addresses: true, @@ -208,6 +208,8 @@ pub fn generate_wallet_test_data< .collect(); let mut message_index = 0; + let mut wallet_event_stream = wallet.transaction_service.get_event_stream_fused(); + // Generate contacts let mut generated_contacts = Vec::new(); for i in 0..names.len() { @@ -247,7 +249,7 @@ pub fn generate_wallet_test_data< generated_contacts[0].1.clone(), alice_temp_dir.clone(), ); - + let mut alice_event_stream = wallet_alice.transaction_service.get_event_stream_fused(); for i in 0..20 { let (_ti, uo) = make_input(&mut OsRng.clone(), MicroTari::from(1_500_000 + i * 530_500), &factories); wallet_alice @@ -267,6 +269,8 @@ pub fn generate_wallet_test_data< generated_contacts[1].1.clone(), bob_temp_dir.clone(), ); + let mut bob_event_stream = wallet_bob.transaction_service.get_event_stream_fused(); + for i in 0..20 { let (_ti, uo) = make_input( &mut OsRng.clone(), @@ -284,6 +288,7 @@ pub fn generate_wallet_test_data< wallet .runtime .block_on(wallet.comms.peer_manager().add_peer(alice_peer))?; + let bob_peer = wallet_bob.comms.node_identity().to_peer(); wallet @@ -310,35 +315,54 @@ pub fn generate_wallet_test_data< .unwrap(); info!(target: LOG_TARGET, "Starting to execute test transactions"); + // Grab the first 2 outbound tx_ids for later + let mut outbound_tx_ids = Vec::new(); + // Completed TX - wallet.runtime.block_on(wallet.transaction_service.send_transaction( + let tx_id = wallet.runtime.block_on(wallet.transaction_service.send_transaction( contacts[0].public_key.clone(), MicroTari::from(1_100_000), MicroTari::from(100), messages[message_index].clone(), ))?; - + outbound_tx_ids.push(tx_id); message_index = (message_index + 1) % 
messages.len(); - wallet.runtime.block_on(wallet.transaction_service.send_transaction( + + let tx_id = wallet.runtime.block_on(wallet.transaction_service.send_transaction( contacts[0].public_key.clone(), MicroTari::from(2_010_500), MicroTari::from(110), messages[message_index].clone(), ))?; + outbound_tx_ids.push(tx_id); message_index = (message_index + 1) % messages.len(); - // Grab the first 2 outbound tx_ids for later - let mut outbound_tx_ids = Vec::new(); - let wallet_event_stream = wallet.transaction_service.get_event_stream_fused(); - let wallet_stream = wallet - .runtime - .block_on(async { collect_stream!(wallet_event_stream, take = 4, timeout = Duration::from_secs(60)) }); - for v in wallet_stream { - if let TransactionEvent::TransactionSendResult(tx_id, _) = &*v { - outbound_tx_ids.push(tx_id.clone()); + wallet_alice.runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut count = 0; + loop { + futures::select! { + event = alice_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::ReceivedTransaction(_) => { + count +=1; + }, + TransactionEvent::ReceivedFinalizedTransaction(_) => { + count +=1; + }, + _ => (), + } + if count >=4 { + break; + } + }, + () = delay => { + break; + }, + } } - } - assert_eq!(outbound_tx_ids.len(), 2); + assert!(count >= 4, "Event waiting timed out before receiving expected events 1"); + }); wallet.runtime.block_on(wallet.transaction_service.send_transaction( contacts[0].public_key.clone(), @@ -419,31 +443,60 @@ pub fn generate_wallet_test_data< )); message_index = (message_index + 1) % messages.len(); - // Make sure that the messages have been received by the alice and bob wallets before they start sending messages so - // that they have the wallet in their peer_managers - let alice_event_stream = wallet_alice.transaction_service.get_event_stream_fused(); - let bob_event_stream = wallet_bob.transaction_service.get_event_stream_fused(); - - let _alice_stream = wallet_alice - .runtime - .block_on(async { collect_stream!(alice_event_stream, take = 12, timeout = Duration::from_secs(60)) }); - - let _bob_stream = wallet_bob - .runtime - .block_on(async { collect_stream!(bob_event_stream, take = 8, timeout = Duration::from_secs(60)) }); - // Make sure that the messages have been received by the alice and bob wallets before they start sending messages so - // that they have the wallet in their peer_managers - let alice_event_stream = wallet_alice.transaction_service.get_event_stream_fused(); - let bob_event_stream = wallet_bob.transaction_service.get_event_stream_fused(); - - let _alice_stream = wallet_bob - .runtime - .block_on(async { collect_stream!(alice_event_stream, take = 6, timeout = Duration::from_secs(60)) }); - - let _bob_stream = wallet_bob - .runtime - .block_on(async { collect_stream!(bob_event_stream, take = 2, timeout = Duration::from_secs(60)) }); - + wallet.runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut count = 0; + loop { + futures::select! 
{ + event = wallet_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::TransactionDirectSendResult(_,_) => { + count+=1; + if count >= 10 { + break; + } + }, + _ => (), + } + }, + () = delay => { + break; + }, + } + } + assert!( + count >= 10, + "Event waiting timed out before receiving expected events 2" + ); + }); + + wallet_bob.runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut count = 0; + loop { + futures::select! { + event = bob_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::ReceivedTransaction(_) => { + count+=1; + }, + TransactionEvent::ReceivedFinalizedTransaction(_) => { + count+=1; + }, + _ => (), + } + if count >= 8 { + break; + } + }, + () = delay => { + break; + }, + } + } + assert!(count >= 8, "Event waiting timed out before receiving expected events 3"); + }); + log::error!("Inbound Transactions starting"); // Pending Inbound wallet_alice .runtime @@ -494,10 +547,30 @@ pub fn generate_wallet_test_data< messages[message_index].clone(), ))?; - let wallet_event_stream = wallet.transaction_service.get_event_stream_fused(); - let _wallet_stream = wallet - .runtime - .block_on(async { collect_stream!(wallet_event_stream, take = 30, timeout = Duration::from_secs(120)) }); + wallet.runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut count = 0; + loop { + futures::select! { + event = wallet_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::ReceivedFinalizedTransaction(_) => { + count+=1; + if count >= 5 { + break; + } + }, + _ => (), + } + }, + () = delay => { + break; + }, + } + } + assert!(count >= 5, "Event waiting timed out before receiving expected events 4"); + }); + let txs = wallet .runtime .block_on(wallet.transaction_service.get_completed_transactions()) @@ -678,7 +751,7 @@ pub fn receive_test_transaction< .runtime .block_on(wallet.transaction_service.test_accept_transaction( OsRng.next_u64(), - MicroTari::from(10_000 + OsRng.next_u64() % 10_1000), + MicroTari::from(10_000 + OsRng.next_u64() % 101_000), public_key, ))?; diff --git a/base_layer/wallet/src/text_message_service/service.rs b/base_layer/wallet/src/text_message_service/service.rs index a99dde41ea..b1fe247ff4 100644 --- a/base_layer/wallet/src/text_message_service/service.rs +++ b/base_layer/wallet/src/text_message_service/service.rs @@ -287,10 +287,18 @@ where let message_inner = message.clone().into_inner(); poll_fn(move |_| blocking(|| message_inner.commit(&conn))).await??; - self.event_publisher + let _ = self + .event_publisher .send(TextMessageEvent::ReceivedTextMessage) .await - .map_err(|_| TextMessageError::EventStreamError)?; + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); Ok(()) } @@ -315,10 +323,18 @@ where ); poll_fn(move |_| blocking(|| SentTextMessage::mark_sent_message_ack(&message_ack_inner.id, &conn))).await??; - self.event_publisher + let _ = self + .event_publisher .send(TextMessageEvent::ReceivedTextMessageAck) .await - .map_err(|_| TextMessageError::EventStreamError)?; + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); Ok(()) } diff --git a/base_layer/wallet/src/transaction_service/error.rs b/base_layer/wallet/src/transaction_service/error.rs index c5b6e7cd67..b2e7a50996 100644 --- 
a/base_layer/wallet/src/transaction_service/error.rs +++ b/base_layer/wallet/src/transaction_service/error.rs @@ -26,12 +26,14 @@ use crate::{ }; use derive_error::Error; use diesel::result::Error as DieselError; +use futures::channel::oneshot::Canceled; use serde_json::Error as SerdeJsonError; use tari_comms::peer_manager::node_id::NodeIdError; use tari_comms_dht::outbound::DhtOutboundError; use tari_core::transactions::{transaction::TransactionError, transaction_protocol::TransactionProtocolError}; use tari_service_framework::reply_channel::TransportChannelError; use time::OutOfRangeError; +use tokio::sync::broadcast::RecvError; #[derive(Debug, Error)] pub enum TransactionServiceError { @@ -73,6 +75,16 @@ pub enum TransactionServiceError { InvalidCompletedTransaction, /// No Base Node public keys are provided for Base chain broadcast and monitoring NoBaseNodeKeysProvided, + /// Error sending data to Protocol via registered channels + ProtocolChannelError, + /// Transaction detected as rejected by mempool + MempoolRejection, + /// Mempool response key does not match the one that is expected + UnexpectedMempoolResponse, + /// Base Node response key does not match the one that is expected + UnexpectedBaseNodeResponse, + /// The current transaction has been cancelled + TransactionCancelled, DhtOutboundError(DhtOutboundError), OutputManagerError(OutputManagerError), TransportChannelError(TransportChannelError), @@ -86,6 +98,8 @@ pub enum TransactionServiceError { #[error(msg_embedded, no_from, non_std)] ConversionError(String), NodeIdError(NodeIdError), + BroadcastRecvError(RecvError), + OneshotCancelled(Canceled), } #[derive(Debug, Error)] @@ -114,3 +128,23 @@ pub enum TransactionStorageError { #[error(msg_embedded, non_std, no_from)] BlockingTaskSpawnError(String), } + +/// This error type is used to return TransactionServiceErrors from inside a Transaction Service protocol but also +/// include the ID of the protocol +#[derive(Debug)] +pub struct TransactionServiceProtocolError { + pub id: u64, + pub error: TransactionServiceError, +} + +impl TransactionServiceProtocolError { + pub fn new(id: u64, error: TransactionServiceError) -> Self { + Self { id, error } + } +} + +impl From<TransactionServiceProtocolError> for TransactionServiceError { + fn from(tspe: TransactionServiceProtocolError) -> Self { + tspe.error + } +} diff --git a/base_layer/wallet/src/transaction_service/handle.rs b/base_layer/wallet/src/transaction_service/handle.rs index 8e4259d627..aa578a08c6 100644 --- a/base_layer/wallet/src/transaction_service/handle.rs +++ b/base_layer/wallet/src/transaction_service/handle.rs @@ -29,13 +29,12 @@ use crate::{ }, }; use futures::{stream::Fuse, StreamExt}; -use std::{collections::HashMap, fmt}; -use tari_broadcast_channel::Subscriber; +use std::{collections::HashMap, fmt, sync::Arc}; use tari_comms::types::CommsPublicKey; use tari_core::transactions::{tari_amount::MicroTari, transaction::Transaction}; use tari_service_framework::reply_channel::SenderService; +use tokio::sync::broadcast; use tower::Service; - /// API Request enum #[derive(Debug)] pub enum TransactionServiceRequest { @@ -44,10 +43,12 @@ pub enum TransactionServiceRequest { GetCompletedTransactions, SetBaseNodePublicKey(CommsPublicKey), SendTransaction((CommsPublicKey, MicroTari, MicroTari, String)), + CancelTransaction(TxId), RequestCoinbaseSpendingKey((MicroTari, u64)), CompleteCoinbaseTransaction((TxId, Transaction)), CancelPendingCoinbaseTransaction(TxId), ImportUtxo(MicroTari, CommsPublicKey, String), + SubmitTransaction((TxId, Transaction, MicroTari,
MicroTari, String)), #[cfg(feature = "test_harness")] CompletePendingOutboundTransaction(CompletedTransaction), #[cfg(feature = "test_harness")] @@ -70,6 +71,7 @@ impl fmt::Display for TransactionServiceRequest { Self::SendTransaction((k, v, _, msg)) => { f.write_str(&format!("SendTransaction (to {}, {}, {})", k, v, msg)) }, + Self::CancelTransaction(t) => f.write_str(&format!("CancelTransaction ({})", t)), Self::RequestCoinbaseSpendingKey((v, h)) => { f.write_str(&format!("RequestCoinbaseSpendingKey ({}, maturity={})", v, h)) }, @@ -78,6 +80,7 @@ impl fmt::Display for TransactionServiceRequest { f.write_str(&format!("CancelPendingCoinbaseTransaction ({}) ", id)) }, Self::ImportUtxo(v, k, msg) => f.write_str(&format!("ImportUtxo (from {}, {}, {})", k, v, msg)), + Self::SubmitTransaction((id, _, _, _, _)) => f.write_str(&format!("SubmitTransaction ({})", id)), #[cfg(feature = "test_harness")] Self::CompletePendingOutboundTransaction(tx) => { f.write_str(&format!("CompletePendingOutboundTransaction ({})", tx.tx_id)) @@ -99,7 +102,8 @@ impl fmt::Display for TransactionServiceRequest { /// API Response enum #[derive(Debug)] pub enum TransactionServiceResponse { - TransactionSent, + TransactionSent(TxId), + TransactionCancelled, PendingInboundTransactions(HashMap), PendingOutboundTransactions(HashMap), CompletedTransactions(HashMap), @@ -108,6 +112,7 @@ pub enum TransactionServiceResponse { CoinbaseTransactionCancelled, BaseNodePublicKeySet, UtxoImported(TxId), + TransactionSubmitted, #[cfg(feature = "test_harness")] CompletedPendingTransaction, #[cfg(feature = "test_harness")] @@ -127,33 +132,39 @@ pub enum TransactionEvent { ReceivedTransaction(TxId), ReceivedTransactionReply(TxId), ReceivedFinalizedTransaction(TxId), - TransactionSendResult(TxId, bool), - TransactionSendDiscoveryComplete(TxId, bool), + TransactionDirectSendResult(TxId, bool), + TransactionStoreForwardSendResult(TxId, bool), + TransactionCancelled(TxId), TransactionBroadcast(TxId), TransactionMined(TxId), TransactionMinedRequestTimedOut(TxId), Error(String), } +pub type TransactionEventSender = broadcast::Sender>; +pub type TransactionEventReceiver = broadcast::Receiver>; /// The Transaction Service Handle is a struct that contains the interfaces used to communicate with a running /// Transaction Service #[derive(Clone)] pub struct TransactionServiceHandle { handle: SenderService>, - event_stream: Subscriber, + event_stream_sender: TransactionEventSender, } impl TransactionServiceHandle { pub fn new( handle: SenderService>, - event_stream: Subscriber, + event_stream_sender: TransactionEventSender, ) -> Self { - Self { handle, event_stream } + Self { + handle, + event_stream_sender, + } } - pub fn get_event_stream_fused(&self) -> Fuse> { - self.event_stream.clone().fuse() + pub fn get_event_stream_fused(&self) -> Fuse { + self.event_stream_sender.subscribe().fuse() } pub async fn send_transaction( @@ -162,7 +173,7 @@ impl TransactionServiceHandle { amount: MicroTari, fee_per_gram: MicroTari, message: String, - ) -> Result<(), TransactionServiceError> + ) -> Result { match self .handle @@ -174,7 +185,18 @@ impl TransactionServiceHandle { ))) .await?? 
{ - TransactionServiceResponse::TransactionSent => Ok(()), + TransactionServiceResponse::TransactionSent(tx_id) => Ok(tx_id), + _ => Err(TransactionServiceError::UnexpectedApiResponse), + } + } + + pub async fn cancel_transaction(&mut self, tx_id: TxId) -> Result<(), TransactionServiceError> { + match self + .handle + .call(TransactionServiceRequest::CancelTransaction(tx_id)) + .await?? + { + TransactionServiceResponse::TransactionCancelled => Ok(()), _ => Err(TransactionServiceError::UnexpectedApiResponse), } } @@ -303,6 +325,27 @@ impl TransactionServiceHandle { } } + pub async fn submit_transaction( + &mut self, + tx_id: u64, + tx: Transaction, + fee: MicroTari, + amount: MicroTari, + message: String, + ) -> Result<(), TransactionServiceError> + { + match self + .handle + .call(TransactionServiceRequest::SubmitTransaction(( + tx_id, tx, fee, amount, message, + ))) + .await?? + { + TransactionServiceResponse::TransactionSubmitted => Ok(()), + _ => Err(TransactionServiceError::UnexpectedApiResponse), + } + } + #[cfg(feature = "test_harness")] pub async fn test_complete_pending_transaction( &mut self, diff --git a/base_layer/wallet/src/transaction_service/mod.rs b/base_layer/wallet/src/transaction_service/mod.rs index 102fe9550f..812a766f74 100644 --- a/base_layer/wallet/src/transaction_service/mod.rs +++ b/base_layer/wallet/src/transaction_service/mod.rs @@ -23,6 +23,7 @@ pub mod config; pub mod error; pub mod handle; +pub mod protocols; pub mod service; pub mod storage; @@ -38,8 +39,7 @@ use crate::{ use futures::{future, Future, Stream, StreamExt}; use log::*; use std::sync::Arc; -use tari_broadcast_channel::bounded; -use tari_comms::{peer_manager::NodeIdentity, protocol::messaging::MessagingEventReceiver}; +use tari_comms::peer_manager::NodeIdentity; use tari_comms_dht::outbound::OutboundMessageRequester; use tari_core::{ base_node::proto::base_node as BaseNodeProto, @@ -60,7 +60,7 @@ use tari_service_framework::{ ServiceInitializer, }; use tari_shutdown::ShutdownSignal; -use tokio::runtime; +use tokio::{runtime, sync::broadcast}; const LOG_TARGET: &str = "wallet::transaction_service"; @@ -69,7 +69,6 @@ where T: TransactionBackend { config: TransactionServiceConfig, subscription_factory: Arc>>, - message_event_receiver: Option, backend: Option, node_identity: Arc, factories: CryptoFactories, @@ -81,7 +80,6 @@ where T: TransactionBackend pub fn new( config: TransactionServiceConfig, subscription_factory: Arc>>, - message_event_receiver: MessagingEventReceiver, backend: T, node_identity: Arc, factories: CryptoFactories, @@ -90,7 +88,6 @@ where T: TransactionBackend Self { config, subscription_factory, - message_event_receiver: Some(message_event_receiver), backend: Some(backend), node_identity, factories, @@ -153,9 +150,9 @@ where T: TransactionBackend + Clone + 'static let mempool_response_stream = self.mempool_response_stream(); let base_node_response_stream = self.base_node_response_stream(); - let (publisher, subscriber) = bounded(100); + let (publisher, _) = broadcast::channel(200); - let transaction_handle = TransactionServiceHandle::new(sender, subscriber); + let transaction_handle = TransactionServiceHandle::new(sender, publisher.clone()); // Register handle before waiting for handles to be ready handles_fut.register(transaction_handle); @@ -165,11 +162,6 @@ where T: TransactionBackend + Clone + 'static .take() .expect("Cannot start Transaction Service without providing a backend"); - let message_event_receiver = self - .message_event_receiver - .take() - .expect("Cannot start 
Transaction Service without providing an Message Event Receiver"); - let node_identity = self.node_identity.clone(); let factories = self.factories.clone(); let config = self.config.clone(); @@ -195,7 +187,6 @@ where T: TransactionBackend + Clone + 'static base_node_response_stream, output_manager_service, outbound_message_service, - message_event_receiver, publisher, node_identity, factories, diff --git a/base_layer/core/src/proof_of_work/diff_adj_manager/mod.rs b/base_layer/wallet/src/transaction_service/protocols/mod.rs similarity index 88% rename from base_layer/core/src/proof_of_work/diff_adj_manager/mod.rs rename to base_layer/wallet/src/transaction_service/protocols/mod.rs index 30b2fac4a3..6ff235e36e 100644 --- a/base_layer/core/src/proof_of_work/diff_adj_manager/mod.rs +++ b/base_layer/wallet/src/transaction_service/protocols/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2019. The Tari Project +// Copyright 2020. The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the // following conditions are met: @@ -20,9 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -mod diff_adj_manager; -mod diff_adj_storage; -mod error; - -pub use diff_adj_manager::DiffAdjManager; -pub use error::DiffAdjManagerError; +pub mod transaction_broadcast_protocol; +pub mod transaction_chain_monitoring_protocol; +pub mod transaction_receive_protocol; +pub mod transaction_send_protocol; diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs new file mode 100644 index 0000000000..21c4a4ef7a --- /dev/null +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_broadcast_protocol.rs @@ -0,0 +1,460 @@ +// Copyright 2020. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
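The initializer above now publishes events through tokio's broadcast channel instead of tari_broadcast_channel: the service keeps the single Sender, and every handle clone obtains its own receiver via subscribe(). One property worth noting is that broadcast::Sender::send returns an error whenever no subscriber currently exists, which is why this patch logs event-send failures at trace level and discards them instead of returning an EventStreamError. A minimal sketch of that behaviour under tokio 0.2, using an illustrative WalletEvent enum in place of the real TransactionEvent:

    // A minimal sketch, not part of the patch. `WalletEvent` is an
    // illustrative stand-in for the real TransactionEvent enum.
    use std::sync::Arc;
    use tokio::sync::broadcast;

    #[derive(Debug, Clone)]
    enum WalletEvent {
        TransactionBroadcast(u64),
    }

    #[tokio::main]
    async fn main() {
        // The service keeps the Sender; the initial Receiver is dropped here,
        // mirroring `let (publisher, _) = broadcast::channel(200);`.
        let (publisher, _) = broadcast::channel::<Arc<WalletEvent>>(200);

        // With no live subscribers, send() fails; the patch logs and ignores
        // this case rather than propagating an error.
        assert!(publisher.send(Arc::new(WalletEvent::TransactionBroadcast(1))).is_err());

        // Each handle gets its own receiver via subscribe(), which is what
        // get_event_stream_fused() wraps with .fuse().
        let mut events = publisher.subscribe();
        publisher.send(Arc::new(WalletEvent::TransactionBroadcast(2))).unwrap();
        let event = events.recv().await.unwrap();
        println!("got {:?}", event);
    }

Events are wrapped in Arc so that one send can be cloned cheaply to every subscriber.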
+ +use crate::transaction_service::{ + error::{TransactionServiceError, TransactionServiceProtocolError}, + handle::TransactionEvent, + service::TransactionServiceResources, + storage::database::{TransactionBackend, TransactionStatus}, +}; +use futures::{channel::mpsc::Receiver, FutureExt, StreamExt}; +use log::*; +use std::{convert::TryFrom, sync::Arc, time::Duration}; +use tari_comms::types::CommsPublicKey; +use tari_comms_dht::{domain_message::OutboundDomainMessage, outbound::OutboundEncryption}; +use tari_core::{ + base_node::proto::{ + base_node as BaseNodeProto, + base_node::{ + base_node_service_request::Request as BaseNodeRequestProto, + base_node_service_response::Response as BaseNodeResponseProto, + }, + }, + mempool::{ + proto::mempool as MempoolProto, + service::{MempoolResponse, MempoolServiceResponse}, + TxStorageResponse, + }, + transactions::transaction::TransactionOutput, +}; +use tari_crypto::tari_utilities::{hex::Hex, Hashable}; +use tari_p2p::tari_message::TariMessageType; +use tokio::time::delay_for; + +const LOG_TARGET: &str = "wallet::transaction_service::protocols::broadcast_protocol"; + +/// This protocol defines the process of monitoring a mempool and base node to detect when a Completed transaction is +/// Broadcast to the mempool or potentially Mined +pub struct TransactionBroadcastProtocol<TBackend> +where TBackend: TransactionBackend + Clone + 'static +{ + id: u64, + resources: TransactionServiceResources<TBackend>, + timeout: Duration, + base_node_public_key: CommsPublicKey, + mempool_response_receiver: Option<Receiver<MempoolServiceResponse>>, + base_node_response_receiver: Option<Receiver<BaseNodeProto::BaseNodeServiceResponse>>, +} + +impl<TBackend> TransactionBroadcastProtocol<TBackend> +where TBackend: TransactionBackend + Clone + 'static +{ + pub fn new( + id: u64, + resources: TransactionServiceResources<TBackend>, + timeout: Duration, + base_node_public_key: CommsPublicKey, + mempool_response_receiver: Receiver<MempoolServiceResponse>, + base_node_response_receiver: Receiver<BaseNodeProto::BaseNodeServiceResponse>, + ) -> Self + { + Self { + id, + resources, + timeout, + base_node_public_key, + mempool_response_receiver: Some(mempool_response_receiver), + base_node_response_receiver: Some(base_node_response_receiver), + } + } + + /// The task that defines the execution of the protocol. + pub async fn execute(mut self) -> Result<u64, TransactionServiceProtocolError> { + let mut mempool_response_receiver = self + .mempool_response_receiver + .take() + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; + + let mut base_node_response_receiver = self + .base_node_response_receiver + .take() + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; + + // This is the main loop of the protocol and it follows these steps: + // 1) Check transaction being monitored is still in the Completed state and needs to be monitored + // 2) Send a MempoolRequest::SubmitTransaction to Mempool and a Mined?
Request to base node + // 3) Wait for a either a Mempool response, Base Node response for the correct Id OR a Timeout + // a) A Mempool response for this Id is received > update the Tx status and end the protocol + // b) A Basenode response for this Id is received showing it is mined > Update Tx status and end protocol + // c) Timeout is reached > Start again + loop { + let completed_tx = match self.resources.db.get_completed_transaction(self.id).await { + Ok(tx) => tx, + Err(e) => { + error!( + target: LOG_TARGET, + "Cannot find Completed Transaction (TxId: {}) referred to by this Broadcast protocol: {:?}", + self.id, + e + ); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionDoesNotExistError, + )); + }, + }; + + if completed_tx.status != TransactionStatus::Completed { + debug!( + target: LOG_TARGET, + "Transaction (TxId: {}) no longer in Completed state and will stop being broadcast", self.id + ); + return Ok(self.id); + } + + info!( + target: LOG_TARGET, + "Attempting to Broadcast Transaction (TxId: {} and Kernel Signature: {}) to Mempool", + self.id, + completed_tx.transaction.body.kernels()[0] + .excess_sig + .get_signature() + .to_hex() + ); + trace!(target: LOG_TARGET, "{}", completed_tx.transaction); + + // Send Mempool Request + let mempool_request = MempoolProto::MempoolServiceRequest { + request_key: completed_tx.tx_id, + request: Some(MempoolProto::mempool_service_request::Request::SubmitTransaction( + completed_tx.transaction.clone().into(), + )), + }; + + self.resources + .outbound_message_service + .send_direct( + self.base_node_public_key.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::MempoolRequest, mempool_request.clone()), + ) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + // Send Base Node query + let mut hashes = Vec::new(); + for o in completed_tx.transaction.body.outputs() { + hashes.push(o.hash()); + } + + let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { outputs: hashes }); + let service_request = BaseNodeProto::BaseNodeServiceRequest { + request_key: self.id, + request: Some(request), + }; + self.resources + .outbound_message_service + .send_direct( + self.base_node_public_key.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), + ) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + let mut delay = delay_for(self.timeout).fuse(); + futures::select! { + mempool_response = mempool_response_receiver.select_next_some() => { + if self.handle_mempool_response(mempool_response).await? { + break; + } + }, + base_node_response = base_node_response_receiver.select_next_some() => { + if self.handle_base_node_response(base_node_response).await? 
{ + break; + } + }, + () = delay => { + }, + } + + info!( + target: LOG_TARGET, + "Mempool broadcast timed out for Transaction with TX_ID: {}", self.id + ); + + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::MempoolBroadcastTimedOut(self.id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + } + + Ok(self.id) + } + + async fn handle_mempool_response( + &mut self, + response: MempoolServiceResponse, + ) -> Result + { + if response.request_key != self.id { + trace!( + target: LOG_TARGET, + "Mempool response key does not match this Broadcast Protocol Id" + ); + return Ok(false); + } + + // Handle a receive Mempool Response + match response.response { + MempoolResponse::Stats(_) => { + error!(target: LOG_TARGET, "Invalid Mempool response variant"); + }, + MempoolResponse::State(_) => { + error!(target: LOG_TARGET, "Invalid Mempool response variant"); + }, + MempoolResponse::TxStorage(ts) => { + let completed_tx = match self + .resources + .db + .get_completed_transaction(response.request_key.clone()) + .await + { + Ok(tx) => tx, + Err(e) => { + error!( + target: LOG_TARGET, + "Cannot find Completed Transaction (TxId: {}) referred to by this Broadcast protocol: {:?}", + self.id, + e + ); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionDoesNotExistError, + )); + }, + }; + match completed_tx.status { + TransactionStatus::Completed => match ts { + // Getting this response means the Mempool Rejected this transaction so it will be + // cancelled. + TxStorageResponse::NotStored => { + error!( + target: LOG_TARGET, + "Mempool response received for TxId: {:?}. Transaction was REJECTED. 
Cancelling \ + transaction.", + self.id + ); + if let Err(e) = self + .resources + .output_manager_service + .cancel_transaction(completed_tx.tx_id) + .await + { + error!( + target: LOG_TARGET, + "Failed to Cancel outputs for TX_ID: {} after failed sending attempt with error \ + {:?}", + completed_tx.tx_id, + e + ); + } + if let Err(e) = self.resources.db.cancel_completed_transaction(completed_tx.tx_id).await { + error!( + target: LOG_TARGET, + "Failed to Cancel TX_ID: {} after failed sending attempt with error {:?}", + completed_tx.tx_id, + e + ); + } + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionCancelled(self.id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::MempoolRejection, + )); + }, + // Any other variant of this enum means the transaction has been received by the + // base_node and is in one of the various mempools + _ => { + // If this transaction is still in the Completed State it should be upgraded to the + // Broadcast state + info!( + target: LOG_TARGET, + "Completed Transaction (TxId: {} and Kernel Excess Sig: {}) detected as Broadcast to \ + Base Node Mempool in {:?}", + self.id, + completed_tx.transaction.body.kernels()[0] + .excess_sig + .get_signature() + .to_hex(), + ts + ); + + self.resources + .db + .broadcast_completed_transaction(self.id) + .await + .map_err(|e| { + TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)) + })?; + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionBroadcast(self.id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + return Ok(true); + }, + }, + _ => (), + } + }, + } + + Ok(false) + } + + async fn handle_base_node_response( + &mut self, + response: BaseNodeProto::BaseNodeServiceResponse, + ) -> Result + { + if response.request_key != self.id { + trace!( + target: LOG_TARGET, + "Base Node response key does not match this Broadcast Protocol Id" + ); + return Ok(false); + } + + let response: Vec = match response.response { + Some(BaseNodeResponseProto::TransactionOutputs(outputs)) => outputs.outputs, + _ => { + return Ok(false); + }, + }; + + let completed_tx = match self.resources.db.get_completed_transaction(self.id).await { + Ok(tx) => tx, + Err(_) => { + error!( + target: LOG_TARGET, + "Cannot find Completed Transaction (TxId: {}) referred to by this Broadcast protocol", self.id + ); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionDoesNotExistError, + )); + }, + }; + + if !response.is_empty() && + (completed_tx.status == TransactionStatus::Broadcast || + completed_tx.status == TransactionStatus::Completed) + { + let mut check = true; + + for output in response.iter() { + let transaction_output = TransactionOutput::try_from(output.clone()).map_err(|_| { + TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::ConversionError("Could not convert Transaction Output".to_string()), + ) + })?; + + check = check && + completed_tx + .transaction + .body + .outputs() + .iter() + .any(|item| item == &transaction_output); + } + // If all outputs are present then mark this transaction as mined. 
+ if check && !response.is_empty() { + self.resources + .output_manager_service + .confirm_transaction( + self.id, + completed_tx.transaction.body.inputs().clone(), + completed_tx.transaction.body.outputs().clone(), + ) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + self.resources + .db + .mine_completed_transaction(self.id) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionMined(self.id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + + info!( + target: LOG_TARGET, + "Transaction (TxId: {:?}) detected as mined on the Base Layer", self.id + ); + + return Ok(true); + } + } + + Ok(false) + } +} diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_chain_monitoring_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_chain_monitoring_protocol.rs new file mode 100644 index 0000000000..02f9e6aa07 --- /dev/null +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_chain_monitoring_protocol.rs @@ -0,0 +1,479 @@ +// Copyright 2020. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
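+
+//! Chain monitoring protocol for the wallet's transaction service. It polls a base node
+//! and its mempool to track a Broadcast transaction until it is detected as mined or is
+//! rejected; see `TransactionChainMonitoringProtocol` below for the full flow.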
+
+use crate::{
+    output_manager_service::TxId,
+    transaction_service::{
+        error::{TransactionServiceError, TransactionServiceProtocolError},
+        handle::TransactionEvent,
+        service::TransactionServiceResources,
+        storage::database::{TransactionBackend, TransactionStatus},
+    },
+};
+use futures::{channel::mpsc::Receiver, FutureExt, StreamExt};
+use log::*;
+use std::{convert::TryFrom, sync::Arc, time::Duration};
+use tari_comms::types::CommsPublicKey;
+use tari_comms_dht::{domain_message::OutboundDomainMessage, outbound::OutboundEncryption};
+use tari_core::{
+    base_node::proto::{
+        base_node as BaseNodeProto,
+        base_node::{
+            base_node_service_request::Request as BaseNodeRequestProto,
+            base_node_service_response::Response as BaseNodeResponseProto,
+        },
+    },
+    mempool::{
+        proto::mempool as MempoolProto,
+        service::{MempoolResponse, MempoolServiceResponse},
+        TxStorageResponse,
+    },
+    transactions::transaction::TransactionOutput,
+};
+use tari_crypto::tari_utilities::{hex::Hex, Hashable};
+use tari_p2p::tari_message::TariMessageType;
+use tokio::time::delay_for;
+
+const LOG_TARGET: &str = "wallet::transaction_service::protocols::chain_monitoring_protocol";
+
+/// This protocol defines the process of monitoring a mempool and base node to detect when a Broadcast transaction is
+/// mined or leaves the mempool, in which case it should be cancelled
+pub struct TransactionChainMonitoringProtocol<TBackend>
+where TBackend: TransactionBackend + Clone + 'static
+{
+    id: u64,
+    tx_id: TxId,
+    resources: TransactionServiceResources<TBackend>,
+    timeout: Duration,
+    base_node_public_key: CommsPublicKey,
+    mempool_response_receiver: Option<Receiver<MempoolServiceResponse>>,
+    base_node_response_receiver: Option<Receiver<BaseNodeProto::BaseNodeServiceResponse>>,
+}
+
+impl<TBackend> TransactionChainMonitoringProtocol<TBackend>
+where TBackend: TransactionBackend + Clone + 'static
+{
+    pub fn new(
+        id: u64,
+        tx_id: TxId,
+        resources: TransactionServiceResources<TBackend>,
+        timeout: Duration,
+        base_node_public_key: CommsPublicKey,
+        mempool_response_receiver: Receiver<MempoolServiceResponse>,
+        base_node_response_receiver: Receiver<BaseNodeProto::BaseNodeServiceResponse>,
+    ) -> Self
+    {
+        Self {
+            id,
+            tx_id,
+            resources,
+            timeout,
+            base_node_public_key,
+            mempool_response_receiver: Some(mempool_response_receiver),
+            base_node_response_receiver: Some(base_node_response_receiver),
+        }
+    }
+
+    /// The task that defines the execution of the protocol.
+    pub async fn execute(mut self) -> Result<u64, TransactionServiceProtocolError> {
+        let mut mempool_response_receiver = self
+            .mempool_response_receiver
+            .take()
+            .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?;
+
+        let mut base_node_response_receiver = self
+            .base_node_response_receiver
+            .take()
+            .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?;
+
+        trace!(
+            target: LOG_TARGET,
+            "Starting chain monitoring protocol for TxId: {} with Protocol ID: {}",
+            self.tx_id,
+            self.id
+        );
+
+        // This is the main loop of the protocol; each iteration performs the following steps:
+        // 1) Check that the transaction being monitored is still in the Broadcast state and needs to be monitored
+        // 2) Send a MempoolRequest::GetTxStateWithExcessSig to the Mempool and a Mined? request to the base node
+        // 3) Wait for both a Mempool response and a Base Node response for the correct Id OR a Timeout
+        //      a) If the Tx is not in the mempool AND is not mined the protocol ends and the Tx should be cancelled
+        //      b) If the Tx is in the mempool AND not mined > perform another iteration
+        //      c) If the Tx is in the mempool AND mined then update the status of the Tx and end the protocol
+        //      d) Timeout is reached > Start again
+        loop {
+            let completed_tx = match self.resources.db.get_completed_transaction(self.tx_id).await {
+                Ok(tx) => tx,
+                Err(e) => {
+                    error!(
+                        target: LOG_TARGET,
+                        "Cannot find Completed Transaction (TxId: {}) referred to by this Chain Monitoring Protocol: \
+                         {:?}",
+                        self.tx_id,
+                        e
+                    );
+                    return Err(TransactionServiceProtocolError::new(
+                        self.id,
+                        TransactionServiceError::TransactionDoesNotExistError,
+                    ));
+                },
+            };
+
+            if completed_tx.status != TransactionStatus::Broadcast {
+                debug!(
+                    target: LOG_TARGET,
+                    "Transaction (TxId: {}) no longer in Broadcast state and will stop being monitored for being Mined",
+                    self.tx_id
+                );
+                return Ok(self.id);
+            }
+
+            let mut hashes = Vec::new();
+            for o in completed_tx.transaction.body.outputs() {
+                hashes.push(o.hash());
+            }
+
+            info!(
+                target: LOG_TARGET,
+                "Sending Transaction Mined? request for TxId: {} and Kernel Signature {} to Base Node (Contains {} \
+                 outputs)",
+                completed_tx.tx_id,
+                completed_tx.transaction.body.kernels()[0]
+                    .excess_sig
+                    .get_signature()
+                    .to_hex(),
+                hashes.len(),
+            );
+
+            // Send Mempool query
+            let tx_excess_sig = completed_tx.transaction.body.kernels()[0].excess_sig.clone();
+            let mempool_request = MempoolProto::MempoolServiceRequest {
+                request_key: self.id,
+                request: Some(MempoolProto::mempool_service_request::Request::GetTxStateWithExcessSig(
+                    tx_excess_sig.into(),
+                )),
+            };
+
+            self.resources
+                .outbound_message_service
+                .send_direct(
+                    self.base_node_public_key.clone(),
+                    OutboundEncryption::None,
+                    OutboundDomainMessage::new(TariMessageType::MempoolRequest, mempool_request.clone()),
+                )
+                .await
+                .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?;
+
+            // Send Base Node query
+            let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { outputs: hashes });
+            let service_request = BaseNodeProto::BaseNodeServiceRequest {
+                request_key: self.id,
+                request: Some(request),
+            };
+            self.resources
+                .outbound_message_service
+                .send_direct(
+                    self.base_node_public_key.clone(),
+                    OutboundEncryption::None,
+                    OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request),
+                )
+                .await
+                .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?;
+
+            let mut delay = delay_for(self.timeout).fuse();
+            let mut received_mempool_response = None;
+            let mut mempool_response_received = false;
+            let mut base_node_response_received = false;
+            // Loop until both a Mempool response AND a Base Node response are received OR the timeout expires.
+            loop {
+                futures::select! {
+                    mempool_response = mempool_response_receiver.select_next_some() => {
+                        // We must first check the Base Node response before checking the mempool response, so we
+                        // will keep it for the end of the round
+                        received_mempool_response = Some(mempool_response);
+                        mempool_response_received = true;
+                    },
+                    base_node_response = base_node_response_receiver.select_next_some() => {
+                        // We can immediately check the Base Node Response
+                        if self
+                            .handle_base_node_response(completed_tx.tx_id, base_node_response)
+                            .await?
+ { + // Tx is mined! + return Ok(self.id); + } + base_node_response_received = true; + }, + () = delay => { + break; + }, + } + + // If we have received both responses from this round we can check the mempool status and then continue + // to next round + if received_mempool_response.is_some() && base_node_response_received { + if let Some(mempool_response) = received_mempool_response { + if !self + .handle_mempool_response(completed_tx.tx_id, mempool_response) + .await? + { + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::MempoolRejection, + )); + } + } + + break; + } + } + + if mempool_response_received && base_node_response_received { + info!( + target: LOG_TARGET, + "Base node and Mempool response received. TxId: {:?} not mined yet.", completed_tx.tx_id, + ); + // Finish out the rest of this period before moving onto next round + delay.await; + } + + info!( + target: LOG_TARGET, + "Chain monitoring process timed out for Transaction TX_ID: {}", completed_tx.tx_id + ); + + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionMinedRequestTimedOut( + completed_tx.tx_id, + ))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + } + } + + async fn handle_mempool_response( + &mut self, + tx_id: TxId, + response: MempoolServiceResponse, + ) -> Result + { + // Handle a receive Mempool Response + match response.response { + MempoolResponse::Stats(_) => { + error!(target: LOG_TARGET, "Invalid Mempool response variant"); + }, + MempoolResponse::State(_) => { + error!(target: LOG_TARGET, "Invalid Mempool response variant"); + }, + MempoolResponse::TxStorage(ts) => { + let completed_tx = match self.resources.db.get_completed_transaction(tx_id).await { + Ok(tx) => tx, + Err(e) => { + error!( + target: LOG_TARGET, + "Cannot find Completed Transaction (TxId: {}) referred to by this Chain Monitoring \ + Protocol: {:?}", + self.tx_id, + e + ); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionDoesNotExistError, + )); + }, + }; + match completed_tx.status { + TransactionStatus::Broadcast => match ts { + // Getting this response means the Mempool Rejected this transaction so it will be + // cancelled. + TxStorageResponse::NotStored => { + error!( + target: LOG_TARGET, + "Mempool response received for TxId: {:?}. Transaction was REJECTED. 
Cancelling \ + transaction.", + tx_id + ); + if let Err(e) = self + .resources + .output_manager_service + .cancel_transaction(completed_tx.tx_id) + .await + { + error!( + target: LOG_TARGET, + "Failed to Cancel outputs for TX_ID: {} after failed sending attempt with error \ + {:?}", + completed_tx.tx_id, + e + ); + } + if let Err(e) = self.resources.db.cancel_completed_transaction(completed_tx.tx_id).await { + error!( + target: LOG_TARGET, + "Failed to Cancel TX_ID: {} after failed sending attempt with error {:?}", + completed_tx.tx_id, + e + ); + } + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionCancelled(self.id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::MempoolRejection, + )); + }, + // Any other variant of this enum means the transaction has been received by the + // base_node and is in one of the various mempools + _ => { + // If this transaction is still in the Completed State it should be upgraded to the + // Broadcast state + info!( + target: LOG_TARGET, + "Completed Transaction (TxId: {} and Kernel Excess Sig: {}) detected in Base Node \ + Mempool in {:?}", + completed_tx.tx_id, + completed_tx.transaction.body.kernels()[0] + .excess_sig + .get_signature() + .to_hex(), + ts + ); + return Ok(true); + }, + }, + _ => (), + } + }, + } + + Ok(true) + } + + async fn handle_base_node_response( + &mut self, + tx_id: TxId, + response: BaseNodeProto::BaseNodeServiceResponse, + ) -> Result + { + let response: Vec = match response.response { + Some(BaseNodeResponseProto::TransactionOutputs(outputs)) => outputs.outputs, + _ => { + return Ok(false); + }, + }; + + let completed_tx = match self.resources.db.get_completed_transaction(tx_id).await { + Ok(tx) => tx, + Err(e) => { + error!( + target: LOG_TARGET, + "Cannot find Completed Transaction (TxId: {}) referred to by this Chain Monitoring Protocol: {:?}", + self.tx_id, + e + ); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionDoesNotExistError, + )); + }, + }; + + if completed_tx.status == TransactionStatus::Broadcast { + let mut check = true; + + for output in response.iter() { + let transaction_output = TransactionOutput::try_from(output.clone()).map_err(|_| { + TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::ConversionError("Could not convert Transaction Output".to_string()), + ) + })?; + + check = check && + completed_tx + .transaction + .body + .outputs() + .iter() + .any(|item| item == &transaction_output); + } + // If all outputs are present then mark this transaction as mined. 
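+            // When the check passes, the output manager confirms the spent inputs and new
+            // outputs, the wallet database marks the transaction as mined, and a
+            // `TransactionMined` event is published to any subscribers.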
+ if check && !response.is_empty() { + self.resources + .output_manager_service + .confirm_transaction( + completed_tx.tx_id, + completed_tx.transaction.body.inputs().clone(), + completed_tx.transaction.body.outputs().clone(), + ) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + self.resources + .db + .mine_completed_transaction(completed_tx.tx_id) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionMined(completed_tx.tx_id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + + info!( + target: LOG_TARGET, + "Transaction (TxId: {:?}) detected as mined on the Base Layer", completed_tx.tx_id + ); + + return Ok(true); + } + } + + Ok(false) + } +} diff --git a/base_layer/core/src/base_node/states/error.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs similarity index 77% rename from base_layer/core/src/base_node/states/error.rs rename to base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs index b35a4b4cc5..a851ecd17a 100644 --- a/base_layer/core/src/base_node/states/error.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_receive_protocol.rs @@ -1,4 +1,4 @@ -// Copyright 2019. The Tari Project +// Copyright 2020. The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the // following conditions are met: @@ -20,17 +20,13 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::base_node::states::StateEvent; -use derive_error::Error; - -#[derive(Clone, Debug, Error)] -pub enum BaseNodeError { - #[error(msg_embedded, non_std, no_from)] - ConfigurationError(String), -} - -impl BaseNodeError { - pub fn as_fatal(&self, preface: &str) -> StateEvent { - StateEvent::FatalError(format!("{} {}", preface, self.to_string())) - } -} +// pub struct TransactionReceiveProtocol { +// id: u64, +// db: TransactionDatabase, +// output_manager_service: OutputManagerHandle, +// outbound_message_service: OutboundMessageRequester, +// event_publisher: Publisher, +// node_identity: Arc, +// factories: CryptoFactories, +// transaction_finalized_channel: Receiver, +// } diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs new file mode 100644 index 0000000000..ed917299ca --- /dev/null +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_send_protocol.rs @@ -0,0 +1,511 @@ +// Copyright 2020. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. 
Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use std::sync::Arc;
+
+use chrono::Utc;
+use futures::{channel::mpsc::Receiver, FutureExt, StreamExt};
+use log::*;
+
+use crate::transaction_service::{
+    error::{TransactionServiceError, TransactionServiceProtocolError},
+    handle::TransactionEvent,
+    service::TransactionServiceResources,
+    storage::database::{CompletedTransaction, OutboundTransaction, TransactionBackend, TransactionStatus},
+};
+use futures::channel::oneshot;
+use tari_comms::{peer_manager::NodeId, types::CommsPublicKey};
+use tari_comms_dht::{domain_message::OutboundDomainMessage, envelope::NodeDestination, outbound::OutboundEncryption};
+use tari_core::transactions::{
+    tari_amount::MicroTari,
+    transaction::{KernelFeatures, TransactionError},
+    transaction_protocol::{proto, recipient::RecipientSignedMessage},
+    SenderTransactionProtocol,
+};
+use tari_p2p::tari_message::TariMessageType;
+
+const LOG_TARGET: &str = "wallet::transaction_service::protocols::send_protocol";
+
+#[derive(Debug, PartialEq)]
+pub enum TransactionProtocolStage {
+    Initial,
+    WaitForReply,
+}
+
+pub struct TransactionSendProtocol<TBackend>
+where TBackend: TransactionBackend + Clone + 'static
+{
+    id: u64,
+    resources: TransactionServiceResources<TBackend>,
+    transaction_reply_receiver: Option<Receiver<(CommsPublicKey, RecipientSignedMessage)>>,
+    cancellation_receiver: Option<oneshot::Receiver<()>>,
+    dest_pubkey: CommsPublicKey,
+    amount: MicroTari,
+    message: String,
+    sender_protocol: SenderTransactionProtocol,
+    stage: TransactionProtocolStage,
+}
+
+#[allow(clippy::too_many_arguments)]
+impl<TBackend> TransactionSendProtocol<TBackend>
+where TBackend: TransactionBackend + Clone + 'static
+{
+    pub fn new(
+        id: u64,
+        resources: TransactionServiceResources<TBackend>,
+        transaction_reply_receiver: Receiver<(CommsPublicKey, RecipientSignedMessage)>,
+        cancellation_receiver: oneshot::Receiver<()>,
+        dest_pubkey: CommsPublicKey,
+        amount: MicroTari,
+        message: String,
+        sender_protocol: SenderTransactionProtocol,
+        stage: TransactionProtocolStage,
+    ) -> Self
+    {
+        Self {
+            id,
+            resources,
+            transaction_reply_receiver: Some(transaction_reply_receiver),
+            cancellation_receiver: Some(cancellation_receiver),
+            dest_pubkey,
+            amount,
+            message,
+            sender_protocol,
+            stage,
+        }
+    }
+
+    /// Execute the Transaction Send Protocol as an async task.
+    pub async fn execute(mut self) -> Result<u64, TransactionServiceProtocolError> {
+        info!(
+            "Starting Transaction Send protocol for TxId: {} at Stage {:?}",
+            self.id, self.stage
+        );
+
+        // Only send the transaction if the protocol stage is Initial.
If the protocol is started in a later stage + // ignore this + if self.stage == TransactionProtocolStage::Initial { + self.send_transaction().await?; + } + + // Waiting for Transaction Reply + let tx_id = self.id; + let mut receiver = self + .transaction_reply_receiver + .take() + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))?; + + let mut cancellation_receiver = self + .cancellation_receiver + .take() + .ok_or_else(|| TransactionServiceProtocolError::new(self.id, TransactionServiceError::InvalidStateError))? + .fuse(); + + let mut outbound_tx = self + .resources + .db + .get_pending_outbound_transaction(tx_id) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + if !outbound_tx.sender_protocol.is_collecting_single_signature() { + error!(target: LOG_TARGET, "Pending Transaction not in correct state"); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::InvalidStateError, + )); + } + + let mut source_pubkey; + #[allow(unused_assignments)] + let mut reply = None; + loop { + #[allow(unused_assignments)] + let mut rr_tx_id = 0; + futures::select! { + (spk, rr) = receiver.select_next_some() => { + source_pubkey = spk; + rr_tx_id = rr.tx_id; + reply = Some(rr); + }, + _ = cancellation_receiver => { + info!(target: LOG_TARGET, "Cancelling Transaction Send Protocol for TxId: {}", self.id); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionCancelled, + )); + } + } + + if outbound_tx.destination_public_key != source_pubkey { + error!( + target: LOG_TARGET, + "Transaction Reply did not come from the expected Public Key" + ); + } else if !outbound_tx.sender_protocol.check_tx_id(rr_tx_id) { + error!(target: LOG_TARGET, "Transaction Reply does not have the correct TxId"); + } else { + break; + } + } + + let recipient_reply = reply.ok_or(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionCancelled, + ))?; + + outbound_tx + .sender_protocol + .add_single_recipient_info(recipient_reply, &self.resources.factories.range_proof) + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + let finalize_result = outbound_tx + .sender_protocol + .finalize(KernelFeatures::empty(), &self.resources.factories) + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + if !finalize_result { + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::TransactionError(TransactionError::ValidationError( + "Transaction could not be finalized".to_string(), + )), + )); + } + + let tx = outbound_tx + .sender_protocol + .get_transaction() + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + let completed_transaction = CompletedTransaction { + tx_id, + source_public_key: self.resources.node_identity.public_key().clone(), + destination_public_key: outbound_tx.destination_public_key.clone(), + amount: outbound_tx.amount, + fee: outbound_tx.fee, + transaction: tx.clone(), + status: TransactionStatus::Completed, + message: outbound_tx.message.clone(), + timestamp: Utc::now().naive_utc(), + }; + + self.resources + .db + .complete_outbound_transaction(tx_id, completed_transaction.clone()) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + info!( + target: LOG_TARGET, + 
"Transaction Recipient Reply for TX_ID = {} received", tx_id, + ); + + let finalized_transaction_message = proto::TransactionFinalizedMessage { + tx_id, + transaction: Some(tx.clone().into()), + }; + + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::ReceivedTransactionReply(tx_id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + + // TODO Actually monitor the send status of this message + self.resources + .outbound_message_service + .send_direct( + outbound_tx.destination_public_key.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new( + TariMessageType::TransactionFinalized, + finalized_transaction_message.clone(), + ), + ) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + // TODO Monitor the final send result of this process + match self + .resources + .outbound_message_service + .propagate( + NodeDestination::NodeId(Box::new(NodeId::from_key(&self.dest_pubkey).map_err(|e| { + TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)) + })?)), + OutboundEncryption::EncryptFor(Box::new(self.dest_pubkey.clone())), + vec![], + OutboundDomainMessage::new( + TariMessageType::TransactionFinalized, + finalized_transaction_message.clone(), + ), + ) + .await + { + Ok(result) => match result.resolve_ok().await { + None => { + error!( + target: LOG_TARGET, + "Sending Finalized Transaction (TxId: {}) to neighbours for Store and Forward failed", self.id + ); + }, + Some(tags) if !tags.is_empty() => { + info!( + target: LOG_TARGET, + "Sending Finalized Transaction (TxId: {}) to Neighbours for Store and Forward successful with \ + Message Tags: {:?}", + tx_id, + tags, + ); + }, + Some(_) => { + error!( + target: LOG_TARGET, + "Sending Finalized Transaction to Neighbours for Store and Forward for TX_ID: {} was \ + unsuccessful and no messages were sent", + tx_id + ); + }, + }, + Err(e) => { + error!( + target: LOG_TARGET, + "Sending Finalized Transaction (TxId: {}) to neighbours for Store and Forward failed: {:?}", + self.id, + e + ); + }, + }; + + Ok(self.id) + } + + /// Contains all the logic to initially send the transaction. This will only be done on the first time this Protocol + /// is executed. 
+ async fn send_transaction(&mut self) -> Result<(), TransactionServiceProtocolError> { + if !self.sender_protocol.is_single_round_message_ready() { + error!(target: LOG_TARGET, "Sender Transaction Protocol is in an invalid state"); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::InvalidStateError, + )); + } + + let msg = self + .sender_protocol + .build_single_round_message() + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + let tx_id = msg.tx_id; + + if tx_id != self.id { + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::InvalidStateError, + )); + } + + let proto_message = proto::TransactionSenderMessage::single(msg.into()); + let mut direct_send_success = false; + match self + .resources + .outbound_message_service + .send_direct( + self.dest_pubkey.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message.clone()), + ) + .await + { + Ok(result) => match result.resolve_ok().await { + Some(send_states) if send_states.len() == 1 => { + info!( + target: LOG_TARGET, + "Transaction (TxId: {}) Direct Send to {} successful with Message Tag: {:?}", + tx_id, + self.dest_pubkey, + send_states[0].tag, + ); + direct_send_success = true; + + let event_publisher = self.resources.event_publisher.clone(); + // Launch a task to monitor if the message gets sent + tokio::spawn(async move { + match send_states.wait_single().await { + true => { + info!( + target: LOG_TARGET, + "Direct Send process for TX_ID: {} was successful", tx_id + ); + let _ = event_publisher + .send(Arc::new(TransactionEvent::TransactionDirectSendResult(tx_id, true))); + }, + false => { + error!( + target: LOG_TARGET, + "Direct Send process for TX_ID: {} was unsuccessful and no message was sent", tx_id + ); + let _ = event_publisher + .send(Arc::new(TransactionEvent::TransactionDirectSendResult(tx_id, false))); + }, + } + }); + }, + _ => { + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionDirectSendResult(tx_id, false))); + error!(target: LOG_TARGET, "Transaction Send Direct for TxID: {} failed", tx_id); + }, + }, + Err(e) => { + error!(target: LOG_TARGET, "Direct Transaction Send failed: {:?}", e); + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionDirectSendResult(tx_id, false))); + }, + }; + + // TODO Actually monitor the send status of this message + let mut store_and_forward_send_success = false; + match self + .resources + .outbound_message_service + .propagate( + NodeDestination::NodeId(Box::new(NodeId::from_key(&self.dest_pubkey).map_err(|e| { + TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)) + })?)), + OutboundEncryption::EncryptFor(Box::new(self.dest_pubkey.clone())), + vec![], + OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message), + ) + .await + { + Ok(result) => match result.resolve_ok().await { + None => { + error!( + target: LOG_TARGET, + "Transaction Send (TxId: {}) to neighbours for Store and Forward failed", self.id + ); + }, + Some(tags) if !tags.is_empty() => { + info!( + target: LOG_TARGET, + "Transaction (TxId: {}) Send to Neighbours for Store and Forward successful with Message \ + Tags: {:?}", + tx_id, + tags, + ); + store_and_forward_send_success = true; + }, + Some(_) => { + error!( + target: LOG_TARGET, + "Transaction Send to Neighbours for Store and Forward for TX_ID: 
{} was unsuccessful and no \ + messages were sent", + tx_id + ); + }, + }, + Err(e) => { + error!( + target: LOG_TARGET, + "Transaction Send (TxId: {}) to neighbours for Store and Forward failed: {:?}", self.id, e + ); + }, + }; + + if !direct_send_success && !store_and_forward_send_success { + error!( + target: LOG_TARGET, + "Failed to Send Transaction (TxId: {}) both Directly or via Store and Forward. Pending Transaction \ + will be cancelled", + tx_id + ); + if let Err(e) = self.resources.output_manager_service.cancel_transaction(tx_id).await { + error!( + target: LOG_TARGET, + "Failed to Cancel TX_ID: {} after failed sending attempt with error {:?}", tx_id, e + ); + }; + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionStoreForwardSendResult( + tx_id, false, + ))); + return Err(TransactionServiceProtocolError::new( + self.id, + TransactionServiceError::OutboundSendFailure, + )); + } + + self.resources + .output_manager_service + .confirm_pending_transaction(tx_id) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + let fee = self + .sender_protocol + .get_fee_amount() + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + let outbound_tx = OutboundTransaction { + tx_id, + destination_public_key: self.dest_pubkey.clone(), + amount: self.amount, + fee, + sender_protocol: self.sender_protocol.clone(), + status: TransactionStatus::Pending, + message: self.message.clone(), + timestamp: Utc::now().naive_utc(), + }; + + self.resources + .db + .add_pending_outbound_transaction(outbound_tx.tx_id, outbound_tx) + .await + .map_err(|e| TransactionServiceProtocolError::new(self.id, TransactionServiceError::from(e)))?; + + info!( + target: LOG_TARGET, + "Pending Outbound Transaction TxId: {:?} added. Waiting for Reply or Cancellation", tx_id, + ); + + let _ = self + .resources + .event_publisher + .send(Arc::new(TransactionEvent::TransactionStoreForwardSendResult( + tx_id, true, + ))); + + Ok(()) + } +} diff --git a/base_layer/wallet/src/transaction_service/service.rs b/base_layer/wallet/src/transaction_service/service.rs index ed4e9c3fa2..fd8ef2e895 100644 --- a/base_layer/wallet/src/transaction_service/service.rs +++ b/base_layer/wallet/src/transaction_service/service.rs @@ -20,17 +20,31 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use std::{ - collections::HashMap, - convert::{TryFrom, TryInto}, - sync::Arc, - time::Duration, +use crate::{ + output_manager_service::{handle::OutputManagerHandle, TxId}, + transaction_service::{ + config::TransactionServiceConfig, + error::{TransactionServiceError, TransactionServiceProtocolError}, + handle::{TransactionEvent, TransactionEventSender, TransactionServiceRequest, TransactionServiceResponse}, + protocols::{ + transaction_broadcast_protocol::TransactionBroadcastProtocol, + transaction_chain_monitoring_protocol::TransactionChainMonitoringProtocol, + transaction_send_protocol::{TransactionProtocolStage, TransactionSendProtocol}, + }, + storage::database::{ + CompletedTransaction, + InboundTransaction, + OutboundTransaction, + PendingCoinbaseTransaction, + TransactionBackend, + TransactionDatabase, + TransactionStatus, + }, + }, }; - use chrono::Utc; use futures::{ - channel::oneshot, - future::{BoxFuture, FutureExt}, + channel::{mpsc, mpsc::Sender, oneshot}, pin_mut, stream::FuturesUnordered, SinkExt, @@ -39,41 +53,28 @@ use futures::{ }; use log::*; use rand::{rngs::OsRng, RngCore}; -use tari_broadcast_channel::Publisher; -use tari_crypto::{ - commitment::HomomorphicCommitmentFactory, - keys::SecretKey, - tari_utilities::{hash::Hashable, hex::Hex}, +use std::{ + collections::HashMap, + convert::{TryFrom, TryInto}, + sync::Arc, }; - use tari_comms::{ - message::MessageTag, - peer_manager::NodeIdentity, - protocol::messaging::{MessagingEvent, MessagingEventReceiver}, + peer_manager::{NodeId, NodeIdentity}, types::CommsPublicKey, }; use tari_comms_dht::{ domain_message::OutboundDomainMessage, - outbound::{OutboundEncryption, OutboundMessageRequester, SendMessageResponse}, + envelope::NodeDestination, + outbound::{OutboundEncryption, OutboundMessageRequester}, }; #[cfg(feature = "test_harness")] use tari_core::transactions::{tari_amount::uT, types::BlindingFactor}; use tari_core::{ - base_node::proto::{ - base_node as BaseNodeProto, - base_node::{ - base_node_service_request::Request as BaseNodeRequestProto, - base_node_service_response::Response as BaseNodeResponseProto, - }, - }, - mempool::{ - proto::mempool as MempoolProto, - service::{MempoolResponse, MempoolServiceResponse}, - TxStorageResponse, - }, + base_node::proto::base_node as BaseNodeProto, + mempool::{proto::mempool as MempoolProto, service::MempoolServiceResponse}, transactions::{ tari_amount::MicroTari, - transaction::{KernelFeatures, OutputFeatures, OutputFlags, Transaction, TransactionOutput}, + transaction::{KernelFeatures, OutputFeatures, OutputFlags, Transaction}, transaction_protocol::{ proto, recipient::{RecipientSignedMessage, RecipientState}, @@ -83,27 +84,10 @@ use tari_core::{ ReceiverTransactionProtocol, }, }; +use tari_crypto::{commitment::HomomorphicCommitmentFactory, keys::SecretKey}; use tari_p2p::{domain_message::DomainMessage, tari_message::TariMessageType}; use tari_service_framework::{reply_channel, reply_channel::Receiver}; - -use crate::{ - output_manager_service::{handle::OutputManagerHandle, TxId}, - transaction_service::{ - config::TransactionServiceConfig, - error::TransactionServiceError, - handle::{TransactionEvent, TransactionServiceRequest, TransactionServiceResponse}, - storage::database::{ - CompletedTransaction, - InboundTransaction, - OutboundTransaction, - PendingCoinbaseTransaction, - TransactionBackend, - TransactionDatabase, - TransactionStatus, - }, - }, - util::futures::StateDelay, -}; +use tokio::task::JoinHandle; const LOG_TARGET: &str = 
"wallet::transaction_service::service"; @@ -136,7 +120,6 @@ where TBackend: TransactionBackend + Clone + 'static config: TransactionServiceConfig, db: TransactionDatabase, outbound_message_service: OutboundMessageRequester, - message_event_receiver: Option, output_manager_service: OutputManagerHandle, transaction_stream: Option, transaction_reply_stream: Option, @@ -146,12 +129,15 @@ where TBackend: TransactionBackend + Clone + 'static request_stream: Option< reply_channel::Receiver>, >, - event_publisher: Publisher, + event_publisher: TransactionEventSender, node_identity: Arc, factories: CryptoFactories, base_node_public_key: Option, - pending_outbound_message_results: HashMap, - pending_transaction_mined_queries: HashMap, + service_resources: TransactionServiceResources, + pending_transaction_reply_senders: HashMap>, + mempool_response_senders: HashMap>, + base_node_response_senders: HashMap>, + send_transaction_cancellation_senders: HashMap>, } #[allow(clippy::too_many_arguments)] @@ -179,17 +165,25 @@ where base_node_response_stream: BNResponseStream, output_manager_service: OutputManagerHandle, outbound_message_service: OutboundMessageRequester, - message_event_receiver: MessagingEventReceiver, - event_publisher: Publisher, + event_publisher: TransactionEventSender, node_identity: Arc, factories: CryptoFactories, ) -> Self { + // Collect the resources that all protocols will need so that they can be neatly cloned as the protocols are + // spawned. + let service_resources = TransactionServiceResources { + db: db.clone(), + output_manager_service: output_manager_service.clone(), + outbound_message_service: outbound_message_service.clone(), + event_publisher: event_publisher.clone(), + node_identity: node_identity.clone(), + factories: factories.clone(), + }; TransactionService { config, db, outbound_message_service, - message_event_receiver: Some(message_event_receiver), output_manager_service, transaction_stream: Some(transaction_stream), transaction_reply_stream: Some(transaction_reply_stream), @@ -201,8 +195,11 @@ where node_identity, factories, base_node_public_key: None, - pending_outbound_message_results: HashMap::new(), - pending_transaction_mined_queries: HashMap::new(), + service_resources, + pending_transaction_reply_senders: HashMap::new(), + mempool_response_senders: HashMap::new(), + base_node_response_senders: HashMap::new(), + send_transaction_cancellation_senders: HashMap::new(), } } @@ -244,27 +241,27 @@ where .expect("Transaction Service initialized without base_node_response_stream") .fuse(); pin_mut!(base_node_response_stream); - let message_event_receiver = self - .message_event_receiver - .take() - .expect("Transaction Service initialized without message_event_subscription") - .fuse(); - pin_mut!(message_event_receiver); - let mut discovery_process_futures: FuturesUnordered< - BoxFuture<'static, Result<(MessageTag, OutboundTransaction), TransactionServiceError>>, + let mut send_transaction_protocol_handles: FuturesUnordered< + JoinHandle>, > = FuturesUnordered::new(); - let mut broadcast_timeout_futures: FuturesUnordered> = FuturesUnordered::new(); - let mut mined_request_timeout_futures: FuturesUnordered> = FuturesUnordered::new(); + let mut transaction_broadcast_protocol_handles: FuturesUnordered< + JoinHandle>, + > = FuturesUnordered::new(); + let mut transaction_chain_monitoring_protocol_handles: FuturesUnordered< + JoinHandle>, + > = FuturesUnordered::new(); + + info!(target: LOG_TARGET, "Transaction Service started"); loop { futures::select! 
{ //Incoming request request_context = request_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Service API Request"); let (request, reply_tx) = request_context.split(); - let _ = reply_tx.send(self.handle_request(request, &mut discovery_process_futures, &mut broadcast_timeout_futures, &mut mined_request_timeout_futures).await.or_else(|resp| { + let _ = reply_tx.send(self.handle_request(request, &mut send_transaction_protocol_handles, &mut transaction_broadcast_protocol_handles, &mut transaction_chain_monitoring_protocol_handles).await.or_else(|resp| { error!(target: LOG_TARGET, "Error handling request: {:?}", resp); Err(resp) })).or_else(|resp| { @@ -276,58 +273,54 @@ where msg = transaction_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Transaction Message"); let (origin_public_key, inner_msg) = msg.into_origin_and_inner(); - let result = self.accept_transaction(origin_public_key, inner_msg).await.or_else(|err| { - error!(target: LOG_TARGET, "Failed to handle incoming Transaction message: {:?} for NodeID: {}", err, self.node_identity.node_id().short_str()); - Err(err) - }); + let result = self.accept_transaction(origin_public_key, inner_msg).await; - if result.is_err() { - let _ = self.event_publisher - .send(TransactionEvent::Error( - "Error handling Transaction Sender message".to_string(), - )) - .await; + match result { + Err(TransactionServiceError::RepeatedMessageError) => { + trace!(target: LOG_TARGET, "A repeated Transaction message was received"); + } + Err(e) => { + error!(target: LOG_TARGET, "Failed to handle incoming Transaction message: {:?} for NodeID: {}", e, self.node_identity.node_id().short_str()); + let _ = self.event_publisher.send(Arc::new(TransactionEvent::Error(format!("Error handling Transaction Sender message: {:?}", e).to_string()))); + } + _ => (), } }, // Incoming messages from the Comms layer msg = transaction_reply_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Transaction Reply Message"); let (origin_public_key, inner_msg) = msg.into_origin_and_inner(); - let result = self.accept_recipient_reply(origin_public_key, inner_msg, &mut broadcast_timeout_futures).await.or_else(|err| { - error!(target: LOG_TARGET, "Failed to handle incoming Transaction Reply message: {:?} for NodeId: {}", err, self.node_identity.node_id().short_str()); - Err(err) - }); + let result = self.accept_recipient_reply(origin_public_key, inner_msg).await; - if result.is_err() { - let _ = self.event_publisher - .send(TransactionEvent::Error( - "Error handling Transaction Recipient Reply message".to_string(), - )) - .await; + match result { + Err(TransactionServiceError::TransactionDoesNotExistError) => { + debug!(target: LOG_TARGET, "Unable to handle incoming Transaction Reply message from NodeId: {} due to Transaction not Existing. 
This usually means the message was a repeated message from Store and Forward", self.node_identity.node_id().short_str()); + }, + Err(e) => { + error!(target: LOG_TARGET, "Failed to handle incoming Transaction Reply message: {:?} for NodeId: {}", e, self.node_identity.node_id().short_str()); + let _ = self.event_publisher.send(Arc::new(TransactionEvent::Error("Error handling Transaction Recipient Reply message".to_string()))); + }, + Ok(_) => (), } }, // Incoming messages from the Comms layer msg = transaction_finalized_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Transaction Finalized Message"); let (origin_public_key, inner_msg) = msg.into_origin_and_inner(); - let result = self.accept_finalized_transaction(origin_public_key, inner_msg, &mut broadcast_timeout_futures).await.or_else(|err| { + let result = self.accept_finalized_transaction(origin_public_key, inner_msg, &mut transaction_broadcast_protocol_handles).await.or_else(|err| { error!(target: LOG_TARGET, "Failed to handle incoming Transaction Finalized message: {:?} for NodeID: {}", err , self.node_identity.node_id().short_str()); Err(err) }); if result.is_err() { - let _ = self.event_publisher - .send(TransactionEvent::Error( - "Error handling Transaction Finalized message".to_string(), - )) - .await; + let _ = self.event_publisher.send(Arc::new(TransactionEvent::Error("Error handling Transaction Finalized message".to_string(),))); } }, // Incoming messages from the Comms layer msg = mempool_response_stream.select_next_some() => { trace!(target: LOG_TARGET, "Handling Mempool Response"); let (origin_public_key, inner_msg) = msg.into_origin_and_inner(); - let _ = self.handle_mempool_response(inner_msg, &mut mined_request_timeout_futures).await.or_else(|resp| { + let _ = self.handle_mempool_response(inner_msg).await.or_else(|resp| { error!(target: LOG_TARGET, "Error handling mempool service response: {:?}", resp); Err(resp) }); @@ -341,57 +334,26 @@ where Err(resp) }); } - response = discovery_process_futures.select_next_some() => { - trace!(target: LOG_TARGET, "Handling Discovery Process Completion"); - match response { - Ok((message_tag, outbound_tx)) => { - info!( - target: LOG_TARGET, - "Discovery process completed for TxId: {} with Message Tag: {} now waiting for MessageSent event", - outbound_tx.tx_id, - message_tag, - ); - self.db - .add_pending_outbound_transaction(outbound_tx.tx_id, outbound_tx.clone()) - .await?; - self.pending_outbound_message_results.insert(message_tag.clone(), outbound_tx); - }, - Err(TransactionServiceError::DiscoveryProcessFailed(tx_id)) => { - if let Err(e) = self.output_manager_service.cancel_transaction(tx_id).await { - error!(target: LOG_TARGET, "Failed to Cancel TX_ID: {} after failed sending attempt", tx_id); - } - error!(target: LOG_TARGET, "Discovery and Send failed for TX_ID: {}", tx_id); - let _ = self.event_publisher - .send(TransactionEvent::TransactionSendDiscoveryComplete(tx_id, false)) - .await; - } - Err(e) => error!(target: LOG_TARGET, "Discovery and Send failed with Error: {:?}", e), - } - }, - message_event = message_event_receiver.select_next_some() => { - match message_event { - Ok(event) => { - let _ = self.handle_message_event((*event).clone()).await.or_else(|resp| { - error!(target: LOG_TARGET, "Error handling outbound message event: {:?}", resp); - Err(resp) - }); - }, - Err(e) => error!(target: LOG_TARGET, "Error handling Outbound Message Event: {:?}", e), - } + join_result = send_transaction_protocol_handles.select_next_some() => { + trace!(target: 
LOG_TARGET, "Send Protocol for Transaction has ended with result {:?}", join_result); + match join_result { + Ok(join_result_inner) => self.complete_send_transaction_protocol(join_result_inner, &mut transaction_broadcast_protocol_handles).await, + Err(e) => error!(target: LOG_TARGET, "Error resolving Join Handle: {:?}", e), + }; } - tx_id = broadcast_timeout_futures.select_next_some() => { - trace!(target: LOG_TARGET, "Handling Broadcast Timeout"); - let _ = self.handle_mempool_broadcast_timeout(tx_id, &mut broadcast_timeout_futures).await.or_else(|resp| { - error!(target: LOG_TARGET, "Error handling mempool broadcast timeout : {:?}", resp); - Err(resp) - }); + join_result = transaction_broadcast_protocol_handles.select_next_some() => { + trace!(target: LOG_TARGET, "Transaction Broadcast protocol has ended with result {:?}", join_result); + match join_result { + Ok(join_result_inner) => self.complete_transaction_broadcast_protocol(join_result_inner, &mut transaction_chain_monitoring_protocol_handles).await, + Err(e) => error!(target: LOG_TARGET, "Error resolving Join Handle: {:?}", e), + }; } - tx_id = mined_request_timeout_futures.select_next_some() => { - trace!(target: LOG_TARGET, "Handling Mined Request Timeout"); - let _ = self.handle_transaction_mined_request_timeout(tx_id, &mut mined_request_timeout_futures).await.or_else(|resp| { - error!(target: LOG_TARGET, "Error handling transaction mined? request timeout : {:?}", resp); - Err(resp) - }); + join_result = transaction_chain_monitoring_protocol_handles.select_next_some() => { + trace!(target: LOG_TARGET, "Transaction chain monitoring protocol has ended with result {:?}", join_result); + match join_result { + Ok(join_result_inner) => self.complete_transaction_chain_monitoring_protocol(join_result_inner), + Err(e) => error!(target: LOG_TARGET, "Error resolving Join Handle: {:?}", e), + }; } complete => { info!(target: LOG_TARGET, "Transaction service shutting down"); @@ -406,19 +368,29 @@ where async fn handle_request( &mut self, request: TransactionServiceRequest, - discovery_process_futures: &mut FuturesUnordered< - BoxFuture<'static, Result<(MessageTag, OutboundTransaction), TransactionServiceError>>, + send_transaction_join_handles: &mut FuturesUnordered>>, + transaction_broadcast_join_handles: &mut FuturesUnordered< + JoinHandle>, >, - broadcast_timeout_futures: &mut FuturesUnordered>, - mined_request_timeout_futures: &mut FuturesUnordered>, + chain_monitoring_join_handles: &mut FuturesUnordered>>, ) -> Result { trace!(target: LOG_TARGET, "Handling Service Request: {}", request); match request { TransactionServiceRequest::SendTransaction((dest_pubkey, amount, fee_per_gram, message)) => self - .send_transaction(dest_pubkey, amount, fee_per_gram, message, discovery_process_futures) + .send_transaction( + dest_pubkey, + amount, + fee_per_gram, + message, + send_transaction_join_handles, + ) + .await + .map(TransactionServiceResponse::TransactionSent), + TransactionServiceRequest::CancelTransaction(tx_id) => self + .cancel_transaction(tx_id) .await - .map(|_| TransactionServiceResponse::TransactionSent), + .map(|_| TransactionServiceResponse::TransactionCancelled), TransactionServiceRequest::GetPendingInboundTransactions => Ok( TransactionServiceResponse::PendingInboundTransactions(self.get_pending_inbound_transactions().await?), ), @@ -443,13 +415,22 @@ where Ok(TransactionServiceResponse::CoinbaseTransactionCancelled) }, TransactionServiceRequest::SetBaseNodePublicKey(public_key) => self - .set_base_node_public_key(public_key, 
broadcast_timeout_futures, mined_request_timeout_futures) + .set_base_node_public_key( + public_key, + transaction_broadcast_join_handles, + chain_monitoring_join_handles, + send_transaction_join_handles, + ) .await .map(|_| TransactionServiceResponse::BaseNodePublicKeySet), TransactionServiceRequest::ImportUtxo(value, source_public_key, message) => self .add_utxo_import_transaction(value, source_public_key, message) .await .map(TransactionServiceResponse::UtxoImported), + TransactionServiceRequest::SubmitTransaction((tx_id, tx, fee, amount, message)) => self + .submit_transaction(transaction_broadcast_join_handles, tx_id, tx, fee, amount, message) + .await + .map(|_| TransactionServiceResponse::TransactionSubmitted), #[cfg(feature = "test_harness")] TransactionServiceRequest::CompletePendingOutboundTransaction(completed_transaction) => { self.complete_pending_outbound_transaction(completed_transaction) @@ -479,62 +460,6 @@ where } } - async fn handle_message_event(&mut self, message_event: MessagingEvent) -> Result<(), TransactionServiceError> { - let (message_tag, result) = match message_event { - MessagingEvent::MessageSent(message_tag) => (message_tag, true), - MessagingEvent::SendMessageFailed(outbound_message, _reason) => (outbound_message.tag, false), - _ => return Ok(()), - }; - match self.pending_outbound_message_results.remove(&message_tag) { - None => (), - Some(outbound_tx) => { - // If the message was successfully sent then add it to the pending transaction list - if result { - self.output_manager_service - .confirm_pending_transaction(outbound_tx.tx_id) - .await?; - info!( - target: LOG_TARGET, - "Pending Outbound Transaction TxId: {:?} was successfully sent with Message Tag: {:?}", - outbound_tx.tx_id, - message_tag - ); - } else { - error!( - target: LOG_TARGET, - "Pending Outbound Transaction TxId: {:?} with Message Tag {:?} could not be sent", - outbound_tx.tx_id, - message_tag, - ); - if let Err(e) = self.db.remove_pending_outbound_transaction(outbound_tx.tx_id).await { - error!( - target: LOG_TARGET, - "Failed to remove pending transaction TX_ID: {} after failed sending attempt with error \ - {:?}", - outbound_tx.tx_id, - e - ); - } - if let Err(e) = self.output_manager_service.cancel_transaction(outbound_tx.tx_id).await { - error!( - target: LOG_TARGET, - "Failed to Cancel TX_ID: {} after failed sending attempt with error {:?}", - outbound_tx.tx_id, - e - ); - } - } - - let _ = self - .event_publisher - .send(TransactionEvent::TransactionSendResult(outbound_tx.tx_id, result)) - .await; - }, - } - - Ok(()) - } - /// Sends a new transaction to a recipient /// # Arguments /// 'dest_pubkey': The Comms pubkey of the recipient node @@ -546,107 +471,37 @@ where amount: MicroTari, fee_per_gram: MicroTari, message: String, - discovery_process_futures: &mut FuturesUnordered< - BoxFuture<'static, Result<(MessageTag, OutboundTransaction), TransactionServiceError>>, - >, - ) -> Result<(), TransactionServiceError> + join_handles: &mut FuturesUnordered>>, + ) -> Result { - let mut sender_protocol = self + let sender_protocol = self .output_manager_service .prepare_transaction_to_send(amount, fee_per_gram, None, message.clone()) .await?; - if !sender_protocol.is_single_round_message_ready() { - return Err(TransactionServiceError::InvalidStateError); - } - - let msg = sender_protocol.build_single_round_message()?; - let tx_id = msg.tx_id; - let proto_message = proto::TransactionSenderMessage::single(msg.into()); - - match self - .outbound_message_service - .send_direct( - 
dest_pubkey.clone(), - OutboundEncryption::EncryptForPeer, - OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message), - ) - .await? - { - SendMessageResponse::Queued(tags) => match tags.len() { - 0 => error!( - target: LOG_TARGET, - "Queuing Transaction TX_ID: {} for send was unsuccessful and no message was sent", tx_id - ), - 1 => { - info!( - target: LOG_TARGET, - "Transaction (TxId: {}) Send successfully queued for send with Message Tag: {:?}", - tx_id, - tags[0], - ); - - let outbound_tx = OutboundTransaction { - tx_id, - destination_public_key: dest_pubkey.clone(), - amount, - fee: sender_protocol.get_fee_amount()?, - sender_protocol, - status: TransactionStatus::Pending, - message, - timestamp: Utc::now().naive_utc(), - }; - self.db - .add_pending_outbound_transaction(outbound_tx.tx_id, outbound_tx.clone()) - .await?; - self.pending_outbound_message_results - .insert(tags[0].clone(), outbound_tx); - }, - _ => error!( - target: LOG_TARGET, - "Send process for TX_ID: {} was unsuccessful due to more than 1 MessageTag being returned", tx_id - ), - }, - SendMessageResponse::Failed => return Err(TransactionServiceError::OutboundSendFailure), - SendMessageResponse::PendingDiscovery(r) => { - // The sending of the message resulted in a long running Discovery process being performed by the Comms - // layer. This can take minutes so we will spawn a task to wait for the result and then act - // appropriately on it - let tx_id_clone = tx_id; - let outbound_tx_clone = OutboundTransaction { - tx_id, - destination_public_key: dest_pubkey.clone(), - amount, - fee: sender_protocol.get_fee_amount()?, - sender_protocol: sender_protocol.clone(), - status: TransactionStatus::Pending, - message: message.clone(), - timestamp: Utc::now().naive_utc(), - }; - - info!( - target: LOG_TARGET, - "Send Transaction request for TxID: {:?} to recipient with public_key {} requires that a \ - Discovery Process be conducted", - tx_id, - dest_pubkey - ); - - let discovery_future = async move { - transaction_send_discovery_process_completion(r, tx_id_clone, outbound_tx_clone).await - }; - discovery_process_futures.push(discovery_future.boxed()); - - return Err(TransactionServiceError::OutboundSendDiscoveryInProgress(tx_id)); - }, - } + let tx_id = sender_protocol.get_tx_id()?; - info!( - target: LOG_TARGET, - "Transaction with TX_ID = {} queued to be sent to {}", tx_id, dest_pubkey + let (tx_reply_sender, tx_reply_receiver) = mpsc::channel(100); + let (cancellation_sender, cancellation_receiver) = oneshot::channel(); + self.pending_transaction_reply_senders.insert(tx_id, tx_reply_sender); + self.send_transaction_cancellation_senders + .insert(tx_id, cancellation_sender); + let protocol = TransactionSendProtocol::new( + tx_id, + self.service_resources.clone(), + tx_reply_receiver, + cancellation_receiver, + dest_pubkey, + amount, + message, + sender_protocol, + TransactionProtocolStage::Initial, ); - Ok(()) + let join_handle = tokio::spawn(protocol.execute()); + join_handles.push(join_handle); + + Ok(tx_id) } /// Accept the public reply from a recipient and apply the reply to the relevant transaction protocol @@ -656,82 +511,137 @@ where &mut self, source_pubkey: CommsPublicKey, recipient_reply: proto::RecipientSignedMessage, - broadcast_timeout_futures: &mut FuturesUnordered>, ) -> Result<(), TransactionServiceError> { let recipient_reply: RecipientSignedMessage = recipient_reply .try_into() .map_err(TransactionServiceError::InvalidMessageError)?; - let mut outbound_tx = 
self.db.get_pending_outbound_transaction(recipient_reply.tx_id).await?; - let tx_id = recipient_reply.tx_id; - if !outbound_tx.sender_protocol.check_tx_id(tx_id.clone()) || - !outbound_tx.sender_protocol.is_collecting_single_signature() - { - return Err(TransactionServiceError::InvalidStateError); - } - - outbound_tx - .sender_protocol - .add_single_recipient_info(recipient_reply, &self.factories.range_proof)?; - outbound_tx - .sender_protocol - .finalize(KernelFeatures::empty(), &self.factories)?; - let tx = outbound_tx.sender_protocol.get_transaction()?; - let completed_transaction = CompletedTransaction { - tx_id, - source_public_key: self.node_identity.public_key().clone(), - destination_public_key: outbound_tx.destination_public_key, - amount: outbound_tx.amount, - fee: outbound_tx.fee, - transaction: tx.clone(), - status: TransactionStatus::Completed, - message: outbound_tx.message.clone(), - timestamp: Utc::now().naive_utc(), + let sender = match self.pending_transaction_reply_senders.get_mut(&tx_id) { + None => return Err(TransactionServiceError::TransactionDoesNotExistError), + Some(s) => s, }; - self.db - .complete_outbound_transaction(tx_id.clone(), completed_transaction.clone()) - .await?; - info!( - target: LOG_TARGET, - "Transaction Recipient Reply for TX_ID = {} received", tx_id, - ); - let finalized_transaction_message = proto::TransactionFinalizedMessage { - tx_id, - transaction: Some(tx.clone().into()), - }; + sender + .send((source_pubkey, recipient_reply)) + .await + .map_err(|_| TransactionServiceError::ProtocolChannelError)?; - self.outbound_message_service - .send_direct( - source_pubkey.clone(), - OutboundEncryption::EncryptForPeer, - OutboundDomainMessage::new(TariMessageType::TransactionFinalized, finalized_transaction_message), - ) - .await?; + Ok(()) + } + + /// Handle the final clean up after a Send Transaction protocol completes + async fn complete_send_transaction_protocol( + &mut self, + join_result: Result, + transaction_broadcast_join_handles: &mut FuturesUnordered< + JoinHandle>, + >, + ) + { + match join_result { + Ok(id) => { + let _ = self.pending_transaction_reply_senders.remove(&id); + let _ = self.send_transaction_cancellation_senders.remove(&id); + let _ = self + .broadcast_completed_transaction_to_mempool(id, transaction_broadcast_join_handles) + .await + .or_else(|resp| { + error!( + target: LOG_TARGET, + "Error starting Broadcast Protocol after completed Send Transaction Protocol : {:?}", resp + ); + Err(resp) + }); + trace!( + target: LOG_TARGET, + "Send Transaction Protocol for TxId: {} completed successfully", + id + ); + }, + Err(TransactionServiceProtocolError { id, error }) => { + let _ = self.pending_transaction_reply_senders.remove(&id); + let _ = self.send_transaction_cancellation_senders.remove(&id); + error!( + target: LOG_TARGET, + "Error completing Send Transaction Protocol (Id: {}): {:?}", id, error + ); + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::Error(format!("{:?}", error)))); + }, + } + } + + /// Cancel a pending outbound transaction + async fn cancel_transaction(&mut self, tx_id: TxId) -> Result<(), TransactionServiceError> { + self.db.cancel_pending_transaction(tx_id).await.map_err(|e| { + error!( + target: LOG_TARGET, + "Pending Transaction does not exist and could not be cancelled: {:?}", e + ); + e + })?; + + self.output_manager_service.cancel_transaction(tx_id).await?; + + if let Some(cancellation_sender) = self.send_transaction_cancellation_senders.remove(&tx_id) { + let _ = 
cancellation_sender.send(()); + } + let _ = self.pending_transaction_reply_senders.remove(&tx_id); - // Logging this error here instead of propogating it up to the select! catchall which generates the Error Event. let _ = self - .broadcast_completed_transaction_to_mempool( - tx_id, - self.config.initial_mempool_broadcast_timeout, - broadcast_timeout_futures, - ) - .await + .event_publisher + .send(Arc::new(TransactionEvent::TransactionCancelled(tx_id))) .map_err(|e| { - error!( + trace!( target: LOG_TARGET, - "Error broadcasting completed transaction to mempool: {:?}", e + "Error sending event, usually because there are no subscribers: {:?}", + e ); e }); - self.event_publisher - .send(TransactionEvent::ReceivedTransactionReply(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; + info!(target: LOG_TARGET, "Pending Transaction (TxId: {}) cancelled", tx_id); + + Ok(()) + } + + async fn restart_all_send_transaction_protocols( + &mut self, + join_handles: &mut FuturesUnordered>>, + ) -> Result<(), TransactionServiceError> + { + let outbound_txs = self.db.get_pending_outbound_transactions().await?; + for (tx_id, tx) in outbound_txs { + if !self.pending_transaction_reply_senders.contains_key(&tx_id) { + debug!( + target: LOG_TARGET, + "Restarting listening for Reply for Pending Outbound Transaction TxId: {}", tx_id + ); + let (tx_reply_sender, tx_reply_receiver) = mpsc::channel(100); + let (cancellation_sender, cancellation_receiver) = oneshot::channel(); + self.pending_transaction_reply_senders.insert(tx_id, tx_reply_sender); + self.send_transaction_cancellation_senders + .insert(tx_id, cancellation_sender); + let protocol = TransactionSendProtocol::new( + tx_id, + self.service_resources.clone(), + tx_reply_receiver, + cancellation_receiver, + tx.destination_public_key, + tx.amount, + tx.message, + tx.sender_protocol, + TransactionProtocolStage::WaitForReply, + ); + + let join_handle = tokio::spawn(protocol.execute()); + join_handles.push(join_handle); + } + } Ok(()) } @@ -752,6 +662,23 @@ where // Currently we will only reply to a Single sender transaction protocol if let TransactionSenderMessage::Single(data) = sender_message.clone() { + trace!( + target: LOG_TARGET, + "Transaction (TxId: {}) received from {}", + data.tx_id, + source_pubkey + ); + // Check this is not a repeat message i.e. tx_id doesn't already exist in our pending or completed + // transactions + if self.db.transaction_exists(data.tx_id).await? { + trace!( + target: LOG_TARGET, + "Transaction (TxId: {}) already present in database.", + data.tx_id + ); + return Err(TransactionServiceError::RepeatedMessageError); + } + let amount = data.amount; let spending_key = self @@ -769,18 +696,21 @@ where ); let recipient_reply = rtp.get_signed_data()?.clone(); - // Check this is not a repeat message i.e. tx_id doesn't already exist in our pending or completed - // transactions - if self.db.transaction_exists(recipient_reply.tx_id).await? 
{ - return Err(TransactionServiceError::RepeatedMessageError); - } - let tx_id = recipient_reply.tx_id; let proto_message: proto::RecipientSignedMessage = recipient_reply.into(); self.outbound_message_service .send_direct( source_pubkey.clone(), - OutboundEncryption::EncryptForPeer, + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::ReceiverPartialTransactionReply, proto_message.clone()), + ) + .await?; + + self.outbound_message_service + .propagate( + NodeDestination::NodeId(Box::new(NodeId::from_key(&source_pubkey)?)), + OutboundEncryption::EncryptFor(Box::new(source_pubkey.clone())), + vec![], OutboundDomainMessage::new(TariMessageType::ReceiverPartialTransactionReply, proto_message), ) .await?; @@ -808,10 +738,17 @@ where "Transaction (TX_ID: {}) - Amount: {} - Message: {}", tx_id, amount, data.message ); - self.event_publisher - .send(TransactionEvent::ReceivedTransaction(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::ReceivedTransaction(tx_id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); } Ok(()) } @@ -824,7 +761,9 @@ where &mut self, source_pubkey: CommsPublicKey, finalized_transaction: proto::TransactionFinalizedMessage, - broadcast_timeout_futures: &mut FuturesUnordered>, + transaction_broadcast_join_handles: &mut FuturesUnordered< + JoinHandle>, + >, ) -> Result<(), TransactionServiceError> { let tx_id = finalized_transaction.tx_id; @@ -841,6 +780,19 @@ where "Cannot convert Transaction field from TransactionFinalized message".to_string(), ) })?; + + let inbound_tx = match self.db.get_pending_inbound_transaction(tx_id).await { + Ok(tx) => tx, + Err(_e) => { + warn!( + target: LOG_TARGET, + "TxId for received Finalized Transaction does not exist in Pending Inbound Transactions, could be \ + a repeat Store and Forward message" + ); + return Ok(()); + }, + }; + info!( target: LOG_TARGET, "Finalized Transaction with TX_ID = {} received from {}", @@ -848,18 +800,6 @@ where source_pubkey.clone() ); - let inbound_tx = self - .db - .get_pending_inbound_transaction(tx_id.clone()) - .await - .map_err(|e| { - error!( - target: LOG_TARGET, - "Finalized transaction TxId does not exist in Pending Inbound Transactions" - ); - e - })?; - if inbound_tx.source_public_key != source_pubkey { error!( target: LOG_TARGET, @@ -896,14 +836,9 @@ where }; self.db - .complete_inbound_transaction(tx_id.clone(), completed_transaction.clone()) + .complete_inbound_transaction(tx_id, completed_transaction.clone()) .await?; - self.event_publisher - .send(TransactionEvent::ReceivedFinalizedTransaction(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; - info!( target: LOG_TARGET, "Inbound Transaction with TX_ID = {} from {} moved to Completed Transactions", @@ -913,11 +848,7 @@ where // Logging this error here instead of propogating it up to the select! catchall which generates the Error Event. 
let _ = self - .broadcast_completed_transaction_to_mempool( - tx_id, - self.config.initial_mempool_broadcast_timeout, - broadcast_timeout_futures, - ) + .broadcast_completed_transaction_to_mempool(tx_id, transaction_broadcast_join_handles) .await .map_err(|e| { error!( @@ -927,6 +858,18 @@ where e }); + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::ReceivedFinalizedTransaction(tx_id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + Ok(()) } @@ -941,11 +884,11 @@ where let spending_key = self .output_manager_service - .get_coinbase_spending_key(tx_id.clone(), amount.clone(), maturity_height) + .get_coinbase_spending_key(tx_id, amount.clone(), maturity_height) .await?; self.db - .add_pending_coinbase_transaction(tx_id.clone(), PendingCoinbaseTransaction { + .add_pending_coinbase_transaction(tx_id, PendingCoinbaseTransaction { tx_id, amount, commitment: self.factories.commitment.commit_value(&spending_key, u64::from(amount)), @@ -964,17 +907,13 @@ where completed_transaction: Transaction, ) -> Result<(), TransactionServiceError> { - let coinbase_tx = self - .db - .get_pending_coinbase_transaction(tx_id.clone()) - .await - .map_err(|e| { - error!( - target: LOG_TARGET, - "Finalized coinbase transaction TxId does not exist in Pending Inbound Transactions" - ); - e - })?; + let coinbase_tx = self.db.get_pending_coinbase_transaction(tx_id).await.map_err(|e| { + error!( + target: LOG_TARGET, + "Finalized coinbase transaction TxId does not exist in Pending Inbound Transactions" + ); + e + })?; if !completed_transaction.body.inputs().is_empty() || completed_transaction.body.outputs().len() != 1 || @@ -1018,17 +957,13 @@ where /// If a specific coinbase transaction will not be mined then the Miner can cancel it pub async fn cancel_pending_coinbase_transaction(&mut self, tx_id: TxId) -> Result<(), TransactionServiceError> { - let _ = self - .db - .get_pending_coinbase_transaction(tx_id.clone()) - .await - .map_err(|e| { - error!( - target: LOG_TARGET, - "Finalized coinbase transaction TxId does not exist in Pending Inbound Transactions" - ); - e - })?; + let _ = self.db.get_pending_coinbase_transaction(tx_id).await.map_err(|e| { + error!( + target: LOG_TARGET, + "Finalized coinbase transaction TxId does not exist in Pending Inbound Transactions" + ); + e + })?; self.output_manager_service.cancel_transaction(tx_id).await?; @@ -1061,8 +996,9 @@ where async fn set_base_node_public_key( &mut self, base_node_public_key: CommsPublicKey, - broadcast_timeout_futures: &mut FuturesUnordered>, - mined_request_timeout_futures: &mut FuturesUnordered>, + broadcast_join_handles: &mut FuturesUnordered>>, + chain_monitoring_join_handles: &mut FuturesUnordered>>, + send_transaction_join_handles: &mut FuturesUnordered>>, ) -> Result<(), TransactionServiceError> { let startup_broadcast = self.base_node_public_key.is_none(); @@ -1071,7 +1007,7 @@ where if startup_broadcast { let _ = self - .broadcast_all_completed_transactions_to_mempool(broadcast_timeout_futures) + .broadcast_all_completed_transactions_to_mempool(broadcast_join_handles) .await .or_else(|resp| { error!( @@ -1082,7 +1018,7 @@ where }); let _ = self - .monitor_all_completed_transactions_for_mining(mined_request_timeout_futures) + .start_chain_monitoring_for_all_broadcast_transactions(chain_monitoring_join_handles) .await .or_else(|resp| { error!( @@ -1091,6 +1027,17 @@ where ); Err(resp) }); + + let _ = self + 
.restart_all_send_transaction_protocols(send_transaction_join_handles) + .await + .or_else(|resp| { + error!( + target: LOG_TARGET, + "Error restarting protocols for all pending outbound transactions: {:?}", resp + ); + Err(resp) + }); } Ok(()) } @@ -1101,11 +1048,10 @@ where pub async fn broadcast_completed_transaction_to_mempool( &mut self, tx_id: TxId, - timeout: Duration, - broadcast_timeout_futures: &mut FuturesUnordered>, + join_handles: &mut FuturesUnordered>>, ) -> Result<(), TransactionServiceError> { - let completed_tx = self.db.get_completed_transaction(tx_id.clone()).await?; + let completed_tx = self.db.get_completed_transaction(tx_id).await?; if completed_tx.status != TransactionStatus::Completed || completed_tx.transaction.body.kernels().is_empty() { return Err(TransactionServiceError::InvalidCompletedTransaction); @@ -1113,35 +1059,20 @@ where match self.base_node_public_key.clone() { None => return Err(TransactionServiceError::NoBaseNodeKeysProvided), Some(pk) => { - info!( - target: LOG_TARGET, - "Attempting to Broadcast Transaction (TxId: {} and Kernel Signature: {}) to Mempool", - completed_tx.tx_id, - completed_tx.transaction.body.kernels()[0] - .excess_sig - .get_signature() - .to_hex() + let (mempool_response_sender, mempool_response_receiver) = mpsc::channel(100); + let (base_node_response_sender, base_node_response_receiver) = mpsc::channel(100); + self.mempool_response_senders.insert(tx_id, mempool_response_sender); + self.base_node_response_senders.insert(tx_id, base_node_response_sender); + let protocol = TransactionBroadcastProtocol::new( + tx_id, + self.service_resources.clone(), + self.config.mempool_broadcast_timeout, + pk, + mempool_response_receiver, + base_node_response_receiver, ); - trace!(target: LOG_TARGET, "{}", completed_tx.transaction); - - // Send Mempool Request - let mempool_request = MempoolProto::MempoolServiceRequest { - request_key: completed_tx.tx_id, - request: Some(MempoolProto::mempool_service_request::Request::SubmitTransaction( - completed_tx.transaction.into(), - )), - }; - self.outbound_message_service - .send_direct( - pk.clone(), - OutboundEncryption::EncryptForPeer, - OutboundDomainMessage::new(TariMessageType::MempoolRequest, mempool_request), - ) - .await?; - // Start Timeout - let state_timeout = StateDelay::new(timeout, completed_tx.tx_id); - - broadcast_timeout_futures.push(state_timeout.delay().boxed()); + let join_handle = tokio::spawn(protocol.execute()); + join_handles.push(join_handle); }, } @@ -1152,303 +1083,174 @@ where /// node followed by mempool requests to confirm that they have been received async fn broadcast_all_completed_transactions_to_mempool( &mut self, - broadcast_timeout_futures: &mut FuturesUnordered>, + join_handles: &mut FuturesUnordered>>, ) -> Result<(), TransactionServiceError> { - trace!(target: LOG_TARGET, "Querying Broadcast? 
for all completed Transactions"); + trace!(target: LOG_TARGET, "Attempting to Broadcast all Completed Transactions"); let completed_txs = self.db.get_completed_transactions().await?; for completed_tx in completed_txs.values() { - if completed_tx.status == TransactionStatus::Completed { - self.broadcast_completed_transaction_to_mempool( - completed_tx.tx_id.clone(), - self.config.initial_mempool_broadcast_timeout, - broadcast_timeout_futures, - ) - .await?; + if completed_tx.status == TransactionStatus::Completed && + !self.mempool_response_senders.contains_key(&completed_tx.tx_id) + { + self.broadcast_completed_transaction_to_mempool(completed_tx.tx_id, join_handles) + .await?; } } Ok(()) } - /// Handle the timeout of a pending transaction broadcast request. This will check if the transaction's status has - /// been updated by received MempoolRepsonse during the course of this timeout. If it has not been updated the - /// transaction is broadcast again - pub async fn handle_mempool_broadcast_timeout( + /// Handle an incoming mempool response message + pub async fn handle_mempool_response( &mut self, - tx_id: TxId, - broadcast_timeout_futures: &mut FuturesUnordered>, + response: MempoolProto::MempoolServiceResponse, ) -> Result<(), TransactionServiceError> { - let completed_tx = self.db.get_completed_transaction(tx_id.clone()).await?; - - if completed_tx.status == TransactionStatus::Completed { - info!( - target: LOG_TARGET, - "Mempool broadcast timed out for Transaction with TX_ID: {}", tx_id - ); + let response = MempoolServiceResponse::try_from(response).unwrap(); + trace!(target: LOG_TARGET, "Received Mempool Response: {:?}", response); - self.broadcast_completed_transaction_to_mempool( - tx_id, - self.config.mempool_broadcast_timeout, - broadcast_timeout_futures, - ) - .await?; + let tx_id = response.request_key; - self.event_publisher - .send(TransactionEvent::MempoolBroadcastTimedOut(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; - } + let sender = match self.mempool_response_senders.get_mut(&tx_id) { + None => { + trace!( + target: LOG_TARGET, + "Received Mempool response with unexpected key: {}. Not for this service", + response.request_key + ); + return Ok(()); + }, + Some(s) => s, + }; + sender + .send(response) + .await + .map_err(|_| TransactionServiceError::ProtocolChannelError)?; Ok(()) } - /// Handle an incoming mempool response message - pub async fn handle_mempool_response( + /// Handle the final clean up after a Transaction Broadcast protocol completes + async fn complete_transaction_broadcast_protocol( &mut self, - response: MempoolProto::MempoolServiceResponse, - mined_request_timeout_futures: &mut FuturesUnordered>, - ) -> Result<(), TransactionServiceError> + join_result: Result, + transaction_chain_monitoring_join_handles: &mut FuturesUnordered< + JoinHandle>, + >, + ) { - let response = MempoolServiceResponse::try_from(response).unwrap(); - let tx_id = response.request_key; - match response.response { - MempoolResponse::Stats(_) => { - return Err(TransactionServiceError::InvalidMessageError( - "Mempool Response of invalid type".to_string(), - )) - }, - MempoolResponse::TxStorage(ts) => { - let completed_tx = self.db.get_completed_transaction(response.request_key.clone()).await?; - - match completed_tx.status { - TransactionStatus::Completed => match ts { - // Getting this response means the Mempool Rejected this transaction so it will be cancelled. 
- TxStorageResponse::NotStored => { - // If this transaction is still in the Completed State it should be upgraded to the - // Broadcast state - error!( + match join_result { + Ok(id) => { + // Cleanup any registered senders + let _ = self.mempool_response_senders.remove(&id); + let _ = self.base_node_response_senders.remove(&id); + trace!( + target: LOG_TARGET, + "Transaction Broadcast Protocol for TxId: {} completed successfully", + id + ); + let _ = self + .start_transaction_chain_monitoring_protocol(id, transaction_chain_monitoring_join_handles) + .await + .or_else(|resp| { + match resp { + TransactionServiceError::InvalidCompletedTransaction => trace!( target: LOG_TARGET, - "Mempool response received for TxId: {:?}. Transaction was REJECTED. Cancelling \ - transaction.", - tx_id - ); - if let Err(e) = self.output_manager_service.cancel_transaction(completed_tx.tx_id).await { - error!( - target: LOG_TARGET, - "Failed to Cancel outputs for TX_ID: {} after failed sending attempt with error \ - {:?}", - completed_tx.tx_id, - e - ); - } - if let Err(e) = self.db.cancel_completed_transaction(completed_tx.tx_id).await { - error!( - target: LOG_TARGET, - "Failed to Cancel TX_ID: {} after failed sending attempt with error {:?}", - completed_tx.tx_id, - e - ); - } - self.event_publisher - .send(TransactionEvent::TransactionSendDiscoveryComplete(tx_id, false)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; - }, - // Any other variant of this enum means the transaction has been received by the base_node and - // is in one of the various mempools - _ => { - // If this transaction is still in the Completed State it should be upgraded to the - // Broadcast state - - info!( + "Not starting Chain monitoring protocol as transaction cannot be found, either \ + cancelled or already mined." 
+ ), + _ => error!( target: LOG_TARGET, - "Completed Transaction (TxId: {} and Kernel Excess Sig: {}) detected as Broadcast to \ - Base Node Mempool", - tx_id, - completed_tx.transaction.body.kernels()[0] - .excess_sig - .get_signature() - .to_hex() - ); - self.db.broadcast_completed_transaction(tx_id.clone()).await?; - // Start monitoring the base node to see if this Tx has been mined - self.send_transaction_mined_request( - tx_id.clone(), - self.config.base_node_mined_timeout, - mined_request_timeout_futures, - ) - .await?; - - self.event_publisher - .send(TransactionEvent::TransactionBroadcast(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; - }, - }, - TransactionStatus::Broadcast => { - info!( - target: LOG_TARGET, - "Mempool query for transaction Tx_ID: {} returned {:?}", completed_tx.tx_id, ts - ); - if let Some(result) = self.pending_transaction_mined_queries.get_mut(&completed_tx.tx_id) { - match ts { - TxStorageResponse::NotStored => result.mempool_response = Some(false), - _ => result.mempool_response = Some(true), - } - debug!(target: LOG_TARGET, "Current Mempool/Mined state {:?}", result); - if result.is_complete() { - self.handle_transaction_mined_request_result(completed_tx.tx_id).await; - } + "Error starting Chain Monitoring Protocol after completed Broadcast Protocol : {:?}", + resp + ), } - }, - _ => (), - } + Err(resp) + }); + }, + Err(TransactionServiceProtocolError { id, error }) => { + let _ = self.mempool_response_senders.remove(&id); + let _ = self.base_node_response_senders.remove(&id); + error!( + target: LOG_TARGET, + "Error completing Transaction Broadcast Protocol (Id: {}): {:?}", id, error + ); + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::Error(format!("{:?}", error)))); }, } - Ok(()) } /// Send a request to the Base Node to see if the specified transaction has been mined yet. This function will send /// the request and store a timeout future to check in on the status of the transaction in the future. - async fn send_transaction_mined_request( + async fn start_transaction_chain_monitoring_protocol( &mut self, tx_id: TxId, - timeout: Duration, - mined_request_timeout_futures: &mut FuturesUnordered>, + join_handles: &mut FuturesUnordered>>, ) -> Result<(), TransactionServiceError> { - let completed_tx = self.db.get_completed_transaction(tx_id.clone()).await?; + let completed_tx = self.db.get_completed_transaction(tx_id).await?; - if (completed_tx.status != TransactionStatus::Broadcast && completed_tx.status != TransactionStatus::Completed) || - completed_tx.transaction.body.kernels().is_empty() - { + if completed_tx.status != TransactionStatus::Broadcast || completed_tx.transaction.body.kernels().is_empty() { return Err(TransactionServiceError::InvalidCompletedTransaction); } match self.base_node_public_key.clone() { None => return Err(TransactionServiceError::NoBaseNodeKeysProvided), Some(pk) => { - let mut hashes = Vec::new(); - for o in completed_tx.transaction.body.outputs() { - hashes.push(o.hash()); - } - - info!( - target: LOG_TARGET, - "Sending Transaction Mined? 
request for TxId: {} to Base Node with {} outputs", - tx_id, - hashes.len(), + let protocol_id = OsRng.next_u64(); + + let (mempool_response_sender, mempool_response_receiver) = mpsc::channel(100); + let (base_node_response_sender, base_node_response_receiver) = mpsc::channel(100); + self.mempool_response_senders + .insert(protocol_id, mempool_response_sender); + self.base_node_response_senders + .insert(protocol_id, base_node_response_sender); + let protocol = TransactionChainMonitoringProtocol::new( + protocol_id, + completed_tx.tx_id, + self.service_resources.clone(), + self.config.base_node_mined_timeout, + pk, + mempool_response_receiver, + base_node_response_receiver, ); - - // Send a request to the mempool to find the state of the Tx there - let tx_excess_sig = completed_tx.transaction.body.kernels()[0].excess_sig.clone(); - let mempool_request = MempoolProto::MempoolServiceRequest { - request_key: completed_tx.tx_id, - request: Some(MempoolProto::mempool_service_request::Request::GetTxStateWithExcessSig( - tx_excess_sig.into(), - )), - }; - self.outbound_message_service - .send_direct( - pk.clone(), - OutboundEncryption::EncryptForPeer, - OutboundDomainMessage::new(TariMessageType::MempoolRequest, mempool_request), - ) - .await?; - - // Ask the base node if the outputs are in the chain - let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { outputs: hashes }); - let service_request = BaseNodeProto::BaseNodeServiceRequest { - request_key: tx_id, - request: Some(request), - }; - self.outbound_message_service - .send_direct( - pk.clone(), - OutboundEncryption::EncryptForPeer, - OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), - ) - .await?; - // Start Timeout - let state_timeout = StateDelay::new(timeout, completed_tx.tx_id); - let _ = self - .pending_transaction_mined_queries - .insert(tx_id, TransactionMinedRequestResult::default()); - mined_request_timeout_futures.push(state_timeout.delay().boxed()); + let join_handle = tokio::spawn(protocol.execute()); + join_handles.push(join_handle); }, } Ok(()) } - /// Handle the timeout of a pending transaction mined? request. This will check if the transaction's status has - /// been updated by received BaseNodeRepsonse during the course of this timeout. If it has not been updated the - /// transaction is broadcast again - pub async fn handle_transaction_mined_request_timeout( + /// Handle the final clean up after a Transaction Chain Monitoring protocol completes + fn complete_transaction_chain_monitoring_protocol( &mut self, - tx_id: TxId, - mined_request_timeout_futures: &mut FuturesUnordered>, - ) -> Result<(), TransactionServiceError> + join_result: Result, + ) { - let completed_tx = self.db.get_completed_transaction(tx_id.clone()).await?; - - if completed_tx.status == TransactionStatus::Broadcast || completed_tx.status == TransactionStatus::Completed { - info!( - target: LOG_TARGET, - "Transaction Mined? 
request timed out for TX_ID: {}", tx_id - ); - - self.send_transaction_mined_request( - tx_id, - self.config.base_node_mined_timeout, - mined_request_timeout_futures, - ) - .await?; - - self.event_publisher - .send(TransactionEvent::TransactionMinedRequestTimedOut(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; - } - - Ok(()) - } - - /// Handle the result of receiving all the stages needed to complete a Transaction Mined request - pub async fn handle_transaction_mined_request_result(&mut self, tx_id: TxId) { - if let Some(result) = self.pending_transaction_mined_queries.remove(&tx_id) { - // If the transaction is not in mempool AND not mined then the Tx was reorged out and will never appear - // in the chain and should be cancelled - if result.mempool_response == Some(false) && result.chain_response == Some(false) { + match join_result { + Ok(id) => { + // Cleanup any registered senders + let _ = self.mempool_response_senders.remove(&id); + let _ = self.base_node_response_senders.remove(&id); + trace!( + target: LOG_TARGET, + "Transaction chain monitoring Protocol for TxId: {} completed successfully", + id + ); + }, + Err(TransactionServiceProtocolError { id, error }) => { + let _ = self.mempool_response_senders.remove(&id); + let _ = self.base_node_response_senders.remove(&id); error!( target: LOG_TARGET, - "Transaction (TxId: {}) has left the Mempool while not being Mined. It will be cancelled.", tx_id, + "Error completing Transaction chain monitoring Protocol (Id: {}): {:?}", id, error ); - let _ = self - .output_manager_service - .cancel_transaction(tx_id) - .await - .map_err(|e| { - error!( - target: LOG_TARGET, - "Failed to Cancel outputs for TX_ID: {} after failed sending attempt with error {:?}", - tx_id, - e - ); - }); - let _ = self.db.cancel_completed_transaction(tx_id).await.map_err(|e| { - error!( - target: LOG_TARGET, - "Failed to Cancel TX_ID: {} after failed sending attempt with error {:?}", tx_id, e - ); - }); let _ = self .event_publisher - .send(TransactionEvent::TransactionSendDiscoveryComplete(tx_id, false)) - .await - .map_err(|e| error!(target: LOG_TARGET, "Failed send event {:?}", e)); - } + .send(Arc::new(TransactionEvent::Error(format!("{:?}", error)))); + }, } } @@ -1458,115 +1260,41 @@ where response: BaseNodeProto::BaseNodeServiceResponse, ) -> Result<(), TransactionServiceError> { - let tx_id = response.request_key; - let response: Vec = match response.response { - Some(BaseNodeResponseProto::TransactionOutputs(outputs)) => outputs.outputs, - _ => { - return Ok(()); - }, - }; - - let completed_tx = match self.db.get_completed_transaction(tx_id.clone()).await { - Ok(tx) => tx, - Err(_) => { - debug!( + let sender = match self.base_node_response_senders.get_mut(&response.request_key) { + None => { + trace!( target: LOG_TARGET, - "Base Node Response received with unexpected key {:?}", tx_id + "Received Base Node response with unexpected key: {}. Not for this service", + response.request_key ); return Ok(()); }, + Some(s) => s, }; - // If this transaction is still in the Broadcast or Completed State it should be upgraded to the Mined state - if completed_tx.status == TransactionStatus::Broadcast || completed_tx.status == TransactionStatus::Completed { - // Confirm that all outputs were reported as mined for the transaction - if response.len() != completed_tx.transaction.body.outputs().len() { - info!( - target: LOG_TARGET, - "Base node response received. TxId: {:?} not mined yet. 
({} outputs requested but {} returned)", - tx_id, - completed_tx.transaction.body.outputs().len(), - response.len(), - ); - - if completed_tx.status == TransactionStatus::Broadcast { - if let Some(result) = self.pending_transaction_mined_queries.get_mut(&completed_tx.tx_id) { - result.chain_response = Some(false); - debug!(target: LOG_TARGET, "Current Mempool/Mined state {:?}", result); - if result.is_complete() { - self.handle_transaction_mined_request_result(completed_tx.tx_id).await; - } - } - } - } else { - let mut check = true; - - for output in response.iter() { - let transaction_output = TransactionOutput::try_from(output.clone()) - .map_err(TransactionServiceError::ConversionError)?; - - check = check && - completed_tx - .transaction - .body - .outputs() - .iter() - .any(|item| item == &transaction_output); - } - // If all outputs are present then mark this transaction as mined. - if check { - self.output_manager_service - .confirm_transaction( - tx_id.clone(), - completed_tx.transaction.body.inputs().clone(), - completed_tx.transaction.body.outputs().clone(), - ) - .await?; - - self.db.mine_completed_transaction(tx_id).await?; - - self.event_publisher - .send(TransactionEvent::TransactionMined(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; - - info!( - target: LOG_TARGET, - "Transaction (TxId: {:?}) detected as mined on the Base Layer", tx_id - ); - } - } - } else { - debug!( - target: LOG_TARGET, - "Base node response received for TxId: {:?} but this transaction is not in the Broadcast state", tx_id - ); - } + sender + .send(response.clone()) + .await + .map_err(|_| TransactionServiceError::ProtocolChannelError)?; Ok(()) } - /// Go through all completed transactions that have been broadcast and start querying the base_node to see if they + /// Go through all completed transactions that have been broadcast and start querying the base_node to see if they /// have been mined - async fn monitor_all_completed_transactions_for_mining( + async fn start_chain_monitoring_for_all_broadcast_transactions( &mut self, - mined_request_timeout_futures: &mut FuturesUnordered>, + join_handles: &mut FuturesUnordered>>, ) -> Result<(), TransactionServiceError> { trace!( target: LOG_TARGET, - "Querying Transaction Mined? 
for all Broadcast Transactions"
+            "Starting Chain monitoring for all Broadcast Transactions"
         );
         let completed_txs = self.db.get_completed_transactions().await?;
         for completed_tx in completed_txs.values() {
-            if completed_tx.status == TransactionStatus::Broadcast ||
-                completed_tx.status == TransactionStatus::Completed
-            {
-                self.send_transaction_mined_request(
-                    completed_tx.tx_id.clone(),
-                    self.config.initial_base_node_mined_timeout,
-                    mined_request_timeout_futures,
-                )
-                .await?;
+            if completed_tx.status == TransactionStatus::Broadcast {
+                self.start_transaction_chain_monitoring_protocol(completed_tx.tx_id, join_handles)
+                    .await?;
             }
         }
@@ -1584,7 +1312,7 @@ where
         let tx_id = OsRng.next_u64();
         self.db
             .add_utxo_import_transaction(
-                tx_id.clone(),
+                tx_id,
                 value,
                 source_public_key,
                 self.node_identity.public_key().clone(),
@@ -1594,6 +1322,43 @@ where
         Ok(tx_id)
     }

+    /// Submit a completed transaction to the Transaction Manager
+    pub async fn submit_transaction(
+        &mut self,
+        transaction_broadcast_join_handles: &mut FuturesUnordered<
+            JoinHandle<Result<u64, TransactionServiceProtocolError>>,
+        >,
+        tx_id: TxId,
+        tx: Transaction,
+        fee: MicroTari,
+        amount: MicroTari,
+        message: String,
+    ) -> Result<(), TransactionServiceError>
+    {
+        trace!(target: LOG_TARGET, "Submit transaction ({}) to db.", tx_id);
+        self.db
+            .insert_completed_transaction(tx_id, CompletedTransaction {
+                tx_id,
+                source_public_key: self.node_identity.public_key().clone(),
+                destination_public_key: self.node_identity.public_key().clone(),
+                amount,
+                fee,
+                transaction: tx,
+                status: TransactionStatus::Completed,
+                message,
+                timestamp: Utc::now().naive_utc(),
+            })
+            .await?;
+        trace!(
+            target: LOG_TARGET,
+            "Launch the transaction broadcast protocol for submitted transaction ({}).",
+            tx_id
+        );
+        self.complete_send_transaction_protocol(Ok(tx_id), transaction_broadcast_join_handles)
+            .await;
+        Ok(())
+    }
+
     /// This function is only available for testing by the client of LibWallet. It simulates a receiver accepting and
     /// replying to a Pending Outbound Transaction. This results in that transaction being "completed" and it's status
     /// set to `Broadcast` which indicated it is in a base_layer mempool.
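The refactor applies one pattern throughout the service: each transaction protocol runs as its own Tokio task, the service registers any reply and cancellation channels, and the resulting JoinHandle is pushed into a FuturesUnordered that the main event loop polls for completions. A minimal, self-contained sketch of that spawn-and-track pattern, assuming tokio and futures 0.3; execute_protocol and its error type are illustrative stand-ins, not the service's actual API:

    use futures::stream::FuturesUnordered;
    use futures::StreamExt;
    use tokio::task::JoinHandle;

    #[derive(Debug)]
    struct ProtocolError {
        id: u64,
    }

    // Stand-in for a protocol state machine such as TransactionSendProtocol;
    // it resolves with the id it was spawned for.
    async fn execute_protocol(id: u64) -> Result<u64, ProtocolError> {
        Ok(id)
    }

    #[tokio::main]
    async fn main() {
        let mut join_handles: FuturesUnordered<JoinHandle<Result<u64, ProtocolError>>> =
            FuturesUnordered::new();

        // Spawning detaches each protocol from the service's event loop; the
        // JoinHandle is kept so its completion can be observed later.
        join_handles.push(tokio::spawn(execute_protocol(1)));
        join_handles.push(tokio::spawn(execute_protocol(2)));

        // The service polls this collection inside its select! loop and runs
        // the matching final clean-up for each protocol as it completes.
        while let Some(join_result) = join_handles.next().await {
            match join_result {
                Ok(Ok(id)) => println!("protocol {} completed", id),
                Ok(Err(e)) => eprintln!("protocol {} failed", e.id),
                Err(e) => eprintln!("task panicked: {}", e),
            }
        }
    }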
@@ -1604,7 +1369,7 @@ where ) -> Result<(), TransactionServiceError> { self.db - .complete_outbound_transaction(completed_tx.tx_id.clone(), completed_tx.clone()) + .complete_outbound_transaction(completed_tx.tx_id, completed_tx.clone()) .await?; Ok(()) } @@ -1615,16 +1380,23 @@ where #[cfg(feature = "test_harness")] pub async fn broadcast_transaction(&mut self, tx_id: TxId) -> Result<(), TransactionServiceError> { let completed_txs = self.db.get_completed_transactions().await?; - completed_txs.get(&tx_id.clone()).ok_or_else(|| { + completed_txs.get(&tx_id).ok_or_else(|| { TransactionServiceError::TestHarnessError("Could not find Completed TX to broadcast.".to_string()) })?; self.db.broadcast_completed_transaction(tx_id).await?; - self.event_publisher - .send(TransactionEvent::TransactionBroadcast(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::TransactionBroadcast(tx_id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); Ok(()) } @@ -1636,18 +1408,18 @@ where #[cfg(feature = "test_harness")] pub async fn mine_transaction(&mut self, tx_id: TxId) -> Result<(), TransactionServiceError> { let completed_txs = self.db.get_completed_transactions().await?; - let _found_tx = completed_txs.get(&tx_id.clone()).ok_or_else(|| { + let _found_tx = completed_txs.get(&tx_id).ok_or_else(|| { TransactionServiceError::TestHarnessError("Could not find Completed TX to mine.".to_string()) })?; let pending_tx_outputs = self.output_manager_service.get_pending_transactions().await?; - let pending_tx = pending_tx_outputs.get(&tx_id.clone()).ok_or_else(|| { + let pending_tx = pending_tx_outputs.get(&tx_id).ok_or_else(|| { TransactionServiceError::TestHarnessError("Could not find Pending TX to complete.".to_string()) })?; self.output_manager_service .confirm_transaction( - tx_id.clone(), + tx_id, pending_tx .outputs_to_be_spent .iter() @@ -1666,10 +1438,17 @@ where self.db.mine_completed_transaction(tx_id).await?; - self.event_publisher - .send(TransactionEvent::TransactionMined(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::TransactionMined(tx_id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); Ok(()) } @@ -1689,7 +1468,7 @@ where service::OutputManagerService, storage::{database::OutputManagerDatabase, memory_db::OutputManagerMemoryDatabase}, }; - use futures::{channel::mpsc, stream}; + use futures::stream; use tari_broadcast_channel::bounded; let (_sender, receiver) = reply_channel::unbounded(); @@ -1713,7 +1492,7 @@ where fake_oms.add_output(uo).await?; let mut stp = fake_oms - .prepare_transaction_to_send(amount, MicroTari::from(100), None, "".to_string()) + .prepare_transaction_to_send(amount, MicroTari::from(25), None, "".to_string()) .await?; let msg = stp.build_single_round_message()?; @@ -1724,7 +1503,7 @@ where let spending_key = self .output_manager_service - .get_recipient_spending_key(tx_id.clone(), amount.clone()) + .get_recipient_spending_key(tx_id, amount.clone()) .await?; let nonce = PrivateKey::random(&mut OsRng); let rtp = ReceiverTransactionProtocol::new( @@ -1746,13 +1525,20 @@ where }; self.db - .add_pending_inbound_transaction(tx_id.clone(), inbound_transaction.clone()) + .add_pending_inbound_transaction(tx_id, 
inbound_transaction.clone()) .await?; - self.event_publisher - .send(TransactionEvent::ReceivedTransaction(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::ReceivedTransaction(tx_id))) + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); Ok(()) } @@ -1763,7 +1549,7 @@ where pub async fn finalize_received_test_transaction(&mut self, tx_id: TxId) -> Result<(), TransactionServiceError> { let inbound_txs = self.db.get_pending_inbound_transactions().await?; - let found_tx = inbound_txs.get(&tx_id.clone()).ok_or_else(|| { + let found_tx = inbound_txs.get(&tx_id).ok_or_else(|| { TransactionServiceError::TestHarnessError("Could not find Pending Inbound TX to finalize.".to_string()) })?; @@ -1780,85 +1566,32 @@ where }; self.db - .complete_inbound_transaction(tx_id.clone(), completed_transaction.clone()) + .complete_inbound_transaction(tx_id, completed_transaction.clone()) .await?; - self.event_publisher - .send(TransactionEvent::ReceivedFinalizedTransaction(tx_id)) - .await - .map_err(|_| TransactionServiceError::EventStreamError)?; - Ok(()) - } -} - -// Asynchronous Tasks - -async fn transaction_send_discovery_process_completion( - response_channel: oneshot::Receiver, - tx_id: TxId, - outbound_tx: OutboundTransaction, -) -> Result<(MessageTag, OutboundTransaction), TransactionServiceError> -{ - let mut message_tag: Option = None; - match response_channel.await { - Ok(response) => match response { - SendMessageResponse::Queued(tags) => match tags.len() { - 0 => error!( - target: LOG_TARGET, - "Send Discovery process for TX_ID: {} was unsuccessful and no message was sent", tx_id - ), - 1 => { - message_tag = Some(tags[0]); - - info!( - target: LOG_TARGET, - "Transaction (TxId: {}) Send Discovery process successful with Message Tag: {:?}", - tx_id, - message_tag, - ); - }, - _ => error!( - target: LOG_TARGET, - "Send Discovery process for TX_ID: {} was unsuccessful due to more than 1 MessageTag being \ - returned", - tx_id - ), - }, - _ => { - error!( + let _ = self + .event_publisher + .send(Arc::new(TransactionEvent::ReceivedFinalizedTransaction(tx_id))) + .map_err(|e| { + trace!( target: LOG_TARGET, - "Transaction (TxId: {}) Send Discovery process failed", tx_id + "Error sending event, usually because there are no subscribers: {:?}", + e ); - }, - }, - Err(_) => { - error!( - target: LOG_TARGET, - "Transaction (TxId: {}) Send Response One-shot channel dropped", tx_id - ); - }, - } - - if let Some(mt) = message_tag { - let updated_outbound_tx = OutboundTransaction { - timestamp: Utc::now().naive_utc(), - ..outbound_tx.clone() - }; - Ok((mt, updated_outbound_tx)) - } else { - Err(TransactionServiceError::DiscoveryProcessFailed(tx_id)) + e + }); + Ok(()) } } -/// This struct holds the responses of a multistage base node monitoring request to see if a Transaction has been mined. -/// This is used to keep track of the status of the transaction in both the mempool and chain. -#[derive(Debug, Default)] -struct TransactionMinedRequestResult { - mempool_response: Option, - chain_response: Option, -} - -impl TransactionMinedRequestResult { - fn is_complete(&self) -> bool { - self.mempool_response.is_some() && self.chain_response.is_some() - } +/// This struct is a collection of the common resources that a protocol in the service requires. 
+#[derive(Clone)]
+pub struct TransactionServiceResources<TBackend>
+where TBackend: TransactionBackend + Clone + 'static
+{
+    pub db: TransactionDatabase<TBackend>,
+    pub output_manager_service: OutputManagerHandle,
+    pub outbound_message_service: OutboundMessageRequester,
+    pub event_publisher: TransactionEventSender,
+    pub node_identity: Arc<NodeIdentity>,
+    pub factories: CryptoFactories,
+}
diff --git a/base_layer/wallet/src/transaction_service/storage/database.rs b/base_layer/wallet/src/transaction_service/storage/database.rs
index 041183f47e..1895195970 100644
--- a/base_layer/wallet/src/transaction_service/storage/database.rs
+++ b/base_layer/wallet/src/transaction_service/storage/database.rs
@@ -79,6 +79,10 @@ pub trait TransactionBackend: Send + Sync {
     fn broadcast_completed_transaction(&self, tx_id: TxId) -> Result<(), TransactionStorageError>;
     /// Indicated that a completed transaction has been detected as mined on the base layer
     fn mine_completed_transaction(&self, tx_id: TxId) -> Result<(), TransactionStorageError>;
+    /// Cancel a completed transaction; this will update the transaction status
+    fn cancel_completed_transaction(&self, tx_id: TxId) -> Result<(), TransactionStorageError>;
+    /// Cancel a pending transaction; this will update the transaction status
+    fn cancel_pending_transaction(&self, tx_id: TxId) -> Result<(), TransactionStorageError>;
     /// Update a completed transactions timestamp for use in test data generation
     #[cfg(feature = "test_harness")]
     fn update_completed_transaction_timestamp(
@@ -101,6 +105,8 @@ pub enum TransactionStatus {
     Imported,
     /// This transaction is still being negotiated by the parties
     Pending,
+    /// This transaction has been cancelled
+    Cancelled,
 }

 impl TryFrom<i32> for TransactionStatus {
@@ -113,6 +119,7 @@ impl TryFrom<i32> for TransactionStatus {
             2 => Ok(TransactionStatus::Mined),
             3 => Ok(TransactionStatus::Imported),
             4 => Ok(TransactionStatus::Pending),
+            5 => Ok(TransactionStatus::Cancelled),
             _ => Err(TransactionStorageError::ConversionError),
         }
     }
@@ -124,6 +131,13 @@ impl Default for TransactionStatus {
     }
 }

+impl Display for TransactionStatus {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
+        // No struct or tuple variants
+        write!(f, "{:?}", self)
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct InboundTransaction {
     pub tx_id: TxId,
@@ -337,6 +351,25 @@ where T: TransactionBackend + 'static
         .and_then(|inner_result| inner_result)
     }

+    pub async fn insert_completed_transaction(
+        &self,
+        tx_id: TxId,
+        transaction: CompletedTransaction,
+    ) -> Result<Option<DbValue>, TransactionStorageError>
+    {
+        let db_clone = self.db.clone();
+
+        tokio::task::spawn_blocking(move || {
+            db_clone.write(WriteOperation::Insert(DbKeyValuePair::CompletedTransaction(
+                tx_id,
+                Box::new(transaction),
+            )))
+        })
+        .await
+        .or_else(|err| Err(TransactionStorageError::BlockingTaskSpawnError(err.to_string())))
+        .and_then(|inner_result| inner_result)
+    }
+
     pub async fn get_pending_outbound_transaction(
         &self,
         tx_id: TxId,
@@ -530,7 +563,15 @@ where T: TransactionBackend + 'static
     pub async fn cancel_completed_transaction(&mut self, tx_id: TxId) -> Result<(), TransactionStorageError> {
         let db_clone = self.db.clone();

-        tokio::task::spawn_blocking(move || db_clone.write(WriteOperation::Remove(DbKey::CompletedTransaction(tx_id))))
+        tokio::task::spawn_blocking(move || db_clone.cancel_completed_transaction(tx_id))
             .await
             .or_else(|err| Err(TransactionStorageError::BlockingTaskSpawnError(err.to_string())))??;
         Ok(())
     }
+
+    pub async fn cancel_pending_transaction(&mut self,
tx_id: TxId) -> Result<(), TransactionStorageError> { + let db_clone = self.db.clone(); + tokio::task::spawn_blocking(move || db_clone.cancel_pending_transaction(tx_id)) .await .or_else(|err| Err(TransactionStorageError::BlockingTaskSpawnError(err.to_string())))??; Ok(()) diff --git a/base_layer/wallet/src/transaction_service/storage/memory_db.rs b/base_layer/wallet/src/transaction_service/storage/memory_db.rs index ce29886845..52f733839b 100644 --- a/base_layer/wallet/src/transaction_service/storage/memory_db.rs +++ b/base_layer/wallet/src/transaction_service/storage/memory_db.rs @@ -81,32 +81,70 @@ impl TransactionBackend for TransactionMemoryDatabase { fn fetch(&self, key: &DbKey) -> Result, TransactionStorageError> { let db = acquire_read_lock!(self.db); let result = match key { - DbKey::PendingOutboundTransaction(t) => db - .pending_outbound_transactions - .get(t) - .map(|v| DbValue::PendingOutboundTransaction(Box::new(v.clone()))), - DbKey::PendingInboundTransaction(t) => db - .pending_inbound_transactions - .get(t) - .map(|v| DbValue::PendingInboundTransaction(Box::new(v.clone()))), - DbKey::CompletedTransaction(t) => db - .completed_transactions - .get(t) - .map(|v| DbValue::CompletedTransaction(Box::new(v.clone()))), + DbKey::PendingOutboundTransaction(t) => { + let mut result = None; + if let Some(v) = db.pending_outbound_transactions.get(t) { + if v.status != TransactionStatus::Cancelled { + result = Some(DbValue::PendingOutboundTransaction(Box::new(v.clone()))); + } + } + result + }, + DbKey::PendingInboundTransaction(t) => { + let mut result = None; + if let Some(v) = db.pending_inbound_transactions.get(t) { + if v.status != TransactionStatus::Cancelled { + result = Some(DbValue::PendingInboundTransaction(Box::new(v.clone()))); + } + } + result + }, + DbKey::CompletedTransaction(t) => { + let mut result = None; + if let Some(v) = db.completed_transactions.get(t) { + if v.status != TransactionStatus::Cancelled { + result = Some(DbValue::CompletedTransaction(Box::new(v.clone()))); + } + } + result + }, DbKey::PendingCoinbaseTransaction(t) => db .pending_coinbase_transactions .get(t) .map(|v| DbValue::PendingCoinbaseTransaction(Box::new(v.clone()))), - DbKey::PendingOutboundTransactions => Some(DbValue::PendingOutboundTransactions( - db.pending_outbound_transactions.clone(), - )), - DbKey::PendingInboundTransactions => Some(DbValue::PendingInboundTransactions( - db.pending_inbound_transactions.clone(), - )), + DbKey::PendingOutboundTransactions => { + // Filter out cancelled transactions + let mut result = HashMap::new(); + for (k, v) in db.pending_outbound_transactions.iter() { + if v.status != TransactionStatus::Cancelled { + result.insert(k.clone(), v.clone()); + } + } + Some(DbValue::PendingOutboundTransactions(result)) + }, + DbKey::PendingInboundTransactions => { + // Filter out cancelled transactions + let mut result = HashMap::new(); + for (k, v) in db.pending_inbound_transactions.iter() { + if v.status != TransactionStatus::Cancelled { + result.insert(k.clone(), v.clone()); + } + } + Some(DbValue::PendingInboundTransactions(result)) + }, DbKey::PendingCoinbaseTransactions => Some(DbValue::PendingCoinbaseTransactions( db.pending_coinbase_transactions.clone(), )), - DbKey::CompletedTransactions => Some(DbValue::CompletedTransactions(db.completed_transactions.clone())), + DbKey::CompletedTransactions => { + // Filter out cancelled transactions + let mut result = HashMap::new(); + for (k, v) in db.completed_transactions.iter() { + if v.status != TransactionStatus::Cancelled 
{ + result.insert(k.clone(), v.clone()); + } + } + Some(DbValue::CompletedTransactions(result)) + }, }; Ok(result) @@ -296,11 +334,48 @@ impl TransactionBackend for TransactionMemoryDatabase { .completed_transactions .get_mut(&tx_id) .ok_or_else(|| TransactionStorageError::ValueNotFound(DbKey::CompletedTransaction(tx_id)))?; + + if completed_tx.status == TransactionStatus::Cancelled { + return Err(TransactionStorageError::ValueNotFound(DbKey::CompletedTransaction( + tx_id, + ))); + } + completed_tx.status = TransactionStatus::Mined; Ok(()) } + fn cancel_completed_transaction(&self, tx_id: TxId) -> Result<(), TransactionStorageError> { + let mut db = acquire_write_lock!(self.db); + + let mut completed_tx = db + .completed_transactions + .get_mut(&tx_id) + .ok_or_else(|| TransactionStorageError::ValueNotFound(DbKey::CompletedTransaction(tx_id)))?; + + completed_tx.status = TransactionStatus::Cancelled; + + Ok(()) + } + + fn cancel_pending_transaction(&self, tx_id: u64) -> Result<(), TransactionStorageError> { + let mut db = acquire_write_lock!(self.db); + + if db.pending_inbound_transactions.contains_key(&tx_id) { + if let Some(inbound) = db.pending_inbound_transactions.get_mut(&tx_id) { + inbound.status = TransactionStatus::Cancelled; + } + } else if db.pending_outbound_transactions.contains_key(&tx_id) { + if let Some(outbound) = db.pending_outbound_transactions.get_mut(&tx_id) { + outbound.status = TransactionStatus::Cancelled; + } + } else { + return Err(TransactionStorageError::ValuesNotFound); + } + Ok(()) + } + #[cfg(feature = "test_harness")] fn update_completed_transaction_timestamp( &self, diff --git a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs index d4bdd98837..0ea6de1834 100644 --- a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs @@ -365,9 +365,9 @@ impl TransactionBackend for TransactionServiceSqliteDatabase { } }, Err(TransactionStorageError::DieselError(DieselError::NotFound)) => { - return Err(TransactionStorageError::ValueNotFound( - DbKey::PendingInboundTransaction(tx_id), - )) + return Err(TransactionStorageError::ValueNotFound(DbKey::CompletedTransaction( + tx_id, + ))) }, Err(e) => return Err(e), }; @@ -388,15 +388,52 @@ impl TransactionBackend for TransactionServiceSqliteDatabase { )?; }, Err(TransactionStorageError::DieselError(DieselError::NotFound)) => { - return Err(TransactionStorageError::ValueNotFound( - DbKey::PendingInboundTransaction(tx_id), - )) + return Err(TransactionStorageError::ValueNotFound(DbKey::CompletedTransaction( + tx_id, + ))) + }, + Err(e) => return Err(e), + }; + Ok(()) + } + + fn cancel_completed_transaction(&self, tx_id: u64) -> Result<(), TransactionStorageError> { + let conn = acquire_lock!(self.database_connection); + match CompletedTransactionSql::find(tx_id, &(*conn)) { + Ok(v) => { + v.cancel(&(*conn))?; + }, + Err(TransactionStorageError::DieselError(DieselError::NotFound)) => { + return Err(TransactionStorageError::ValueNotFound(DbKey::CompletedTransaction( + tx_id, + ))); }, Err(e) => return Err(e), }; Ok(()) } + fn cancel_pending_transaction(&self, tx_id: u64) -> Result<(), TransactionStorageError> { + let conn = acquire_lock!(self.database_connection); + match InboundTransactionSql::find(tx_id, &(*conn)) { + Ok(v) => { + let _ = v.cancel(&(*conn))?; + }, + Err(_) => { + match OutboundTransactionSql::find(tx_id, &(*conn)) { + Ok(v) => { + let _ = 
v.cancel(&(*conn))?; + }, + Err(TransactionStorageError::DieselError(DieselError::NotFound)) => { + return Err(TransactionStorageError::ValuesNotFound); + }, + Err(e) => return Err(e), + }; + }, + }; + Ok(()) + } + #[cfg(feature = "test_harness")] fn update_completed_transaction_timestamp( &self, @@ -460,6 +497,11 @@ impl InboundTransactionSql { Ok(()) } + + pub fn cancel(&self, conn: &SqliteConnection) -> Result<(), TransactionStorageError> { + // TODO Once sqlite migrations are implemented have cancellation be done with a Status flag + self.delete(conn) + } } impl TryFrom for InboundTransactionSql { @@ -536,6 +578,11 @@ impl OutboundTransactionSql { Ok(()) } + + pub fn cancel(&self, conn: &SqliteConnection) -> Result<(), TransactionStorageError> { + // TODO Once sqlite migrations are implemented have cancellation be done with a Status flag + self.delete(conn) + } } impl TryFrom for OutboundTransactionSql { @@ -664,12 +711,15 @@ impl CompletedTransactionSql { } pub fn index(conn: &SqliteConnection) -> Result, TransactionStorageError> { - Ok(completed_transactions::table.load::(conn)?) + Ok(completed_transactions::table + .filter(completed_transactions::status.ne(TransactionStatus::Cancelled as i32)) + .load::(conn)?) } pub fn find(tx_id: TxId, conn: &SqliteConnection) -> Result { Ok(completed_transactions::table .filter(completed_transactions::tx_id.eq(tx_id as i64)) + .filter(completed_transactions::status.ne(TransactionStatus::Cancelled as i32)) .first::(conn)?) } @@ -704,6 +754,24 @@ impl CompletedTransactionSql { Ok(CompletedTransactionSql::find(self.tx_id as u64, conn)?) } + + pub fn cancel(&self, conn: &SqliteConnection) -> Result<(), TransactionStorageError> { + let num_updated = + diesel::update(completed_transactions::table.filter(completed_transactions::tx_id.eq(&self.tx_id))) + .set(UpdateCompletedTransactionSql { + status: Some(TransactionStatus::Cancelled as i32), + timestamp: None, + }) + .execute(conn)?; + + if num_updated == 0 { + return Err(TransactionStorageError::UnexpectedResult( + "Database update error".to_string(), + )); + } + + Ok(()) + } } impl TryFrom for CompletedTransactionSql { diff --git a/base_layer/wallet/src/util/emoji.rs b/base_layer/wallet/src/util/emoji.rs index 4b389ac9ea..6fcb1fa3a3 100644 --- a/base_layer/wallet/src/util/emoji.rs +++ b/base_layer/wallet/src/util/emoji.rs @@ -30,20 +30,20 @@ use tari_crypto::tari_utilities::{ }; const EMOJI: [char; 256] = [ - '😀', '😃', '😄', '😁', '😆', '😅', '🤣', '😂', '🙂', '🙃', '😉', '😊', '😇', '🥰', '😍', '🤩', '😘', '😗', '😚', - '😙', '😋', '😛', '😜', '🤪', '😝', '🐠', '🤗', '🤭', '🤫', '🤔', '🤐', '🤨', '😐', '😑', '😶', '😏', '😒', '🙄', - '😬', '🤥', '😌', '😔', '😪', '🤤', '😴', '😷', '🤒', '🤕', '🤢', '🤮', '🤧', '🥵', '🥶', '🥴', '😵', '🤯', '🤠', '🥳', - '😎', '🤓', '🧐', '😕', '😟', '🙁', '😮', '😯', '😲', '😳', '🥺', '😦', '😧', '😨', '😰', '😥', '😢', '😭', '😱', - '😖', '😣', '😞', '😓', '😩', '😫', '😤', '😡', '😠', '🤬', '😈', '👿', '💀', '🐟', '💩', '🤡', '👹', '👺', '👻', - '👽', '👾', '🤖', '😺', '😹', '😻', '😼', '😽', '🙀', '😿', '😾', '💋', '👋', '🤚', '🖐', '✋', '🖖', '👌', '🤞', - '🤟', '🤘', '🤙', '👈', '👉', '👆', '🖕', '👇', '👍', '👎', '✊', '👊', '🤛', '🤜', '👏', '🙌', '👐', '🤲', '🤝', - '🙏', '💅', '🤳', '💪', '🦵', '🦶', '👂', '👃', '🧠', '🦷', '🦴', '👀', '👁', '👅', '👄', '🚶', '👣', '🧳', '🌂', '☂', - '🧵', '🧶', '👓', '🕶', '🥽', '🥼', '👔', '👕', '👖', '🧣', '🧤', '🧥', '🧦', '👗', '👘', '👙', '👚', '👛', '👜', '👝', - '🎒', '👞', '👟', '🥾', '🥿', '👠', '👡', '👢', '👑', '👒', '🎩', '🎓', '🧢', '⛑', '💄', '💍', '💼', '🙈', '🙉', - '🙊', '💥', '💫', '💦', '💨', '🐵', '🐒', '🦍', '🐶', '🐕', '🐩', '🐺', '🦊', '🦝', '🐱', '🐈', '🦁', '🐯', '🐅', - '🐆', '🐴', '🐎', '🦄', '🦓', '🦌', 
'🐮', '🐂', '🐃', '🐄', '🐷', '🐖', '🐗', '🐽', '🐏', '🐑', '🐐', '🐪', '🐫', - '🦙', '🦒', '🐘', '🦏', '🦛', '🐭', '🐁', '🐀', '🐹', '🐰', '🐇', '🐿', '🦔', '🦇', '🐻', '🐨', '🐼', '🦘', '🦡', '🐾', - '🦃', '🐓', '🐣', '🐋', '🐬', + '😀', '😂', '🤣', '😉', '😊', '😎', '😍', '😘', '🤗', '🤩', '🤔', '🙄', '😮', '🤐', '😴', '😛', '🤤', '🙃', '🤑', + '😤', '😨', '🤯', '😬', '😱', '🤪', '😵', '😷', '🤢', '🤮', '🤠', '🤡', '🤫', '🤭', '🤓', '😈', '👻', '👽', '🤖', + '💩', '😺', '👶', '👩', '👨', '👮', '🤴', '👸', '🧜', '🙅', '🙋', '🤦', '🤷', '💇', '🏃', '💃', '🧗', '🛀', '🛌', + '👤', '🏄', '🚴', '🤹', '💏', '👪', '💪', '👈', '👍', '✋', '👊', '👐', '🙏', '🤝', '💅', '👂', '👀', '🧠', '👄', + '💔', '💖', '💙', '💌', '💤', '💣', '💥', '💦', '💨', '💫', '👔', '👕', '👖', '🧣', '🧤', '🧦', '👗', '👙', '👜', + '🎒', '👑', '🧢', '💍', '💎', '🐒', '🐶', '🦁', '🐴', '🦄', '🐮', '🐷', '🐑', '🐫', '🦒', '🐘', '🐭', '🐇', '🐔', + '🦆', '🐸', '🐍', '🐳', '🐚', '🦀', '🐌', '🦋', '🌸', '🌲', '🌵', '🍇', '🍉', '🍌', '🍎', '🍒', '🍓', '🥑', '🥕', + '🌽', '🍄', '🥜', '🍞', '🧀', '🍖', '🍔', '🍟', '🍕', '🍿', '🍦', '🍪', '🍰', '🍫', '🍬', '🍷', '🍺', '🍴', '🌍', + '🌋', '🏠', '⛺', '🎡', '🎢', '🎨', '🚂', '🚌', '🚑', '🚒', '🚔', '🚕', '🚜', '🚲', '⛽', '🚦', '🚧', '⛵', '🚢', + '🛫', '💺', '🚁', '🚀', '🛸', '🚪', '🚽', '🚿', '⌛', '⏰', '🕙', '🌛', '🌞', '⛅', '🌀', '🌈', '🌂', '🔥', '✨', + '🎈', '🎉', '🎀', '🎁', '🏆', '🏅', '⚽', '🏀', '🏈', '🎾', '🥊', '🎯', '⛳', '🎣', '🎮', '🎲', '🔈', '🔔', '🎶', + '🎤', '🎧', '📻', '🎸', '🎹', '🎺', '🎻', '🥁', '📱', '🔋', '💻', '📷', '🔍', '🔭', '📡', '💡', '🔦', '📖', '📚', + '📝', '📅', '📌', '📎', '🔒', '🔑', '🔨', '🏹', '🔧', '💉', '💊', '🏧', '⛔', '🚫', '✅', '❌', '❓', '❕', '💯', + '🆗', '🆘', '⬛', '🔶', '🔵', '🏁', '🚩', '🎌', '🏴', ]; lazy_static! { @@ -70,9 +70,9 @@ lazy_static! { /// ``` /// use tari_wallet::util::emoji::EmojiId; /// -/// assert!(EmojiId::is_valid("🖖🥴😍🙃💦🤘🤜👁🙃🙌😱🖐🙀🤳🖖👍✊🐈☂💀👚😶🤟😳👢😘😺🙌🎩🤬🐼😎🥺")); +/// assert!(EmojiId::is_valid("🐇💃😴🤩⚽🐍🍎🍫🤩🍓💔🐘🦄🍞🐇🌲🍇🎶🏠🧣🚢😈🐸👊🕙🤤💎🍓⛅👔🆗🏄👐")); /// let eid = EmojiId::from_hex("70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a").unwrap(); -/// assert_eq!(eid.as_str(), "🖖🥴😍🙃💦🤘🤜👁🙃🙌😱🖐🙀🤳🖖👍✊🐈☂💀👚😶🤟😳👢😘😺🙌🎩🤬🐼😎🥺"); +/// assert_eq!(eid.as_str(), "🐇💃😴🤩⚽🐍🍎🍫🤩🍓💔🐘🦄🍞🐇🌲🍇🎶🏠🧣🚢😈🐸👊🕙🤤💎🍓⛅👔🆗🏄👐"); /// ``` #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] pub struct EmojiId(String); @@ -170,7 +170,7 @@ mod test { let eid = EmojiId::from_hex("70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a").unwrap(); assert_eq!( eid.as_str(), - "🖖🥴😍🙃💦🤘🤜👁🙃🙌😱🖐🙀🤳🖖👍✊🐈☂💀👚😶🤟😳👢😘😺🙌🎩🤬🐼😎🥺" + "🐇💃😴🤩⚽🐍🍎🍫🤩🍓💔🐘🦄🍞🐇🌲🍇🎶🏠🧣🚢😈🐸👊🕙🤤💎🍓⛅👔🆗🏄👐" ); assert_eq!(EmojiId::from_pubkey(&pubkey), eid); assert_eq!( @@ -178,7 +178,7 @@ mod test { "70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a" ); assert_eq!( - EmojiId::str_to_pubkey("🖖🥴😍🙃💦🤘🤜👁🙃🙌😱🖐🙀🤳🖖👍✊🐈☂💀👚😶🤟😳👢😘😺🙌🎩🤬🐼😎🥺").unwrap(), + EmojiId::str_to_pubkey("🐇💃😴🤩⚽🐍🍎🍫🤩🍓💔🐘🦄🍞🐇🌲🍇🎶🏠🧣🚢😈🐸👊🕙🤤💎🍓⛅👔🆗🏄👐").unwrap(), pubkey ); } @@ -188,9 +188,10 @@ mod test { let eid = EmojiId::from_hex("70350e09c474809209824c6e6888707b7dd09959aa227343b5106382b856f73a").unwrap(); // Valid emojiID assert!(EmojiId::is_valid(eid.as_str())); + assert_eq!(EmojiId::is_valid(""), false, "Emoji ID too short"); assert_eq!(EmojiId::is_valid("😂"), false, "Emoji ID too short"); assert_eq!( - EmojiId::is_valid("🖖🥴😍🙃💦🤘🤜👁🙃🙌😱🖐🙀🤳🖖👍✊🐈☂💀👚😶🤟😳👢😘😺🙌🎩🤬🐼"), + EmojiId::is_valid("🤩⚽🐍🍎🍫🤩🍓💔🐘🦄🍞🐇🌲🍇🎶🏠🧣🚢😈🐸👊🕙🤤💎🍓⛅👔🆗🏄👐"), false, "Emoji ID too short" ); @@ -200,12 +201,12 @@ mod test { "Not emoji string" ); assert_eq!( - EmojiId::is_valid("🖖🥴😍🙃💦🤘🤜👁🙃🙌😱🖐🙀🤳🖖👍✊🐈☂💀👚😶🤟😳👢😘😺🙌🎩🤬🐼😎"), + EmojiId::is_valid("🐇💃😴🤩⚽🐍🍎🍫🤩🍓💔🐘🦄🍞🐇🌲🍇🎶🏠🧣🚢😈🐸👊🕙🤤💎🍓⛅👔🆗🏄"), false, "No checksum" ); assert_eq!( - EmojiId::is_valid("🖖🥴😍🙃💦🤘🤜👁🙃🙌😱🖐🙀🤳🖖👍✊🐈☂💀👚😶🤟😳👢😘😺🙌🎩🤬🐼😎😈"), + EmojiId::is_valid("🐇💃😴🤩⚽🐍🍎🍫🤩🍓💔🐘🦄🍞🐇🌲🍇🎶🏠🧣🚢😈🐸👊🕙🤤💎🍓⛅👔🆗🏄🎣"), false, "Wrong checksum" ); diff --git a/base_layer/wallet/src/util/luhn.rs 
b/base_layer/wallet/src/util/luhn.rs index 170aad4b36..f1b2661bc0 100644 --- a/base_layer/wallet/src/util/luhn.rs +++ b/base_layer/wallet/src/util/luhn.rs @@ -36,6 +36,9 @@ pub fn checksum(arr: &[usize], dict_len: usize) -> usize { /// Checks whether the last digit in the array matches the checksum for the array minus the last digit. pub fn is_valid(arr: &[usize], dict_len: usize) -> bool { + if arr.len() < 2 { + return false; + } let cs = checksum(&arr[..arr.len() - 1], dict_len); cs == arr[arr.len() - 1] } diff --git a/base_layer/wallet/src/wallet.rs b/base_layer/wallet/src/wallet.rs index cea739d5ab..8d24caff9c 100644 --- a/base_layer/wallet/src/wallet.rs +++ b/base_layer/wallet/src/wallet.rs @@ -47,7 +47,7 @@ use tari_comms::{ types::CommsPublicKey, CommsNode, }; -use tari_comms_dht::Dht; +use tari_comms_dht::{store_forward::StoreAndForwardRequester, Dht}; use tari_core::transactions::{ tari_amount::MicroTari, transaction::{OutputFeatures, UnblindedOutput}, @@ -90,6 +90,7 @@ where { pub comms: CommsNode, pub dht_service: Dht, + pub store_and_forward_requester: StoreAndForwardRequester, pub liveness_service: LivenessHandle, pub output_manager_service: OutputManagerHandle, pub transaction_service: TransactionServiceHandle, @@ -141,11 +142,11 @@ where LivenessConfig { auto_ping_interval: Some(Duration::from_secs(30)), enable_auto_join: true, - enable_auto_stored_message_request: true, - refresh_neighbours_interval: Default::default(), + ..Default::default() }, Arc::clone(&subscription_factory), dht.dht_requester(), + comms.connection_manager(), )) .add_initializer(OutputManagerServiceInitializer::new( OutputManagerServiceConfig::default(), @@ -156,7 +157,6 @@ where .add_initializer(TransactionServiceInitializer::new( config.transaction_service_config.unwrap_or_default(), subscription_factory.clone(), - comms.subscribe_messaging_events(), transaction_backend, comms.node_identity(), factories.clone(), @@ -184,9 +184,12 @@ where runtime.block_on(output_manager_handle.set_base_node_public_key(p.public_key.clone()))?; } + let store_and_forward_requester = dht.store_and_forward_requester(); + Ok(Wallet { comms, dht_service: dht, + store_and_forward_requester, liveness_service: liveness_handle, output_manager_service: output_manager_handle, transaction_service: transaction_service_handle, @@ -301,6 +304,9 @@ where /// Have all the wallet components that need to start a sync process with the set base node to confirm the wallets /// state is accurately reflected on the blockchain pub fn sync_with_base_node(&mut self) -> Result { + self.runtime + .block_on(self.store_and_forward_requester.request_saf_messages_from_neighbours())?; + let request_key = self .runtime .block_on(self.output_manager_service.sync_with_base_node())?; diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index 9f5265cf38..f3a8e46c2d 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -273,7 +273,7 @@ fn send_no_change(backend: T) { let (mut oms, _, _shutdown, _) = setup_output_manager_service(&mut runtime, backend); let fee_per_gram = MicroTari::from(20); - let fee_without_change = Fee::calculate(fee_per_gram, 2, 1); + let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let key1 = PrivateKey::random(&mut OsRng); let value1 = 500; runtime @@ -347,7 +347,7 @@ fn send_not_enough_for_change(backend: T) { let (mut oms, _, _shutdown, _) = 
setup_output_manager_service(&mut runtime, backend); let fee_per_gram = MicroTari::from(20); - let fee_without_change = Fee::calculate(fee_per_gram, 2, 1); + let fee_without_change = Fee::calculate(fee_per_gram, 1, 2, 1); let key1 = PrivateKey::random(&mut OsRng); let value1 = 500; runtime @@ -680,8 +680,8 @@ fn test_startup_utxo_scan() { let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None); runtime.block_on(oms.add_output(output3.clone())).unwrap(); - let call = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body .decode_part::(1) .unwrap() @@ -755,8 +755,8 @@ fn test_startup_utxo_scan() { runtime.block_on(oms.sync_with_base_node()).unwrap(); - let call = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body .decode_part::(1) .unwrap() @@ -881,3 +881,88 @@ fn sending_transaction_with_short_term_clear_sqlite_db() { sending_transaction_with_short_term_clear(OutputManagerSqliteDatabase::new(connection)); } + +fn coin_split_with_change(backend: T) { + let factories = CryptoFactories::default(); + let mut runtime = Runtime::new().unwrap(); + let (mut oms, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone()); + + let val1 = 6_000 * uT; + let val2 = 7_000 * uT; + let val3 = 8_000 * uT; + let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); + let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); + let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); + assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); + assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); + assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + + let fee_per_gram = MicroTari::from(25); + let split_count = 8; + let (_tx_id, coin_split_tx, fee, amount) = runtime + .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + .unwrap(); + assert_eq!(coin_split_tx.body.inputs().len(), 2); + assert_eq!(coin_split_tx.body.outputs().len(), split_count + 1); + assert_eq!(fee, Fee::calculate(fee_per_gram, 1, 2, split_count + 1)); + assert_eq!(amount, val2 + val3); +} + +#[test] +fn coin_split_with_change_memory_db() { + coin_split_with_change(OutputManagerMemoryDatabase::new()); +} + +#[test] +fn coin_split_with_change_sqlite_db() { + let db_name = format!("{}.sqlite3", random_string(8).as_str()); + let db_tempdir = TempDir::new(random_string(8).as_str()).unwrap(); + let db_folder = db_tempdir.path().to_str().unwrap().to_string(); + let db_path = format!("{}/{}", db_folder, db_name); + let connection = run_migration_and_create_sqlite_connection(&db_path).unwrap(); + + coin_split_with_change(OutputManagerSqliteDatabase::new(connection)); +} + +fn coin_split_no_change(backend: T) { + let factories = CryptoFactories::default(); + let mut runtime = Runtime::new().unwrap(); + let (mut oms, _, _, _) = setup_output_manager_service(&mut runtime, backend.clone()); + + let fee_per_gram = MicroTari::from(25); + let split_count = 15; + let fee = Fee::calculate(fee_per_gram, 1, 3, 15); + 
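// The third input below is topped up by exactly `fee`, so the three inputs
+    // sum to 1000 uT * 15 + fee. The split then consumes the inputs exactly,
+    // which is why the assertions at the end expect `split_count` outputs and
+    // no change output.
+   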
let val1 = 4_000 * uT; + let val2 = 5_000 * uT; + let val3 = 6_000 * uT + fee; + let (_ti, uo1) = make_input(&mut OsRng.clone(), val1, &factories.commitment); + let (_ti, uo2) = make_input(&mut OsRng.clone(), val2, &factories.commitment); + let (_ti, uo3) = make_input(&mut OsRng.clone(), val3, &factories.commitment); + assert!(runtime.block_on(oms.add_output(uo1)).is_ok()); + assert!(runtime.block_on(oms.add_output(uo2)).is_ok()); + assert!(runtime.block_on(oms.add_output(uo3)).is_ok()); + + let (_tx_id, coin_split_tx, fee, amount) = runtime + .block_on(oms.create_coin_split(1000.into(), split_count, fee_per_gram, None)) + .unwrap(); + assert_eq!(coin_split_tx.body.inputs().len(), 3); + assert_eq!(coin_split_tx.body.outputs().len(), split_count); + assert_eq!(fee, Fee::calculate(fee_per_gram, 1, 3, split_count)); + assert_eq!(amount, val1 + val2 + val3); +} + +#[test] +fn coin_split_no_change_memory_db() { + coin_split_no_change(OutputManagerMemoryDatabase::new()); +} + +#[test] +fn coin_split_no_change_sqlite_db() { + let db_name = format!("{}.sqlite3", random_string(8).as_str()); + let db_tempdir = TempDir::new(random_string(8).as_str()).unwrap(); + let db_folder = db_tempdir.path().to_str().unwrap().to_string(); + let db_path = format!("{}/{}", db_folder, db_name); + let connection = run_migration_and_create_sqlite_connection(&db_path).unwrap(); + + coin_split_no_change(OutputManagerSqliteDatabase::new(connection)); +} diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index 5635a09d9a..b4d9f21002 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -176,7 +176,7 @@ pub fn test_db_backend(backend: T) { let outputs_to_encumber = vec![outputs[0].clone(), outputs[1].clone()]; let total_encumbered = outputs[0].clone().value + outputs[1].clone().value; runtime - .block_on(db.encumber_outputs(2, outputs_to_encumber, Some(uo_change.clone()))) + .block_on(db.encumber_outputs(2, outputs_to_encumber, vec![uo_change.clone()])) .unwrap(); runtime.block_on(db.confirm_encumbered_outputs(2)).unwrap(); @@ -373,11 +373,9 @@ pub async fn test_short_term_encumberance(bac let (_ti, uo) = make_input(&mut OsRng, MicroTari::from(50), &factories.commitment); pending_tx.outputs_to_be_received.push(uo); - db.encumber_outputs( - pending_tx.tx_id, - pending_tx.outputs_to_be_spent.clone(), - Some(pending_tx.outputs_to_be_received[0].clone()), - ) + db.encumber_outputs(pending_tx.tx_id, pending_tx.outputs_to_be_spent.clone(), vec![ + pending_tx.outputs_to_be_received[0].clone(), + ]) .await .unwrap(); @@ -389,11 +387,13 @@ pub async fn test_short_term_encumberance(bac let balance = db.get_balance().await.unwrap(); assert_eq!(available_balance, balance.available_balance); - db.encumber_outputs( - pending_tx.tx_id, - pending_tx.outputs_to_be_spent.clone(), - Some(pending_tx.outputs_to_be_received[0].clone()), - ) + pending_tx.outputs_to_be_received.clear(); + let (_ti, uo) = make_input(&mut OsRng, MicroTari::from(50), &factories.commitment); + pending_tx.outputs_to_be_received.push(uo); + + db.encumber_outputs(pending_tx.tx_id, pending_tx.outputs_to_be_spent.clone(), vec![ + pending_tx.outputs_to_be_received[0].clone(), + ]) .await .unwrap(); @@ -403,13 +403,15 @@ pub async fn test_short_term_encumberance(bac let balance = db.get_balance().await.unwrap(); assert_eq!(balance.available_balance, MicroTari(0)); + pending_tx.outputs_to_be_received.clear(); + 
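// Each encumbrance attempt below regenerates the output to be received
+    // rather than re-using the one from the previous call; re-encumbering an
+    // output already written to the backend would presumably collide on its
+    // spending key, so a fresh output is drawn every time.
+   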
let (_ti, uo) = make_input(&mut OsRng, MicroTari::from(50), &factories.commitment); + pending_tx.outputs_to_be_received.push(uo); + db.cancel_pending_transaction_outputs(pending_tx.tx_id).await.unwrap(); - db.encumber_outputs( - pending_tx.tx_id, - pending_tx.outputs_to_be_spent.clone(), - Some(pending_tx.outputs_to_be_received[0].clone()), - ) + db.encumber_outputs(pending_tx.tx_id, pending_tx.outputs_to_be_spent.clone(), vec![ + pending_tx.outputs_to_be_received[0].clone(), + ]) .await .unwrap(); diff --git a/base_layer/wallet/tests/support/comms_and_services.rs b/base_layer/wallet/tests/support/comms_and_services.rs index a8d50ed86a..13a6afd2da 100644 --- a/base_layer/wallet/tests/support/comms_and_services.rs +++ b/base_layer/wallet/tests/support/comms_and_services.rs @@ -77,13 +77,15 @@ pub fn create_dummy_message(inner: T, public_key: &CommsPublicKey) -> DomainM ); DomainMessage { dht_header: DhtMessageHeader { - origin: None, + ephemeral_public_key: None, + origin_mac: Vec::new(), version: Default::default(), message_type: Default::default(), flags: Default::default(), network: Network::LocalTest, destination: Default::default(), }, + authenticated_origin: None, source_peer: peer_source, inner, } diff --git a/base_layer/wallet/tests/transaction_service/service.rs b/base_layer/wallet/tests/transaction_service/service.rs index 424dc9fa0c..861c13a673 100644 --- a/base_layer/wallet/tests/transaction_service/service.rs +++ b/base_layer/wallet/tests/transaction_service/service.rs @@ -43,7 +43,6 @@ use tari_broadcast_channel::bounded; use tari_comms::{ message::EnvelopeBody, peer_manager::{NodeIdentity, PeerFeatures}, - protocol::messaging::MessagingEventSender, CommsNode, }; use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServiceMockState}; @@ -60,10 +59,11 @@ use tari_core::{ transactions::{ proto::types::TransactionOutput as TransactionOutputProto, tari_amount::*, - transaction::{KernelBuilder, KernelFeatures, OutputFeatures, Transaction, TransactionOutput}, + transaction::{KernelBuilder, KernelFeatures, OutputFeatures, Transaction, TransactionOutput, UnblindedOutput}, transaction_protocol::{proto, recipient::RecipientSignedMessage, sender::TransactionSenderMessage}, types::{CryptoFactories, PrivateKey, PublicKey, RangeProof, Signature}, ReceiverTransactionProtocol, + SenderTransactionProtocol, }, }; use tari_crypto::{ @@ -76,7 +76,7 @@ use tari_p2p::{ services::comms_outbound::CommsOutboundServiceInitializer, }; use tari_service_framework::{reply_channel, StackBuilder}; -use tari_test_utils::{collect_stream, paths::with_temp_dir, unpack_enum}; +use tari_test_utils::{collect_stream, paths::with_temp_dir}; use tari_wallet::{ output_manager_service::{ config::OutputManagerServiceConfig, @@ -88,7 +88,6 @@ use tari_wallet::{ storage::connection_manager::run_migration_and_create_sqlite_connection, transaction_service::{ config::TransactionServiceConfig, - error::TransactionServiceError, handle::{TransactionEvent, TransactionServiceHandle}, service::TransactionService, storage::{ @@ -105,12 +104,13 @@ use tari_wallet::{ }, TransactionServiceInitializer, }, + types::HashDigest, }; use tempdir::TempDir; use tokio::{ runtime, runtime::{Builder, Runtime}, - sync::broadcast, + sync::broadcast::channel, time::delay_for, }; @@ -158,7 +158,6 @@ pub fn setup_transaction_service( ..Default::default() }, subscription_factory, - comms.subscribe_messaging_events(), backend, comms.node_identity().clone(), factories.clone(), @@ -189,13 +188,12 @@ pub fn 
setup_transaction_service_no_comms>, Sender>, Sender>, - MessagingEventSender, ) { let (oms_request_sender, oms_request_receiver) = reply_channel::unbounded(); let (oms_event_publisher, oms_event_subscriber) = bounded(100); - let (outbound_message_requester, mock_outbound_service) = create_outbound_service_mock(20); + let (outbound_message_requester, mock_outbound_service) = create_outbound_service_mock(100); let output_manager_service = runtime .block_on(OutputManagerService::new( @@ -212,8 +210,8 @@ pub fn setup_transaction_service_no_comms( .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let alice_event_stream = alice_ts.get_event_stream_fused(); + let mut alice_event_stream = alice_ts.get_event_stream_fused(); let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( &mut runtime, @@ -320,7 +314,7 @@ fn manage_single_transaction( .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let bob_event_stream = bob_ts.get_event_stream_fused(); + let mut bob_event_stream = bob_ts.get_event_stream_fused(); runtime .block_on( @@ -353,22 +347,49 @@ fn manage_single_transaction( )) .unwrap(); - let _alice_events = - runtime.block_on(async { collect_stream!(alice_event_stream, take = 2, timeout = Duration::from_secs(20)) }); - - let bob_events = - runtime.block_on(async { collect_stream!(bob_event_stream, take = 2, timeout = Duration::from_secs(20)) }); + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let mut count = 0; + loop { + futures::select! { + event = alice_event_stream.select_next_some() => { + count+=1; + if count>=2 { + break; + } + }, + () = delay => { + break; + }, + } + } + }); - let tx_id = bob_events - .iter() - .find_map(|e| { - if let TransactionEvent::ReceivedFinalizedTransaction(tx_id) = &**e { - Some(tx_id.clone()) - } else { - None + let mut tx_id = 0u64; + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(90)).fuse(); + let mut finalized = 0; + loop { + futures::select! 
{ + event = bob_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::ReceivedFinalizedTransaction(id) => { + tx_id = *id; + finalized+=1; + }, + _ => (), + } + if finalized == 1 { + break; + } + }, + () = delay => { + break; + }, } - }) - .unwrap(); + } + assert_eq!(finalized, 1); + }); let mut bob_completed_tx = runtime.block_on(bob_ts.get_completed_transactions()).unwrap(); @@ -443,6 +464,13 @@ fn manage_multiple_transactions( NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(), ); + log::info!( + "wallet::manage_multiple_transactions: Alice: '{}', Bob: '{}', carol: '{}'", + alice_node_identity.node_id().short_str(), + bob_node_identity.node_id().short_str(), + carol_node_identity.node_id().short_str() + ); + let (mut alice_ts, mut alice_oms, alice_comms) = setup_transaction_service( &mut runtime, alice_node_identity.clone(), @@ -450,8 +478,35 @@ fn manage_multiple_transactions( factories.clone(), alice_backend, database_path.clone(), + Duration::from_secs(60), + ); + let mut alice_event_stream = alice_ts.get_event_stream_fused(); + + // Spin up Bob and Carol + let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( + &mut runtime, + bob_node_identity.clone(), + vec![alice_node_identity.clone()], + factories.clone(), + bob_backend, + database_path.clone(), + Duration::from_secs(1), + ); + let mut bob_event_stream = bob_ts.get_event_stream_fused(); + let (mut carol_ts, mut carol_oms, carol_comms) = setup_transaction_service( + &mut runtime, + carol_node_identity.clone(), + vec![alice_node_identity.clone()], + factories.clone(), + carol_backend, + database_path, Duration::from_secs(1), ); + let mut carol_event_stream = carol_ts.get_event_stream_fused(); + let (_utxo, uo2) = make_input(&mut OsRng, MicroTari(3500), &factories.commitment); + runtime.block_on(bob_oms.add_output(uo2)).unwrap(); + let (_utxo, uo3) = make_input(&mut OsRng, MicroTari(4500), &factories.commitment); + runtime.block_on(carol_oms.add_output(uo3)).unwrap(); // Add some funds to Alices wallet let (_utxo, uo1a) = make_input(&mut OsRng, MicroTari(5500), &factories.commitment); @@ -466,7 +521,8 @@ fn manage_multiple_transactions( let value_a_to_b_2 = MicroTari::from(800); let value_b_to_a_1 = MicroTari::from(1100); let value_a_to_c_1 = MicroTari::from(1400); - runtime + log::trace!("Sending A to B 1"); + let tx_id_a_to_b_1 = runtime .block_on(alice_ts.send_transaction( bob_node_identity.public_key().clone(), value_a_to_b_1, @@ -474,7 +530,9 @@ fn manage_multiple_transactions( "a to b 1".to_string(), )) .unwrap(); - runtime + log::trace!("A to B 1 TxID: {}", tx_id_a_to_b_1); + log::trace!("Sending A to C 1"); + let tx_id_a_to_c_1 = runtime .block_on(alice_ts.send_transaction( carol_node_identity.public_key().clone(), value_a_to_c_1, @@ -484,31 +542,7 @@ fn manage_multiple_transactions( .unwrap(); let alice_completed_tx = runtime.block_on(alice_ts.get_completed_transactions()).unwrap(); assert_eq!(alice_completed_tx.len(), 0); - - // Spin up Bob and Carol - let (mut bob_ts, mut bob_oms, bob_comms) = setup_transaction_service( - &mut runtime, - bob_node_identity.clone(), - vec![alice_node_identity.clone()], - factories.clone(), - bob_backend, - database_path.clone(), - Duration::from_secs(1), - ); - let (mut carol_ts, mut carol_oms, carol_comms) = setup_transaction_service( - &mut runtime, - carol_node_identity.clone(), - vec![alice_node_identity.clone()], - factories.clone(), - carol_backend, - database_path, - 
Duration::from_secs(1), - ); - - let (_utxo, uo2) = make_input(&mut OsRng, MicroTari(3500), &factories.commitment); - runtime.block_on(bob_oms.add_output(uo2)).unwrap(); - let (_utxo, uo3) = make_input(&mut OsRng, MicroTari(4500), &factories.commitment); - runtime.block_on(carol_oms.add_output(uo3)).unwrap(); + log::trace!("A to C 1 TxID: {}", tx_id_a_to_c_1); runtime .block_on(bob_ts.send_transaction( @@ -527,8 +561,6 @@ fn manage_multiple_transactions( )) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); - runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(90)).fuse(); let mut tx_reply = 0; @@ -536,12 +568,12 @@ fn manage_multiple_transactions( loop { futures::select! { event = alice_event_stream.select_next_some() => { - if let TransactionEvent::ReceivedTransactionReply(_) = &*event{ - tx_reply+=1; - } - if let TransactionEvent::ReceivedFinalizedTransaction(_) = &*event{ - finalized+=1; + match &*event.unwrap() { + TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, + TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, + _ => (), } + if tx_reply == 3 && finalized ==1 { break; } @@ -551,12 +583,10 @@ fn manage_multiple_transactions( }, } } - assert_eq!(tx_reply, 3); + assert_eq!(tx_reply, 3, "Need 3 replies"); assert_eq!(finalized, 1); }); - let mut bob_event_stream = bob_ts.get_event_stream_fused(); - runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(90)).fuse(); let mut tx_reply = 0; @@ -564,11 +594,10 @@ fn manage_multiple_transactions( loop { futures::select! { event = bob_event_stream.select_next_some() => { - if let TransactionEvent::ReceivedTransactionReply(_) = &*event{ - tx_reply+=1; - } - if let TransactionEvent::ReceivedFinalizedTransaction(_) = &*event{ - finalized+=1; + match &*event.unwrap() { + TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, + TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, + _ => (), } if tx_reply == 1 && finalized == 2 { break; @@ -583,15 +612,15 @@ fn manage_multiple_transactions( assert_eq!(finalized, 2); }); - let mut carol_event_stream = carol_ts.get_event_stream_fused(); runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(90)).fuse(); let mut finalized = 0; loop { futures::select! 
{ event = carol_event_stream.select_next_some() => { - if let TransactionEvent::ReceivedFinalizedTransaction(_) = &*event{ - finalized+=1; + match &*event.unwrap() { + TransactionEvent::ReceivedFinalizedTransaction(_) => finalized+=1, + _ => (), } if finalized == 1 { break; @@ -608,11 +637,11 @@ fn manage_multiple_transactions( let alice_pending_outbound = runtime.block_on(alice_ts.get_pending_outbound_transactions()).unwrap(); let alice_completed_tx = runtime.block_on(alice_ts.get_completed_transactions()).unwrap(); assert_eq!(alice_pending_outbound.len(), 0); - assert_eq!(alice_completed_tx.len(), 4); + assert_eq!(alice_completed_tx.len(), 4, "Not enough transactions for Alice"); let bob_pending_outbound = runtime.block_on(bob_ts.get_pending_outbound_transactions()).unwrap(); let bob_completed_tx = runtime.block_on(bob_ts.get_completed_transactions()).unwrap(); assert_eq!(bob_pending_outbound.len(), 0); - assert_eq!(bob_completed_tx.len(), 3); + assert_eq!(bob_completed_tx.len(), 3, "Not enough transactions for Bob"); let carol_pending_inbound = runtime.block_on(carol_ts.get_pending_inbound_transactions()).unwrap(); let carol_completed_tx = runtime.block_on(carol_ts.get_completed_transactions()).unwrap(); @@ -660,103 +689,8 @@ fn manage_multiple_transactions_sqlite_db() { ); } -fn test_sending_repeated_tx_ids(alice_backend: T, bob_backend: T) { - let mut runtime = create_runtime(); - let factories = CryptoFactories::default(); - - let bob_node_identity = NodeIdentity::random( - &mut OsRng, - "/ip4/127.0.0.1/tcp/55741".parse().unwrap(), - PeerFeatures::COMMUNICATION_NODE, - ) - .unwrap(); - - let ( - alice_ts, - _alice_output_manager, - alice_outbound_service, - mut alice_tx_sender, - _alice_tx_ack_sender, - _alice_mempool_response_sender, - _, - _, - _, - ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, None); - let (_bob_ts, mut bob_output_manager, _bob_outbound_service, _bob_tx_sender, _bob_tx_ack_sender, _, _, _, _) = - setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, None); - let alice_event_stream = alice_ts.get_event_stream_fused(); - - let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); - - runtime.block_on(bob_output_manager.add_output(uo)).unwrap(); - - let mut stp = runtime - .block_on(bob_output_manager.prepare_transaction_to_send( - MicroTari::from(500), - MicroTari::from(1000), - None, - "".to_string(), - )) - .unwrap(); - let msg = stp.build_single_round_message().unwrap(); - let tx_message = create_dummy_message( - TransactionSenderMessage::Single(Box::new(msg.clone())).into(), - &bob_node_identity.public_key(), - ); - - runtime.block_on(alice_tx_sender.send(tx_message.clone())).unwrap(); - runtime.block_on(alice_tx_sender.send(tx_message.clone())).unwrap(); - - let result = - runtime.block_on(async { collect_stream!(alice_event_stream, take = 2, timeout = Duration::from_secs(10)) }); - - alice_outbound_service - .wait_call_count(1, Duration::from_secs(10)) - .unwrap(); - - assert_eq!(result.len(), 2); - assert!(result - .iter() - .find(|i| if let TransactionEvent::ReceivedTransaction(_) = &***i { - true - } else { - false - }) - .is_some()); - assert!(result - .iter() - .find(|i| if let TransactionEvent::Error(s) = &***i { - s == &"Error handling Transaction Sender message".to_string() - } else { - false - }) - .is_some()); -} - -#[test] -fn test_sending_repeated_tx_ids_memory_db() { - test_sending_repeated_tx_ids(TransactionMemoryDatabase::new(), 
TransactionMemoryDatabase::new()); -} - -#[test] -fn test_sending_repeated_tx_ids_sqlite_db() { - with_temp_dir(|dir_path| { - let path_string = dir_path.to_str().unwrap().to_string(); - let alice_db_name = format!("{}.sqlite3", random_string(8).as_str()); - let alice_db_path = format!("{}/{}", path_string, alice_db_name); - let bob_db_name = format!("{}.sqlite3", random_string(8).as_str()); - let bob_db_path = format!("{}/{}", path_string, bob_db_name); - let connection_alice = run_migration_and_create_sqlite_connection(&alice_db_path).unwrap(); - let connection_bob = run_migration_and_create_sqlite_connection(&bob_db_path).unwrap(); - test_sending_repeated_tx_ids( - TransactionServiceSqliteDatabase::new(connection_alice), - TransactionServiceSqliteDatabase::new(connection_bob), - ); - }); -} - fn test_accepting_unknown_tx_id_and_malformed_reply(alice_backend: T) { - let mut runtime = create_runtime(); + let mut runtime = Runtime::new().unwrap(); let factories = CryptoFactories::default(); let bob_node_identity = @@ -770,10 +704,9 @@ fn test_accepting_unknown_tx_id_and_malformed_reply(1) .unwrap() @@ -823,15 +758,28 @@ fn test_accepting_unknown_tx_id_and_malformed_reply { + if let TransactionEvent::Error(s) = &*event.unwrap() { + if s == &"TransactionError(ValidationError(\"Transaction could not be finalized\"))".to_string() { + errors+=1; + } + if errors >= 2 { + break; + } + } + }, + () = delay => { + break; + }, + } + } + assert!(errors >= 1); + }); } #[test] @@ -850,26 +798,68 @@ fn test_accepting_unknown_tx_id_and_malformed_reply_sqlite_db() { }); } -fn finalize_tx_with_nonexistent_txid(alice_backend: T) { +fn finalize_tx_with_incorrect_pubkey(alice_backend: T, bob_backend: T) { let mut runtime = create_runtime(); let factories = CryptoFactories::default(); let ( alice_ts, _alice_output_manager, - _alice_outbound_service, - _alice_tx_sender, + alice_outbound_service, + mut alice_tx_sender, _alice_tx_ack_sender, mut alice_tx_finalized, _, _, - _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, None); let alice_event_stream = alice_ts.get_event_stream_fused(); - let tx = Transaction::new(vec![], vec![], vec![], PrivateKey::random(&mut OsRng)); + let bob_node_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); + let (_bob_ts, mut bob_output_manager, _bob_outbound_service, _bob_tx_sender, _bob_tx_ack_sender, _, _, _) = + setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, None); + + let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); + + runtime.block_on(bob_output_manager.add_output(uo)).unwrap(); + + let mut stp = runtime + .block_on(bob_output_manager.prepare_transaction_to_send( + MicroTari::from(500), + MicroTari::from(1000), + None, + "".to_string(), + )) + .unwrap(); + let msg = stp.build_single_round_message().unwrap(); + let tx_message = create_dummy_message( + TransactionSenderMessage::Single(Box::new(msg.clone())).into(), + &bob_node_identity.public_key(), + ); + + runtime.block_on(alice_tx_sender.send(tx_message.clone())).unwrap(); + + alice_outbound_service + .wait_call_count(2, Duration::from_secs(10)) + .unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message + + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let recipient_reply: RecipientSignedMessage = envelope_body + .decode_part::(1) + 
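// decode_part reads the indexed part out of the multipart envelope body;
+        // the two unwraps below peel the decode Result and the optional part,
+        // and try_into then converts the protobuf message into the domain type.
+       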
.unwrap() + .unwrap() + .try_into() + .unwrap(); + + stp.add_single_recipient_info(recipient_reply.clone(), &factories.range_proof) + .unwrap(); + stp.finalize(KernelFeatures::empty(), &factories).unwrap(); + let tx = stp.get_transaction().unwrap(); + let finalized_transaction_message = proto::TransactionFinalizedMessage { - tx_id: 88u64, + tx_id: recipient_reply.tx_id, transaction: Some(tx.clone().into()), }; @@ -881,108 +871,9 @@ fn finalize_tx_with_nonexistent_txid(al .unwrap(); assert!(runtime - .block_on(async { collect_stream!(alice_event_stream, take = 1, timeout = Duration::from_secs(10)) }) + .block_on(async { collect_stream!(alice_event_stream, take = 2, timeout = Duration::from_secs(10)) }) .iter() - .find(|i| if let TransactionEvent::Error(s) = &***i { - s == &"Error handling Transaction Finalized message".to_string() - } else { - false - }) - .is_some()); -} - -#[test] -fn finalize_tx_with_nonexistent_txid_memory_db() { - finalize_tx_with_nonexistent_txid(TransactionMemoryDatabase::new()); -} - -#[test] -fn finalize_tx_with_nonexistent_txid_sqlite_db() { - with_temp_dir(|dir_path| { - let path_string = dir_path.to_str().unwrap().to_string(); - let alice_db_name = format!("{}.sqlite3", random_string(8).as_str()); - let alice_db_path = format!("{}/{}", path_string, alice_db_name); - let connection_alice = run_migration_and_create_sqlite_connection(&alice_db_path).unwrap(); - - finalize_tx_with_nonexistent_txid(TransactionServiceSqliteDatabase::new(connection_alice)); - }); -} - -fn finalize_tx_with_incorrect_pubkey(alice_backend: T, bob_backend: T) { - let mut runtime = create_runtime(); - let factories = CryptoFactories::default(); - - let ( - alice_ts, - _alice_output_manager, - alice_outbound_service, - mut alice_tx_sender, - _alice_tx_ack_sender, - mut alice_tx_finalized, - _, - _, - _, - ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, None); - let alice_event_stream = alice_ts.get_event_stream_fused(); - - let bob_node_identity = - NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); - let (_bob_ts, mut bob_output_manager, _bob_outbound_service, _bob_tx_sender, _bob_tx_ack_sender, _, _, _, _) = - setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, None); - - let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); - - runtime.block_on(bob_output_manager.add_output(uo)).unwrap(); - - let mut stp = runtime - .block_on(bob_output_manager.prepare_transaction_to_send( - MicroTari::from(500), - MicroTari::from(1000), - None, - "".to_string(), - )) - .unwrap(); - let msg = stp.build_single_round_message().unwrap(); - let tx_message = create_dummy_message( - TransactionSenderMessage::Single(Box::new(msg.clone())).into(), - &bob_node_identity.public_key(), - ); - - runtime.block_on(alice_tx_sender.send(tx_message.clone())).unwrap(); - - alice_outbound_service - .wait_call_count(1, Duration::from_secs(10)) - .unwrap(); - let (_, body) = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.as_slice()).unwrap(); - let recipient_reply: RecipientSignedMessage = envelope_body - .decode_part::(1) - .unwrap() - .unwrap() - .try_into() - .unwrap(); - - stp.add_single_recipient_info(recipient_reply.clone(), &factories.range_proof) - .unwrap(); - stp.finalize(KernelFeatures::empty(), &factories).unwrap(); - let tx = stp.get_transaction().unwrap(); - - let finalized_transaction_message = 
proto::TransactionFinalizedMessage { - tx_id: recipient_reply.tx_id, - transaction: Some(tx.clone().into()), - }; - - runtime - .block_on(alice_tx_finalized.send(create_dummy_message( - finalized_transaction_message.clone(), - &PublicKey::from_secret_key(&PrivateKey::random(&mut OsRng)), - ))) - .unwrap(); - - assert!(runtime - .block_on(async { collect_stream!(alice_event_stream, take = 2, timeout = Duration::from_secs(10)) }) - .iter() - .find(|i| if let TransactionEvent::Error(s) = &***i { + .find(|i| if let TransactionEvent::Error(s) = &**(**i).as_ref().unwrap() { s == &"Error handling Transaction Finalized message".to_string() } else { false @@ -1025,13 +916,12 @@ fn finalize_tx_with_missing_output(alic mut alice_tx_finalized, _, _, - _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), alice_backend, None); let alice_event_stream = alice_ts.get_event_stream_fused(); let bob_node_identity = NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); - let (_bob_ts, mut bob_output_manager, _bob_outbound_service, _bob_tx_sender, _bob_tx_ack_sender, _, _, _, _) = + let (_bob_ts, mut bob_output_manager, _bob_outbound_service, _bob_tx_sender, _bob_tx_ack_sender, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), bob_backend, None); let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); @@ -1055,10 +945,12 @@ fn finalize_tx_with_missing_output(alic runtime.block_on(alice_tx_sender.send(tx_message.clone())).unwrap(); alice_outbound_service - .wait_call_count(1, Duration::from_secs(10)) + .wait_call_count(2, Duration::from_secs(10)) .unwrap(); let (_, body) = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.as_slice()).unwrap(); + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message + + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let recipient_reply: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -1085,7 +977,7 @@ fn finalize_tx_with_missing_output(alic assert!(runtime .block_on(async { collect_stream!(alice_event_stream, take = 2, timeout = Duration::from_secs(10)) }) .iter() - .find(|i| if let TransactionEvent::Error(s) = &***i { + .find(|i| if let TransactionEvent::Error(s) = &**(**i).as_ref().unwrap() { s == &"Error handling Transaction Finalized message".to_string() } else { false @@ -1222,48 +1114,97 @@ fn discovery_async_return_test() { let value_a_to_c_1 = MicroTari::from(1400); - let tx_id = match runtime.block_on(alice_ts.send_transaction( - carol_node_identity.public_key().clone(), - value_a_to_c_1, - MicroTari::from(20), - "Discovery Tx!".to_string(), - )) { - Err(TransactionServiceError::OutboundSendDiscoveryInProgress(tx_id)) => tx_id, - _ => { - assert!(false, "Send should not succeed as Peer is not known"); - 0u64 - }, - }; + let tx_id = runtime + .block_on(alice_ts.send_transaction( + carol_node_identity.public_key().clone(), + value_a_to_c_1, + MicroTari::from(20), + "Discovery Tx!".to_string(), + )) + .unwrap(); + assert_ne!(initial_balance, runtime.block_on(alice_oms.get_balance()).unwrap()); - let event = runtime.block_on(alice_event_stream.next()).unwrap(); - unpack_enum!(TransactionEvent::TransactionSendDiscoveryComplete(txid, is_success) = &*event); - assert_eq!(txid, &tx_id); - assert_eq!(*is_success, false); + let mut txid = 0; + let mut is_success = true; + runtime.block_on(async { + let mut delay = 
delay_for(Duration::from_secs(60)).fuse(); + loop { + futures::select! { + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionDirectSendResult(tx_id, result) = (*event.unwrap()).clone() { + txid = tx_id; + is_success = result; + break; + } + }, + () = delay => { + break; + }, + } + } + }); + assert_eq!(txid, tx_id); + assert_eq!(is_success, false); - assert_eq!(initial_balance, runtime.block_on(alice_oms.get_balance()).unwrap()); + let tx_id2 = runtime + .block_on(alice_ts.send_transaction( + dave_node_identity.public_key().clone(), + value_a_to_c_1, + MicroTari::from(20), + "Discovery Tx2!".to_string(), + )) + .unwrap(); - let tx_id2 = match runtime.block_on(alice_ts.send_transaction( - dave_node_identity.public_key().clone(), - value_a_to_c_1, - MicroTari::from(20), - "Discovery Tx2!".to_string(), - )) { - Err(TransactionServiceError::OutboundSendDiscoveryInProgress(tx_id)) => tx_id, - _ => { - assert!(false, "Send should not succeed as Peer is not known"); - 0u64 - }, - }; + let mut success_result = false; + let mut success_tx_id = 0u64; + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut success_count = 0; - let event = runtime.block_on(alice_event_stream.next()).unwrap(); - unpack_enum!(TransactionEvent::TransactionSendResult(txid, is_success) = &*event); - assert_eq!(txid, &tx_id2); - assert!(is_success); + loop { + futures::select! { + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionDirectSendResult(tx_id, success) = &*event.unwrap() { + success_count+=1; + success_result = success.clone(); + success_tx_id = *tx_id; + if success_count >= 1 { + break; + } + } + }, + () = delay => { + break; + }, + } + } + assert!(success_count >= 1); + }); - let event = runtime.block_on(alice_event_stream.next()).unwrap(); - unpack_enum!(TransactionEvent::ReceivedTransactionReply(txid) = &*event); - assert_eq!(txid, &tx_id2); + assert_eq!(success_tx_id, tx_id2); + assert!(success_result); + + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut tx_reply = 0; + loop { + futures::select! 
{ + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::ReceivedTransactionReply(tx_id) = &*event.unwrap() { + if tx_id == &tx_id2 { + tx_reply +=1; + break; + } + } + }, + () = delay => { + break; + }, + } + } + assert!(tx_reply >= 1); + }); runtime.block_on(async move { alice_comms.shutdown().await; @@ -1285,7 +1226,6 @@ fn test_coinbase(backend: T) { _, _, _, - _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, None); let balance = runtime.block_on(alice_output_manager.get_balance()).unwrap(); @@ -1412,34 +1352,25 @@ fn transaction_mempool_broadcast() { mut alice_tx_ack_sender, _, mut alice_mempool_response_sender, - _, - _, + mut alice_base_node_response_sender, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), TransactionMemoryDatabase::new(), None); + let mut alice_event_stream = alice_ts.get_event_stream_fused(); runtime .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let ( - mut bob_ts, - _bob_output_manager, - bob_outbound_service, - mut bob_tx_sender, - _, - mut bob_tx_finalized_sender, - _, - _, - _, - ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), TransactionMemoryDatabase::new(), None); - - runtime - .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) - .unwrap(); + let (_bob_ts, _bob_output_manager, bob_outbound_service, mut bob_tx_sender, _, _, _, _) = + setup_transaction_service_no_comms(&mut runtime, factories.clone(), TransactionMemoryDatabase::new(), None); let (_utxo, uo) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); runtime.block_on(alice_output_manager.add_output(uo)).unwrap(); - runtime + let (_utxo, uo2) = make_input(&mut OsRng, MicroTari(250000), &factories.commitment); + runtime.block_on(alice_output_manager.add_output(uo2)).unwrap(); + + // Send Tx1 + let tx_id1 = runtime .block_on(alice_ts.send_transaction( bob_node_identity.public_key().clone(), 10000 * uT, @@ -1447,20 +1378,23 @@ fn transaction_mempool_broadcast() { "Testing Message".to_string(), )) .unwrap(); + alice_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .expect("Alice call wait 1"); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let _ = alice_outbound_service.pop_call().unwrap(); // Burn the SAF version of the message - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() .unwrap() .try_into() .unwrap(); - let tx_id = match tx_sender_msg.clone() { - TransactionSenderMessage::Single(s) => s.tx_id, + match tx_sender_msg.clone() { + TransactionSenderMessage::Single(_) => (), _ => { assert!(false, "Transaction is the not a single rounder sender variant"); - 0 }, }; @@ -1470,93 +1404,181 @@ fn transaction_mempool_broadcast() { alice_node_identity.public_key(), ))) .unwrap(); + bob_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .expect("bob call wait 1"); - let _result_stream = runtime.block_on(async { - collect_stream!( - bob_ts.get_event_stream_fused(), - take = 1, - timeout = Duration::from_secs(20) - ) - }); let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); - let tx_reply_msg: RecipientSignedMessage = envelope_body + let 
_ = bob_outbound_service.pop_call().unwrap(); // Burn the SAF version of the message
+    let envelope_body = EnvelopeBody::decode(&mut call.1.to_vec().as_slice()).unwrap();
+    let bob_tx_reply_msg1: RecipientSignedMessage = envelope_body
+        .decode_part::(1)
+        .unwrap()
+        .unwrap()
+        .try_into()
+        .unwrap();
+
+    // Send Tx2
+    let tx_id2 = runtime
+        .block_on(alice_ts.send_transaction(
+            bob_node_identity.public_key().clone(),
+            10001 * uT,
+            100 * uT,
+            "Testing Message2".to_string(),
+        ))
+        .unwrap();
+    alice_outbound_service
+        .wait_call_count(2, Duration::from_secs(60))
+        .expect("Alice call wait 2");
+
+    let call = alice_outbound_service.pop_call().unwrap();
+    let _ = alice_outbound_service.pop_call().unwrap(); // Burn the SAF version of the message
+    let tx_sender_msg = try_decode_sender_message(call.1.to_vec().clone()).unwrap();
+
+    match tx_sender_msg.clone() {
+        TransactionSenderMessage::Single(_) => (),
+        _ => {
+            assert!(false, "Transaction is not a single-round sender variant");
+        },
+    };
+
+    runtime
+        .block_on(bob_tx_sender.send(create_dummy_message(
+            tx_sender_msg.into(),
+            alice_node_identity.public_key(),
+        )))
+        .unwrap();
+    bob_outbound_service
+        .wait_call_count(2, Duration::from_secs(60))
+        .expect("Bob call wait 2");
+
+    let (_, body) = bob_outbound_service.pop_call().unwrap();
+    let _ = bob_outbound_service.pop_call().unwrap(); // Burn the SAF version of the message
+
+    let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap();
+    let bob_tx_reply_msg2: RecipientSignedMessage = envelope_body
         .decode_part::(1)
         .unwrap()
         .unwrap()
         .try_into()
         .unwrap();
 
+    // Give Alice both of Bob's replies
     runtime
         .block_on(alice_tx_ack_sender.send(create_dummy_message(
-            tx_reply_msg.into(),
+            bob_tx_reply_msg1.into(),
+            bob_node_identity.public_key(),
+        )))
+        .unwrap();
+
+    runtime
+        .block_on(alice_tx_ack_sender.send(create_dummy_message(
+            bob_tx_reply_msg2.into(),
             bob_node_identity.public_key(),
         )))
         .unwrap();
 
-    let mut alice_event_stream = alice_ts.get_event_stream_fused();
     runtime.block_on(async {
         let mut delay = delay_for(Duration::from_secs(60)).fuse();
-        let mut broadcast_timeout_count = 0;
+        let mut tx1_timeout = false;
+        let mut tx2_timeout = false;
         loop {
             futures::select!
{ event = alice_event_stream.select_next_some() => { - if let TransactionEvent::MempoolBroadcastTimedOut(_) = &*event{ - broadcast_timeout_count +=1; - if broadcast_timeout_count >= 2 { + if let TransactionEvent::MempoolBroadcastTimedOut(tx_id) = &*event.unwrap(){ + if tx_id == &tx_id1 { + tx1_timeout = true; + } + if tx_id == &tx_id2 { + tx2_timeout = true; + } + if tx1_timeout && tx2_timeout { break; } - } }, () = delay => { - log::error!("This select loop timed out"); break; }, } } - assert!(broadcast_timeout_count >= 2); + assert!(tx1_timeout && tx2_timeout); }); - let alice_completed_tx = runtime + let alice_completed_tx1 = runtime .block_on(alice_ts.get_completed_transactions()) .unwrap() - .remove(&tx_id) + .remove(&tx_id1) .expect("Transaction must be in collection"); - assert_eq!(alice_completed_tx.status, TransactionStatus::Completed); - - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); - let msr = envelope_body - .decode_part::(1) + let alice_completed_tx2 = runtime + .block_on(alice_ts.get_completed_transactions()) .unwrap() - .unwrap(); - - let mempool_service_request = MempoolServiceRequest::try_from(msr.clone()).unwrap(); - - let _ = alice_outbound_service.pop_call().unwrap(); // burn a mempool request - let _ = alice_outbound_service.pop_call().unwrap(); // burn a mempool request - let call = alice_outbound_service.pop_call().unwrap(); // this should be the sending of the finalized tx to the receiver + .remove(&tx_id2) + .expect("Transaction must be in collection"); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); - let tx_finalized = envelope_body - .decode_part::(1) - .unwrap() - .unwrap(); + assert_eq!(alice_completed_tx1.status, TransactionStatus::Completed); + assert_eq!(alice_completed_tx2.status, TransactionStatus::Completed); - runtime - .block_on(bob_tx_finalized_sender.send(create_dummy_message(tx_finalized, alice_node_identity.public_key()))) - .unwrap(); + alice_outbound_service + .wait_call_count(4, Duration::from_secs(60)) + .expect("Alice call wait 3"); + + let mut msr_tx1_found = false; + let mut bsr_tx1_found = false; + let mut msr_tx2_found = false; + let mut bsr_tx2_found = false; + log::info!("Starting to look for MSR and BSR requests"); + for _ in 0..4 { + let call = alice_outbound_service.pop_call().unwrap(); + match try_decode_mempool_request(call.1.to_vec().clone()) { + Some(m) => { + if m.request_key == tx_id1 { + msr_tx1_found = true; + } + if m.request_key == tx_id2 { + msr_tx2_found = true; + } + match m.request { + MempoolRequest::GetStats => assert!(false, "Invalid Mempool Service Request variant"), + MempoolRequest::GetState => assert!(false, "Invalid Mempool Service Request variant"), + MempoolRequest::GetTxStateWithExcessSig(_) => { + assert!(false, "Invalid Mempool Service Request variant") + }, + MempoolRequest::SubmitTransaction(t) => { + if m.request_key == tx_id1 { + assert_eq!(t, alice_completed_tx1.transaction); + } + if m.request_key == tx_id2 { + assert_eq!(t, alice_completed_tx2.transaction); + } + }, + } + }, + None => { + if let Some(bsr) = try_decode_base_node_request(call.1.to_vec().clone()) { + if bsr.request_key == tx_id1 { + bsr_tx1_found = true; + } + if bsr.request_key == tx_id2 { + bsr_tx2_found = true; + } + } + }, + } + } + assert!(msr_tx1_found); + assert!(msr_tx2_found); + assert!(bsr_tx1_found); + assert!(bsr_tx2_found); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); 
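+    // The waits below all follow the same pattern: race the wallet event stream
+    // against a deadline with futures::select!, set a flag or bump a counter for
+    // each event of interest, and assert once either arm breaks the loop. This
+    // replaces the old collect_stream!(take = n) calls, which counted a fixed
+    // number of events and evidently broke once unrelated events (such as
+    // direct-send results or SAF duplicates) interleaved with the expected ones.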
runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(60)).fuse(); let mut broadcast_timeout_count = 0; loop { futures::select! { event = alice_event_stream.select_next_some() => { - if let TransactionEvent::MempoolBroadcastTimedOut(_) = &*event{ + if let TransactionEvent::MempoolBroadcastTimedOut(_) = &*event.unwrap(){ broadcast_timeout_count +=1; if broadcast_timeout_count >= 1 { break; @@ -1565,7 +1587,6 @@ fn transaction_mempool_broadcast() { } }, () = delay => { - log::error!("This select loop timed out"); break; }, } @@ -1573,40 +1594,52 @@ fn transaction_mempool_broadcast() { assert!(broadcast_timeout_count >= 1); }); - assert_eq!(mempool_service_request.request_key, tx_id); - - match mempool_service_request.request { - MempoolRequest::GetStats => assert!(false, "Invalid Mempool Service Request variant"), - MempoolRequest::GetTxStateWithExcessSig(_) => assert!(false, "Invalid Mempool Service Request variant"), - MempoolRequest::SubmitTransaction(tx) => assert_eq!(tx, alice_completed_tx.transaction), - } - let mempool_response = MempoolProto::MempoolServiceResponse { - request_key: tx_id, + request_key: tx_id1, response: Some(MempoolResponse::TxStorage(TxStorageResponse::UnconfirmedPool).into()), }; - runtime .block_on( alice_mempool_response_sender.send(create_dummy_message(mempool_response, base_node_identity.public_key())), ) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let completed_tx_outputs: Vec = alice_completed_tx2 + .transaction + .body + .outputs() + .iter() + .map(|o| TransactionOutputProto::from(o.clone())) + .collect(); + + let base_node_response = BaseNodeProto::BaseNodeServiceResponse { + request_key: tx_id2.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { + outputs: completed_tx_outputs.into(), + }, + )), + }; + + runtime + .block_on(alice_base_node_response_sender.send(create_dummy_message( + base_node_response, + base_node_identity.public_key(), + ))) + .unwrap(); + runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(10)).fuse(); + let mut delay = delay_for(Duration::from_secs(30)).fuse(); let mut broadcast = false; loop { futures::select! { event = alice_event_stream.select_next_some() => { - if let TransactionEvent::TransactionBroadcast(_) = &*event{ - broadcast = true; + if let TransactionEvent::TransactionBroadcast(id) = &*event.unwrap(){ + broadcast = &tx_id1 == id; break; - } }, () = delay => { - log::error!("This select loop timed out"); break; }, } @@ -1614,13 +1647,83 @@ fn transaction_mempool_broadcast() { assert!(broadcast); }); + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(30)).fuse(); + let mut mined = false; + loop { + futures::select! 
{ + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionMined(id) = &*event.unwrap(){ + mined = &tx_id2 == id; + break; + } + }, + () = delay => { + break; + }, + } + } + assert!(mined); + }); + let alice_completed_tx = runtime .block_on(alice_ts.get_completed_transactions()) .unwrap() - .remove(&tx_id) + .remove(&tx_id1) .expect("Transaction must be in collection"); assert_eq!(alice_completed_tx.status, TransactionStatus::Broadcast); + + let alice_completed_tx = runtime + .block_on(alice_ts.get_completed_transactions()) + .unwrap() + .remove(&tx_id2) + .expect("Transaction must be in collection"); + + assert_eq!(alice_completed_tx.status, TransactionStatus::Mined); +} + +fn try_decode_mempool_request(bytes: Vec) -> Option { + let envelope_body = EnvelopeBody::decode(&mut bytes.as_slice()).unwrap(); + let msr = match envelope_body.decode_part::(1) { + Err(_) => return None, + Ok(d) => match d { + None => return None, + Some(r) => r, + }, + }; + + match MempoolServiceRequest::try_from(msr) { + Ok(msr) => Some(msr), + Err(_) => None, + } +} + +fn try_decode_sender_message(bytes: Vec) -> Option { + let envelope_body = EnvelopeBody::decode(&mut bytes.as_slice()).unwrap(); + let tx_sender_msg = match envelope_body.decode_part::(1) { + Err(_) => return None, + Ok(d) => match d { + None => return None, + Some(r) => r, + }, + }; + + match TransactionSenderMessage::try_from(tx_sender_msg) { + Ok(msr) => Some(msr), + Err(_) => None, + } +} + +fn try_decode_base_node_request(bytes: Vec) -> Option { + let envelope_body = EnvelopeBody::decode(&mut bytes.as_slice()).unwrap(); + match envelope_body.decode_part::(1) { + Err(_) => return None, + Ok(d) => match d { + None => return None, + Some(r) => return Some(r), + }, + }; } #[test] @@ -1663,24 +1766,24 @@ fn broadcast_all_completed_transactions_on_startup() { }; db.write(WriteOperation::Insert(DbKeyValuePair::CompletedTransaction( - completed_tx1.tx_id.clone(), + completed_tx1.tx_id, Box::new(completed_tx1.clone()), ))) .unwrap(); db.write(WriteOperation::Insert(DbKeyValuePair::CompletedTransaction( - completed_tx2.tx_id.clone(), + completed_tx2.tx_id, Box::new(completed_tx2.clone()), ))) .unwrap(); db.write(WriteOperation::Insert(DbKeyValuePair::CompletedTransaction( - completed_tx3.tx_id.clone(), + completed_tx3.tx_id, Box::new(completed_tx3.clone()), ))) .unwrap(); - let (mut alice_ts, _, _, _, _, _, _, _, _) = + let (mut alice_ts, _, _, _, _, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), db, None); runtime @@ -1695,7 +1798,7 @@ fn broadcast_all_completed_transactions_on_startup() { loop { futures::select! 
{ event = event_stream.select_next_some() => { - if let TransactionEvent::MempoolBroadcastTimedOut(tx_id) = (*event).clone() { + if let TransactionEvent::MempoolBroadcastTimedOut(tx_id) = (*event.unwrap()).clone() { if tx_id == 1u64 { found1 = true } @@ -1741,15 +1844,12 @@ fn transaction_base_node_monitoring() { _, mut alice_mempool_response_sender, mut alice_base_node_response_sender, - _, ) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), TransactionMemoryDatabase::new(), None); - let (mut bob_ts, _, bob_outbound_service, mut bob_tx_sender, _, _, _, _, _) = - setup_transaction_service_no_comms(&mut runtime, factories.clone(), TransactionMemoryDatabase::new(), None); + let mut alice_event_stream = alice_ts.get_event_stream_fused(); - runtime - .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) - .unwrap(); + let (_, _, bob_outbound_service, mut bob_tx_sender, _, _, _, _) = + setup_transaction_service_no_comms(&mut runtime, factories.clone(), TransactionMemoryDatabase::new(), None); let mut alice_total_available = 250000 * uT; let (_utxo, uo) = make_input(&mut OsRng, alice_total_available, &factories.commitment); @@ -1762,7 +1862,7 @@ fn transaction_base_node_monitoring() { let amount_sent = 10000 * uT; - runtime + let tx_id = runtime .block_on(alice_ts.send_transaction( bob_node_identity.public_key().clone(), amount_sent, @@ -1771,15 +1871,21 @@ fn transaction_base_node_monitoring() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + alice_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .unwrap(); + + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message + + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() .unwrap() .try_into() .unwrap(); - let tx_id = match tx_sender_msg.clone() { + match tx_sender_msg.clone() { TransactionSenderMessage::Single(s) => s.tx_id, _ => { assert!(false, "Transaction is the not a single rounder sender variant"); @@ -1794,48 +1900,24 @@ fn transaction_base_node_monitoring() { ))) .unwrap(); - let _result_stream = runtime.block_on(async { - collect_stream!( - bob_ts.get_event_stream_fused(), - take = 1, - timeout = Duration::from_secs(20) - ) - }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); - let tx_reply_msg: RecipientSignedMessage = envelope_body + bob_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let _ = bob_outbound_service.pop_call().unwrap(); // burn SAF message + + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bob_tx_reply_msg1: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() .unwrap() .try_into() .unwrap(); - runtime - .block_on(alice_tx_ack_sender.send(create_dummy_message( - tx_reply_msg.into(), - bob_node_identity.public_key(), - ))) - .unwrap(); - - let _result_stream = runtime.block_on(async { - collect_stream!( - alice_ts.get_event_stream_fused().map(|i| (*i).clone()), - take = 1, - timeout = Duration::from_secs(60) - ) - }); - let alice_completed_tx = runtime - .block_on(alice_ts.get_completed_transactions()) - .unwrap() - .remove(&tx_id) - 
.expect("Transaction must be in collection"); - - assert_eq!(alice_completed_tx.status, TransactionStatus::Completed); - // Send another transaction let amount_sent2 = 20000 * uT; - runtime + let tx_id2 = runtime .block_on(alice_ts.send_transaction( bob_node_identity.public_key().clone(), amount_sent2, @@ -1844,21 +1926,19 @@ fn transaction_base_node_monitoring() { )) .unwrap(); - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + alice_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message + + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() .unwrap() .try_into() .unwrap(); - let tx_id2 = match tx_sender_msg.clone() { - TransactionSenderMessage::Single(s) => s.tx_id, - _ => { - assert!(false, "Transaction is the not a single rounder sender variant"); - 0 - }, - }; runtime .block_on(bob_tx_sender.send(create_dummy_message( @@ -1866,17 +1946,14 @@ fn transaction_base_node_monitoring() { alice_node_identity.public_key(), ))) .unwrap(); + bob_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let _ = bob_outbound_service.pop_call().unwrap(); // burn SAF message - let _result_stream = runtime.block_on(async { - collect_stream!( - bob_ts.get_event_stream_fused(), - take = 2, - timeout = Duration::from_secs(20) - ) - }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); - let tx_reply_msg: RecipientSignedMessage = envelope_body + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bob_tx_reply_msg2: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() .unwrap() @@ -1885,23 +1962,53 @@ fn transaction_base_node_monitoring() { runtime .block_on(alice_tx_ack_sender.send(create_dummy_message( - tx_reply_msg.into(), + bob_tx_reply_msg1.into(), + bob_node_identity.public_key(), + ))) + .unwrap(); + runtime + .block_on(alice_tx_ack_sender.send(create_dummy_message( + bob_tx_reply_msg2.into(), bob_node_identity.public_key(), ))) .unwrap(); - let _result_stream = runtime.block_on(async { - collect_stream!( - alice_ts.get_event_stream_fused().map(|i| (*i).clone()), - take = 2, - timeout = Duration::from_secs(60) - ) - }); + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut reply_count = 0; + loop { + futures::select! 
{ + event = alice_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::ReceivedTransactionReply(_) => { + reply_count+=1; + if reply_count >= 2 { + break; + } + }, + _ => (), + } + }, + () = delay => { + break; + }, + } + } + }); + + let alice_completed_tx = runtime + .block_on(alice_ts.get_completed_transactions()) + .unwrap() + .remove(&tx_id) + .expect("Transaction must be in collection"); + + assert_eq!(alice_completed_tx.status, TransactionStatus::Completed); + let alice_completed_tx2 = runtime .block_on(alice_ts.get_completed_transactions()) .unwrap() .remove(&tx_id2) - .expect("Transaction must be in collection"); + .expect("Transaction2 must be in collection"); assert_eq!(alice_completed_tx2.status, TransactionStatus::Completed); @@ -1909,26 +2016,20 @@ fn transaction_base_node_monitoring() { .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let _ = alice_outbound_service.pop_call().unwrap(); // burn a base node request - - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); - let msr = envelope_body - .decode_part::(1) - .unwrap() - .unwrap(); - - let mempool_service_request = MempoolServiceRequest::try_from(msr.clone()).unwrap(); + let _ = alice_outbound_service.wait_call_count(6, Duration::from_secs(60)); + for _ in 0..6 { + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message + } - let broadcast_tx_id = mempool_service_request.request_key; - let completed_tx_id = if tx_id == broadcast_tx_id { tx_id2 } else { tx_id }; + let broadcast_tx_id = tx_id; + let completed_tx_id = tx_id2; let broadcast_tx = runtime .block_on(alice_ts.get_completed_transactions()) .unwrap() .remove(&broadcast_tx_id) - .expect("Transaction must be in collection"); - let tx_outputs: Vec = broadcast_tx + .expect("Broadcast Transaction must be in collection"); + let broadcast_tx_outputs: Vec = broadcast_tx .transaction .body .outputs() @@ -1936,10 +2037,18 @@ fn transaction_base_node_monitoring() { .map(|o| TransactionOutputProto::from(o.clone())) .collect(); - match mempool_service_request.request { - MempoolRequest::GetStats => assert!(false, "Invalid Mempool Service Request variant"), - _ => (), - } + let completed_tx = runtime + .block_on(alice_ts.get_completed_transactions()) + .unwrap() + .remove(&completed_tx_id) + .expect("Completed Transaction must be in collection"); + let completed_tx_outputs: Vec = completed_tx + .transaction + .body + .outputs() + .iter() + .map(|o| TransactionOutputProto::from(o.clone())) + .collect(); let mempool_response = MempoolProto::MempoolServiceResponse { request_key: broadcast_tx_id, @@ -1952,27 +2061,33 @@ fn transaction_base_node_monitoring() { ) .unwrap(); - let result_stream = runtime.block_on(async { - collect_stream!( - alice_ts.get_event_stream_fused().map(|i| (*i).clone()), - take = 6, - timeout = Duration::from_secs(60) - ) - }); - assert!( - result_stream.iter().fold(0, |acc, item| { - if let TransactionEvent::TransactionMinedRequestTimedOut(_) = item { - acc + 1 - } else { - acc + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut mined_request_timeout_count = 0; + loop { + futures::select! 
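+            // The mempool response has been fed in, but nothing answers the chain-monitoring queries yet, so expect at least two mined-request timeouts.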
{ + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionMinedRequestTimedOut(_) = &*event.unwrap(){ + mined_request_timeout_count +=1; + if mined_request_timeout_count >= 2 { + break; + } + + } + }, + () = delay => { + break; + }, } - }) >= 2 - ); + } + assert!(mined_request_timeout_count >= 2); + }); - let wrong_outputs = vec![tx_outputs[0].clone(), TransactionOutput::default().into()]; + // Test that receiving a base node response with the wrong outputs does not result in a TX being mined + let wrong_outputs = vec![completed_tx_outputs[0].clone(), TransactionOutput::default().into()]; let base_node_response = BaseNodeProto::BaseNodeServiceResponse { - request_key: tx_id.clone(), + request_key: completed_tx_id, response: Some(BaseNodeResponseProto::TransactionOutputs( BaseNodeProto::TransactionOutputs { outputs: wrong_outputs.into(), @@ -1987,52 +2102,65 @@ fn transaction_base_node_monitoring() { ))) .unwrap(); - let result_stream = runtime.block_on(async { - collect_stream!( - alice_ts.get_event_stream_fused().map(|i| (*i).clone()), - take = 10, - timeout = Duration::from_secs(60) - ) - }); + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut mined_request_timeout_count = 0; + loop { + futures::select! { + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionMinedRequestTimedOut(_) = &*event.unwrap(){ + mined_request_timeout_count +=1; + if mined_request_timeout_count >= 2 { + break; + } - assert!( - result_stream.iter().fold(0, |acc, item| { - if let TransactionEvent::TransactionMinedRequestTimedOut(_) = item { - acc + 1 - } else { - acc + } + }, + () = delay => { + break; + }, } - }) >= 3 - ); + } + assert!(mined_request_timeout_count >= 2); + }); let broadcast_tx = runtime .block_on(alice_ts.get_completed_transactions()) .unwrap() .remove(&broadcast_tx_id) - .expect("Transaction must be in collection"); + .expect("Broadcast Transaction2 must be in collection"); let completed_tx = runtime .block_on(alice_ts.get_completed_transactions()) .unwrap() .remove(&completed_tx_id) - .expect("Transaction must be in collection"); + .expect("Completed Transaction must be in collection"); assert_eq!(broadcast_tx.status, TransactionStatus::Broadcast); assert_eq!(completed_tx.status, TransactionStatus::Completed); - let tx_outputs2: Vec = completed_tx - .transaction - .body - .outputs() - .iter() - .map(|o| TransactionOutputProto::from(o.clone())) - .collect(); + let mut chain_monitoring_id = 0u64; + // We need to get the Protocol ID that is not the completed_tx_id so we might need to pop one or pop up to 3 + for _ in 0..4 { + let call = alice_outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(&mut call.1.to_vec().as_slice()).unwrap(); + let msr = envelope_body + .clone() + .decode_part::(1) + .unwrap() + .unwrap(); + + chain_monitoring_id = msr.request_key; + if chain_monitoring_id != completed_tx_id { + break; + } + } let base_node_response = BaseNodeProto::BaseNodeServiceResponse { - request_key: broadcast_tx_id.clone(), + request_key: chain_monitoring_id, response: Some(BaseNodeResponseProto::TransactionOutputs( BaseNodeProto::TransactionOutputs { - outputs: tx_outputs.into(), + outputs: broadcast_tx_outputs.into(), }, )), }; @@ -2045,10 +2173,10 @@ fn transaction_base_node_monitoring() { .unwrap(); let base_node_response2 = BaseNodeProto::BaseNodeServiceResponse { - request_key: completed_tx_id.clone(), + request_key: completed_tx_id, response: 
Some(BaseNodeResponseProto::TransactionOutputs( BaseNodeProto::TransactionOutputs { - outputs: tx_outputs2.into(), + outputs: completed_tx_outputs.into(), }, )), }; @@ -2060,15 +2188,13 @@ fn transaction_base_node_monitoring() { ))) .unwrap(); - let mut event_stream = alice_ts.get_event_stream_fused(); - runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(60)).fuse(); let mut acc = 0; loop { futures::select! { - event = event_stream.select_next_some() => { - if let TransactionEvent::TransactionMined(_) = &*event { + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionMined(_) = &*event.unwrap() { acc += 1; if acc >= 2 { break; @@ -2087,7 +2213,7 @@ fn transaction_base_node_monitoring() { .block_on(alice_ts.get_completed_transactions()) .unwrap() .remove(&tx_id) - .expect("Transaction must be in collection"); + .expect("Completed Transaction3 must be in collection"); assert_eq!(alice_completed_tx.status, TransactionStatus::Mined); @@ -2095,7 +2221,7 @@ fn transaction_base_node_monitoring() { .block_on(alice_ts.get_completed_transactions()) .unwrap() .remove(&tx_id2) - .expect("Transaction must be in collection"); + .expect("Completed Transaction4 must be in collection"); assert_eq!(alice_completed_tx2.status, TransactionStatus::Mined); @@ -2147,52 +2273,57 @@ fn query_all_completed_transactions_on_startup() { }; db.write(WriteOperation::Insert(DbKeyValuePair::CompletedTransaction( - completed_tx1.tx_id.clone(), + completed_tx1.tx_id, Box::new(completed_tx1.clone()), ))) .unwrap(); db.write(WriteOperation::Insert(DbKeyValuePair::CompletedTransaction( - completed_tx2.tx_id.clone(), + completed_tx2.tx_id, Box::new(completed_tx2.clone()), ))) .unwrap(); db.write(WriteOperation::Insert(DbKeyValuePair::CompletedTransaction( - completed_tx3.tx_id.clone(), + completed_tx3.tx_id, Box::new(completed_tx3.clone()), ))) .unwrap(); - let (mut alice_ts, _, _, _, _, _, _, _, _) = + let (mut alice_ts, _, _, _, _, _, _, _) = setup_transaction_service_no_comms(&mut runtime, factories.clone(), db, None); + let mut alice_event_stream = alice_ts.get_event_stream_fused(); runtime .block_on(alice_ts.set_base_node_public_key(PublicKey::default())) .unwrap(); - let result_stream = runtime.block_on(async { - collect_stream!( - alice_ts.get_event_stream_fused().map(|i| (*i).clone()), - take = 2, - timeout = Duration::from_secs(20) - ) + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut found_tx_mined_1 = false; + let mut found_tx_mined_2 = false; + loop { + futures::select! { + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionMinedRequestTimedOut(tx_id) = &*event.unwrap(){ + match tx_id { + 1u64 => found_tx_mined_1 = true, + 2u64 => found_tx_mined_2 = true, + _ => assert!(false, "Should be no other transactions being broadcast!"), + } + if found_tx_mined_1 && found_tx_mined_2 { + break; + } + } + }, + () = delay => { + break; + }, + } + } + assert!(found_tx_mined_1); + assert!(found_tx_mined_2); }); - - assert!(result_stream - .iter() - .find(|v| match v { - TransactionEvent::TransactionMinedRequestTimedOut(tx_id) => *tx_id == 1u64, - _ => false, - }) - .is_some()); - assert!(result_stream - .iter() - .find(|v| match v { - TransactionEvent::TransactionMinedRequestTimedOut(tx_id) => *tx_id == 2u64, - _ => false, - }) - .is_some()); } #[test] @@ -2283,14 +2414,13 @@ fn test_failed_tx_send_timeout() { loop { futures::select! 
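            // Watch the event stream for the direct-send result; the delay arm keeps the test from hanging if the event never arrives.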
{ event = alice_event_stream.select_next_some() => { - if let TransactionEvent::TransactionSendResult(_, success) = (*event).clone() { + if let TransactionEvent::TransactionDirectSendResult(_, success) = (*event.unwrap()).clone() { returned = true; result = success; break; } }, () = delay => { - log::error!("This select loop timed out"); break; }, } @@ -2305,7 +2435,6 @@ fn test_failed_tx_send_timeout() { #[test] fn transaction_cancellation_when_not_in_mempool() { - let _ = env_logger::try_init(); let factories = CryptoFactories::default(); let mut runtime = Runtime::new().unwrap(); @@ -2327,21 +2456,19 @@ fn transaction_cancellation_when_not_in_mempool() { _, mut alice_mempool_response_sender, mut alice_base_node_response_sender, - _, ) = setup_transaction_service_no_comms( &mut runtime, factories.clone(), TransactionMemoryDatabase::new(), - Some(Duration::from_secs(20)), + Some(Duration::from_secs(5)), ); - - let (mut bob_ts, _, bob_outbound_service, mut bob_tx_sender, _, _, _, _, _) = setup_transaction_service_no_comms( + let mut alice_event_stream = alice_ts.get_event_stream_fused(); + let (mut bob_ts, _, bob_outbound_service, mut bob_tx_sender, _, _, _, _) = setup_transaction_service_no_comms( &mut runtime, factories.clone(), TransactionMemoryDatabase::new(), Some(Duration::from_secs(20)), ); - runtime .block_on(bob_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); @@ -2360,9 +2487,13 @@ fn transaction_cancellation_when_not_in_mempool() { "Testing Message".to_string(), )) .unwrap(); + alice_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .unwrap(); + let (_, body) = alice_outbound_service.pop_call().unwrap(); + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_sender_msg: TransactionSenderMessage = envelope_body .decode_part::(1) .unwrap() @@ -2383,16 +2514,13 @@ fn transaction_cancellation_when_not_in_mempool() { alice_node_identity.public_key(), ))) .unwrap(); + bob_outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .unwrap(); + let (_, body) = bob_outbound_service.pop_call().unwrap(); + let _ = bob_outbound_service.pop_call().unwrap(); // burn SAF message - let _result_stream = runtime.block_on(async { - collect_stream!( - bob_ts.get_event_stream_fused(), - take = 1, - timeout = Duration::from_secs(30) - ) - }); - let call = bob_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let tx_reply_msg: RecipientSignedMessage = envelope_body .decode_part::(1) .unwrap() @@ -2407,12 +2535,25 @@ fn transaction_cancellation_when_not_in_mempool() { ))) .unwrap(); - let _result_stream = runtime.block_on(async { - collect_stream!( - alice_ts.get_event_stream_fused().map(|i| (*i).clone()), - take = 1, - timeout = Duration::from_secs(60) - ) + let _ = alice_outbound_service.wait_call_count(2, Duration::from_secs(60)); + let _ = alice_outbound_service.pop_call().unwrap(); // Burn finalize message + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message + + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + loop { + futures::select! 
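+            // Let Bob's reply reach Alice's transaction service before inspecting the completed-transaction collection.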
{ + event = alice_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::ReceivedTransactionReply(_) => break, + _ => (), + } + }, + () = delay => { + break; + }, + } + } }); let alice_completed_tx = runtime .block_on(alice_ts.get_completed_transactions()) @@ -2426,22 +2567,6 @@ fn transaction_cancellation_when_not_in_mempool() { .block_on(alice_ts.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); - let _ = alice_outbound_service.pop_call().unwrap(); // burn a base node request - - let call = alice_outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(&mut call.1.as_slice()).unwrap(); - let msr = envelope_body - .decode_part::(1) - .unwrap() - .unwrap(); - - let mempool_service_request = MempoolServiceRequest::try_from(msr.clone()).unwrap(); - - match mempool_service_request.request { - MempoolRequest::GetStats => assert!(false, "Invalid Mempool Service Request variant"), - _ => (), - } - let mempool_response = MempoolProto::MempoolServiceResponse { request_key: tx_id, response: Some(MempoolResponse::TxStorage(TxStorageResponse::UnconfirmedPool).into()), @@ -2453,14 +2578,13 @@ fn transaction_cancellation_when_not_in_mempool() { ) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(60)).fuse(); let mut timeouts = 0; loop { futures::select! { event = alice_event_stream.select_next_some() => { - if let TransactionEvent::TransactionMinedRequestTimedOut(_e) = &*event { + if let TransactionEvent::TransactionMinedRequestTimedOut(_e) = &*event.unwrap() { timeouts+=1; if timeouts >= 1 { break; @@ -2475,38 +2599,44 @@ fn transaction_cancellation_when_not_in_mempool() { assert!(timeouts >= 1); }); - let base_node_response = BaseNodeProto::BaseNodeServiceResponse { - request_key: tx_id.clone(), - response: Some(BaseNodeResponseProto::TransactionOutputs( - BaseNodeProto::TransactionOutputs { outputs: vec![] }, - )), - }; - runtime - .block_on(alice_base_node_response_sender.send(create_dummy_message( - base_node_response, - base_node_identity.public_key(), - ))) + let alice_completed_tx = runtime + .block_on(alice_ts.get_completed_transactions()) + .unwrap() + .remove(&tx_id) + .expect("Transaction must be in collection"); + + assert_eq!(alice_completed_tx.status, TransactionStatus::Broadcast); + + let _ = alice_outbound_service.wait_call_count(2, Duration::from_secs(60)); + let call = alice_outbound_service.pop_call().unwrap(); + let _ = alice_outbound_service.pop_call().unwrap(); // burn SAF message + + let envelope_body = EnvelopeBody::decode(&mut call.1.to_vec().as_slice()).unwrap(); + let msr = envelope_body + .decode_part::(1) + .unwrap() .unwrap(); + let chain_monitoring_id = msr.request_key; let mempool_response = MempoolProto::MempoolServiceResponse { - request_key: tx_id, + request_key: chain_monitoring_id, response: Some(MempoolResponse::TxStorage(TxStorageResponse::NotStored).into()), }; + let base_node_response = BaseNodeProto::BaseNodeServiceResponse { - request_key: tx_id.clone(), + request_key: chain_monitoring_id, response: Some(BaseNodeResponseProto::TransactionOutputs( BaseNodeProto::TransactionOutputs { outputs: vec![] }, )), }; - let mut alice_event_stream = alice_ts.get_event_stream_fused(); runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(60)).fuse(); let mut timeouts = 0; loop { futures::select! 
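            // The canned NotStored/empty responses are only sent after this loop, so at least one more mined-request timeout must fire first.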
{ event = alice_event_stream.select_next_some() => { - if let TransactionEvent::TransactionMinedRequestTimedOut(_e) = &*event { + if let TransactionEvent::TransactionMinedRequestTimedOut(_e) = &*event.unwrap() { timeouts+=1; if timeouts >= 1 { break; @@ -2529,6 +2659,7 @@ fn transaction_cancellation_when_not_in_mempool() { alice_mempool_response_sender.send(create_dummy_message(mempool_response, base_node_identity.public_key())), ) .unwrap(); + runtime .block_on(alice_base_node_response_sender.send(create_dummy_message( base_node_response, @@ -2536,17 +2667,14 @@ fn transaction_cancellation_when_not_in_mempool() { ))) .unwrap(); - let mut alice_event_stream = alice_ts.get_event_stream_fused(); runtime.block_on(async { - let mut delay = delay_for(Duration::from_secs(20)).fuse(); - let mut returned = false; - let mut result = true; + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut cancelled = false; loop { futures::select! { event = alice_event_stream.select_next_some() => { - if let TransactionEvent::TransactionSendDiscoveryComplete(_, success) = &*event { - returned = true; - result = success.clone(); + if let TransactionEvent::TransactionCancelled(_) = &*event.unwrap() { + cancelled = true; } }, () = delay => { @@ -2554,8 +2682,7 @@ fn transaction_cancellation_when_not_in_mempool() { }, } } - assert!(returned, "Event should have occured"); - assert!(!result, "Transaction should be cancelled"); + assert!(cancelled, "Tx should have been cancelled"); }); let alice_completed_tx = runtime @@ -2568,3 +2695,132 @@ fn transaction_cancellation_when_not_in_mempool() { assert_eq!(balance.available_balance, alice_total_available); } + +fn test_transaction_cancellation(backend: T) { + let factories = CryptoFactories::default(); + let mut runtime = Runtime::new().unwrap(); + + let bob_node_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); + + let (mut alice_ts, mut alice_output_manager, _alice_outbound_service, mut alice_tx_sender, _, _, _, _) = + setup_transaction_service_no_comms(&mut runtime, factories.clone(), backend, Some(Duration::from_secs(20))); + let mut alice_event_stream = alice_ts.get_event_stream_fused(); + + let alice_total_available = 250000 * uT; + let (_utxo, uo) = make_input(&mut OsRng, alice_total_available, &factories.commitment); + runtime.block_on(alice_output_manager.add_output(uo)).unwrap(); + + let amount_sent = 10000 * uT; + + let tx_id = runtime + .block_on(alice_ts.send_transaction( + bob_node_identity.public_key().clone(), + amount_sent, + 100 * uT, + "Testing Message".to_string(), + )) + .unwrap(); + + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + loop { + futures::select! 
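+            // Bob is a random, unconnected identity, so wait for the store-and-forward send attempt to resolve before cancelling.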
{ + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::TransactionStoreForwardSendResult(_,_) = &*event.unwrap() { + break; + } + }, + () = delay => { + break; + }, + } + } + }); + + runtime + .block_on(alice_ts.get_pending_outbound_transactions()) + .unwrap() + .remove(&tx_id) + .expect("Pending Transaction should be in list"); + + runtime.block_on(alice_ts.cancel_transaction(tx_id)).unwrap(); + + assert!(runtime + .block_on(alice_ts.get_pending_outbound_transactions()) + .unwrap() + .remove(&tx_id) + .is_none()); + + let mut builder = SenderTransactionProtocol::builder(1); + let amount = MicroTari::from(10_000); + let input = UnblindedOutput::new(MicroTari::from(100_000), PrivateKey::random(&mut OsRng), None); + builder + .with_lock_height(0) + .with_fee_per_gram(MicroTari::from(177)) + .with_offset(PrivateKey::random(&mut OsRng)) + .with_private_nonce(PrivateKey::random(&mut OsRng)) + .with_amount(0, amount) + .with_message("Yo!".to_string()) + .with_input( + input.as_transaction_input(&factories.commitment, OutputFeatures::default()), + input.clone(), + ) + .with_change_secret(PrivateKey::random(&mut OsRng)); + + let mut stp = builder.build::(&factories).unwrap(); + let tx_sender_msg = stp.build_single_round_message().unwrap(); + let tx_id2 = tx_sender_msg.tx_id; + let proto_message = proto::TransactionSenderMessage::single(tx_sender_msg.into()); + runtime + .block_on(alice_tx_sender.send(create_dummy_message( + proto_message, + &PublicKey::from_secret_key(&PrivateKey::random(&mut OsRng)), + ))) + .unwrap(); + + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + loop { + futures::select! { + event = alice_event_stream.select_next_some() => { + if let TransactionEvent::ReceivedTransaction(_) = &*event.unwrap() { + break; + } + }, + () = delay => { + break; + }, + } + } + }); + + runtime + .block_on(alice_ts.get_pending_inbound_transactions()) + .unwrap() + .remove(&tx_id2) + .expect("Pending Transaction should be in list"); + + runtime.block_on(alice_ts.cancel_transaction(tx_id2)).unwrap(); + + assert!(runtime + .block_on(alice_ts.get_pending_inbound_transactions()) + .unwrap() + .remove(&tx_id2) + .is_none()); +} + +#[test] +fn test_transaction_cancellation_memory_db() { + test_transaction_cancellation(TransactionMemoryDatabase::new()); +} + +#[test] +fn test_transaction_cancellation_sqlite_db() { + let db_name = format!("{}.sqlite3", random_string(8).as_str()); + let temp_dir = TempDir::new(random_string(8).as_str()).unwrap(); + let db_folder = temp_dir.path().to_str().unwrap().to_string(); + let connection = run_migration_and_create_sqlite_connection(&format!("{}/{}", db_folder, db_name)).unwrap(); + + test_transaction_cancellation(TransactionServiceSqliteDatabase::new(connection)); +} diff --git a/base_layer/wallet/tests/transaction_service/storage.rs b/base_layer/wallet/tests/transaction_service/storage.rs index 4f3a1743d5..a6f6b08187 100644 --- a/base_layer/wallet/tests/transaction_service/storage.rs +++ b/base_layer/wallet/tests/transaction_service/storage.rs @@ -270,7 +270,7 @@ pub fn test_db_backend(backend: T) { ); #[cfg(feature = "test_harness")] runtime - .block_on(db.broadcast_completed_transaction(completed_txs[0].tx_id.clone())) + .block_on(db.broadcast_completed_transaction(completed_txs[0].tx_id)) .unwrap(); let retrieved_completed_txs = runtime.block_on(db.get_completed_transactions()).unwrap(); @@ -282,7 +282,7 @@ pub fn test_db_backend(backend: T) { #[cfg(feature = "test_harness")] runtime - 
.block_on(db.mine_completed_transaction(completed_txs[0].tx_id.clone())) + .block_on(db.mine_completed_transaction(completed_txs[0].tx_id)) .unwrap(); let retrieved_completed_txs = runtime.block_on(db.get_completed_transactions()).unwrap(); @@ -296,12 +296,19 @@ pub fn test_db_backend(backend: T) { let completed_txs = runtime.block_on(db.get_completed_transactions()).unwrap(); let num_completed_txs = completed_txs.len(); + let cancelled_tx_id = completed_txs[&1].tx_id; + assert!(runtime.block_on(db.get_completed_transaction(cancelled_tx_id)).is_ok()); runtime - .block_on(db.cancel_completed_transaction(completed_txs[&1].tx_id)) + .block_on(db.cancel_completed_transaction(cancelled_tx_id)) .unwrap(); - let completed_txs = runtime.block_on(db.get_completed_transactions()).unwrap(); assert_eq!(completed_txs.len(), num_completed_txs - 1); + + assert!(runtime.block_on(db.get_completed_transaction(cancelled_tx_id)).is_err()); + + assert!(runtime + .block_on(db.get_completed_transaction(completed_txs[&0].tx_id)) + .is_ok()); } #[test] diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index cee1c3ed28..37ff8e6e60 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -28,14 +28,15 @@ use tari_comms::{ peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags}, types::CommsPublicKey, }; -#[cfg(feature = "test_harness")] use tari_comms_dht::DhtConfig; use tari_core::transactions::{tari_amount::MicroTari, types::CryptoFactories}; use tari_crypto::keys::PublicKey; use tari_p2p::initialization::CommsConfig; -use tari_test_utils::{collect_stream, paths::with_temp_dir}; +use tari_test_utils::paths::with_temp_dir; use crate::support::comms_and_services::get_next_memory_address; +use futures::{FutureExt, StreamExt}; +use std::path::Path; use tari_core::transactions::{tari_amount::uT, transaction::UnblindedOutput, types::PrivateKey}; use tari_p2p::transport::TransportType; use tari_wallet::{ @@ -47,7 +48,7 @@ use tari_wallet::{ Wallet, }; use tempdir::TempDir; -use tokio::runtime::Runtime; +use tokio::{runtime::Runtime, time::delay_for}; fn create_peer(public_key: CommsPublicKey, net_address: Multiaddr) -> Peer { Peer::new( @@ -60,91 +61,62 @@ fn create_peer(public_key: CommsPublicKey, net_address: Multiaddr) -> Peer { ) } +fn create_wallet( + node_identity: NodeIdentity, + data_path: &Path, + factories: CryptoFactories, +) -> Wallet +{ + let comms_config = CommsConfig { + node_identity: Arc::new(node_identity.clone()), + transport_type: TransportType::Memory { + listener_address: node_identity.public_address(), + }, + datastore_path: data_path.to_path_buf(), + peer_database_name: random_string(8), + max_concurrent_inbound_tasks: 100, + outbound_buffer_size: 100, + dht: DhtConfig { + discovery_request_timeout: Duration::from_secs(1), + ..Default::default() + }, + allow_test_addresses: true, + listener_liveness_whitelist_cidrs: Vec::new(), + listener_liveness_max_sessions: 0, + }; + let config = WalletConfig { + comms_config, + factories, + transaction_service_config: None, + }; + let runtime_node = Runtime::new().unwrap(); + let wallet = Wallet::new( + config, + runtime_node, + WalletMemoryDatabase::new(), + TransactionMemoryDatabase::new(), + OutputManagerMemoryDatabase::new(), + ContactsServiceMemoryDatabase::new(), + ) + .unwrap(); + wallet +} + #[test] fn test_wallet() { with_temp_dir(|dir_path| { let mut runtime = Runtime::new().unwrap(); let factories = CryptoFactories::default(); - let alice_identity = 
NodeIdentity::random( - &mut OsRng, - "/ip4/127.0.0.1/tcp/22523".parse().unwrap(), - PeerFeatures::COMMUNICATION_NODE, - ) - .unwrap(); - let bob_identity = NodeIdentity::random( - &mut OsRng, - "/ip4/127.0.0.1/tcp/22145".parse().unwrap(), - PeerFeatures::COMMUNICATION_NODE, - ) - .unwrap(); + let alice_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); + let bob_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); - let base_node_identity = NodeIdentity::random( - &mut OsRng, - "/ip4/127.0.0.1/tcp/54225".parse().unwrap(), - PeerFeatures::COMMUNICATION_NODE, - ) - .unwrap(); + let base_node_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); - let comms_config1 = CommsConfig { - node_identity: Arc::new(alice_identity.clone()), - transport_type: TransportType::Tcp { - listener_address: alice_identity.public_address(), - tor_socks_config: None, - }, - datastore_path: dir_path.to_path_buf(), - peer_database_name: random_string(8), - max_concurrent_inbound_tasks: 100, - outbound_buffer_size: 100, - dht: Default::default(), - allow_test_addresses: true, - listener_liveness_whitelist_cidrs: Vec::new(), - listener_liveness_max_sessions: 0, - }; - let comms_config2 = CommsConfig { - node_identity: Arc::new(bob_identity.clone()), - transport_type: TransportType::Tcp { - listener_address: bob_identity.public_address(), - tor_socks_config: None, - }, - datastore_path: dir_path.to_path_buf(), - peer_database_name: random_string(8), - max_concurrent_inbound_tasks: 100, - outbound_buffer_size: 100, - dht: Default::default(), - allow_test_addresses: true, - listener_liveness_whitelist_cidrs: Vec::new(), - listener_liveness_max_sessions: 0, - }; - let config1 = WalletConfig { - comms_config: comms_config1, - factories: factories.clone(), - transaction_service_config: None, - }; - let config2 = WalletConfig { - comms_config: comms_config2, - factories: factories.clone(), - transaction_service_config: None, - }; - let runtime_node1 = Runtime::new().unwrap(); - let runtime_node2 = Runtime::new().unwrap(); - let mut alice_wallet = Wallet::new( - config1, - runtime_node1, - WalletMemoryDatabase::new(), - TransactionMemoryDatabase::new(), - OutputManagerMemoryDatabase::new(), - ContactsServiceMemoryDatabase::new(), - ) - .unwrap(); - let mut bob_wallet = Wallet::new( - config2, - runtime_node2, - WalletMemoryDatabase::new(), - TransactionMemoryDatabase::new(), - OutputManagerMemoryDatabase::new(), - ContactsServiceMemoryDatabase::new(), - ) - .unwrap(); + let mut alice_wallet = create_wallet(alice_identity.clone(), dir_path, factories.clone()); + let mut bob_wallet = create_wallet(bob_identity.clone(), dir_path, factories.clone()); alice_wallet .runtime @@ -169,7 +141,7 @@ fn test_wallet() { ) .unwrap(); - let alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); + let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); let value = MicroTari::from(1000); let (_utxo, uo1) = make_input(&mut OsRng, MicroTari(2500), &factories.commitment); @@ -187,18 +159,27 @@ fn test_wallet() { )) .unwrap(); - let result_stream = runtime - .block_on(async { collect_stream!(alice_event_stream, take = 2, timeout = Duration::from_secs(10)) }); - let received_transaction_reply_count = result_stream.iter().fold(0, |acc, x| match &**x { - TransactionEvent::ReceivedTransactionReply(_) => 
acc + 1, - _ => acc, + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut reply_count = false; + loop { + futures::select! { + event = alice_event_stream.select_next_some() => match &*event.unwrap() { + TransactionEvent::ReceivedTransactionReply(_) => { + reply_count = true; + break; + }, + _ => (), + }, + + () = delay => { + break; + }, + } + } + assert!(reply_count); }); - assert_eq!( - received_transaction_reply_count, 1, - "Did not received correct numebr of replies" - ); - let mut contacts = Vec::new(); for i in 0..2 { let (_secret_key, public_key) = PublicKey::random_keypair(&mut OsRng); @@ -218,6 +199,106 @@ fn test_wallet() { }); } +#[test] +fn test_store_and_forward_send_tx() { + let factories = CryptoFactories::default(); + let db_tempdir = TempDir::new(random_string(8).as_str()).unwrap(); + + let alice_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); + let bob_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); + let carol_identity = + NodeIdentity::random(&mut OsRng, get_next_memory_address(), PeerFeatures::COMMUNICATION_NODE).unwrap(); + + let mut alice_wallet = create_wallet(alice_identity.clone(), &db_tempdir.path(), factories.clone()); + let mut bob_wallet = create_wallet(bob_identity.clone(), &db_tempdir.path(), factories.clone()); + let mut alice_event_stream = alice_wallet.transaction_service.get_event_stream_fused(); + + alice_wallet + .runtime + .block_on(alice_wallet.comms.peer_manager().add_peer(create_peer( + bob_identity.public_key().clone(), + bob_identity.public_address(), + ))) + .unwrap(); + + bob_wallet + .runtime + .block_on(bob_wallet.comms.peer_manager().add_peer(create_peer( + alice_identity.public_key().clone(), + alice_identity.public_address(), + ))) + .unwrap(); + + bob_wallet + .runtime + .block_on(bob_wallet.comms.peer_manager().add_peer(create_peer( + carol_identity.public_key().clone(), + carol_identity.public_address(), + ))) + .unwrap(); + + let value = MicroTari::from(1000); + let (_utxo, uo1) = make_input(&mut OsRng, MicroTari(2500), &factories.commitment); + + alice_wallet + .runtime + .block_on(alice_wallet.output_manager_service.add_output(uo1)) + .unwrap(); + + alice_wallet + .runtime + .block_on(alice_wallet.transaction_service.send_transaction( + carol_identity.public_key().clone(), + value, + MicroTari::from(20), + "Store and Forward!".to_string(), + )) + .unwrap(); + + // Waiting here for a while to make sure the discovery retry is over + alice_wallet + .runtime + .block_on(async { delay_for(Duration::from_secs(10)).await }); + + let mut carol_wallet = create_wallet(carol_identity.clone(), &db_tempdir.path(), factories.clone()); + + carol_wallet + .runtime + .block_on(carol_wallet.comms.peer_manager().add_peer(create_peer( + bob_identity.public_key().clone(), + bob_identity.public_address(), + ))) + .unwrap(); + + alice_wallet.runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut tx_reply = 0; + loop { + futures::select! 
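+            // Carol was offline when the transaction was sent; her reply should now arrive via Bob's store-and-forward delivery.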
{ + event = alice_event_stream.select_next_some() => { + match &*event.unwrap() { + TransactionEvent::ReceivedTransactionReply(_) => tx_reply+=1, + _ => (), + } + if tx_reply == 1 { + break; + } + }, + () = delay => { + break; + }, + } + } + assert_eq!(tx_reply, 1, "Must have received a reply from Carol"); + }); + + alice_wallet.shutdown(); + bob_wallet.shutdown(); + carol_wallet.shutdown(); +} + #[test] fn test_import_utxo() { let factories = CryptoFactories::default(); diff --git a/base_layer/wallet_ffi/Cargo.toml b/base_layer/wallet_ffi/Cargo.toml index f06e999579..bc0c0eaaaf 100644 --- a/base_layer/wallet_ffi/Cargo.toml +++ b/base_layer/wallet_ffi/Cargo.toml @@ -3,15 +3,15 @@ name = "tari_wallet_ffi" authors = ["The Tari Development Community"] description = "Tari cryptocurrency wallet C FFI bindings" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" [dependencies] -tari_comms = { path = "../../comms", version = "^0.0"} -tari_comms_dht = { path = "../../comms/dht", version = "^0.0"} +tari_comms = { path = "../../comms", version = "^0.1"} +tari_comms_dht = { path = "../../comms/dht", version = "^0.1"} tari_crypto = { version = "^0.3" } -tari_p2p = {path = "../p2p", version = "^0.0"} -tari_wallet = { path = "../wallet", version = "^0.0", features = ["test_harness", "c_integration"]} +tari_p2p = {path = "../p2p", version = "^0.1"} +tari_wallet = { path = "../wallet", version = "^0.1", features = ["test_harness", "c_integration"]} tari_shutdown = { path = "../../infrastructure/shutdown", version = "^0.0"} tari_utilities = "^0.1" @@ -27,7 +27,7 @@ log4rs = {version = "0.8.3", features = ["console_appender", "file_appender", "f [dependencies.tari_core] path = "../../base_layer/core" -version = "^0.0" +version = "^0.1" default-features = false features = ["transactions"] diff --git a/base_layer/wallet_ffi/README.md b/base_layer/wallet_ffi/README.md index 82a18d1a7d..6d67503759 100644 --- a/base_layer/wallet_ffi/README.md +++ b/base_layer/wallet_ffi/README.md @@ -35,11 +35,36 @@ For macOS Mojave additional headers need to be installed, run ```Shell Script open /Library/Developer/CommandLineTools/Packages/macOS_SDK_headers_for_macOS_10.14.pkg ``` -and follow the prompts +and follow the prompts. + +For Catalina, if you get compilation errors such as these: + + xcrun: error: SDK "iphoneos" cannot be located + xcrun: error: unable to lookup item 'Path' in SDK 'iphoneos' + +Switch the XCode app defaults with: + + sudo xcode-select --switch /Applications/Xcode.app + +**Note:** If this command fails, XCode was not found and needs to be installed/re-installed. ## Android Dependencies -Download the [Android NDK Bundle](https://developer.android.com/ndk/downloads) +Install [Android Studio](https://developer.android.com/studio) and then use the SDK Manager to install the Android NDK +along with the SDK of your choice (Android Q is recommended). Not all of these tools are required, but will come in +handy during Rust / Android development: + +* LLDB +* NDK (Side by side) +* Android SDK Command-line Tools (latest) +* Android SDK Platform Tools +* Android SDK Tools +* CMake + +When setting up an AVD (Android Virtual Device) please note that a 64-bit image (x86_64) needs to be used and not a +32-bit image (x86). This is to run the application on the simulator with these libraries. + +Alternatively, download the [Android NDK Bundle](https://developer.android.com/ndk/downloads) directly. 
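+
+If you prefer to script this step, the same packages can also be installed with `sdkmanager` from the
+Android SDK command-line tools (a sketch; this assumes `sdkmanager` is on your `PATH`, and the exact
+package names should be confirmed first):
+
+```Shell Script
+sdkmanager --list                          # confirm the available package names and versions
+sdkmanager "ndk-bundle" "platform-tools"   # then install the NDK and platform tools
+```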
## Enable Hidden Files @@ -67,10 +92,12 @@ Install [Rust](https://www.rust-lang.org/tools/install) Install the following tools and system images ```Shell Script -rustup toolchain add nightly-2019-10-04 -rustup default nightly-2019-10-04 +rustup toolchain add nightly-2020-01-08 +rustup default nightly-2020-01-08 rustup component add rustfmt --toolchain nightly rustup component add clippy +rustup target add x86_64-apple-ios aarch64-apple-ios # iPhone and emulator cross compiling +rustup target add x86_64-linux-android aarch64-linux-android armv7-linux-androideabi # Android device cross compiling ``` ## Build Configuration diff --git a/base_layer/wallet_ffi/src/callback_handler.rs b/base_layer/wallet_ffi/src/callback_handler.rs index c704bf48ec..8dccd34d8b 100644 --- a/base_layer/wallet_ffi/src/callback_handler.rs +++ b/base_layer/wallet_ffi/src/callback_handler.rs @@ -55,7 +55,7 @@ use tari_shutdown::ShutdownSignal; use tari_wallet::{ output_manager_service::{handle::OutputManagerEvent, TxId}, transaction_service::{ - handle::TransactionEvent, + handle::{TransactionEvent, TransactionEventReceiver}, storage::database::{CompletedTransaction, InboundTransaction, TransactionBackend, TransactionDatabase}, }, }; @@ -70,10 +70,12 @@ where TBackend: TransactionBackend + 'static callback_received_finalized_transaction: unsafe extern "C" fn(*mut CompletedTransaction), callback_transaction_broadcast: unsafe extern "C" fn(*mut CompletedTransaction), callback_transaction_mined: unsafe extern "C" fn(*mut CompletedTransaction), - callback_discovery_process_complete: unsafe extern "C" fn(TxId, bool), + callback_direct_send_result: unsafe extern "C" fn(TxId, bool), + callback_store_and_forward_send_result: unsafe extern "C" fn(TxId, bool), + callback_transaction_cancellation: unsafe extern "C" fn(TxId), callback_base_node_sync_complete: unsafe extern "C" fn(TxId, bool), db: TransactionDatabase, - transaction_service_event_stream: Fuse>, + transaction_service_event_stream: Fuse, output_manager_service_event_stream: Fuse>, shutdown_signal: Option, } @@ -84,7 +86,7 @@ where TBackend: TransactionBackend + 'static { pub fn new( db: TransactionDatabase, - transaction_service_event_stream: Fuse>, + transaction_service_event_stream: Fuse, output_manager_service_event_stream: Fuse>, shutdown_signal: ShutdownSignal, callback_received_transaction: unsafe extern "C" fn(*mut InboundTransaction), @@ -92,7 +94,9 @@ where TBackend: TransactionBackend + 'static callback_received_finalized_transaction: unsafe extern "C" fn(*mut CompletedTransaction), callback_transaction_broadcast: unsafe extern "C" fn(*mut CompletedTransaction), callback_transaction_mined: unsafe extern "C" fn(*mut CompletedTransaction), - callback_discovery_process_complete: unsafe extern "C" fn(TxId, bool), + callback_direct_send_result: unsafe extern "C" fn(TxId, bool), + callback_store_and_forward_send_result: unsafe extern "C" fn(TxId, bool), + callback_transaction_cancellation: unsafe extern "C" fn(TxId), callback_base_node_sync_complete: unsafe extern "C" fn(u64, bool), ) -> Self { @@ -118,7 +122,15 @@ where TBackend: TransactionBackend + 'static ); info!( target: LOG_TARGET, - "DiscoveryProcessCompleteCallback -> Assigning Fn: {:?}", callback_discovery_process_complete + "DirectSendResultCallback -> Assigning Fn: {:?}", callback_direct_send_result + ); + info!( + target: LOG_TARGET, + "StoreAndForwardSendResultCallback -> Assigning Fn: {:?}", callback_store_and_forward_send_result + ); + info!( + target: LOG_TARGET, + "TransactionCancellationCallback 
-> Assigning Fn: {:?}", callback_transaction_cancellation ); info!( target: LOG_TARGET, @@ -131,7 +143,9 @@ where TBackend: TransactionBackend + 'static callback_received_finalized_transaction, callback_transaction_broadcast, callback_transaction_mined, - callback_discovery_process_complete, + callback_direct_send_result, + callback_store_and_forward_send_result, + callback_transaction_cancellation, callback_base_node_sync_complete, db, transaction_service_event_stream, @@ -150,38 +164,40 @@ where TBackend: TransactionBackend + 'static loop { futures::select! { - msg = self.transaction_service_event_stream.select_next_some() => { - trace!(target: LOG_TARGET, "Transaction Service Callback Handler event {:?}", msg); - match (*msg).clone() { - TransactionEvent::ReceivedTransaction(tx_id) => { - self.receive_transaction_event(tx_id).await; - }, - TransactionEvent::ReceivedTransactionReply(tx_id) => { - self.receive_transaction_reply_event(tx_id).await; - }, - TransactionEvent::ReceivedFinalizedTransaction(tx_id) => { - self.receive_finalized_transaction_event(tx_id).await; - }, - TransactionEvent::TransactionSendDiscoveryComplete(tx_id, result) => { - // If this event result is false we will return that result via the callback as - // no further action will be taken on this send attempt. However if it is true - // then we must wait for a `TransactionSendResult` which - // will tell us the final result of the send - if !result { - self.receive_discovery_process_result(tx_id, result); + result = self.transaction_service_event_stream.select_next_some() => { + match result { + Ok(msg) => { + trace!(target: LOG_TARGET, "Transaction Service Callback Handler event {:?}", msg); + match (*msg).clone() { + TransactionEvent::ReceivedTransaction(tx_id) => { + self.receive_transaction_event(tx_id).await; + }, + TransactionEvent::ReceivedTransactionReply(tx_id) => { + self.receive_transaction_reply_event(tx_id).await; + }, + TransactionEvent::ReceivedFinalizedTransaction(tx_id) => { + self.receive_finalized_transaction_event(tx_id).await; + }, + TransactionEvent::TransactionDirectSendResult(tx_id, result) => { + self.receive_direct_send_result(tx_id, result); + }, + TransactionEvent::TransactionStoreForwardSendResult(tx_id, result) => { + self.receive_store_and_forward_send_result(tx_id, result); + }, + TransactionEvent::TransactionCancelled(tx_id) => { + self.receive_transaction_cancellation(tx_id); + }, + TransactionEvent::TransactionBroadcast(tx_id) => { + self.receive_transaction_broadcast_event(tx_id).await; + }, + TransactionEvent::TransactionMined(tx_id) => { + self.receive_transaction_mined_event(tx_id).await; + }, + /// Only the above variants are mapped to callbacks + _ => (), } }, - TransactionEvent::TransactionBroadcast(tx_id) => { - self.receive_transaction_broadcast_event(tx_id).await; - }, - TransactionEvent::TransactionMined(tx_id) => { - self.receive_transaction_mined_event(tx_id).await; - }, - TransactionEvent::TransactionSendResult(tx_id, result) => { - self.receive_discovery_process_result(tx_id, result); - }, - /// Only the above variants are mapped to callbacks - _ => (), + Err(e) => error!(target: LOG_TARGET, "Error reading from Transaction Service event broadcast channel"), } }, msg = self.output_manager_service_event_stream.select_next_some() => { @@ -260,13 +276,33 @@ where TBackend: TransactionBackend + 'static } } - fn receive_discovery_process_result(&mut self, tx_id: TxId, result: bool) { + fn receive_direct_send_result(&mut self, tx_id: TxId, result: bool) { + debug!( + 
target: LOG_TARGET, + "Calling Direct Send Result callback function for TxId: {} with result {}", tx_id, result + ); + unsafe { + (self.callback_direct_send_result)(tx_id, result); + } + } + + fn receive_store_and_forward_send_result(&mut self, tx_id: TxId, result: bool) { + debug!( + target: LOG_TARGET, + "Calling Store and Forward Send Result callback function for TxId: {} with result {}", tx_id, result + ); + unsafe { + (self.callback_store_and_forward_send_result)(tx_id, result); + } + } + + fn receive_transaction_cancellation(&mut self, tx_id: TxId) { debug!( target: LOG_TARGET, - "Calling Discovery Process Completed callback function for TxId: {} with result {}", tx_id, result + "Calling Transaction Cancellation callback function for TxId: {}", tx_id ); unsafe { - (self.callback_discovery_process_complete)(tx_id, result); + (self.callback_transaction_cancellation)(tx_id); } } diff --git a/base_layer/wallet_ffi/src/error.rs b/base_layer/wallet_ffi/src/error.rs index b7ef6a24cf..2291b07c1d 100644 --- a/base_layer/wallet_ffi/src/error.rs +++ b/base_layer/wallet_ffi/src/error.rs @@ -208,6 +208,10 @@ impl From for LibWalletError { code: 301, message: format!("{:?}", w), }, + WalletError::StoreAndForwardError(_) => Self { + code: 302, + message: format!("{:?}", w), + }, WalletError::ContactsServiceError(ContactsServiceError::ContactNotFound) => Self { code: 401, message: format!("{:?}", w), diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index b0410a4291..7e2e11a77f 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -142,7 +142,7 @@ use tari_comms::{ socks, tor, }; -use tari_comms_dht::DhtConfig; +use tari_comms_dht::{DbConnectionUrl, DhtConfig}; use tari_core::transactions::{tari_amount::MicroTari, types::CryptoFactories}; use tari_crypto::{ keys::{PublicKey, SecretKey}, @@ -215,6 +215,7 @@ pub struct ByteVector(Vec); // declared like this so that it can be exp /// `()` - Does not return a value, equivalent to void in C. 
/// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn string_destroy(ptr: *mut c_char) { if !ptr.is_null() { @@ -240,6 +241,7 @@ pub unsafe extern "C" fn string_destroy(ptr: *mut c_char) { /// element_count when it is created /// /// # Safety +/// The ```byte_vector_destroy``` function must be called when finished with a ByteVector to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn byte_vector_create( byte_array: *const c_uchar, @@ -275,6 +277,7 @@ pub unsafe extern "C" fn byte_vector_create( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn byte_vector_destroy(bytes: *mut ByteVector) { if !bytes.is_null() { @@ -295,6 +298,7 @@ pub unsafe extern "C" fn byte_vector_destroy(bytes: *mut ByteVector) { /// is null or if the position is invalid /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn byte_vector_get_at(ptr: *mut ByteVector, position: c_uint, error_out: *mut c_int) -> c_uchar { let mut error = 0; @@ -325,6 +329,7 @@ pub unsafe extern "C" fn byte_vector_get_at(ptr: *mut ByteVector, position: c_ui /// if ptr is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn byte_vector_get_length(vec: *const ByteVector, error_out: *mut c_int) -> c_uint { let mut error = 0; @@ -353,6 +358,7 @@ pub unsafe extern "C" fn byte_vector_get_length(vec: *const ByteVector, error_ou /// if there was an error with the contents of bytes /// /// # Safety +/// The ```public_key_destroy``` function must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn public_key_create(bytes: *mut ByteVector, error_out: *mut c_int) -> *mut TariPublicKey { let mut error = 0; @@ -385,6 +391,7 @@ pub unsafe extern "C" fn public_key_create(bytes: *mut ByteVector, error_out: *m /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn public_key_destroy(pk: *mut TariPublicKey) { if !pk.is_null() { @@ -403,6 +410,7 @@ pub unsafe extern "C" fn public_key_destroy(pk: *mut TariPublicKey) { /// `*mut ByteVector` - Returns a pointer to a ByteVector. Note that it returns ptr::null_mut() if pk is null /// /// # Safety +/// The ```byte_vector_destroy``` function must be called when finished with the ByteVector to prevent a memory leak. 
#[no_mangle] pub unsafe extern "C" fn public_key_get_bytes(pk: *mut TariPublicKey, error_out: *mut c_int) -> *mut ByteVector { let mut error = 0; @@ -429,6 +437,7 @@ pub unsafe extern "C" fn public_key_get_bytes(pk: *mut TariPublicKey, error_out: /// `*mut TariPublicKey` - Returns a pointer to a TariPublicKey /// /// # Safety +/// The ```private_key_destroy``` method must be called when finished with a private key to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn public_key_from_private_key( secret_key: *mut TariPrivateKey, @@ -458,6 +467,7 @@ pub unsafe extern "C" fn public_key_from_private_key( /// if key is null or if there was an error creating the TariPublicKey from key /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn public_key_from_hex(key: *const c_char, error_out: *mut c_int) -> *mut TariPublicKey { let mut error = 0; @@ -494,6 +504,7 @@ pub unsafe extern "C" fn public_key_from_hex(key: *const c_char, error_out: *mut /// if emoji is null or if there was an error creating the emoji string from TariPublicKey /// /// # Safety +/// The ```string_destroy``` method must be called when finished with a string from rust to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn public_key_to_emoji_id(pk: *mut TariPublicKey, error_out: *mut c_int) -> *mut c_char { let mut error = 0; @@ -521,6 +532,7 @@ pub unsafe extern "C" fn public_key_to_emoji_id(pk: *mut TariPublicKey, error_ou /// `*mut c_char` - Returns a pointer to a TariPublicKey. Note that it returns null on error. /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn emoji_id_to_public_key(emoji: *const c_char, error_out: *mut c_int) -> *mut TariPublicKey { let mut error = 0; @@ -561,6 +573,7 @@ pub unsafe extern "C" fn emoji_id_to_public_key(emoji: *const c_char, error_out: /// if bytes is null or if there was an error creating the TariPrivateKey from bytes /// /// # Safety +/// The ```private_key_destroy``` method must be called when finished with a TariPrivateKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn private_key_create(bytes: *mut ByteVector, error_out: *mut c_int) -> *mut TariPrivateKey { let mut error = 0; @@ -593,6 +606,7 @@ pub unsafe extern "C" fn private_key_create(bytes: *mut ByteVector, error_out: * /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn private_key_destroy(pk: *mut TariPrivateKey) { if !pk.is_null() { @@ -612,6 +626,7 @@ pub unsafe extern "C" fn private_key_destroy(pk: *mut TariPrivateKey) { /// if pk is null /// /// # Safety +/// The ```byte_vector_destroy``` must be called when finished with a ByteVector to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn private_key_get_bytes(pk: *mut TariPrivateKey, error_out: *mut c_int) -> *mut ByteVector { let mut error = 0; @@ -636,6 +651,7 @@ pub unsafe extern "C" fn private_key_get_bytes(pk: *mut TariPrivateKey, error_ou /// `*mut TariPrivateKey` - Returns a pointer to a TariPrivateKey /// /// # Safety +/// The ```private_key_destroy``` method must be called when finished with a TariPrivateKey to prevent a memory leak. 
#[no_mangle] pub unsafe extern "C" fn private_key_generate() -> *mut TariPrivateKey { let secret_key = TariPrivateKey::random(&mut OsRng); @@ -654,6 +670,7 @@ pub unsafe extern "C" fn private_key_generate() -> *mut TariPrivateKey { /// if key is null or if there was an error creating the TariPrivateKey from key /// /// # Safety +/// The ```private_key_destroy``` method must be called when finished with a TariPrivateKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn private_key_from_hex(key: *const c_char, error_out: *mut c_int) -> *mut TariPrivateKey { let mut error = 0; @@ -696,6 +713,7 @@ pub unsafe extern "C" fn private_key_from_hex(key: *const c_char, error_out: *mu /// if alias is null or if pk is null /// /// # Safety +/// The ```contact_destroy``` method must be called when finished with a TariContact #[no_mangle] pub unsafe extern "C" fn contact_create( alias: *const c_char, @@ -739,6 +757,7 @@ pub unsafe extern "C" fn contact_create( /// contact is null /// /// # Safety +/// The ```string_destroy``` method must be called when finished with a string from rust to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn contact_get_alias(contact: *mut TariContact, error_out: *mut c_int) -> *mut c_char { let mut error = 0; @@ -765,6 +784,7 @@ pub unsafe extern "C" fn contact_get_alias(contact: *mut TariContact, error_out: /// ptr::null_mut() if contact is null /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn contact_get_public_key( contact: *mut TariContact, @@ -790,6 +810,7 @@ pub unsafe extern "C" fn contact_get_public_key( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn contact_destroy(contact: *mut TariContact) { if !contact.is_null() { @@ -810,6 +831,7 @@ pub unsafe extern "C" fn contact_destroy(contact: *mut TariContact) { /// `c_uint` - Returns number of elements in , zero if contacts is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn contacts_get_length(contacts: *mut TariContacts, error_out: *mut c_int) -> c_uint { let mut error = 0; @@ -837,6 +859,7 @@ pub unsafe extern "C" fn contacts_get_length(contacts: *mut TariContacts, error_ /// null or position is invalid /// /// # Safety +/// The ```contact_destroy``` method must be called when finished with a TariContact to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn contacts_get_at( contacts: *mut TariContacts, @@ -869,6 +892,7 @@ pub unsafe extern "C" fn contacts_get_at( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn contacts_destroy(contacts: *mut TariContacts) { if !contacts.is_null() { @@ -892,6 +916,7 @@ pub unsafe extern "C" fn contacts_destroy(contacts: *mut TariContacts) { /// zero if transactions is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transactions_get_length( transactions: *mut TariCompletedTransactions, @@ -923,6 +948,8 @@ pub unsafe extern "C" fn completed_transactions_get_length( /// note that ptr::null_mut() is returned if transactions is null or position is invalid /// /// # Safety +/// The ```completed_transaction_destroy``` method must be called when finished with a TariCompletedTransaction to +/// prevent a memory leak #[no_mangle] pub unsafe extern "C" fn completed_transactions_get_at( transactions: *mut TariCompletedTransactions, @@ -955,6 
+982,7 @@ pub unsafe extern "C" fn completed_transactions_get_at( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transactions_destroy(transactions: *mut TariCompletedTransactions) { if !transactions.is_null() { @@ -978,6 +1006,7 @@ pub unsafe extern "C" fn completed_transactions_destroy(transactions: *mut TariC /// zero if transactions is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transactions_get_length( transactions: *mut TariPendingOutboundTransactions, @@ -1010,6 +1039,8 @@ pub unsafe extern "C" fn pending_outbound_transactions_get_length( /// note that ptr::null_mut() is returned if transactions is null or position is invalid /// /// # Safety +/// The ```pending_outbound_transaction_destroy``` method must be called when finished with a +/// TariPendingOutboundTransaction to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn pending_outbound_transactions_get_at( transactions: *mut TariPendingOutboundTransactions, @@ -1042,6 +1073,7 @@ pub unsafe extern "C" fn pending_outbound_transactions_get_at( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transactions_destroy(transactions: *mut TariPendingOutboundTransactions) { if !transactions.is_null() { @@ -1065,6 +1097,7 @@ pub unsafe extern "C" fn pending_outbound_transactions_destroy(transactions: *mu /// it will be zero if transactions is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_inbound_transactions_get_length( transactions: *mut TariPendingInboundTransactions, @@ -1096,6 +1129,8 @@ pub unsafe extern "C" fn pending_inbound_transactions_get_length( /// note that ptr::null_mut() is returned if transactions is null or position is invalid /// /// # Safety +/// The ```pending_inbound_transaction_destroy``` method must be called when finished with a +/// TariPendingInboundTransaction to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn pending_inbound_transactions_get_at( transactions: *mut TariPendingInboundTransactions, @@ -1128,6 +1163,7 @@ pub unsafe extern "C" fn pending_inbound_transactions_get_at( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_inbound_transactions_destroy(transactions: *mut TariPendingInboundTransactions) { if !transactions.is_null() { @@ -1150,6 +1186,7 @@ pub unsafe extern "C" fn pending_inbound_transactions_destroy(transactions: *mut /// `c_ulonglong` - Returns the TransactionID, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transaction_get_transaction_id( transaction: *mut TariCompletedTransaction, @@ -1178,6 +1215,7 @@ pub unsafe extern "C" fn completed_transaction_get_transaction_id( /// ptr::null_mut() if transaction is null /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn completed_transaction_get_destination_public_key( transaction: *mut TariCompletedTransaction, @@ -1207,6 +1245,7 @@ pub unsafe extern "C" fn completed_transaction_get_destination_public_key( /// ptr::null_mut() if transaction is null /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe
extern "C" fn completed_transaction_get_source_public_key( transaction: *mut TariCompletedTransaction, @@ -1243,6 +1282,7 @@ pub unsafe extern "C" fn completed_transaction_get_source_public_key( /// | 4 | Pending | /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transaction_get_status( transaction: *mut TariCompletedTransaction, @@ -1271,6 +1311,7 @@ pub unsafe extern "C" fn completed_transaction_get_status( /// `c_ulonglong` - Returns the amount, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transaction_get_amount( transaction: *mut TariCompletedTransaction, @@ -1298,6 +1339,7 @@ pub unsafe extern "C" fn completed_transaction_get_amount( /// `c_ulonglong` - Returns the fee, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transaction_get_fee( transaction: *mut TariCompletedTransaction, @@ -1325,6 +1367,7 @@ pub unsafe extern "C" fn completed_transaction_get_fee( /// `c_ulonglong` - Returns the timestamp, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transaction_get_timestamp( transaction: *mut TariCompletedTransaction, @@ -1353,6 +1396,7 @@ pub unsafe extern "C" fn completed_transaction_get_timestamp( /// to an empty char array if transaction is null /// /// # Safety +/// The ```string_destroy``` method must be called when finished with a string coming from rust to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn completed_transaction_get_message( transaction: *mut TariCompletedTransaction, @@ -1382,6 +1426,7 @@ pub unsafe extern "C" fn completed_transaction_get_message( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn completed_transaction_destroy(transaction: *mut TariCompletedTransaction) { if !transaction.is_null() { @@ -1404,6 +1449,7 @@ pub unsafe extern "C" fn completed_transaction_destroy(transaction: *mut TariCom /// `c_ulonglong` - Returns the TransactionID, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_get_transaction_id( transaction: *mut TariPendingOutboundTransaction, @@ -1432,6 +1478,7 @@ pub unsafe extern "C" fn pending_outbound_transaction_get_transaction_id( /// ptr::null_mut() if transaction is null /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_get_destination_public_key( transaction: *mut TariPendingOutboundTransaction, @@ -1460,6 +1507,7 @@ pub unsafe extern "C" fn pending_outbound_transaction_get_destination_public_key /// `c_ulonglong` - Returns the amount, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_get_amount( transaction: *mut TariPendingOutboundTransaction, @@ -1487,6 +1535,7 @@ pub unsafe extern "C" fn pending_outbound_transaction_get_amount( /// `c_ulonglong` - Returns the fee, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_get_fee( transaction: *mut TariPendingOutboundTransaction, @@ -1514,6 +1563,7 @@ pub unsafe extern "C" fn pending_outbound_transaction_get_fee( /// `c_ulonglong` -
Returns the timestamp, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_get_timestamp( transaction: *mut TariPendingOutboundTransaction, @@ -1542,6 +1592,8 @@ pub unsafe extern "C" fn pending_outbound_transaction_get_timestamp( /// to an empty char array if transaction is null /// /// # Safety +/// The ```string_destroy``` method must be called when finished with a string coming from rust to prevent a memory +/// leak #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_get_message( transaction: *mut TariPendingOutboundTransaction, @@ -1581,6 +1633,7 @@ pub unsafe extern "C" fn pending_outbound_transaction_get_message( /// | 4 | Pending | /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_get_status( transaction: *mut TariPendingOutboundTransaction, @@ -1607,6 +1660,7 @@ pub unsafe extern "C" fn pending_outbound_transaction_get_status( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_outbound_transaction_destroy(transaction: *mut TariPendingOutboundTransaction) { if !transaction.is_null() { @@ -1629,6 +1683,7 @@ pub unsafe extern "C" fn pending_outbound_transaction_destroy(transaction: *mut /// `c_ulonglong` - Returns the TransactionId, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_inbound_transaction_get_transaction_id( transaction: *mut TariPendingInboundTransaction, @@ -1657,6 +1712,7 @@ pub unsafe extern "C" fn pending_inbound_transaction_get_transaction_id( /// ptr::null_mut() if transaction is null /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn pending_inbound_transaction_get_source_public_key( transaction: *mut TariPendingInboundTransaction, @@ -1685,6 +1741,7 @@ pub unsafe extern "C" fn pending_inbound_transaction_get_source_public_key( /// `c_ulonglong` - Returns the amount, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_inbound_transaction_get_amount( transaction: *mut TariPendingInboundTransaction, @@ -1712,6 +1769,7 @@ pub unsafe extern "C" fn pending_inbound_transaction_get_amount( /// `c_ulonglong` - Returns the timestamp, note that it will be zero if transaction is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_inbound_transaction_get_timestamp( transaction: *mut TariPendingInboundTransaction, @@ -1740,6 +1798,8 @@ pub unsafe extern "C" fn pending_inbound_transaction_get_timestamp( /// to an empty char array if transaction is null /// /// # Safety +/// The ```string_destroy``` method must be called when finished with a string coming from rust to prevent a memory +/// leak #[no_mangle] pub unsafe extern "C" fn pending_inbound_transaction_get_message( transaction: *mut TariPendingInboundTransaction, @@ -1779,6 +1839,7 @@ pub unsafe extern "C" fn pending_inbound_transaction_get_message( /// | 4 | Pending | /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn pending_inbound_transaction_get_status( transaction: *mut TariPendingInboundTransaction, @@ -1805,6 +1866,7 @@ pub unsafe extern "C" fn pending_inbound_transaction_get_status( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub
unsafe extern "C" fn pending_inbound_transaction_destroy(transaction: *mut TariPendingInboundTransaction) { if !transaction.is_null() { @@ -1825,6 +1887,8 @@ pub unsafe extern "C" fn pending_inbound_transaction_destroy(transaction: *mut T /// `*mut TariTransportType` - Returns a pointer to a memory TariTransportType /// /// # Safety +/// The ```transport_type_destroy``` method must be called when finished with a TariTransportType to prevent a memory +/// leak #[no_mangle] pub unsafe extern "C" fn transport_memory_create() -> *mut TariTransportType { let transport = TariTransportType::Memory { @@ -1844,6 +1908,8 @@ pub unsafe extern "C" fn transport_memory_create() -> *mut TariTransportType { /// `*mut TariTransportType` - Returns a pointer to a tcp TariTransportType, null on error. /// /// # Safety +/// The ```transport_type_destroy``` method must be called when finished with a TariTransportType to prevent a memory +/// leak #[no_mangle] pub unsafe extern "C" fn transport_tcp_create( listener_address: *const c_char, @@ -1884,6 +1950,8 @@ pub unsafe extern "C" fn transport_tcp_create( /// `*mut TariTransportType` - Returns a pointer to a tor TariTransportType, null on error. /// /// # Safety +/// The ```transport_type_destroy``` method must be called when finished with a TariTransportType to prevent a memory +/// leak #[no_mangle] pub unsafe extern "C" fn transport_tor_create( control_server_address: *const c_char, @@ -1967,6 +2035,7 @@ pub unsafe extern "C" fn transport_tor_create( /// `*mut c_char` - Returns the address as a pointer to a char array, array will be empty on error /// /// # Safety +/// Can only be used with a memory transport type, will crash otherwise #[no_mangle] pub unsafe extern "C" fn transport_memory_get_address( transport: *const TariTransportType, @@ -2006,6 +2075,7 @@ pub unsafe extern "C" fn transport_memory_get_address( /// be empty on error. 
/// /// # Safety +/// Can only be used with a tor transport type, will crash otherwise #[no_mangle] pub unsafe extern "C" fn wallet_get_tor_identity(wallet: *const TariWallet, error_out: *mut c_int) -> *mut ByteVector { let mut error = 0; @@ -2073,6 +2143,7 @@ pub unsafe extern "C" fn transport_type_destroy(transport: *mut TariTransportTyp /// null or a problem is encountered when constructing the NetAddress a ptr::null_mut() is returned /// /// # Safety +/// The ```comms_config_destroy``` method must be called when finished with a TariCommsConfig to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn comms_config_create( public_address: *const c_char, @@ -2080,6 +2151,7 @@ pub unsafe extern "C" fn comms_config_create( database_name: *const c_char, datastore_path: *const c_char, secret_key: *mut TariPrivateKey, + discovery_timeout_in_secs: c_ulonglong, error_out: *mut c_int, ) -> *mut TariCommsConfig { @@ -2111,7 +2183,9 @@ pub unsafe extern "C" fn comms_config_create( ptr::swap(error_out, &mut error as *mut c_int); return ptr::null_mut(); } + let datastore_path = PathBuf::from(datastore_path_string); + let dht_database_path = datastore_path.join("dht.db"); let public_address = public_address_str.parse::<Multiaddr>(); match public_address { @@ -2126,12 +2200,13 @@ pub unsafe extern "C" fn comms_config_create( let config = TariCommsConfig { node_identity: Arc::new(ni), transport_type: (*transport_type).clone(), - datastore_path: PathBuf::from(datastore_path_string), + datastore_path, peer_database_name: database_name_string, max_concurrent_inbound_tasks: 100, outbound_buffer_size: 100, dht: DhtConfig { - discovery_request_timeout: Duration::from_secs(30), + discovery_request_timeout: Duration::from_secs(discovery_timeout_in_secs), + database_url: DbConnectionUrl::File(dht_database_path), ..Default::default() }, // TODO: This should be set to false for non-test wallets.
See the `allow_test_addresses` field @@ -2167,6 +2242,7 @@ pub unsafe extern "C" fn comms_config_create( /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn comms_config_destroy(wc: *mut TariCommsConfig) { if !wc.is_null() { @@ -2208,6 +2284,7 @@ pub unsafe extern "C" fn comms_config_destroy(wc: *mut TariCommsConfig) { /// if config is null, a wallet error was encountered or if the runtime could not be created /// /// # Safety +/// The ```wallet_destroy``` method must be called when finished with a TariWallet to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_create( config: *mut TariCommsConfig, @@ -2217,7 +2294,9 @@ pub unsafe extern "C" fn wallet_create( callback_received_finalized_transaction: unsafe extern "C" fn(*mut TariCompletedTransaction), callback_transaction_broadcast: unsafe extern "C" fn(*mut TariCompletedTransaction), callback_transaction_mined: unsafe extern "C" fn(*mut TariCompletedTransaction), - callback_discovery_process_complete: unsafe extern "C" fn(c_ulonglong, bool), + callback_direct_send_result: unsafe extern "C" fn(c_ulonglong, bool), + callback_store_and_forward_send_result: unsafe extern "C" fn(c_ulonglong, bool), + callback_transaction_cancellation: unsafe extern "C" fn(c_ulonglong), callback_base_node_sync_complete: unsafe extern "C" fn(u64, bool), error_out: *mut c_int, ) -> *mut TariWallet @@ -2230,17 +2309,8 @@ pub unsafe extern "C" fn wallet_create( return ptr::null_mut(); } - let logging_path_string = if !log_path.is_null() { - Some(CStr::from_ptr(log_path).to_str().unwrap().to_owned()) - } else { - None - }; - - let runtime = Runtime::new(); - let factories = CryptoFactories::default(); - let w; - - if let Some(path) = logging_path_string { + if !log_path.is_null() { + let path = CStr::from_ptr(log_path).to_str().unwrap().to_owned(); let logfile = FileAppender::builder() .encoder(Box::new(PatternEncoder::new( "{d(%Y-%m-%d %H:%M:%S.%f)} [{t}] {l:5} {m}{n}", @@ -2258,6 +2328,10 @@ pub unsafe extern "C" fn wallet_create( debug!(target: LOG_TARGET, "Logging started"); } + let runtime = Runtime::new(); + let factories = CryptoFactories::default(); + let w; + match runtime { Ok(runtime) => { let sql_database_path = (*config) @@ -2305,7 +2379,9 @@ pub unsafe extern "C" fn wallet_create( callback_received_finalized_transaction, callback_transaction_broadcast, callback_transaction_mined, - callback_discovery_process_complete, + callback_direct_send_result, + callback_store_and_forward_send_result, + callback_transaction_cancellation, callback_base_node_sync_complete, ); @@ -2340,6 +2416,7 @@ pub unsafe extern "C" fn wallet_create( /// public nonce, separated by a pipe character. Empty if an error occurred. /// /// # Safety +/// The ```string_destroy``` method must be called when finished with a string coming from rust to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_sign_message( wallet: *mut TariWallet, @@ -2398,6 +2475,7 @@ pub unsafe extern "C" fn wallet_sign_message( /// `bool` - Returns if the signature is valid or not, will be false if an error occurs.
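`wallet_sign_message` pairs with the verification function documented next: the returned string carries both the hex signature and the public nonce, and that same pipe-delimited string is what verification consumes. A brief usage sketch, assuming a valid wallet pointer obtained from `wallet_create` (the `wallet_get_public_key`, `public_key_destroy` and `string_destroy` helpers are the ones defined elsewhere in this file):

```rust
use std::ffi::CString;
use std::os::raw::c_int;

// Sketch only: sign a message and verify it against the wallet's own
// public key.
unsafe fn sign_and_verify(wallet: *mut TariWallet) {
    let mut error: c_int = 0;
    let error_ptr = &mut error as *mut c_int;

    let msg = CString::new("message to sign").unwrap();

    // Returns "<hex signature>|<hex public nonce>" as a C string.
    let hex_sig_nonce = wallet_sign_message(wallet, msg.as_ptr(), error_ptr);
    assert_eq!(error, 0);

    let public_key = wallet_get_public_key(wallet, error_ptr);
    let valid = wallet_verify_message_signature(wallet, public_key, hex_sig_nonce, msg.as_ptr(), error_ptr);
    assert!(valid);

    // Free everything that crossed the FFI boundary.
    public_key_destroy(public_key);
    string_destroy(hex_sig_nonce);
}
```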
/// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_verify_message_signature( wallet: *mut TariWallet, @@ -2479,6 +2557,7 @@ pub unsafe extern "C" fn wallet_verify_message_signature( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_test_generate_data( wallet: *mut TariWallet, @@ -2528,6 +2607,7 @@ pub unsafe extern "C" fn wallet_test_generate_data( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_test_receive_transaction(wallet: *mut TariWallet, error_out: *mut c_int) -> bool { let mut error = 0; @@ -2561,6 +2641,7 @@ pub unsafe extern "C" fn wallet_test_receive_transaction(wallet: *mut TariWallet /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_test_complete_sent_transaction( wallet: *mut TariWallet, @@ -2602,6 +2683,7 @@ pub unsafe extern "C" fn wallet_test_complete_sent_transaction( /// `bool` - Returns if the transaction was originally sent from the wallet /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_is_completed_transaction_outbound( wallet: *mut TariWallet, @@ -2642,6 +2724,7 @@ pub unsafe extern "C" fn wallet_is_completed_transaction_outbound( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_test_finalize_received_transaction( wallet: *mut TariWallet, @@ -2680,6 +2763,7 @@ pub unsafe extern "C" fn wallet_test_finalize_received_transaction( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_test_broadcast_transaction( wallet: *mut TariWallet, @@ -2719,6 +2803,7 @@ pub unsafe extern "C" fn wallet_test_broadcast_transaction( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_test_mine_transaction( wallet: *mut TariWallet, @@ -2756,6 +2841,7 @@ pub unsafe extern "C" fn wallet_test_mine_transaction( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_add_base_node_peer( wallet: *mut TariWallet, @@ -2810,6 +2896,7 @@ pub unsafe extern "C" fn wallet_add_base_node_peer( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_upsert_contact( wallet: *mut TariWallet, @@ -2855,6 +2942,7 @@ pub unsafe extern "C" fn wallet_upsert_contact( /// `bool` - Returns if successful or not /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_remove_contact( wallet: *mut TariWallet, @@ -2899,6 +2987,7 @@ pub unsafe extern "C" fn wallet_remove_contact( /// `c_ulonglong` - The available balance, 0 if wallet is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_get_available_balance(wallet: *mut TariWallet, error_out: *mut c_int) -> c_ulonglong { let mut error = 0; @@ -2934,6 +3023,7 @@ pub unsafe extern "C" fn wallet_get_available_balance(wallet: *mut TariWallet, e /// `c_ulonglong` - The incoming balance, 0 if wallet is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_get_pending_incoming_balance( wallet: *mut TariWallet, @@ -2973,6 +3063,7 @@ pub unsafe extern "C" fn wallet_get_pending_incoming_balance( /// `c_ulonglong` - The outgoing balance, 0 if wallet is null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_get_pending_outgoing_balance( wallet: *mut TariWallet, @@ 
-3012,9 +3103,10 @@ pub unsafe extern "C" fn wallet_get_pending_outgoing_balance( /// as an out parameter. /// /// ## Returns -/// `bool` - Returns if successful or not +/// `unsigned long long` - Returns 0 if unsuccessful or the TxId of the sent transaction if successful /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_send_transaction( wallet: *mut TariWallet, @@ -3023,20 +3115,20 @@ pub unsafe extern "C" fn wallet_send_transaction( fee_per_gram: c_ulonglong, message: *const c_char, error_out: *mut c_int, -) -> bool +) -> c_ulonglong { let mut error = 0; ptr::swap(error_out, &mut error as *mut c_int); if wallet.is_null() { error = LibWalletError::from(InterfaceError::NullError("wallet".to_string())).code; ptr::swap(error_out, &mut error as *mut c_int); - return false; + return 0; } if dest_public_key.is_null() { error = LibWalletError::from(InterfaceError::NullError("dest_public_key".to_string())).code; ptr::swap(error_out, &mut error as *mut c_int); - return false; + return 0; } let message_string = if !message.is_null() { @@ -3055,11 +3147,11 @@ pub unsafe extern "C" fn wallet_send_transaction( MicroTari::from(fee_per_gram), message_string, )) { - Ok(_) => true, + Ok(tx_id) => tx_id, Err(e) => { error = LibWalletError::from(WalletError::TransactionServiceError(e)).code; ptr::swap(error_out, &mut error as *mut c_int); - false + 0 }, } } @@ -3076,6 +3168,7 @@ pub unsafe extern "C" fn wallet_send_transaction( /// wallet is null /// /// # Safety +/// The ```contacts_destroy``` method must be called when finished with a TariContacts to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_contacts(wallet: *mut TariWallet, error_out: *mut c_int) -> *mut TariContacts { let mut error = 0; @@ -3113,6 +3206,8 @@ pub unsafe extern "C" fn wallet_get_contacts(wallet: *mut TariWallet, error_out: /// wallet is null or an error is encountered /// /// # Safety +/// The ```completed_transactions_destroy``` method must be called when finished with a TariCompletedTransactions to +/// prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_completed_transactions( wallet: *mut TariWallet, @@ -3168,6 +3263,8 @@ pub unsafe extern "C" fn wallet_get_completed_transactions( /// wallet is null or an error is encountered /// /// # Safety +/// The ```pending_inbound_transactions_destroy``` method must be called when finished with a +/// TariPendingInboundTransactions to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_pending_inbound_transactions( wallet: *mut TariWallet, @@ -3235,6 +3332,8 @@ pub unsafe extern "C" fn wallet_get_pending_inbound_transactions( /// wallet is null or an error is encountered /// /// # Safety +/// The ```pending_outbound_transactions_destroy``` method must be called when finished with a +/// TariPendingOutboundTransactions to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_pending_outbound_transactions( wallet: *mut TariWallet, @@ -3298,6 +3397,8 @@ pub unsafe extern "C" fn wallet_get_pending_outbound_transactions( /// wallet is null, an error is encountered or if the transaction is not found /// /// # Safety +/// The ```completed_transaction_destroy``` method must be called when finished with a TariCompletedTransaction to +/// prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_completed_transaction_by_id( wallet: *mut TariWallet, @@ -3350,6 +3451,8 @@ pub unsafe extern "C" fn wallet_get_completed_transaction_by_id( /// wallet is null, an error is encountered or if
the transaction is not found /// /// # Safety +/// The ```pending_inbound_transaction_destroy``` method must be called when finished with a +/// TariPendingInboundTransaction to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_pending_inbound_transaction_by_id( wallet: *mut TariWallet, @@ -3420,6 +3523,8 @@ pub unsafe extern "C" fn wallet_get_pending_inbound_transaction_by_id( /// wallet is null, an error is encountered or if the transaction is not found /// /// # Safety +/// The ```pending_outbound_transaction_destroy``` method must be called when finished with a +/// TariPendingOutboundTransaction to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_pending_outbound_transaction_by_id( wallet: *mut TariWallet, @@ -3489,6 +3594,7 @@ pub unsafe extern "C" fn wallet_get_pending_outbound_transaction_by_id( /// if wc is null /// /// # Safety +/// The ```public_key_destroy``` method must be called when finished with a TariPublicKey to prevent a memory leak #[no_mangle] pub unsafe extern "C" fn wallet_get_public_key(wallet: *mut TariWallet, error_out: *mut c_int) -> *mut TariPublicKey { let mut error = 0; @@ -3519,6 +3625,7 @@ pub unsafe extern "C" fn wallet_get_public_key(wallet: *mut TariWallet, error_ou /// null /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_import_utxo( wallet: *mut TariWallet, @@ -3572,8 +3679,50 @@ pub unsafe extern "C" fn wallet_import_utxo( } } +/// Cancel a Pending Outbound Transaction +/// +/// ## Arguments +/// `wallet` - The TariWallet pointer +/// `transaction_id` - The TransactionId +/// `error_out` - Pointer to an int which will be modified to an error code should one occur, may not be null. Functions +/// as an out parameter. +/// +/// ## Returns +/// `bool` - Returns whether the transaction could be cancelled +/// +/// # Safety +/// None +#[no_mangle] +pub unsafe extern "C" fn wallet_cancel_pending_transaction( + wallet: *mut TariWallet, + transaction_id: c_ulonglong, + error_out: *mut c_int, +) -> bool +{ + let mut error = 0; + ptr::swap(error_out, &mut error as *mut c_int); + if wallet.is_null() { + error = LibWalletError::from(InterfaceError::NullError("wallet".to_string())).code; + ptr::swap(error_out, &mut error as *mut c_int); + return false; + } + + match (*wallet) + .runtime + .block_on((*wallet).transaction_service.cancel_transaction(transaction_id)) + { + Ok(_) => true, + Err(e) => { + error = LibWalletError::from(WalletError::TransactionServiceError(e)).code; + ptr::swap(error_out, &mut error as *mut c_int); + false + }, + } +} + /// This function will tell the wallet to query the set base node to confirm the status of wallet data. For example this -/// will check that Unspent Outputs stored in the wallet are still available as UTXO's on the blockchain +/// will check that Unspent Outputs stored in the wallet are still available as UTXO's on the blockchain. This will also +/// trigger a request for outstanding SAF messages to your neighbours /// /// ## Arguments /// `wallet` - The TariWallet pointer /// @@ -3585,6 +3734,7 @@ pub unsafe extern "C" fn wallet_import_utxo( /// request.
Note the result will be 0 if there was an error /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_sync_with_base_node(wallet: *mut TariWallet, error_out: *mut c_int) -> c_ulonglong { let mut error = 0; @@ -3614,6 +3764,7 @@ pub unsafe extern "C" fn wallet_sync_with_base_node(wallet: *mut TariWallet, err /// `()` - Does not return a value, equivalent to void in C /// /// # Safety +/// None #[no_mangle] pub unsafe extern "C" fn wallet_destroy(wallet: *mut TariWallet) { if !wallet.is_null() { @@ -3622,6 +3773,22 @@ pub unsafe extern "C" fn wallet_destroy(wallet: *mut TariWallet) { } } +/// This function will log the provided string at debug level. To be used to have a client log messages to the LibWallet +/// logs. +/// +/// ## Arguments +/// `msg` - A string that will be logged at the debug level. If msg is null nothing will be done. +/// +/// # Safety +/// None +#[no_mangle] +pub unsafe extern "C" fn log_debug_message(msg: *const c_char) { + if !msg.is_null() { + let message = CStr::from_ptr(msg).to_str().unwrap().to_owned(); + debug!(target: LOG_TARGET, "{}", message); + } +} + #[cfg(test)] mod test { extern crate libc; @@ -3644,7 +3811,9 @@ mod test { pub received_finalized_tx_callback_called: bool, pub broadcast_tx_callback_called: bool, pub mined_tx_callback_called: bool, - pub discovery_send_callback_called: bool, + pub direct_send_callback_called: bool, + pub store_and_forward_send_callback_called: bool, + pub tx_cancellation_callback_called: bool, pub base_node_sync_callback_called: bool, } @@ -3656,8 +3825,10 @@ mod test { received_finalized_tx_callback_called: false, broadcast_tx_callback_called: false, mined_tx_callback_called: false, - discovery_send_callback_called: false, + direct_send_callback_called: false, + store_and_forward_send_callback_called: false, base_node_sync_callback_called: false, + tx_cancellation_callback_called: false, } } @@ -3667,7 +3838,9 @@ mod test { self.received_finalized_tx_callback_called = false; self.broadcast_tx_callback_called = false; self.mined_tx_callback_called = false; - self.discovery_send_callback_called = false; + self.direct_send_callback_called = false; + self.store_and_forward_send_callback_called = false; + self.tx_cancellation_callback_called = false; self.base_node_sync_callback_called = false; } } @@ -3743,7 +3916,15 @@ mod test { completed_transaction_destroy(tx); } - unsafe extern "C" fn discovery_process_complete_callback(_tx_id: c_ulonglong, _result: bool) { + unsafe extern "C" fn direct_send_callback(_tx_id: c_ulonglong, _result: bool) { + assert!(true); + } + + unsafe extern "C" fn store_and_forward_send_callback(_tx_id: c_ulonglong, _result: bool) { + assert!(true); + } + + unsafe extern "C" fn tx_cancellation_callback(_tx_id: c_ulonglong) { assert!(true); } @@ -3800,7 +3981,15 @@ mod test { completed_transaction_destroy(tx); } - unsafe extern "C" fn discovery_process_complete_callback_bob(_tx_id: c_ulonglong, _result: bool) { + unsafe extern "C" fn direct_send_callback_bob(_tx_id: c_ulonglong, _result: bool) { + assert!(true); + } + + unsafe extern "C" fn store_and_forward_send_callback_bob(_tx_id: c_ulonglong, _result: bool) { + assert!(true); + } + + unsafe extern "C" fn tx_cancellation_callback_bob(_tx_id: c_ulonglong) { assert!(true); } @@ -4046,7 +4235,6 @@ mod test { let mut lock = CALLBACK_STATE_FFI.lock().unwrap(); lock.reset(); } - let mut error = 0; let error_ptr = &mut error as *mut c_int; let secret_key_alice = private_key_generate(); @@ -4067,9 +4255,10 @@ mod test { db_name_alice_str, 
db_path_alice_str, secret_key_alice, + 20, error_ptr, ); - (*alice_config).allow_test_addresses = true; + let alice_wallet = wallet_create( alice_config, ptr::null(), @@ -4078,7 +4267,9 @@ mod test { received_tx_finalized_callback, broadcast_callback, mined_callback, - discovery_process_complete_callback, + direct_send_callback, + store_and_forward_send_callback, + tx_cancellation_callback, base_node_sync_process_complete_callback, error_ptr, ); @@ -4099,6 +4290,7 @@ mod test { db_name_bob_str, db_path_bob_str, secret_key_bob, + 20, error_ptr, ); let bob_wallet = wallet_create( @@ -4109,7 +4301,9 @@ mod test { received_tx_finalized_callback_bob, broadcast_callback_bob, mined_callback_bob, - discovery_process_complete_callback_bob, + direct_send_callback_bob, + store_and_forward_send_callback_bob, + tx_cancellation_callback_bob, base_node_sync_process_complete_callback_bob, error_ptr, ); diff --git a/base_layer/wallet_ffi/wallet.h b/base_layer/wallet_ffi/wallet.h index 5601326dd4..27d5833194 100644 --- a/base_layer/wallet_ffi/wallet.h +++ b/base_layer/wallet_ffi/wallet.h @@ -319,11 +319,12 @@ void pending_inbound_transactions_destroy(struct TariPendingInboundTransactions /// -------------------------------- TariCommsConfig ----------------------------------------------- /// // Creates a TariCommsConfig -struct TariCommsConfig *comms_config_create(char *public_address, +struct TariCommsConfig *comms_config_create(const char *public_address, struct TariTransportType *transport, - char *database_name, - char *datastore_path, + const char *database_name, + const char *datastore_path, struct TariPrivateKey *secret_key, + unsigned long long discovery_timeout_in_secs, int* error_out); // Frees memory for a TariCommsConfig @@ -333,13 +334,15 @@ void comms_config_destroy(struct TariCommsConfig *wc); // Creates a TariWallet struct TariWallet *wallet_create(struct TariWalletConfig *config, - char *log_path, + const char *log_path, void (*callback_received_transaction)(struct TariPendingInboundTransaction*), void (*callback_received_transaction_reply)(struct TariCompletedTransaction*), void (*callback_received_finalized_transaction)(struct TariCompletedTransaction*), void (*callback_transaction_broadcast)(struct TariCompletedTransaction*), void (*callback_transaction_mined)(struct TariCompletedTransaction*), - void (*callback_discovery_process_complete)(unsigned long long, bool), + void (*callback_direct_send_result)(unsigned long long, bool), + void (*callback_store_and_forward_send_result)(unsigned long long, bool), + void (*callback_transaction_cancellation)(unsigned long long), void (*callback_base_node_sync_complete)(unsigned long long, bool), int* error_out); @@ -350,7 +353,7 @@ char* wallet_sign_message(struct TariWallet *wallet, const char* msg, int* error bool wallet_verify_message_signature(struct TariWallet *wallet, struct TariPublicKey *public_key, const char* hex_sig_nonce, const char* msg, int* error_out); /// Generates test data -bool wallet_test_generate_data(struct TariWallet *wallet, char *datastore_path,int* error_out); +bool wallet_test_generate_data(struct TariWallet *wallet, const char *datastore_path,int* error_out); // Adds a base node peer to the TariWallet bool wallet_add_base_node_peer(struct TariWallet *wallet, struct TariPublicKey *public_key, const char *address,int* error_out); @@ -371,7 +374,7 @@ unsigned long long wallet_get_pending_incoming_balance(struct TariWallet *wallet unsigned long long wallet_get_pending_outgoing_balance(struct TariWallet *wallet,int* 
error_out); // Sends a TariPendingOutboundTransaction -bool wallet_send_transaction(struct TariWallet *wallet, struct TariPublicKey *destination, unsigned long long amount, unsigned long long fee_per_gram,const char *message,int* error_out); +unsigned long long wallet_send_transaction(struct TariWallet *wallet, struct TariPublicKey *destination, unsigned long long amount, unsigned long long fee_per_gram,const char *message,int* error_out); // Get the TariContacts from a TariWallet struct TariContacts *wallet_get_contacts(struct TariWallet *wallet,int* error_out); @@ -423,9 +426,15 @@ bool wallet_test_mine_transaction(struct TariWallet *wallet, unsigned long long // Simulates a TariPendingInboundTransaction being received bool wallet_test_receive_transaction(struct TariWallet *wallet,int* error_out); +/// Cancel a Pending Outbound Transaction +bool wallet_cancel_pending_transaction(struct TariWallet *wallet, unsigned long long transaction_id, int* error_out); + // Frees memory for a TariWallet void wallet_destroy(struct TariWallet *wallet); +/// This function will log the provided string at debug level. To be used to have a client log messages to the LibWallet logs +void log_debug_message(const char* msg); + #ifdef __cplusplus } #endif diff --git a/buildtools/base_node.Dockerfile b/buildtools/base_node.Dockerfile index 6f27246149..13a421d8f9 100644 --- a/buildtools/base_node.Dockerfile +++ b/buildtools/base_node.Dockerfile @@ -26,7 +26,7 @@ FROM base COPY --from=builder /tari_base_node/target/release/tari_base_node /usr/bin/ #COPY --from=builder /basenode/target/release/tari_base_node.d /etc/tari_base_node.d RUN mkdir /etc/tari_base_node.d -COPY --from=builder /tari_base_node/config/tari_config_sample.toml /etc/tari_base_node.d/tari_config_sample.toml +COPY --from=builder /tari_base_node/common/config/tari_config_sample.toml /etc/tari_base_node.d/tari_config_sample.toml COPY --from=builder /tari_base_node/common/logging/log4rs-sample.yml /root/.tari/log4rs.yml CMD ["tari_base_node"] diff --git a/common/Cargo.toml b/common/Cargo.toml index 71b2b8efac..277b040ea7 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -6,13 +6,15 @@ repository = "https://github.com/tari-project/tari" homepage = "https://tari.com" readme = "README.md" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" [dependencies] -clap = "2.33.0" -config = { version = "0.9.3" } +structopt = { version = "0.3.13", default_features = false } +config = { version = "0.9.3", default_features = false, features = ["toml"] } +serde = { version = "1.0.106", default_features = false } +serde_json = "1.0.51" dirs = "2.0" get_if_addrs = "0.5.3" log = "0.4.8" @@ -24,3 +26,5 @@ sha2 = "0.8.0" [dev-dependencies] tempdir = "0.3.7" tari_test_utils = { version = "^0.0", path = "../infrastructure/test_utils"} +serde = { version = "1.0.106", features = ["derive"] } +anyhow = "1.0" diff --git a/config/README.md b/common/config/README.md similarity index 100% rename from config/README.md rename to common/config/README.md diff --git a/common/config/presets/rincewind-simple.toml b/common/config/presets/rincewind-simple.toml new file mode 100644 index 0000000000..c96a37c0fd --- /dev/null +++ b/common/config/presets/rincewind-simple.toml @@ -0,0 +1,33 @@ +# A simple set of sane defaults for connecting to the Rincewind testnet +[common] +#peer_database = "~/.tari/peers" + +[base_node] +network = "rincewind" + +[base_node.rincewind] +db_type = "lmdb" +transport = "tor" +peer_seeds = [ #t-tbn-nvir
"06e98e9c5eb52bd504836edec1878eccf12eb9f26a5fe5ec0e279423156e657a::/onion3/bsmuof2cn4y2ysz253gzsvg3s72fcgh4f3qcm3hdlxdtcwe6al2dicyd:18141", + + #t-tbn-ncal + "3a5081a0c9ff72b2d5cf52f8d78cc5a206d643259cdeb7d934512f519e090e6c::/onion3/gfynjxfm7rcxcekhu6jwuhvyyfdjvmruvvehurmlte565v74oinv2lad:18141", + + #t-tbn-oregon + "e6f3c83dc592af45ede5424974f52c776e9e6859e530066e57c1b0dd59d9b61c::/onion3/ixihtmcbvr2q53bj23bsla5hi7ek37odnsxzkapr7znyfneqxzhq7zad:18141", + + #t-tbn-london + "ce2254825d0e0294d31a86c6aac18f83c9a7b3d01d9cdb6866b4b2af8fd3fd17::/onion3/gm7kxmr4cyjg5fhcw4onav2ofa3flscrocfxiiohdcys3kshdhcjeuyd:18141", + + #t-tbn-stockholm + "461d4d7be657521969896f90e3f611f0c4e902ca33d3b808c03357ad86fd7801::/onion3/4me2aw6auq2ucql34uuvsegtjxmvsmcyk55qprtrpqxvu3whxajvb5ad:18141", + + #t-tbn-seoul + "d440b328e69b20dd8ee6c4a61aeb18888939f0f67cf96668840b7f72055d834c::/onion3/j5x7xkcxnrich5lcwibwszd5kylclbf6a5unert5sy6ykio2kphnopad:18141", + + #t-tbn-sydney + "b81b4071f72418cc410166d9baf0c6ef7a8c309e64671fafbbed88f7e1ee7709::/onion3/lwwcv4nq7epgem5vdcawom4mquqsw2odbwfcjzv3j6sksx4gr24e52ad:18141", +] +enable_mining = false diff --git a/config/tari.config.json b/common/config/tari.config.json similarity index 100% rename from config/tari.config.json rename to common/config/tari.config.json diff --git a/config/tari_config_sample.toml b/common/config/tari_config_sample.toml similarity index 100% rename from config/tari_config_sample.toml rename to common/config/tari_config_sample.toml diff --git a/common/examples/base_node_init.rs b/common/examples/base_node_init.rs new file mode 100644 index 0000000000..b743bcd517 --- /dev/null +++ b/common/examples/base_node_init.rs @@ -0,0 +1,77 @@ +use serde::{Deserialize, Serialize}; +use structopt::StructOpt; +use tari_common::{ConfigBootstrap, DefaultConfigLoader, NetworkConfigPath}; + +#[derive(StructOpt, Debug)] +/// The reference Tari cryptocurrency base node implementation +struct Arguments { + /// Custom application parameters might be specified as usual + #[structopt(long, default_value = "any structopt options allowed")] + my_param: String, + #[structopt(flatten)] + bootstrap: ConfigBootstrap, +} + +// The following config does not require any key customization +// and might be deserialized just as +// `let my_config: BasicConfig = config.try_into()?` +#[derive(Deserialize, Debug)] +struct BasicConfig { + #[serde(default = "welcome")] + welcome_message: String, +} +fn welcome() -> String { + "welcome from tari_common".into() +} + +// The following config is loaded from the key `my_node.{network}`, where +// `{network}` is the value of the `my_node.use_network` parameter. +// This is achieved with the DefaultConfigLoader trait, which inherits a default impl +// when the struct implements Serialize, Deserialize, Default and NetworkConfigPath. +// ```ignore +// let my_config = MyNodeConfig::try_from(&config)?
+// ``` +#[derive(Serialize, Deserialize, Debug)] +struct MyNodeConfig { + welcome_message: String, + goodbye_message: String, +} +impl Default for MyNodeConfig { + fn default() -> Self { + Self { + welcome_message: welcome(), + goodbye_message: "bye bye".into(), + } + } +} +impl NetworkConfigPath for MyNodeConfig { + fn main_key_prefix() -> &'static str { + "my_node" + } +} + +fn main() -> anyhow::Result<()> { + Arguments::clap().print_help()?; + let mut args = Arguments::from_args(); + args.bootstrap.init_dirs()?; + println!("CLI arguments:\n"); + dbg!(&args); + + let mut config = args.bootstrap.load_configuration()?; + + // load basic config directly via Deserialize trait: + let basic_config: BasicConfig = config.clone().try_into()?; + assert_eq!(basic_config.welcome_message, welcome()); + + let my_config: MyNodeConfig = MyNodeConfig::load_from(&config)?; + assert_eq!(my_config.welcome_message, welcome()); + + config.set("my_node.use_network", "mainnet")?; + config.set("my_node.mainnet.welcome_message", "welcome from mainnet")?; + + let my_config = MyNodeConfig::load_from(&config)?; + assert_eq!(my_config.welcome_message, "welcome from mainnet".to_string()); + assert_eq!(my_config.goodbye_message, "bye bye".to_string()); + + Ok(()) +} diff --git a/common/examples/mempool_config.rs b/common/examples/mempool_config.rs new file mode 100644 index 0000000000..c88ef8c574 --- /dev/null +++ b/common/examples/mempool_config.rs @@ -0,0 +1,153 @@ +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tari_common::{ConfigurationError, DefaultConfigLoader, NetworkConfigPath}; + +const UNCONFIRMED_STORAGE_CAPACITY: usize = 1024; +const ORPHAN_STORAGE_CAPACITY: usize = 2048; +const PENDING_STORAGE_CAPACITY: usize = 4096; +const REORG_STORAGE_CAPACITY: usize = 512; +const UNCONFIRMED_TX_SKIP: usize = 2; +const ORPHAN_TX_TTL: Duration = Duration::from_secs(2); +const REORG_TX_TTL: Duration = Duration::from_secs(10); + +#[derive(Clone, Copy, Serialize, Deserialize)] +pub struct UnconfirmedPoolConfig { + /// The maximum number of transactions that can be stored in the Unconfirmed Transaction pool + #[serde(rename = "unconfirmed_pool_storage_capacity")] + pub storage_capacity: usize, + /// The maximum number of transactions that can be skipped when compiling a set of highest priority transactions; + /// skipping over large transactions is performed in an attempt to fit more transactions into the remaining space. + #[serde(rename = "weight_tx_skip_count")] + pub weight_tx_skip_count: usize, +} +impl Default for UnconfirmedPoolConfig { + fn default() -> Self { + Self { + storage_capacity: UNCONFIRMED_STORAGE_CAPACITY, + weight_tx_skip_count: UNCONFIRMED_TX_SKIP, + } + } +} +/// Configuration for the OrphanPool +#[derive(Clone, Copy, Serialize, Deserialize)] +pub struct OrphanPoolConfig { + /// The maximum number of transactions that can be stored in the Orphan pool + #[serde(rename = "orphan_pool_storage_capacity")] + pub storage_capacity: usize, + /// The Time-to-live for each stored transaction + #[serde(rename = "orphan_tx_ttl", with = "seconds")] + pub tx_ttl: Duration, +} +impl Default for OrphanPoolConfig { + fn default() -> Self { + Self { + storage_capacity: ORPHAN_STORAGE_CAPACITY, + tx_ttl: ORPHAN_TX_TTL, + } + } +} +/// Configuration for the PendingPool. +#[derive(Clone, Copy, Serialize, Deserialize)] +pub struct PendingPoolConfig { + /// The maximum number of transactions that can be stored in the Pending pool.
+ #[serde(rename = "pending_pool_storage_capacity")] + pub storage_capacity: usize, +} +impl Default for PendingPoolConfig { + fn default() -> Self { + Self { + storage_capacity: PENDING_STORAGE_CAPACITY, + } + } +} + +/// Configuration for the ReorgPool +#[derive(Clone, Copy, Serialize, Deserialize)] +pub struct ReorgPoolConfig { + /// The maximum number of transactions that can be stored in the ReorgPool + #[serde(rename = "reorg_pool_storage_capacity")] + pub storage_capacity: usize, + /// The Time-to-live for each stored transaction + #[serde(rename = "reorg_tx_ttl", with = "seconds")] + pub tx_ttl: Duration, +} +impl Default for ReorgPoolConfig { + fn default() -> Self { + Self { + storage_capacity: REORG_STORAGE_CAPACITY, + tx_ttl: REORG_TX_TTL, + } + } +} + +/// Configuration for the Mempool. +#[derive(Clone, Copy, Default, Serialize, Deserialize)] +pub struct MempoolConfig { + #[serde(flatten)] + pub unconfirmed_pool_config: UnconfirmedPoolConfig, + #[serde(flatten)] + pub orphan_pool_config: OrphanPoolConfig, + #[serde(flatten)] + pub pending_pool_config: PendingPoolConfig, + #[serde(flatten)] + pub reorg_pool_config: ReorgPoolConfig, +} +impl NetworkConfigPath for MempoolConfig { + fn main_key_prefix() -> &'static str { + "mempool" + } +} + +fn main() -> Result<(), ConfigurationError> { + let mut config = config::Config::new(); + + config.set("mempool.orphan_tx_ttl", 70)?; + config.set("mempool.unconfirmed_pool_storage_capacity", 3)?; + config.set("mempool.mainnet.pending_pool_storage_capacity", 100)?; + config.set("mempool.mainnet.orphan_tx_ttl", 99)?; + let my_config = MempoolConfig::load_from(&config)?; + // no use_network value + // [X] mempool.mainnet, [ ] mempool, [X] Default = 4096 + assert_eq!(my_config.pending_pool_config.storage_capacity, PENDING_STORAGE_CAPACITY); + // [X] mempool.mainnet, [X] mempool = 70s, [X] Default + assert_eq!(my_config.orphan_pool_config.tx_ttl, Duration::from_secs(70)); + // [ ] mempool.mainnet, [X] mempool = 3, [X] Default + assert_eq!(my_config.unconfirmed_pool_config.storage_capacity, 3); + // [ ] mempool.mainnet, [ ] mempool, [X] Default = 512 + assert_eq!(my_config.reorg_pool_config.storage_capacity, REORG_STORAGE_CAPACITY); + // [ ] mempool.mainnet, [ ] mempool, [X] Default = 10s + assert_eq!(my_config.reorg_pool_config.tx_ttl, REORG_TX_TTL); + + config.set("mempool.use_network", "mainnet")?; + // use_network = mainnet + let my_config = MempoolConfig::load_from(&config)?; + // [X] mempool.mainnet = 100, [ ] mempool, [X] Default + assert_eq!(my_config.pending_pool_config.storage_capacity, 100); + // [X] mempool.mainnet = 99s, [X] mempool, [X] Default + assert_eq!(my_config.orphan_pool_config.tx_ttl, Duration::from_secs(99)); + // [ ] mempool.mainnet, [X] mempool = 3, [X] Default + assert_eq!(my_config.unconfirmed_pool_config.storage_capacity, 3); + // [ ] mempool.mainnet, [ ] mempool, [X] Default = 512 + assert_eq!(my_config.reorg_pool_config.storage_capacity, REORG_STORAGE_CAPACITY); + // [ ] mempool.mainnet, [ ] mempool, [X] Default = 10s + assert_eq!(my_config.reorg_pool_config.tx_ttl, REORG_TX_TTL); + + config.set("mempool.use_network", "wrong_network")?; + assert!(MempoolConfig::load_from(&config).is_err()); + + Ok(()) +} + +mod seconds { + use serde::{Deserialize, Deserializer, Serializer}; + use std::time::Duration; + + pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error> + where D: Deserializer<'de> { + Ok(Duration::from_secs(u64::deserialize(deserializer)?)) + } + pub fn serialize<S>(duration: &Duration, s: S) -> Result<S::Ok, S::Error> + where S:
Serializer { + s.serialize_u64(duration.as_secs()) + } +} diff --git a/common/presets b/common/presets deleted file mode 120000 index f8cdeb96cd..0000000000 --- a/common/presets +++ /dev/null @@ -1 +0,0 @@ -../config/presets \ No newline at end of file diff --git a/common/src/configuration/bootstrap.rs b/common/src/configuration/bootstrap.rs new file mode 100644 index 0000000000..797eb4af3a --- /dev/null +++ b/common/src/configuration/bootstrap.rs @@ -0,0 +1,424 @@ +//! # Building tari-based applications CLI +//! +//! To help with building a tari-enabled CLI from scratch as easily as possible, this crate exposes +//! the [`ConfigBootstrap`] struct. ConfigBootstrap implements the [`structopt::StructOpt`] trait, so all CLI options +//! required for initializing configs can be embedded in any StructOpt derived struct. +//! +//! After loading the ConfigBootstrap parameters it is necessary to call [`ConfigBootstrap::init_dirs()`], +//! which creates the necessary configuration files based on the input parameters. This is usually followed by: +//! - [`ConfigBootstrap::initialize_logging()`], which initializes log4rs logging. +//! - [`ConfigBootstrap::load_configuration()`], which loads the [config::Config] from the .tari config file. +//! +//! ## Example - a CLI that loads and deserializes the global config file +//! +//! ```ignore +//! # use tempdir::TempDir; +//! use tari_common::ConfigBootstrap; +//! use structopt::StructOpt; +//! +//! #[derive(StructOpt)] +//! /// The reference Tari cryptocurrency base node implementation +//! struct Arguments { +//! /// Create and save new node identity if one doesn't exist +//! #[structopt(long)] +//! id: bool, +//! #[structopt(flatten)] +//! bootstrap: ConfigBootstrap, +//! } +//! +//! let mut args = Arguments::from_args(); +//! # let temp_dir = TempDir::new(string(8).as_str()).unwrap(); +//! # args.bootstrap.base_path = temp_dir.path().to_path_buf(); +//! # args.bootstrap.init = true; +//! args.bootstrap.init_dirs(); +//! args.bootstrap.initialize_logging(); +//! let config = args.bootstrap.load_configuration(); +//! assert_eq!(config.network, Network::MainNet); +//! assert_eq!(config.blocking_threads, 4); +//! # std::fs::remove_dir_all(&dir_utils::default_subdir("", Some(dir))).unwrap(); +//! ``` +//! +//! ```shell +//! > main -h +//! main 0.0.0 +//! The reference Tari cryptocurrency base node implementation +//! +//! USAGE: +//! main [FLAGS] [OPTIONS] +//! +//! FLAGS: +//! -h, --help Prints help information +//! --create Create and save new node identity if one doesn't exist +//! --init Create a default configuration file if it doesn't exist +//! -V, --version Prints version information +//! +//! OPTIONS: +//! --base-path A path to a directory to store your files +//! --config A path to the configuration file to use (config.toml) +//! --log-config The path to the log configuration file. It is set using the following precedence +//! set: [env: TARI_LOG_CONFIGURATION=]
``` + +use super::{ + error::ConfigError, + utils::{install_default_config_file, load_configuration}, +}; +use crate::{dir_utils, initialize_logging, logging, DEFAULT_CONFIG, DEFAULT_LOG_CONFIG}; +use std::{ + io, + path::{Path, PathBuf}, +}; +use structopt::{clap::ArgMatches, StructOpt}; + +#[derive(StructOpt, Debug)] +pub struct ConfigBootstrap { + /// A path to a directory to store your files + #[structopt(short, long, alias("base_dir"), hide_default_value(true), default_value = "")] + pub base_path: PathBuf, + /// A path to the configuration file to use (config.toml) + #[structopt(short, long, hide_default_value(true), default_value = "")] + pub config: PathBuf, + /// The path to the log configuration file. It is set using the following precedence set + #[structopt( + short, + long, + alias("log_config"), + env = "TARI_LOG_CONFIGURATION", + hide_default_value(true), + default_value = "" + )] + pub log_config: PathBuf, + /// Create a default configuration file if it doesn't exist + #[structopt(long)] + pub init: bool, +} + +impl Default for ConfigBootstrap { + fn default() -> Self { + ConfigBootstrap { + base_path: dir_utils::default_path("", None), + config: dir_utils::default_path(DEFAULT_CONFIG, None), + log_config: dir_utils::default_path(DEFAULT_LOG_CONFIG, None), + init: false, + } + } +} + +impl ConfigBootstrap { + const ARGS: &'static [&'static str] = &["base-path", "base_dir", "config", "init", "log-config", "log_config"]; + + /// Initialize configuration and directories based on ConfigBootstrap options. + /// + /// If not present it will create base directory (default ~/.tari/, depending on OS). + /// Log and tari configs will be initialized in the base directory too. + /// + /// Without `--init` flag provided configuration and directories will be created only + /// after user's confirmation. + pub fn init_dirs(&mut self) -> Result<(), ConfigError> { + if self.base_path.to_str() == Some("") { + self.base_path = dir_utils::default_path("", None); + } + + // Create the tari data directory + dir_utils::create_data_directory(Some(&self.base_path)).map_err(|err| { + ConfigError::new( + "We couldn't create a default Tari data directory and have to quit now. This makes us sad :(", + Some(err.to_string()), + ) + })?; + + if self.config.to_str() == Some("") { + self.config = dir_utils::default_path(DEFAULT_CONFIG, Some(&self.base_path)); + } + + let log_config = if self.log_config.to_str() == Some("") { + None + } else { + Some(self.log_config.clone()) + }; + self.log_config = logging::get_log_configuration_path(log_config); + + if !self.config.exists() { + let install = if !self.init { + prompt("Config file does not exist. We can create a default one for you now, or you can say 'no' here, \ + and generate a customised one at https://config.tari.com.\n\ + Would you like to try the default configuration (Y/n)?") + } else { + true + }; + + if install { + println!( + "Installing new config file at {}", + self.config.to_str().unwrap_or("[??]") + ); + install_configuration(&self.config, install_default_config_file); + } + } + + if !self.log_config.exists() { + let install = if !self.init { + prompt("Logging configuration file does not exist. Would you like to create a new one (Y/n)?") + } else { + true + }; + if install { + println!( + "Installing new logfile configuration at {}", + self.log_config.to_str().unwrap_or("[??]") + ); + install_configuration(&self.log_config, logging::install_default_logfile_config); + } + }; + Ok(()) + } + + /// Fill in ConfigBootstrap from clap ArgMatches. 
+ /// + /// ## Example: + /// ```edition2018 + /// # use structopt::clap::clap_app; + /// # use tari_common::*; + /// let matches = clap_app!(myapp => + /// (@arg base_path: -b --("base-path") +takes_value "A path to a directory to store your files") + /// (@arg config: -c --config +takes_value "A path to the configuration file to use (config.toml)") + /// (@arg log_config: -l --("log-config") +takes_value "A path to the logfile configuration (log4rs.yml))") + /// (@arg init: -i --init "Create a default configuration file if it doesn't exist") + /// ).get_matches(); + /// let bootstrap = ConfigBootstrap::from_matches(&matches); + /// ``` + pub fn from_matches(matches: &ArgMatches) -> Result { + let iter = matches + .args + .keys() + .flat_map(|arg| match Self::ARGS.binary_search(arg) { + Ok(_) => vec![ + Some(std::ffi::OsString::from(format!("--{}", arg))), + matches.value_of_os(arg).map(|s| s.to_os_string()), + ], + _ => vec![], + }) + .filter_map(|arg| arg); + + let mut vals: Vec = iter.collect(); + vals.insert(0, "".into()); + Ok(ConfigBootstrap::from_iter_safe(vals.iter())?) + } + + /// Set up application-level logging using the Log4rs configuration file + /// based on supplied CLI arguments + pub fn initialize_logging(&self) -> Result<(), ConfigError> { + match initialize_logging(&self.log_config) { + true => Ok(()), + false => Err(ConfigError::new("failed to initalize logging", None)), + } + } + + /// Load configuration from files located based on supplied CLI arguments + pub fn load_configuration(&self) -> Result { + load_configuration(self).map_err(|source| ConfigError::new("failed to load configuration", Some(source))) + } +} + +/// Fill in ConfigBootstrap from clap ArgMatches +/// +/// ```rust +/// # use structopt::clap::clap_app; +/// # use tari_common::*; +/// let matches = clap_app!(myapp => +/// (version: "0.0.10") +/// (author: "The Tari Community") +/// (about: "The reference Tari cryptocurrency base node implementation") +/// (@arg base_path: -b --("base-path") +takes_value "A path to a directory to store your files") +/// (@arg config: -c --config +takes_value "A path to the configuration file to use (config.toml)") +/// (@arg log_config: -l --("log-config") +takes_value "A path to the logfile configuration (log4rs.yml))") +/// (@arg init: -i --init "Create a default configuration file if it doesn't exist") +/// (@arg create_id: --("create-id") "Create and save new node identity if one doesn't exist ") +/// ).get_matches(); +/// let bootstrap = bootstrap_config_from_cli(&matches); +/// ``` +/// ## Caveats +/// It will exit with code 1 if no base dir and fails to create one +pub fn bootstrap_config_from_cli(matches: &ArgMatches) -> ConfigBootstrap { + let mut bootstrap = ConfigBootstrap::from_matches(matches).expect("failed to extract matches"); + match bootstrap.init_dirs() { + Err(err) => { + println!("{}", err); + std::process::exit(1); + }, + Ok(_) => bootstrap, + } +} + +fn prompt(question: &str) -> bool { + println!("{}", question); + let mut input = "".to_string(); + io::stdin().read_line(&mut input).unwrap(); + let input = input.trim().to_lowercase(); + input == "y" || input.is_empty() +} + +pub fn install_configuration(path: &Path, installer: F) +where F: Fn(&Path) -> Result<(), std::io::Error> { + if let Err(e) = installer(path) { + println!( + "We could not install a new configuration file in {}: {}", + path.to_str().unwrap_or("?"), + e.to_string() + ) + } +} + +#[cfg(test)] +mod test { + use super::ConfigBootstrap; + use crate::{bootstrap_config_from_cli, 
dir_utils, dir_utils::default_subdir, load_configuration}; + use structopt::{clap::clap_app, StructOpt}; + use tari_test_utils::random::string; + use tempdir::TempDir; + + #[test] + fn test_bootstrap_from_matches() { + // Create command line test data + let app = clap_app!(myapp => + (@arg base_dir: -b --base_dir +takes_value "A path to a directory to store your files") + (@arg config: -c --config +takes_value "A path to the configuration file to use (config.toml)") + (@arg log_config: -l--log_config +takes_value "A path to the logfile configuration (log4rs.yml))") + (@arg init: --init "Create a default configuration file if it doesn't exist") + ); + let matches = app.clone().get_matches_from(vec![ + "", + "--log_config", + "no-file-created", + "--config", + "no-file-created", + "--base_dir", + "no-dir-created", + "--init", + ]); + let bootstrap = ConfigBootstrap::from_matches(&matches).expect("failed to extract matches"); + assert!(bootstrap.init); + assert_eq!(bootstrap.base_path.to_str(), Some("no-dir-created")); + assert_eq!(bootstrap.log_config.to_str(), Some("no-file-created")); + assert_eq!(bootstrap.config.to_str(), Some("no-file-created")); + + // Check aliases too + let app = clap_app!(myapp => + (@arg ("base-path"): -b --("base-path") +takes_value "A path to a directory to store your files") + (@arg config: -c --config +takes_value "A path to the configuration file to use (config.toml)") + (@arg ("log-config"): -l --("log-config") +takes_value "A path to the logfile configuration (log4rs.yml))") + (@arg init: --init "Create a default configuration file if it doesn't exist") + ); + let matches = app.get_matches_from(vec![ + "", + "--log-config", + "no-file-created", + "--base-path", + "no-dir-created", + ]); + let bootstrap = ConfigBootstrap::from_matches(&matches).expect("failed to extract matches"); + assert!(!bootstrap.init); + assert_eq!(bootstrap.base_path.to_str(), Some("no-dir-created")); + assert_eq!(bootstrap.log_config.to_str(), Some("no-file-created")); + assert_eq!(bootstrap.config.to_str(), Some("")); + } + + #[test] + fn test_bootstrap_config_from_cli_and_load_configuration() { + let temp_dir = TempDir::new(string(8).as_str()).unwrap(); + let dir = &temp_dir.path().to_path_buf(); + // Create test folder + dir_utils::create_data_directory(Some(dir)).unwrap(); + + // Create command line test data + let matches = clap_app!(myapp => + (version: "0.0.10") + (author: "The Tari Community") + (about: "The reference Tari cryptocurrency base node implementation") + (@arg base_dir: -b --base_dir +takes_value "A path to a directory to store your files") + (@arg config: -c --config +takes_value "A path to the configuration file to use (config.toml)") + (@arg log_config: -l --log_config +takes_value "A path to the logfile configuration (log4rs.yml))") + (@arg init: --init "Create a default configuration file if it doesn't exist") + (@arg create_id: --create_id "Create and save new node identity if one doesn't exist ") + ) + .get_matches_from(vec![ + "", + "--base_dir", + default_subdir("", Some(dir)).as_str(), + "--init", + "--create_id", + ]); + + let bootstrap = ConfigBootstrap::from_matches(&matches).expect("failed to extract matches"); + assert!(bootstrap.init); + assert_eq!(&bootstrap.base_path, dir); + + // Load bootstrap via former API + let bootstrap = bootstrap_config_from_cli(&matches); + let config_exists = std::path::Path::new(&bootstrap.config).exists(); + let log_config_exists = std::path::Path::new(&bootstrap.log_config).exists(); + // Load and apply configuration 
+ let cfg = load_configuration(&bootstrap); + + // Cleanup test data + if std::path::Path::new(&dir_utils::default_subdir("", Some(dir))).exists() { + std::fs::remove_dir_all(&dir_utils::default_subdir("", Some(dir))).expect("failed to cleanup dirs"); + } + + // Assert results + assert!(config_exists); + assert!(log_config_exists); + assert!(&cfg.is_ok()); + } + + #[test] + fn test_bootstrap_config_from_structopt_derive() { + let temp_dir = TempDir::new(string(8).as_str()).unwrap(); + let dir = &temp_dir.path().to_path_buf(); + // Create test folder + dir_utils::create_data_directory(Some(dir)).unwrap(); + + #[derive(StructOpt)] + /// The reference Tari cryptocurrency base node implementation + struct Arguments { + /// Create and save new node identity if one doesn't exist + #[structopt(long = "create_id")] + create_id: bool, + #[structopt(flatten)] + bootstrap: super::ConfigBootstrap, + } + + // Create command line test data + let mut args = Arguments::from_iter_safe(vec![ + "", + "--base_dir", + default_subdir("", Some(dir)).as_str(), + "--init", + "--create_id", + ]) + .expect("failed to process arguments"); + // Init bootstrap dirs + args.bootstrap.init_dirs().expect("failed to initialize dirs"); + // Load and apply configuration file + let cfg = load_configuration(&args.bootstrap); + + // Cleanup test data + if std::path::Path::new(&dir_utils::default_subdir("", Some(dir))).exists() { + std::fs::remove_dir_all(&dir_utils::default_subdir("", Some(dir))).unwrap(); + } + + // Assert results + assert!(args.bootstrap.init); + assert!(args.create_id); + assert!(&cfg.is_ok()); + } + + #[test] + fn check_homedir_is_used_by_default() { + dir_utils::create_data_directory(None).unwrap(); + assert_eq!( + dirs::home_dir().unwrap().join(".tari"), + dir_utils::default_path("", None) + ); + } +} diff --git a/common/src/configuration/error.rs b/common/src/configuration/error.rs new file mode 100644 index 0000000000..6a7e3a9537 --- /dev/null +++ b/common/src/configuration/error.rs @@ -0,0 +1,35 @@ +use std::fmt; +use structopt::clap::Error as ClapError; + +#[derive(Debug)] +pub struct ConfigError { + pub(crate) cause: &'static str, + pub(crate) source: Option<String>, +} + +impl ConfigError { + pub(crate) fn new(cause: &'static str, source: Option<String>) -> Self { + Self { cause, source } + } +} + +impl std::error::Error for ConfigError {} +impl fmt::Display for ConfigError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.cause)?; + if let Some(ref source) = self.source { + write!(f, ":\n{}", source) + } else { + Ok(()) + } + } +} + +impl From<ClapError> for ConfigError { + fn from(e: ClapError) -> Self { + Self { + cause: "Failed to process commandline parameters", + source: Some(e.to_string()), + } + } +} diff --git a/common/src/configuration.rs b/common/src/configuration/global.rs similarity index 57% rename from common/src/configuration.rs rename to common/src/configuration/global.rs index b0468acd14..1553a3eedc 100644 --- a/common/src/configuration.rs +++ b/common/src/configuration/global.rs @@ -20,234 +20,19 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // +//!
# Global configuration of tari base layer system -use crate::{dir_utils::default_subdir, ConfigBootstrap}; +use super::ConfigurationError; use config::{Config, Environment}; -use log::*; -use multiaddr::{Multiaddr, Protocol}; +use multiaddr::Multiaddr; use std::{ convert::TryInto, - error::Error, fmt::{Display, Formatter, Result as FormatResult}, - fs, - net::IpAddr, num::{NonZeroU16, TryFromIntError}, - path::{Path, PathBuf}, + path::PathBuf, str::FromStr, }; -const LOG_TARGET: &str = "common::config"; - -//------------------------------------- Main API functions --------------------------------------// - -pub fn load_configuration(bootstrap: &ConfigBootstrap) -> Result<Config, String> { - debug!( - target: LOG_TARGET, - "Loading configuration file from {}", - bootstrap.config.to_str().unwrap_or("[??]") - ); - let mut cfg = default_config(bootstrap); - // Load the configuration file - let filename = bootstrap - .config - .to_str() - .ok_or_else(|| "Invalid config file path".to_string())?; - let config_file = config::File::with_name(filename); - match cfg.merge(config_file) { - Ok(_) => { - info!(target: LOG_TARGET, "Configuration file loaded."); - Ok(cfg) - }, - Err(e) => Err(format!( - "There was an error loading the configuration file. {}", - e.to_string() - )), - } -} - -/// Installs a new configuration file template, copied from `rincewind-simple.toml` to the given path. -pub fn install_default_config_file(path: &Path) -> Result<(), std::io::Error> { - let source = include_str!("../presets/rincewind-simple.toml"); - fs::write(path, source) -} - -//--------------------------------------------- Network type ------------------------------------------// -#[derive(Clone, Debug, PartialEq)] -pub enum Network { - MainNet, - Rincewind, -} - -impl FromStr for Network { - type Err = ConfigurationError; - - fn from_str(value: &str) -> Result<Self, Self::Err> { - match value.to_lowercase().as_str() { - "rincewind" => Ok(Self::Rincewind), - "mainnet" => Ok(Self::MainNet), - invalid => Err(ConfigurationError::new( - "network", - &format!("Invalid network option: {}", invalid), - )), - } - } -} - -impl Display for Network { - fn fmt(&self, f: &mut Formatter) -> FormatResult { - let msg = match self { - Self::MainNet => "mainnet", - Self::Rincewind => "rincewind", - }; - f.write_str(msg) - } -} - -//------------------------------------------- ConfigExtractor trait ------------------------------------------// -/// Extract parts of the global Config file into custom configuration objects that are more specific and localised. -/// The expected use case for this is to use `load_configuration` to load the global configuration file into a Config -/// object. This is then used to generate other, localised configuration objects, for example, `MempoolConfig` etc. -/// -/// # Example -/// -/// ```edition2018 -/// # use tari_common::*; -/// # use config::Config; -/// struct MyConf { -/// foo: usize, -/// } -/// -/// impl ConfigExtractor for MyConf { -/// fn set_default(cfg: &mut Config) { -/// cfg.set_default("main.foo", 5); -/// cfg.set_default("test.foo", 6); -/// } -/// -/// fn extract_configuration(cfg: &Config, network: Network) -> Result<Self, ConfigurationError> { -/// let key = match network { -/// Network::MainNet => "main.foo", -/// Network::Rincewind => "test.foo", -/// }; -/// let foo = cfg.get_int(key).map_err(|e| ConfigurationError::new(&key, &e.to_string()))? as usize; -/// Ok(MyConf { foo }) -/// } -/// } -/// ``` -pub trait ConfigExtractor { - /// Provides the default values for the Config object.
This is used before `load_configuration` and ensures that - /// all config parameters have at least the default value set. - fn set_default(cfg: &mut Config); - /// After `load_configuration` has been called, you can construct a specific configuration object by calling - /// `extract_configuration` and it will create the object using values from the config file / environment variables - fn extract_configuration(cfg: &Config, network: Network) -> Result<Self, ConfigurationError> - where Self: Sized; -} -//--------------------------------------------- Database type ------------------------------------------// -#[derive(Debug)] -pub enum DatabaseType { - LMDB(PathBuf), - Memory, -} - -//--------------------------------------------- Network Transport ------------------------------------------// -#[derive(Debug, Clone)] -pub enum TorControlAuthentication { - None, - Password(String), -} - -fn parse_key_value(s: &str, split_chr: char) -> (String, Option<&str>) { - let mut parts = s.splitn(2, split_chr); - ( - parts - .next() - .expect("splitn always emits at least one part") - .to_lowercase(), - parts.next(), - ) -} - -impl FromStr for TorControlAuthentication { - type Err = String; - - fn from_str(s: &str) -> Result<Self, Self::Err> { - let (auth_type, maybe_value) = parse_key_value(s, '='); - match auth_type.as_str() { - "none" => Ok(TorControlAuthentication::None), - "password" => { - let password = maybe_value.ok_or_else(|| { - "Invalid format for 'password' tor authentication type. It should be in the format \ - 'password=xxxxxx'." - .to_string() - })?; - Ok(TorControlAuthentication::Password(password.to_string())) - }, - s => Err(format!("Invalid tor auth type '{}'", s)), - } - } -} - -#[derive(Debug, Clone)] -pub enum SocksAuthentication { - None, - UsernamePassword(String, String), -} - -impl FromStr for SocksAuthentication { - type Err = String; - - fn from_str(s: &str) -> Result<Self, Self::Err> { - let (auth_type, maybe_value) = parse_key_value(s, '='); - match auth_type.as_str() { - "none" => Ok(SocksAuthentication::None), - "username_password" => { - let (username, password) = maybe_value - .and_then(|value| { - let (un, pwd) = parse_key_value(value, ':'); - // If pwd is None, return None - pwd.map(|p| (un, p)) - }) - .ok_or_else(|| { - "Invalid format for 'username-password' socks authentication type. It should be in the format \ - 'username_password=my_username:xxxxxx'." - .to_string() - })?; - Ok(SocksAuthentication::UsernamePassword(username, password.to_string())) - }, - s => Err(format!("Invalid tor auth type '{}'", s)), - } - } -} - -#[derive(Debug, Clone)] -pub enum CommsTransport { - /// Use TCP to join the Tari network. This transport can only communicate with TCP/IP addresses, so peers with - /// e.g. tor onion addresses will not be contactable. - Tcp { - listener_address: Multiaddr, - tor_socks_address: Option<Multiaddr>, - tor_socks_auth: Option<SocksAuthentication>, - }, - /// Configures the node to run over a tor hidden service using the Tor proxy. This transport recognises ip/tcp, - /// onion v2, onion v3 and dns addresses. - TorHiddenService { - /// The address of the control server - control_server_address: Multiaddr, - socks_address_override: Option<Multiaddr>, - /// The address used to receive proxied traffic from the tor proxy to the Tari node. This port must be - /// available - forward_address: Multiaddr, - auth: TorControlAuthentication, - onion_port: NonZeroU16, - }, - /// Use a SOCKS5 proxy transport. This transport recognises any addresses supported by the proxy.
- Socks5 { - proxy_address: Multiaddr, - auth: SocksAuthentication, - listener_address: Multiaddr, - }, -} - //------------------------------------- Main Configuration Struct --------------------------------------// #[derive(Debug)] @@ -408,9 +193,8 @@ fn convert_node_config(network: Network, cfg: Config) -> Result<GlobalConfig, ConfigurationError> fn config_string(network: &str, key: &str) -> String { format!("base_node.{}.{}", network, key) } -//------------------------------------- Configuration file defaults --------------------------------------// - -/// Generate the global Tari configuration instance. -/// -/// The `Config` object that is returned holds _all_ the default values possible in the `~/.tari.config.toml` file. -/// These will typically be overridden by userland settings in envars, the config file, or the command line. -pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { - let mut cfg = Config::new(); - let local_ip_addr = get_local_ip().unwrap_or_else(|| "/ip4/1.2.3.4".parse().unwrap()); - - // Common settings - cfg.set_default("common.message_cache_size", 10).unwrap(); - cfg.set_default("common.message_cache_ttl", 1440).unwrap(); - cfg.set_default("common.peer_whitelist", Vec::<String>::new()).unwrap(); - cfg.set_default("common.liveness_max_sessions", 0).unwrap(); - cfg.set_default( - "common.peer_database ", - default_subdir("peers", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default("common.blacklist_ban_period ", 1440).unwrap(); - - // Wallet settings - cfg.set_default("wallet.grpc_enabled", false).unwrap(); - cfg.set_default("wallet.grpc_address", "tcp://127.0.0.1:18040").unwrap(); - cfg.set_default( - "wallet.wallet_file", - default_subdir("wallet/wallet.dat", Some(&bootstrap.base_path)), - ) - .unwrap(); - - //---------------------------------- Mainnet Defaults --------------------------------------------// - - cfg.set_default("base_node.network", "mainnet").unwrap(); - - // Mainnet base node defaults - cfg.set_default("base_node.mainnet.db_type", "lmdb").unwrap(); - cfg.set_default("base_node.mainnet.peer_seeds", Vec::<String>::new()) - .unwrap(); - cfg.set_default("base_node.mainnet.block_sync_strategy", "ViaBestChainMetadata") - .unwrap(); - cfg.set_default("base_node.mainnet.blocking_threads", 4).unwrap(); - cfg.set_default("base_node.mainnet.core_threads", 6).unwrap(); - cfg.set_default( - "base_node.mainnet.data_dir", - default_subdir("mainnet/", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.mainnet.identity_file", - default_subdir("mainnet/node_id.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.mainnet.tor_identity_file", - default_subdir("mainnet/tor.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.mainnet.wallet_identity_file", - default_subdir("mainnet/wallet-identity.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.mainnet.wallet_tor_identity_file", - default_subdir("mainnet/wallet-tor.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.mainnet.public_address", - format!("{}/tcp/18041", local_ip_addr), - ) - .unwrap(); - cfg.set_default("base_node.mainnet.grpc_enabled", false).unwrap(); - cfg.set_default("base_node.mainnet.grpc_address", "tcp://127.0.0.1:18041") - .unwrap(); - cfg.set_default("base_node.mainnet.enable_mining", false).unwrap(); - cfg.set_default("base_node.mainnet.num_mining_threads", 1).unwrap(); - - //---------------------------------- Rincewind Defaults --------------------------------------------// - -
cfg.set_default("base_node.rincewind.db_type", "lmdb").unwrap(); - cfg.set_default("base_node.rincewind.peer_seeds", Vec::::new()) - .unwrap(); - cfg.set_default("base_node.rincewind.block_sync_strategy", "ViaBestChainMetadata") - .unwrap(); - cfg.set_default("base_node.rincewind.blocking_threads", 4).unwrap(); - cfg.set_default("base_node.rincewind.core_threads", 4).unwrap(); - cfg.set_default( - "base_node.rincewind.data_dir", - default_subdir("rincewind/", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.rincewind.tor_identity_file", - default_subdir("rincewind/tor.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.rincewind.wallet_identity_file", - default_subdir("rincewind/wallet-identity.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.rincewind.wallet_tor_identity_file", - default_subdir("rincewind/wallet-tor.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.rincewind.identity_file", - default_subdir("rincewind/node_id.json", Some(&bootstrap.base_path)), - ) - .unwrap(); - cfg.set_default( - "base_node.rincewind.public_address", - format!("{}/tcp/18141", local_ip_addr), - ) - .unwrap(); - cfg.set_default("base_node.rincewind.grpc_enabled", false).unwrap(); - cfg.set_default("base_node.rincewind.grpc_address", "tcp://127.0.0.1:18141") - .unwrap(); - cfg.set_default("base_node.rincewind.enable_mining", false).unwrap(); - cfg.set_default("base_node.rincewind.num_mining_threads", 1).unwrap(); +//--------------------------------------------- Network type ------------------------------------------// +#[derive(Clone, Debug, PartialEq)] +pub enum Network { + MainNet, + Rincewind, +} - set_transport_defaults(&mut cfg); +impl FromStr for Network { + type Err = ConfigurationError; - cfg + fn from_str(value: &str) -> Result { + match value.to_lowercase().as_str() { + "rincewind" => Ok(Self::Rincewind), + "mainnet" => Ok(Self::MainNet), + invalid => Err(ConfigurationError::new( + "network", + &format!("Invalid network option: {}", invalid), + )), + } + } } -fn set_transport_defaults(cfg: &mut Config) { - // Mainnet - // Default transport for mainnet is tcp - cfg.set_default("base_node.mainnet.transport", "tcp").unwrap(); - cfg.set_default("base_node.mainnet.tcp_listener_address", "/ip4/0.0.0.0/tcp/18089") - .unwrap(); - - cfg.set_default("base_node.mainnet.tor_control_address", "/ip4/127.0.0.1/tcp/9051") - .unwrap(); - cfg.set_default("base_node.mainnet.tor_control_auth", "none").unwrap(); - cfg.set_default("base_node.mainnet.tor_forward_address", "/ip4/127.0.0.1/tcp/18141") - .unwrap(); - cfg.set_default("base_node.mainnet.tor_onion_port", "18141").unwrap(); - - cfg.set_default("base_node.mainnet.socks5_proxy_address", "/ip4/0.0.0.0/tcp/9050") - .unwrap(); - cfg.set_default("base_node.mainnet.socks5_listener_address", "/ip4/0.0.0.0/tcp/18099") - .unwrap(); - cfg.set_default("base_node.mainnet.socks5_auth", "none").unwrap(); - - // rincewind - // Default transport for rincewind is tcp - cfg.set_default("base_node.rincewind.transport", "tcp").unwrap(); - cfg.set_default("base_node.rincewind.tcp_listener_address", "/ip4/0.0.0.0/tcp/18189") - .unwrap(); - - cfg.set_default("base_node.rincewind.tor_control_address", "/ip4/127.0.0.1/tcp/9051") - .unwrap(); - cfg.set_default("base_node.rincewind.tor_control_auth", "none").unwrap(); - cfg.set_default("base_node.rincewind.tor_forward_address", "/ip4/127.0.0.1/tcp/18041") - .unwrap(); - 
cfg.set_default("base_node.rincewind.tor_onion_port", "18141").unwrap(); - - cfg.set_default("base_node.rincewind.socks5_proxy_address", "/ip4/0.0.0.0/tcp/9150") - .unwrap(); - cfg.set_default("base_node.rincewind.socks5_listener_address", "/ip4/0.0.0.0/tcp/18199") - .unwrap(); - cfg.set_default("base_node.rincewind.socks5_auth", "none").unwrap(); +impl Display for Network { + fn fmt(&self, f: &mut Formatter) -> FormatResult { + let msg = match self { + Self::MainNet => "mainnet", + Self::Rincewind => "rincewind", + }; + f.write_str(msg) + } } -fn get_local_ip() -> Option { - get_if_addrs::get_if_addrs().ok().and_then(|if_addrs| { - if_addrs - .into_iter() - .find(|if_addr| !if_addr.is_loopback()) - .map(|if_addr| { - let mut addr = Multiaddr::empty(); - match if_addr.ip() { - IpAddr::V4(ip) => { - addr.push(Protocol::Ip4(ip)); - }, - IpAddr::V6(ip) => { - addr.push(Protocol::Ip6(ip)); - }, - } - addr - }) - }) +//--------------------------------------------- Database type ------------------------------------------// +#[derive(Debug)] +pub enum DatabaseType { + LMDB(PathBuf), + Memory, } -//------------------------------------- Configuration errors --------------------------------------// +//--------------------------------------------- Network Transport ------------------------------------------// +#[derive(Debug, Clone)] +pub enum TorControlAuthentication { + None, + Password(String), +} -#[derive(Debug)] -pub struct ConfigurationError { - field: String, - message: String, +fn parse_key_value(s: &str, split_chr: char) -> (String, Option<&str>) { + let mut parts = s.splitn(2, split_chr); + ( + parts + .next() + .expect("splitn always emits at least one part") + .to_lowercase(), + parts.next(), + ) } -impl ConfigurationError { - pub fn new(field: &str, msg: &str) -> Self { - ConfigurationError { - field: String::from(field), - message: String::from(msg), +impl FromStr for TorControlAuthentication { + type Err = String; + + fn from_str(s: &str) -> Result { + let (auth_type, maybe_value) = parse_key_value(s, '='); + match auth_type.as_str() { + "none" => Ok(TorControlAuthentication::None), + "password" => { + let password = maybe_value.ok_or_else(|| { + "Invalid format for 'password' tor authentication type. It should be in the format \ + 'password=xxxxxx'." + .to_string() + })?; + Ok(TorControlAuthentication::Password(password.to_string())) + }, + s => Err(format!("Invalid tor auth type '{}'", s)), } } } -impl Display for ConfigurationError { - fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { - f.write_str(&format!("Invalid value for {}: {}", self.field, self.message)) - } +#[derive(Debug, Clone)] +pub enum SocksAuthentication { + None, + UsernamePassword(String, String), } -impl Error for ConfigurationError {} - -#[cfg(test)] -mod test { - use crate::ConfigurationError; +impl FromStr for SocksAuthentication { + type Err = String; - #[test] - fn configuration_error() { - let e = ConfigurationError::new("test", "is a string"); - assert_eq!(e.to_string(), "Invalid value for test: is a string"); + fn from_str(s: &str) -> Result { + let (auth_type, maybe_value) = parse_key_value(s, '='); + match auth_type.as_str() { + "none" => Ok(SocksAuthentication::None), + "username_password" => { + let (username, password) = maybe_value + .and_then(|value| { + let (un, pwd) = parse_key_value(value, ':'); + // If pwd is None, return None + pwd.map(|p| (un, p)) + }) + .ok_or_else(|| { + "Invalid format for 'username-password' socks authentication type. 
It should be in the format \ + 'username_password=my_username:xxxxxx'." + .to_string() + })?; + Ok(SocksAuthentication::UsernamePassword(username, password.to_string())) + }, + s => Err(format!("Invalid socks auth type '{}'", s)), + } } } + +#[derive(Debug, Clone)] +pub enum CommsTransport { + /// Use TCP to join the Tari network. This transport can only communicate with TCP/IP addresses, so peers with + /// e.g. tor onion addresses will not be contactable. + Tcp { + listener_address: Multiaddr, + tor_socks_address: Option<Multiaddr>, + tor_socks_auth: Option<SocksAuthentication>, + }, + /// Configures the node to run over a tor hidden service using the Tor proxy. This transport recognises ip/tcp, + /// onion v2, onion v3 and dns addresses. + TorHiddenService { + /// The address of the control server + control_server_address: Multiaddr, + socks_address_override: Option<Multiaddr>, + /// The address used to receive proxied traffic from the tor proxy to the Tari node. This port must be + /// available + forward_address: Multiaddr, + auth: TorControlAuthentication, + onion_port: NonZeroU16, + }, + /// Use a SOCKS5 proxy transport. This transport recognises any addresses supported by the proxy. + Socks5 { + proxy_address: Multiaddr, + auth: SocksAuthentication, + listener_address: Multiaddr, + }, +}
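For reference, the two `FromStr` implementations above define the exact strings accepted by the `tor_control_auth` and `socks5_auth` settings. A small sketch of how the accepted forms parse (using the re-exports from `tari_common`; the credential values are illustrative):

```rust
use tari_common::{SocksAuthentication, TorControlAuthentication};

fn main() {
    // "none" disables authentication for either transport
    assert!(matches!(
        "none".parse::<TorControlAuthentication>(),
        Ok(TorControlAuthentication::None)
    ));
    // Tor control passwords use the `password=<value>` form
    assert!(matches!(
        "password=s3cret".parse::<TorControlAuthentication>(),
        Ok(TorControlAuthentication::Password(p)) if p == "s3cret"
    ));
    // SOCKS5 credentials use the `username_password=<user>:<pass>` form
    assert!(matches!(
        "username_password=alice:hunter2".parse::<SocksAuthentication>(),
        Ok(SocksAuthentication::UsernamePassword(u, p)) if u == "alice" && p == "hunter2"
    ));
}
```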
diff --git a/common/src/configuration/loader.rs b/common/src/configuration/loader.rs new file mode 100644 index 0000000000..6814ecef5d --- /dev/null +++ b/common/src/configuration/loader.rs @@ -0,0 +1,505 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +//! # Application configuration +//! +//! Tari uses the config crate, which allows the config file to be extended with application-level configs. +//! To derive a configuration from a Config via the [`ConfigLoader`] trait, the application configuration +//! struct should implement the [`Deserialize`][serde::Deserialize] and [`NetworkConfigPath`] traits. +//! +//! [`ConfigLoader::load_from`] automatically overloads parameters from the [application.{network}] +//! subsection, where the network is specified in the `application.use_network` parameter. +//! +//! [`ConfigPath`] allows the overloading logic to be customized even further, and the [`DefaultConfigLoader`] trait accounts +//! for a struct's [`Default`]s when loading values. +//! +//! ## Example +//! +//! ``` +//! # use config::Config; +//! # use serde::{Deserialize}; +//! # use tari_common::{NetworkConfigPath, ConfigLoader}; +//! #[derive(Deserialize)] +//! struct MyNodeConfig { +//! welcome_message: String, +//! } +//! impl NetworkConfigPath for MyNodeConfig { +//! fn main_key_prefix() -> &'static str { +//! "my_node" +//! } +//! } +//! +//! # let mut config = Config::new(); +//! config.set("my_node.use_network", "rincewind"); +//! config.set("my_node.rincewind.welcome_message", "nice to see you at unseen"); +//! let my_config = <MyNodeConfig as ConfigLoader>::load_from(&config).unwrap(); +//! assert_eq!(my_config.welcome_message, "nice to see you at unseen"); +//! ``` + +use super::Network; +use config::Config; +use std::{ + error::Error, + fmt::{Display, Formatter}, +}; + +//------------------------------------------- ConfigExtractor trait ------------------------------------------// +/// Extract parts of the global Config file into custom configuration objects that are more specific and localised. +/// The expected use case for this is to use `load_configuration` to load the global configuration file into a Config +/// object. This is then used to generate other, localised configuration objects, for example, `MempoolConfig` etc. +/// +/// # Example +/// +/// ```edition2018 +/// # use tari_common::*;
/// # use config::Config; +/// struct MyConf { +/// foo: usize, +/// } +/// +/// impl ConfigExtractor for MyConf { +/// fn set_default(cfg: &mut Config) { +/// cfg.set_default("main.foo", 5); +/// cfg.set_default("test.foo", 6); +/// } +/// +/// fn extract_configuration(cfg: &Config, network: Network) -> Result<Self, ConfigurationError> { +/// let key = match network { +/// Network::MainNet => "main.foo", +/// Network::Rincewind => "test.foo", +/// }; +/// let foo = cfg.get_int(key).map_err(|e| ConfigurationError::new(&key, &e.to_string()))? as usize; +/// Ok(MyConf { foo }) +/// } +/// } +/// ``` +#[deprecated(since = "0.0.10", note = "Please use ConfigPath and ConfigLoader traits instead")] +pub trait ConfigExtractor { + /// Provides the default values for the Config object. This is used before `load_configuration` and ensures that + /// all config parameters have at least the default value set. + fn set_default(cfg: &mut Config); + /// After `load_configuration` has been called, you can construct a specific configuration object by calling + /// `extract_configuration` and it will create the object using values from the config file / environment variables + fn extract_configuration(cfg: &Config, network: Network) -> Result<Self, ConfigurationError> + where Self: Sized; +} + +//------------------------------------------- ConfigLoader trait ------------------------------------------// + +/// Load a struct from the config's main section and a subsection override +/// +/// Implementing this trait along with Deserialize grants a +/// ConfigLoader implementation +pub trait ConfigPath { + /// Main configuration section + fn main_key_prefix() -> &'static str; + /// Overload values from a key prefix based on some configuration value. + /// + /// Should return the path to the configuration table that holds the overloading values. + /// Returns a `ConfigurationError` if the key_prefix field has a wrong value.
+ /// Returns Ok(None) if no overload is required + fn overload_key_prefix(config: &Config) -> Result<Option<String>, ConfigurationError>; + /// Merge and produce a sub-config from overload_key_prefix over main_key_prefix, + /// which can be used to deserialize the Self struct. + /// If the overload key is not present in the config, it has no effect. + fn merge_subconfig(config: &Config) -> Result<Config, ConfigurationError> { + use config::Value; + match Self::overload_key_prefix(config)? { + Some(key) => { + let overload: Value = config.get(key.as_str()).unwrap_or_default(); + let base: Value = config.get(Self::main_key_prefix()).unwrap_or_default(); + let mut base_config = Config::new(); + base_config.set(Self::main_key_prefix(), base)?; + let mut config = Config::new(); + // Some magic is required to make them correctly merge + config.merge(base_config)?; + config.set(Self::main_key_prefix(), overload)?; + Ok(config) + }, + None => Ok(config.clone()), + } + } +} + +/// Load a struct from the config's main section and a network subsection override +/// +/// The network subsection will be chosen based on the `use_network` key value +/// from the main section defined in this trait. +/// +/// A wrong network value will result in an error +pub trait NetworkConfigPath { + /// Main configuration section + fn main_key_prefix() -> &'static str; +} +impl<C: NetworkConfigPath> ConfigPath for C { + fn main_key_prefix() -> &'static str { + <Self as NetworkConfigPath>::main_key_prefix() + } + + fn overload_key_prefix(config: &Config) -> Result<Option<String>, ConfigurationError> { + let main = <Self as NetworkConfigPath>::main_key_prefix(); + let network_key = format!("{}.use_network", main); + let network_val: Option<String> = config.get_str(network_key.as_str()).ok(); + if let Some(s) = network_val { + let network: Network = s.parse()?; + Ok(Some(format!("{}.{}", main, network))) + } else { + Ok(None) + } + } +}
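`NetworkConfigPath` is just one policy for picking the overload section; a config that keys its overrides on something else can implement `ConfigPath` directly. A minimal sketch under that assumption (the `CacheConfig` type and the `cache.profile` key are invented for illustration):

```rust
use config::Config;
use serde::Deserialize;
use tari_common::{ConfigLoader, ConfigPath, ConfigurationError};

#[derive(Deserialize)]
struct CacheConfig {
    size: u64,
}

impl ConfigPath for CacheConfig {
    fn main_key_prefix() -> &'static str {
        "cache"
    }

    // Overload from e.g. `cache.hot` whenever `cache.profile = "hot"` is set
    fn overload_key_prefix(config: &Config) -> Result<Option<String>, ConfigurationError> {
        Ok(config.get_str("cache.profile").ok().map(|p| format!("cache.{}", p)))
    }
}

fn main() -> Result<(), ConfigurationError> {
    let mut config = Config::new();
    config.set("cache.size", 64i64)?;
    config.set("cache.hot.size", 512i64)?;
    config.set("cache.profile", "hot")?;
    // merge_subconfig() lays `cache.hot` over `cache` before deserializing
    let cache = <CacheConfig as ConfigLoader>::load_from(&config)?;
    assert_eq!(cache.size, 512);
    Ok(())
}
```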
+ +/// Configuration loader based on ConfigPath selectors +/// +/// ``` +/// # use config::Config; +/// # use serde::{Deserialize}; +/// use tari_common::{ConfigLoader, NetworkConfigPath}; +/// +/// #[derive(Deserialize)] +/// struct MyNodeConfig { +/// #[serde(default = "welcome")] +/// welcome_message: String, +/// #[serde(default = "bye")] +/// goodbye_message: String, +/// } +/// fn welcome() -> String { +/// "welcome to tari".into() +/// } +/// fn bye() -> String { +/// "bye bye".into() +/// } +/// impl NetworkConfigPath for MyNodeConfig { +/// fn main_key_prefix() -> &'static str { +/// "my_node" +/// } +/// } +/// // Loading preset and serde default value +/// let mut config = Config::new(); +/// config.set("my_node.goodbye_message", "see you later"); +/// config.set("my_node.mainnet.goodbye_message", "see you soon"); +/// let my_config = <MyNodeConfig as ConfigLoader>::load_from(&config).unwrap(); +/// assert_eq!(my_config.goodbye_message, "see you later".to_string()); +/// assert_eq!(my_config.welcome_message, welcome()); +/// // Overloading from network subsection as we use NetworkConfigPath +/// config.set("my_node.use_network", "mainnet"); +/// let my_config = <MyNodeConfig as ConfigLoader>::load_from(&config).unwrap(); +/// assert_eq!(my_config.goodbye_message, "see you soon".to_string()); +/// ``` +pub trait ConfigLoader: ConfigPath + for<'de> serde::de::Deserialize<'de> { + /// Try to load configuration from supplied Config by `main_key_prefix()` + /// with values overloaded from `overload_key_prefix()`. + /// + /// Default values will be taken from + /// - `#[serde(default="value")]` field attribute + /// - value defined in Config::set_default() + /// For automated inheritance of Default values use DefaultConfigLoader. + fn load_from(config: &Config) -> Result<Self, ConfigurationError> { + let merger = Self::merge_subconfig(config)?; + Ok(merger.get(Self::main_key_prefix())?) + } +} +impl<C> ConfigLoader for C where C: ConfigPath + for<'de> serde::de::Deserialize<'de> {} + +/// Configuration loader based on ConfigPath selectors with Defaults +/// +/// ``` +/// use config::Config; +/// use serde::{Deserialize, Serialize}; +/// use tari_common::{DefaultConfigLoader, NetworkConfigPath}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct MyNodeConfig { +/// welcome_message: String, +/// goodbye_message: String, +/// } +/// impl Default for MyNodeConfig { +/// fn default() -> Self { +/// Self { +/// welcome_message: "welcome from tari".into(), +/// goodbye_message: "bye bye".into(), +/// } +/// } +/// } +/// impl NetworkConfigPath for MyNodeConfig { +/// fn main_key_prefix() -> &'static str { +/// "my_node" +/// } +/// } +/// let mut config = Config::new(); +/// config.set("my_node.goodbye_message", "see you later"); +/// let my_config = <MyNodeConfig as DefaultConfigLoader>::load_from(&config).unwrap(); +/// assert_eq!(my_config.goodbye_message, "see you later".to_string()); +/// assert_eq!(my_config.welcome_message, MyNodeConfig::default().welcome_message); +/// ``` +pub trait DefaultConfigLoader: + ConfigPath + Default + serde::ser::Serialize + for<'de> serde::de::Deserialize<'de> +{ + /// Try to load configuration from supplied Config by `main_key_prefix()` + /// with values overloaded from `overload_key_prefix()`. + /// + /// Default values will be taken from the Default impl for the struct + fn load_from(config: &Config) -> Result<Self, ConfigurationError> { + let default = <Self as Default>::default(); + let buf = serde_json::to_string(&default)?; + let value: config::Value = serde_json::from_str(buf.as_str())?; + let mut merger = Self::merge_subconfig(config)?; + merger.set_default(Self::main_key_prefix(), value)?; + Ok(merger.get(Self::main_key_prefix())?) + } +} +impl<C> DefaultConfigLoader for C where C: ConfigPath + Default + serde::ser::Serialize + for<'de> serde::de::Deserialize<'de> +{}
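`DefaultConfigLoader::load_from` injects the struct's `Default` by round-tripping it through serde_json into a `config::Value` and registering that as the default layer. A sketch of that round-trip in isolation (the `MyConf` struct is illustrative):

```rust
use serde::{Deserialize, Serialize};

#[derive(Default, Serialize, Deserialize)]
struct MyConf {
    threads: u32,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The same conversion load_from() performs internally
    let buf = serde_json::to_string(&MyConf::default())?;
    let value: config::Value = serde_json::from_str(&buf)?;
    let mut cfg = config::Config::new();
    cfg.set_default("my_conf", value)?;
    // The struct's Default now acts as the lowest-precedence layer
    assert_eq!(cfg.get_int("my_conf.threads")?, 0);
    Ok(())
}
```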
+ +//------------------------------------- Configuration errors --------------------------------------// + +#[derive(Debug)] +pub struct ConfigurationError { + field: String, + message: String, +} + +impl ConfigurationError { + pub fn new(field: &str, msg: &str) -> Self { + ConfigurationError { + field: String::from(field), + message: String::from(msg), + } + } +} + +impl Display for ConfigurationError { + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + f.write_str(&format!("Invalid value for {}: {}", self.field, self.message)) + } +} + +impl Error for ConfigurationError {} +impl From<config::ConfigError> for ConfigurationError { + fn from(err: config::ConfigError) -> Self { + use config::ConfigError; + match err { + ConfigError::FileParse { uri, cause } if uri.is_some() => Self { + field: uri.unwrap(), + message: cause.to_string(), + }, + ConfigError::Type { ref key, .. } => Self { + field: format!("{:?}", key), + message: err.to_string(), + }, + ConfigError::NotFound(key) => Self { + field: key, + message: "required key not found".to_string(), + }, + x => Self::new("", x.to_string().as_str()), + } + } +} +impl From<serde_json::error::Error> for ConfigurationError { + fn from(err: serde_json::error::Error) -> Self { + Self { + field: "".to_string(), + message: err.to_string(), + } + } +} + +#[cfg(test)] +mod test { + use crate::ConfigurationError; + + #[test] + fn configuration_error() { + let e = ConfigurationError::new("test", "is a string"); + assert_eq!(e.to_string(), "Invalid value for test: is a string"); + } + + use super::*; + use serde::{Deserialize, Serialize}; + + // test NetworkConfigPath both with Default and without Default + #[derive(Serialize, Deserialize)] + struct SubTari { + monero: String, + } + impl Default for SubTari { + fn default() -> Self { + Self { + monero: "isprivate".into(), + } + } + } + #[derive(Default, Serialize, Deserialize)] + struct SuperTari { + #[serde(flatten)] + pub within: SubTari, + pub over: SubTari, + #[serde(default = "serde_default_string")] + bitcoin: String, + } + fn serde_default_string() -> String { + "ispublic".into() + } + impl NetworkConfigPath for SuperTari { + fn main_key_prefix() -> &'static str { + "crypto" + } + } + + #[test] + fn default_network_config_loader() -> anyhow::Result<()> { + let mut config = Config::new(); + + config.set("crypto.monero", "isnottari")?; + config.set("crypto.mainnet.monero", "isnottaritoo")?; + config.set("crypto.mainnet.bitcoin", "isnottaritoo")?; + let crypto = <SuperTari as DefaultConfigLoader>::load_from(&config)?; + // no use_network value + // [X] crypto.mainnet, [X] crypto = "isnottari", [X] Default + assert_eq!(crypto.within.monero, "isnottari"); + // [ ] crypto.mainnet, [ ] crypto, [X] Default = "isprivate" + assert_eq!(crypto.over.monero, "isprivate"); + // [X] crypto.mainnet, [ ] crypto, [X] Default = "", [X] serde(default) + assert_eq!(crypto.bitcoin, ""); + + config.set("crypto.over.monero", "istari")?; + let crypto = <SuperTari as DefaultConfigLoader>::load_from(&config)?; + // [ ] crypto.mainnet, [X] crypto = "istari", [X] Default + assert_eq!(crypto.over.monero, "istari"); + + config.set("crypto.use_network", "mainnet")?; + // use_network = mainnet + let crypto = <SuperTari as DefaultConfigLoader>::load_from(&config)?; + // [X] crypto.mainnet = "isnottaritoo", [X] crypto, [X] Default + assert_eq!(crypto.within.monero, "isnottaritoo"); + // [X] crypto.mainnet = "isnottaritoo", [ ] crypto, [X] serde(default), [X] Default + assert_eq!(crypto.bitcoin, "isnottaritoo"); + // [ ] crypto.mainnet, [X] crypto = "istari", [X] Default + assert_eq!(crypto.over.monero, "istari"); + + config.set("crypto.use_network", "wrong_network")?; + assert!(<SuperTari as DefaultConfigLoader>::load_from(&config).is_err()); + + Ok(()) + } + + #[test] + fn network_config_loader() -> anyhow::Result<()> { + let mut config = Config::new(); + + // no use_network value + config.set("crypto.monero", "isnottari")?; + config.set("crypto.mainnet.bitcoin", "isnottaritoo")?; + // [X] crypto.monero [X] crypto.bitcoin(serde) [ ] crypto.over.monero + assert!(<SuperTari as ConfigLoader>::load_from(&config).is_err()); + + // [X] crypto.monero [X] crypto.bitcoin(serde) [ ] crypto.over.monero [X] mainnet.* + config.set("crypto.mainnet.monero", "isnottaritoo")?; + config.set("crypto.mainnet.over.monero", "istari")?; + assert!(<SuperTari as ConfigLoader>::load_from(&config).is_err()); + + // use_network = mainnet + config.set("crypto.use_network", "mainnet")?; + let crypto = <SuperTari as ConfigLoader>::load_from(&config)?; + // [X] crypto.mainnet = "isnottaritoo", [X] crypto, [X] Default + assert_eq!(crypto.within.monero, "isnottaritoo"); +
// [X] crypto.mainnet = "isnottaritoo", [ ] crypto, [X] serde(default), [X] Default + assert_eq!(crypto.bitcoin, "isnottaritoo"); + // [X] crypto.mainnet = "istari", [ ] crypto, [X] Default + assert_eq!(crypto.over.monero, "istari"); + + let mut config = Config::new(); + // no use_network value + config.set("crypto.monero", "isnottari")?; + config.set("crypto.over.monero", "istari")?; + let crypto = ::load_from(&config)?; + // [ ] crypto.mainnet, [X] crypto = "isnottari" + assert_eq!(crypto.within.monero, "isnottari"); + // [ ] crypto.mainnet, [ ] crypto, [X] serde(default) = "ispublic" + assert_eq!(crypto.bitcoin, "ispublic"); + // [ ] crypto.mainnet, [X] crypto = "istari" + assert_eq!(crypto.over.monero, "istari"); + + config.set("crypto.bitcoin", "isnottaritoo")?; + let crypto = ::load_from(&config)?; + // [ ] crypto.mainnet, [X] crypto = "isnottaritoo", [X] serde(default) + assert_eq!(crypto.bitcoin, "isnottaritoo"); + + Ok(()) + } + + // test ConfigPath reading only from main section + #[derive(Serialize, Deserialize)] + struct OneConfig { + param1: String, + #[serde(default = "param2_serde_default")] + param2: String, + } + impl Default for OneConfig { + fn default() -> Self { + Self { + param1: "param1".into(), + param2: "param2".into(), + } + } + } + fn param2_serde_default() -> String { + "alwaysset".into() + } + impl ConfigPath for OneConfig { + fn main_key_prefix() -> &'static str { + "one" + } + + fn overload_key_prefix(_: &Config) -> Result, ConfigurationError> { + Ok(None) + } + } + + #[test] + fn config_loaders() -> anyhow::Result<()> { + let mut config = Config::new(); + + // no use_network value + // [ ] one.param1(default) [X] one.param1(default) [ ] one.param2 [X] one.param2(serde) + assert!(::load_from(&config).is_err()); + // [ ] one.param1(default) [X] one.param1(default) [ ] one.param2 [X] one.param2(default) + let one = ::load_from(&config)?; + assert_eq!(one.param1, OneConfig::default().param1); + assert_eq!(one.param2, OneConfig::default().param2); + + config.set("one.param1", "can load from main section")?; + let one = ::load_from(&config)?; + assert_eq!(one.param1, "can load from main section"); + assert_eq!(one.param2, "param2"); + + let one = ::load_from(&config)?; + assert_eq!(one.param1, "can load from main section"); + assert_eq!(one.param2, param2_serde_default()); + + config.set("one.param2", "specific param overloads serde")?; + let one = ::load_from(&config)?; + assert_eq!(one.param2, "specific param overloads serde"); + + Ok(()) + } +} diff --git a/common/src/configuration/mod.rs b/common/src/configuration/mod.rs new file mode 100644 index 0000000000..8e9e356d63 --- /dev/null +++ b/common/src/configuration/mod.rs @@ -0,0 +1,46 @@ +//! # Configuration of tari applications +//! +//! Tari application consist of `common`, `base_node`, `wallet` and `application` configuration sections. +//! All tari apps follow traits implemented in this crate for ease and automation, for instance managing config files, +//! defaults configuration, overloading settings from subsections. +//! +//! ## Submodules +//! +//! - [bootstrap] - build CLI and manage/load configuration with [ConfigBootsrap] struct +//! - [global] - load GlobalConfig for Tari +//! - [loader] - build and load configuration modules in a tari-way +//! - [utils] - utilities for working with configuration +//! +//! ## Configuration file +//! +//! The tari configuration file (config.yml) is intended to be a single config file for all Tari desktop apps to use +//! 
to pull configuration variables, whether it's a testnet base node, wallet, validator node, etc. +//! +//! The file lives in ~/.tari by default and has sections which will allow a specific app to determine +//! the config values it needs, e.g. +//! +//! ```toml +//! [common] +//! # Globally common variables +//! ... +//! [base_node] +//! # common vars for all base_node instances +//! [base_node.rincewind] +//! # overrides for rincewind testnet +//! [base_node.mainnet] +//! # overrides for mainnet +//! [wallet] +//! [wallet.rincewind] +//! # etc.. +//! ``` + +pub mod bootstrap; +pub mod error; +pub mod global; +pub mod loader; +pub mod utils; + +pub use bootstrap::ConfigBootstrap; +pub use global::{CommsTransport, DatabaseType, GlobalConfig, Network, SocksAuthentication, TorControlAuthentication}; +pub use loader::ConfigurationError; +pub use utils::{default_config, install_default_config_file, load_configuration}; diff --git a/common/src/configuration/utils.rs b/common/src/configuration/utils.rs new file mode 100644 index 0000000000..1ccb521cd1 --- /dev/null +++ b/common/src/configuration/utils.rs @@ -0,0 +1,229 @@ +use crate::{dir_utils::default_subdir, ConfigBootstrap, LOG_TARGET}; +use config::Config; +use log::{debug, info}; +use multiaddr::{Multiaddr, Protocol}; +use std::{fs, path::Path}; + +//------------------------------------- Main API functions --------------------------------------// + +pub fn load_configuration(bootstrap: &ConfigBootstrap) -> Result<Config, String> { + debug!( + target: LOG_TARGET, + "Loading configuration file from {}", + bootstrap.config.to_str().unwrap_or("[??]") + ); + let mut cfg = default_config(bootstrap); + // Load the configuration file + let filename = bootstrap + .config + .to_str() + .ok_or_else(|| "Invalid config file path".to_string())?; + let config_file = config::File::with_name(filename); + match cfg.merge(config_file) { + Ok(_) => { + info!(target: LOG_TARGET, "Configuration file loaded."); + Ok(cfg) + }, + Err(e) => Err(format!( + "There was an error loading the configuration file. {}", + e.to_string() + )), + } +} + +/// Installs a new configuration file template, copied from `rincewind-simple.toml` to the given path. +pub fn install_default_config_file(path: &Path) -> Result<(), std::io::Error> { + let source = include_str!("../../config/presets/rincewind-simple.toml"); + fs::write(path, source) +} + +//------------------------------------- Configuration file defaults --------------------------------------// + +/// Generate the global Tari configuration instance. +/// +/// The `Config` object that is returned holds _all_ the default values possible in the `~/.tari.config.toml` file. +/// These will typically be overridden by userland settings in envars, the config file, or the command line.
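To make the layering concrete: every key set here lives only in the defaults layer of the returned `Config`, so anything merged later (file, envars, CLI) wins. A sketch of reading one generated default back out, assuming `ConfigBootstrap` still provides a `Default` impl as the former API did:

```rust
use tari_common::{default_config, ConfigBootstrap};

fn main() {
    // Assumption: ConfigBootstrap::default() points at ~/.tari as before
    let bootstrap = ConfigBootstrap::default();
    let cfg = default_config(&bootstrap);
    assert_eq!(cfg.get_int("base_node.rincewind.core_threads").unwrap(), 4);
}
```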
+pub fn default_config(bootstrap: &ConfigBootstrap) -> Config { + let mut cfg = Config::new(); + let local_ip_addr = get_local_ip().unwrap_or_else(|| "/ip4/1.2.3.4".parse().unwrap()); + + // Common settings + cfg.set_default("common.message_cache_size", 10).unwrap(); + cfg.set_default("common.message_cache_ttl", 1440).unwrap(); + cfg.set_default("common.peer_whitelist", Vec::<String>::new()).unwrap(); + cfg.set_default("common.liveness_max_sessions", 0).unwrap(); + cfg.set_default( + "common.peer_database ", + default_subdir("peers", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default("common.blacklist_ban_period ", 1440).unwrap(); + + // Wallet settings + cfg.set_default("wallet.grpc_enabled", false).unwrap(); + cfg.set_default("wallet.grpc_address", "tcp://127.0.0.1:18040").unwrap(); + cfg.set_default( + "wallet.wallet_file", + default_subdir("wallet/wallet.dat", Some(&bootstrap.base_path)), + ) + .unwrap(); + + //---------------------------------- Mainnet Defaults --------------------------------------------// + + cfg.set_default("base_node.network", "mainnet").unwrap(); + + // Mainnet base node defaults + cfg.set_default("base_node.mainnet.db_type", "lmdb").unwrap(); + cfg.set_default("base_node.mainnet.peer_seeds", Vec::<String>::new()) + .unwrap(); + cfg.set_default("base_node.mainnet.block_sync_strategy", "ViaBestChainMetadata") + .unwrap(); + cfg.set_default("base_node.mainnet.blocking_threads", 4).unwrap(); + cfg.set_default("base_node.mainnet.core_threads", 6).unwrap(); + cfg.set_default( + "base_node.mainnet.data_dir", + default_subdir("mainnet/", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.mainnet.identity_file", + default_subdir("mainnet/node_id.json", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.mainnet.tor_identity_file", + default_subdir("mainnet/tor.json", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.mainnet.wallet_identity_file", + default_subdir("mainnet/wallet-identity.json", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.mainnet.wallet_tor_identity_file", + default_subdir("mainnet/wallet-tor.json", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.mainnet.public_address", + format!("{}/tcp/18041", local_ip_addr), + ) + .unwrap(); + cfg.set_default("base_node.mainnet.grpc_enabled", false).unwrap(); + cfg.set_default("base_node.mainnet.grpc_address", "tcp://127.0.0.1:18041") + .unwrap(); + cfg.set_default("base_node.mainnet.enable_mining", false).unwrap(); + cfg.set_default("base_node.mainnet.num_mining_threads", 1).unwrap(); + + //---------------------------------- Rincewind Defaults --------------------------------------------// + + cfg.set_default("base_node.rincewind.db_type", "lmdb").unwrap(); + cfg.set_default("base_node.rincewind.peer_seeds", Vec::<String>::new()) + .unwrap(); + cfg.set_default("base_node.rincewind.block_sync_strategy", "ViaBestChainMetadata") + .unwrap(); + cfg.set_default("base_node.rincewind.blocking_threads", 4).unwrap(); + cfg.set_default("base_node.rincewind.core_threads", 4).unwrap(); + cfg.set_default( + "base_node.rincewind.data_dir", + default_subdir("rincewind/", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.rincewind.tor_identity_file", + default_subdir("rincewind/tor.json", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.rincewind.wallet_identity_file", + default_subdir("rincewind/wallet-identity.json",
Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.rincewind.wallet_tor_identity_file", + default_subdir("rincewind/wallet-tor.json", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.rincewind.identity_file", + default_subdir("rincewind/node_id.json", Some(&bootstrap.base_path)), + ) + .unwrap(); + cfg.set_default( + "base_node.rincewind.public_address", + format!("{}/tcp/18141", local_ip_addr), + ) + .unwrap(); + cfg.set_default("base_node.rincewind.grpc_enabled", false).unwrap(); + cfg.set_default("base_node.rincewind.grpc_address", "tcp://127.0.0.1:18141") + .unwrap(); + cfg.set_default("base_node.rincewind.enable_mining", false).unwrap(); + cfg.set_default("base_node.rincewind.num_mining_threads", 1).unwrap(); + + set_transport_defaults(&mut cfg); + + cfg +} + +fn set_transport_defaults(cfg: &mut Config) { + // Mainnet + // Default transport for mainnet is tcp + cfg.set_default("base_node.mainnet.transport", "tcp").unwrap(); + cfg.set_default("base_node.mainnet.tcp_listener_address", "/ip4/0.0.0.0/tcp/18089") + .unwrap(); + + cfg.set_default("base_node.mainnet.tor_control_address", "/ip4/127.0.0.1/tcp/9051") + .unwrap(); + cfg.set_default("base_node.mainnet.tor_control_auth", "none").unwrap(); + cfg.set_default("base_node.mainnet.tor_forward_address", "/ip4/127.0.0.1/tcp/18141") + .unwrap(); + cfg.set_default("base_node.mainnet.tor_onion_port", "18141").unwrap(); + + cfg.set_default("base_node.mainnet.socks5_proxy_address", "/ip4/0.0.0.0/tcp/9050") + .unwrap(); + cfg.set_default("base_node.mainnet.socks5_listener_address", "/ip4/0.0.0.0/tcp/18099") + .unwrap(); + cfg.set_default("base_node.mainnet.socks5_auth", "none").unwrap(); + + // rincewind + // Default transport for rincewind is tcp + cfg.set_default("base_node.rincewind.transport", "tcp").unwrap(); + cfg.set_default("base_node.rincewind.tcp_listener_address", "/ip4/0.0.0.0/tcp/18189") + .unwrap(); + + cfg.set_default("base_node.rincewind.tor_control_address", "/ip4/127.0.0.1/tcp/9051") + .unwrap(); + cfg.set_default("base_node.rincewind.tor_control_auth", "none").unwrap(); + cfg.set_default("base_node.rincewind.tor_forward_address", "/ip4/127.0.0.1/tcp/18041") + .unwrap(); + cfg.set_default("base_node.rincewind.tor_onion_port", "18141").unwrap(); + + cfg.set_default("base_node.rincewind.socks5_proxy_address", "/ip4/0.0.0.0/tcp/9150") + .unwrap(); + cfg.set_default("base_node.rincewind.socks5_listener_address", "/ip4/0.0.0.0/tcp/18199") + .unwrap(); + cfg.set_default("base_node.rincewind.socks5_auth", "none").unwrap(); +} + +fn get_local_ip() -> Option<Multiaddr> { + use std::net::IpAddr; + + get_if_addrs::get_if_addrs().ok().and_then(|if_addrs| { + if_addrs + .into_iter() + .find(|if_addr| !if_addr.is_loopback()) + .map(|if_addr| { + let mut addr = Multiaddr::empty(); + match if_addr.ip() { + IpAddr::V4(ip) => { + addr.push(Protocol::Ip4(ip)); + }, + IpAddr::V6(ip) => { + addr.push(Protocol::Ip6(ip)); + }, + } + addr + }) + }) +}
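`get_local_ip` is what feeds the `public_address` defaults above: the first non-loopback interface IP is wrapped in a multiaddr, and `default_config` appends the network's TCP port. A sketch of the resulting format (the interface address is illustrative):

```rust
use multiaddr::{Multiaddr, Protocol};
use std::net::Ipv4Addr;

fn main() {
    // Same calls get_local_ip() uses to build the address
    let mut addr = Multiaddr::empty();
    addr.push(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 7)));
    // default_config() then appends e.g. the rincewind port
    assert_eq!(format!("{}/tcp/18141", addr), "/ip4/192.168.1.7/tcp/18141");
}
```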
diff --git a/common/src/dir_utils.rs b/common/src/dir_utils.rs index 5c6e346b6a..abf7e4346a 100644 --- a/common/src/dir_utils.rs +++ b/common/src/dir_utils.rs @@ -45,11 +45,11 @@ pub fn default_subdir(path: &str, base_dir: Option<&PathBuf>) -> String { } pub fn default_path(filename: &str, base_path: Option<&PathBuf>) -> PathBuf { - let mut home = base_path.map(|base_path| base_path.clone()).unwrap_or_else(|| { + let mut home = base_path.cloned().unwrap_or_else(|| { let mut home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); home.push(".tari"); home }); home.push(filename); - home.into() + home } diff --git a/common/src/lib.rs b/common/src/lib.rs index 9317d5b5aa..6d14e06664 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -31,210 +31,61 @@ //! //! 1. Command-line argument //! 2. Environment variable -//! 3. `config.toml` file value +//! 3. `config.toml` file value (see details: [configuration]) //! 4. Configuration default //! //! The utilities exposed in this crate are opinionated, but flexible. In general, all data is stored in a `.tari` //! folder under your home folder. //! -//! ### Example - Loading and deserializing the global config file +//! ## Custom application configuration +//! +//! The Tari configuration file allows custom application-specific sections to be added. Tari uses the [config] crate +//! to load configuration and gives access to the [`config::Config`] struct so that apps can stay flexible. +//! As tari apps follow certain configurability assumptions, tari_common provides helper traits +//! that automate these with minimal code. +//! +//! ## CLI helpers +//! +//! Bootstrapping of tari configuration files can be customized via CLI or environment settings. To make building a +//! tari-enabled CLI from scratch as easy as possible, this crate exposes the [ConfigBootstrap] struct, which +//! implements the [structopt::StructOpt] trait and can easily be reused in any CLI. +//! +//! ## Example - a CLI that loads and deserializes the global config file //! //! ```edition2018 //! # use tari_common::*; -//! let bootstrap: ConfigBootstrap = ConfigBootstrap::default(); -//! let config = default_config(&bootstrap); -//! let config = GlobalConfig::convert_from(config).unwrap(); -//! assert_eq!(config.network, Network::MainNet); -//! assert_eq!(config.blocking_threads, 4); +//! # use tari_test_utils::random::string; +//! # use tempdir::TempDir; +//! # use structopt::StructOpt; +//! let mut args = ConfigBootstrap::from_args(); +//! # let temp_dir = TempDir::new(string(8).as_str()).unwrap(); +//! # args.base_path = temp_dir.path().to_path_buf(); +//! # args.init = true; +//! args.init_dirs(); +//! let config = args.load_configuration().unwrap(); +//! let global = GlobalConfig::convert_from(config).unwrap(); +//! assert_eq!(global.network, Network::Rincewind); +//! assert_eq!(global.blocking_threads, 4); +//! # std::fs::remove_dir_all(temp_dir).unwrap(); //! ``` -use clap::ArgMatches; -use std::path::{Path, PathBuf}; - -mod configuration; +pub mod configuration; #[macro_use] mod logging; pub mod protobuf_build; +pub use configuration::error::ConfigError; pub mod dir_utils; pub use configuration::{ - default_config, - install_default_config_file, - load_configuration, - CommsTransport, - ConfigExtractor, - ConfigurationError, - DatabaseType, - GlobalConfig, - Network, - SocksAuthentication, - TorControlAuthentication, + bootstrap::{bootstrap_config_from_cli, install_configuration, ConfigBootstrap}, + global::{CommsTransport, DatabaseType, GlobalConfig, Network, SocksAuthentication, TorControlAuthentication}, + loader::{ConfigExtractor, ConfigLoader, ConfigPath, ConfigurationError, DefaultConfigLoader, NetworkConfigPath}, + utils::{default_config, install_default_config_file, load_configuration}, }; pub use logging::initialize_logging; -use std::io; + pub const DEFAULT_CONFIG: &str = "config.toml"; pub const DEFAULT_LOG_CONFIG: &str = "log4rs.yml"; -/// A minimal parsed configuration object that's used to bootstrap the main Configuration.
-pub struct ConfigBootstrap { - pub base_path: PathBuf, - pub config: PathBuf, - /// The path to the log configuration file. It is set using the following precedence set: - /// 1. from the command-line parameter, - /// 2. from the `TARI_LOG_CONFIGURATION` environment variable, - /// 3. from a default value, usually `~/.tari/log4rs.yml` (or OS equivalent). - pub log_config: PathBuf, -} - -impl Default for ConfigBootstrap { - fn default() -> Self { - ConfigBootstrap { - base_path: dir_utils::default_path("", None), - config: dir_utils::default_path(DEFAULT_CONFIG, None), - log_config: dir_utils::default_path(DEFAULT_LOG_CONFIG, None), - } - } -} - -pub fn bootstrap_config_from_cli(matches: &ArgMatches) -> ConfigBootstrap { - let base_path = matches - .value_of("base_dir") - .map(PathBuf::from) - .unwrap_or_else(|| dir_utils::default_path("", None)); - - // Create the tari data directory - if let Err(e) = dir_utils::create_data_directory(Some(&base_path)) { - println!( - "We couldn't create a default Tari data directory and have to quit now. This makes us sad :(\n {}", - e.to_string() - ); - std::process::exit(1); - } - - let config = matches - .value_of("config") - .map(PathBuf::from) - .unwrap_or_else(|| dir_utils::default_path(DEFAULT_CONFIG, Some(&base_path))); - - let log_config = matches - .value_of("log_config") - .map(PathBuf::from) - .or_else(|| Some(base_path.clone().join(DEFAULT_LOG_CONFIG))); - let log_config = logging::get_log_configuration_path(log_config); - - if !config.exists() { - let install = if !matches.is_present("init") { - prompt("Config file does not exist. We can create a default one for you now, or you can say 'no' here, \ - and generate a customised one at https://config.tari.com.\n\ - Would you like to try the default configuration (Y/n)?") - } else { - true - }; - - if install { - println!("Installing new config file at {}", config.to_str().unwrap_or("[??]")); - install_configuration(&config, configuration::install_default_config_file); - } - } - - if !log_config.exists() { - let install = if !matches.is_present("init") { - prompt("Logging configuration file does not exist. 
Would you like to create a new one (Y/n)?") - } else { - true - }; - if install { - println!( - "Installing new logfile configuration at {}", - log_config.to_str().unwrap_or("[??]") - ); - install_configuration(&log_config, logging::install_default_logfile_config); - } - } - ConfigBootstrap { - base_path, - config, - log_config, - } -} - -fn prompt(question: &str) -> bool { - println!("{}", question); - let mut input = "".to_string(); - io::stdin().read_line(&mut input).unwrap(); - let input = input.trim().to_lowercase(); - input == "y" || input.is_empty() -} - -pub fn install_configuration<F>(path: &Path, installer: F) -where F: Fn(&Path) -> Result<(), std::io::Error> { - if let Err(e) = installer(path) { - println!( - "We could not install a new configuration file in {}: {}", - path.to_str().unwrap_or("?"), - e.to_string() - ) - } -} - -#[cfg(test)] -mod test { - use crate::{bootstrap_config_from_cli, dir_utils, dir_utils::default_subdir, load_configuration}; - use clap::clap_app; - use tari_test_utils::random::string; - use tempdir::TempDir; - - #[test] - fn test_bootstrap_config_from_cli_and_load_configuration() { - let temp_dir = TempDir::new(string(8).as_str()).unwrap(); - let dir = &temp_dir.path().to_path_buf(); - // Create test folder - dir_utils::create_data_directory(Some(dir)).unwrap(); - - // Create command line test data - let matches = clap_app!(myapp => - (version: "0.0.10") - (author: "The Tari Community") - (about: "The reference Tari cryptocurrency base node implementation") - (@arg base_dir: -b --base_dir +takes_value "A path to a directory to store your files") - (@arg config: -c --config +takes_value "A path to the configuration file to use (config.toml)") - (@arg log_config: -l --log_config +takes_value "A path to the logfile configuration (log4rs.yml))") - (@arg init: --init "Create a default configuration file if it doesn't exist") - (@arg create_id: --create_id "Create and save new node identity if one doesn't exist ") - ) - .get_matches_from(vec![ - "", - "--base_dir", - default_subdir("", Some(dir)).as_str(), - "--init", - "--create_id", - ]); - - // Load bootstrap - let bootstrap = bootstrap_config_from_cli(&matches); - let config_exists = std::path::Path::new(&bootstrap.config).exists(); - let log_config_exists = std::path::Path::new(&bootstrap.log_config).exists(); - // Load and apply configuration file - let cfg = load_configuration(&bootstrap); - - // Cleanup test data - if std::path::Path::new(&dir_utils::default_subdir("", Some(dir))).exists() { - std::fs::remove_dir_all(&dir_utils::default_subdir("", Some(dir))).unwrap(); - } - - // Assert results - assert!(config_exists); - assert!(log_config_exists); - assert!(&cfg.is_ok()); - } - - #[test] - fn check_homedir_is_used_by_default() { - dir_utils::create_data_directory(None).unwrap(); - assert_eq!( - dirs::home_dir().unwrap().join(".tari"), - dir_utils::default_path("", None) - ); - } -} +pub(crate) const LOG_TARGET: &str = "common::config"; diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 2875b8e53e..37a13794a9 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -6,12 +6,12 @@ repository = "https://github.com/tari-project/tari" homepage = "https://tari.com" readme = "README.md" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" [dependencies] tari_crypto = { version = "^0.3" } -tari_storage = { version="^0.0", path = "../infrastructure/storage" } +tari_storage = { version="^0.1", path = "../infrastructure/storage" } tari_shutdown = { version="^0.0", path =
"../infrastructure/shutdown" } bitflags = "1.0.4" @@ -48,4 +48,4 @@ tokio-macros = "0.2.3" tempdir = "0.3.7" [build-dependencies] -tari_common = { version = "^0.0", path="../common"} +tari_common = { version = "^0.1", path="../common"} diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index 2f2679fd99..c618d95eaf 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tari_comms_dht" -version = "0.0.10" +version = "0.1.0" authors = ["The Tari Development Community"] description = "Tari comms DHT module" repository = "https://github.com/tari-project/tari" @@ -13,15 +13,18 @@ edition = "2018" test-mocks = [] [dependencies] -tari_comms = { version = "^0.0", path = "../"} +tari_comms = { version = "^0.1", path = "../"} tari_crypto = { version = "^0.3" } +tari_utilities = { version = "^0.1" } tari_shutdown = { version = "^0.0", path = "../../infrastructure/shutdown"} -tari_storage = { version = "^0.0", path = "../../infrastructure/storage"} +tari_storage = { version = "^0.1", path = "../../infrastructure/storage"} bitflags = "1.2.0" bytes = "0.4.12" chrono = "0.4.9" derive-error = "0.0.4" +diesel = {version="1.4", features = ["sqlite", "serde_json", "chrono"]} +diesel_migrations = "1.4" digest = "0.8.1" futures= {version= "^0.3.1"} log = "0.4.8" @@ -34,6 +37,7 @@ serde_repr = "0.1.5" tokio = {version="0.2.10", features=["rt-threaded", "blocking"]} tower= "0.3.0" ttl_cache = "0.5.1" + # tower-filter dependencies pin-project = "0.4" @@ -55,4 +59,4 @@ futures-util = "^0.3.1" lazy_static = "1.4.0" [build-dependencies] -tari_common = { version = "^0.0", path="../../common"} +tari_common = { version = "^0.1", path="../../common"} diff --git a/comms/dht/diesel.toml b/comms/dht/diesel.toml new file mode 100644 index 0000000000..92267c829f --- /dev/null +++ b/comms/dht/diesel.toml @@ -0,0 +1,5 @@ +# For documentation on how to configure this file, +# see diesel.rs/guides/configuring-diesel-cli + +[print_schema] +file = "src/schema.rs" diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index 94c1d63b4b..77b63952ad 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -36,9 +36,10 @@ //! 
`RUST_BACKTRACE=1 RUST_LOG=trace cargo run --example memorynet 2> /tmp/debug.log` // Size of network -const NUM_NODES: usize = 40; +const NUM_NODES: usize = 39; // Must be at least 2 const NUM_WALLETS: usize = 8; +const QUIET_MODE: bool = true; mod memory_net; @@ -69,7 +70,6 @@ use tari_comms::{ PeerConnection, }; use tari_comms_dht::{envelope::NodeDestination, inbound::DecryptedDhtMessage, Dht, DhtBuilder}; -use tari_crypto::tari_utilities::ByteArray; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; use tari_test_utils::{paths::create_temporary_data_path, random}; use tokio::{runtime, time}; @@ -177,14 +177,12 @@ async fn main() { node, seed_node ); node.dht.dht_requester().send_join().await.unwrap(); - seed_node.expect_peer_connection(&node.get_node_id()).await.unwrap(); - println!(); } take_a_break().await; - peer_list_summary(&nodes).await; + // peer_list_summary(&nodes).await; banner!( "Now, {} wallets are going to join from a random base node.", @@ -206,27 +204,32 @@ async fn main() { .expect_peer_connection(&wallet.get_node_id()) .await .unwrap(); - println!(); } - drain_messaging_events(&mut messaging_events_rx, false).await; + let mut total_messages = 0; + total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; take_a_break().await; - drain_messaging_events(&mut messaging_events_rx, false).await; - - peer_list_summary(&wallets).await; + total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; - discovery(&wallets, &mut messaging_events_rx, false, true).await; + network_peer_list_stats(&nodes, &wallets).await; + network_connectivity_stats(&nodes, &wallets).await; - take_a_break().await; - drain_messaging_events(&mut messaging_events_rx, false).await; + { + let all_known_peers = seed_node.comms.peer_manager().all().await.unwrap(); + println!("Seed node knows {} peers", all_known_peers.len()); + } + // peer_list_summary(&wallets).await; - discovery(&wallets, &mut messaging_events_rx, true, false).await; + total_messages += discovery(&wallets, &mut messaging_events_rx).await; take_a_break().await; - drain_messaging_events(&mut messaging_events_rx, false).await; + total_messages += drain_messaging_events(&mut messaging_events_rx, false).await; - discovery(&wallets, &mut messaging_events_rx, false, false).await; + let random_wallet = wallets.remove(OsRng.gen_range(0, wallets.len() - 1)); + total_messages += + do_store_and_forward_discovery(random_wallet, &wallets, messaging_events_tx, &mut messaging_events_rx).await; + println!("{} messages sent in total across the network", total_messages); banner!("That's it folks! 
Network is shutting down..."); shutdown_all(nodes).await; @@ -238,36 +241,18 @@ async fn shutdown_all(nodes: Vec) { future::join_all(tasks).await; } -async fn discovery( - wallets: &[TestNode], - messaging_events_rx: &mut MessagingEventRx, - use_network_region: bool, - use_destination_node_id: bool, -) -{ +async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut MessagingEventRx) -> usize { let mut successes = 0; + let mut total_messages = 0; let mut total_time = Duration::from_secs(0); for i in 0..wallets.len() - 1 { let wallet1 = wallets.get(i).unwrap(); let wallet2 = wallets.get(i + 1).unwrap(); - banner!("'{}' is going to try discover '{}'.", wallet1, wallet2); - - peer_list_summary(&[wallet1, wallet2]).await; + banner!("🌎 '{}' is going to try discover '{}'.", wallet1, wallet2); - let mut destination = NodeDestination::Unknown; - if use_network_region { - let mut new_node_id = [0; 13]; - let node_id = wallet2.get_node_id(); - let buf = &mut new_node_id[..10]; - buf.copy_from_slice(&node_id.as_bytes()[..10]); - let regional_node_id = NodeId::from_bytes(&new_node_id).unwrap(); - destination = NodeDestination::NodeId(Box::new(regional_node_id)); - } - - let mut node_id_dest = None; - if use_destination_node_id { - node_id_dest = Some(wallet2.get_node_id()); + if !QUIET_MODE { + peer_list_summary(&[wallet1, wallet2]).await; } let start = Instant::now(); @@ -276,51 +261,48 @@ async fn discovery( .discovery_service_requester() .discover_peer( Box::new(wallet2.node_identity().public_key().clone()), - node_id_dest, - destination, + wallet2.node_identity().node_id().clone().into(), ) .await; - let end = Instant::now(); - banner!("Discovery is done."); - match discovery_result { Ok(peer) => { successes += 1; - total_time += end - start; - println!( + total_time += start.elapsed(); + banner!( "⚡️🎉😎 '{}' discovered peer '{}' ({}) in {}ms", wallet1, get_name(&peer.node_id), peer, - (end - start).as_millis() + start.elapsed().as_millis() ); - println!(); time::delay_for(Duration::from_secs(5)).await; - drain_messaging_events(messaging_events_rx, false).await; + total_messages += drain_messaging_events(messaging_events_rx, false).await; }, Err(err) => { - println!( + banner!( "💩 '{}' failed to discover '{}' after {}ms because '{:?}'", wallet1, wallet2, - (end - start).as_millis(), + start.elapsed().as_millis(), err ); - println!(); time::delay_for(Duration::from_secs(5)).await; - drain_messaging_events(messaging_events_rx, true).await; + total_messages += drain_messaging_events(messaging_events_rx, true).await; }, } } banner!( - "✨ The set of discoveries succeeded {}% of the time and took a total of {:.1}s.", - (successes as f32 / (wallets.len() - 1) as f32) * 100.0, - total_time.as_secs_f32() + "✨ The set of discoveries succeeded {} out of {} times and took a total of {:.1}s with {} messages sent.", + successes, + wallets.len() - 1, + total_time.as_secs_f32(), + total_messages ); + total_messages } async fn peer_list_summary<'a, I: IntoIterator, T: AsRef>(network: I) { @@ -330,7 +312,7 @@ async fn peer_list_summary<'a, I: IntoIterator, T: AsRef>(ne .as_ref() .comms .peer_manager() - .closest_peers(node_identity.node_id(), 10, &[]) + .closest_peers(node_identity.node_id(), 10, &[], None) .await .unwrap(); let mut table = Table::new(); @@ -356,7 +338,130 @@ async fn peer_list_summary<'a, I: IntoIterator, T: AsRef>(ne } } -async fn drain_messaging_events(messaging_rx: &mut MessagingEventRx, show_logs: bool) { +async fn network_peer_list_stats(nodes: &[TestNode], wallets: &[TestNode]) { + let 
mut stats = HashMap::<String, usize>::with_capacity(wallets.len()); + for wallet in wallets { + let mut num_known = 0; + for node in nodes { + if node + .comms + .peer_manager() + .exists(wallet.node_identity().public_key()) + .await + { + num_known += 1; + } + } + stats.insert(get_name(wallet.node_identity().node_id()), num_known); + } + + let mut avg = Vec::with_capacity(wallets.len()); + for (n, v) in stats { + let perc = v as f32 / nodes.len() as f32; + avg.push(perc); + println!( + "{} is known by {} out of {} nodes ({:.2}%)", + n, + v, + nodes.len(), + perc * 100.0 + ); + } + println!( + "Average {}%", + avg.into_iter().sum::<f32>() / wallets.len() as f32 * 100.0 + ); +} + +async fn network_connectivity_stats(nodes: &[TestNode], wallets: &[TestNode]) { + async fn display(nodes: &[TestNode]) -> (usize, usize) { + let mut total = 0; + let mut avg = Vec::new(); + for node in nodes { + let conns = node.comms.connection_manager().get_active_connections().await.unwrap(); + total += conns.len(); + avg.push(conns.len()); + + if !QUIET_MODE { + println!("{} connected to {} nodes", node, conns.len()); + for c in conns { + println!(" {} ({})", get_name(c.peer_node_id()), c.direction()); + } + } + } + (total, avg.into_iter().sum()) + } + let (mut total, mut avg) = display(nodes).await; + let (t, a) = display(wallets).await; + total += t; + avg += a; + println!("{} total connections on the network. ({} average)", total, avg); +} + +async fn do_store_and_forward_discovery( + wallet: TestNode, + wallets: &[TestNode], + messaging_tx: MessagingEventTx, + messaging_rx: &mut MessagingEventRx, +) -> usize +{ + println!("{} chosen at random to be discovered using store and forward", wallet); + let all_peers = wallet.comms.peer_manager().all().await.unwrap(); + let node_identity = wallet.comms.node_identity().clone(); + + banner!("😴 {} is going offline", wallet); + wallet.comms.shutdown().await; + + banner!( + "🌎 {} ({}) is going to attempt to discover {} ({})", + wallets[0], + wallets[0].comms.node_identity().public_key(), + get_name(node_identity.node_id()), + node_identity.public_key(), + ); + let mut first_wallet_discovery_req = wallets[0].dht.discovery_service_requester(); + + let start = Instant::now(); + let discovery_task = runtime::Handle::current().spawn({ + let node_identity = node_identity.clone(); + let dest_public_key = Box::new(node_identity.public_key().clone()); + async move { + first_wallet_discovery_req + .discover_peer(dest_public_key.clone(), NodeDestination::PublicKey(dest_public_key)) + .await + } + }); + + println!("Waiting a few seconds for discovery to propagate around the network..."); + time::delay_for(Duration::from_secs(5)).await; + + let mut total_messages = drain_messaging_events(messaging_rx, false).await; + + banner!("🤓 {} is coming back online", get_name(node_identity.node_id())); + let (tx, ims_rx) = mpsc::channel(1); + let (comms, dht) = setup_comms_dht(node_identity, create_peer_storage(all_peers), tx).await; + let wallet = TestNode::new(comms, dht, None, ims_rx, messaging_tx); + wallet.dht.dht_requester().send_join().await.unwrap(); + + total_messages += match discovery_task.await.unwrap() { + Ok(peer) => { + banner!("🎉 Discovered peer {} in {}ms", peer, start.elapsed().as_millis()); + drain_messaging_events(messaging_rx, false).await + }, + Err(err) => { + banner!( + "💩 Failed to discover peer after {}ms using store and forward '{:?}'", + start.elapsed().as_millis(), + err + ); + drain_messaging_events(messaging_rx, true).await + }, + }; + + total_messages +} + +async fn 
drain_messaging_events(messaging_rx: &mut MessagingEventRx, show_logs: bool) -> usize { let drain_fut = DrainBurst::new(messaging_rx); if show_logs { let messages = drain_fut.await; @@ -383,9 +488,11 @@ async fn drain_messaging_events(messaging_rx: &mut MessagingEventRx, show_logs: } } println!("{} messages sent between nodes", num_messages); + num_messages } else { let len = drain_fut.await.len(); println!("📨 {} messages exchanged", len); + len } } @@ -394,6 +501,9 @@ fn connection_manager_logger( ) -> impl FnMut(Arc) -> Arc { let node_name = get_name(&node_id); move |event| { + if QUIET_MODE { + return event; + } use ConnectionManagerEvent::*; print!("EVENT: "); match &*event { @@ -415,8 +525,8 @@ fn connection_manager_logger( PeerConnectFailed(node_id, err) => { println!( "'{}' failed to connect to '{}' because '{:?}'", - get_name(node_id), node_name, + get_name(node_id), err ); }, @@ -630,9 +740,11 @@ async fn setup_comms_dht( comms.shutdown_signal(), ) .local_test() - .with_discovery_timeout(Duration::from_secs(60)) + .with_discovery_timeout(Duration::from_secs(15)) .with_num_neighbouring_nodes(8) - .finish(); + .finish() + .await + .unwrap(); let dht_outbound_layer = dht.outbound_middleware_layer(); @@ -660,5 +772,5 @@ async fn setup_comms_dht( async fn take_a_break() { banner!("Taking a break for a few seconds to let things settle..."); - time::delay_for(Duration::from_millis(NUM_NODES as u64 * 500)).await; + time::delay_for(Duration::from_millis(NUM_NODES as u64 * 300)).await; } diff --git a/comms/dht/migrations/.gitkeep b/comms/dht/migrations/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/comms/dht/migrations/2020-04-01-095825_initial/down.sql b/comms/dht/migrations/2020-04-01-095825_initial/down.sql new file mode 100644 index 0000000000..6e0a2cbbd7 --- /dev/null +++ b/comms/dht/migrations/2020-04-01-095825_initial/down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS stored_messages; +DROP TABLE IF EXISTS dht_settings; diff --git a/comms/dht/migrations/2020-04-01-095825_initial/up.sql b/comms/dht/migrations/2020-04-01-095825_initial/up.sql new file mode 100644 index 0000000000..3805ecaa31 --- /dev/null +++ b/comms/dht/migrations/2020-04-01-095825_initial/up.sql @@ -0,0 +1,27 @@ +CREATE TABLE stored_messages ( + id INTEGER NOT NULL PRIMARY KEY, + version INT NOT NULL, + origin_pubkey TEXT NOT NULL, + origin_signature TEXT NOT NULL, + message_type INT NOT NULL, + destination_pubkey TEXT, + destination_node_id TEXT, + header BLOB NOT NULL, + body BLOB NOT NULL, + is_encrypted BOOLEAN NOT NULL CHECK (is_encrypted IN (0,1)), + priority INT NOT NULL, + stored_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_stored_messages_destination_pubkey ON stored_messages (destination_pubkey); +CREATE INDEX idx_stored_messages_destination_node_id ON stored_messages (destination_node_id); +CREATE INDEX idx_stored_messages_stored_at ON stored_messages (stored_at); +CREATE INDEX idx_stored_messages_priority ON stored_messages (priority); + +CREATE TABLE dht_settings ( + id INTEGER PRIMARY KEY NOT NULL, + key TEXT NOT NULL, + value BLOB NOT NULL +); + +CREATE UNIQUE INDEX idx_dht_settings_key ON dht_settings (key); diff --git a/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql new file mode 100644 index 0000000000..a4b1198712 --- /dev/null +++ b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/down.sql @@ -0,0 +1 @@ +-- No going back diff --git 
a/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql new file mode 100644 index 0000000000..d672ff7ac5 --- /dev/null +++ b/comms/dht/migrations/2020-04-07-161148_remove_origin_signature/up.sql @@ -0,0 +1,20 @@ +DROP TABLE stored_messages; + +CREATE TABLE stored_messages( + id INTEGER NOT NULL PRIMARY KEY, + version INT NOT NULL, + origin_pubkey TEXT, + message_type INT NOT NULL, + destination_pubkey TEXT, + destination_node_id TEXT, + header BLOB NOT NULL, + body BLOB NOT NULL, + is_encrypted BOOLEAN NOT NULL CHECK (is_encrypted IN (0,1)), + priority INT NOT NULL, + stored_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_stored_messages_destination_pubkey ON stored_messages (destination_pubkey); +CREATE INDEX idx_stored_messages_destination_node_id ON stored_messages (destination_node_id); +CREATE INDEX idx_stored_messages_stored_at ON stored_messages (stored_at); +CREATE INDEX idx_stored_messages_priority ON stored_messages (priority); diff --git a/comms/dht/migrations/2020-04-16-165626_clear_stored_messages/down.sql b/comms/dht/migrations/2020-04-16-165626_clear_stored_messages/down.sql new file mode 100644 index 0000000000..291a97c5ce --- /dev/null +++ b/comms/dht/migrations/2020-04-16-165626_clear_stored_messages/down.sql @@ -0,0 +1 @@ +-- This file should undo anything in `up.sql` \ No newline at end of file diff --git a/comms/dht/migrations/2020-04-16-165626_clear_stored_messages/up.sql b/comms/dht/migrations/2020-04-16-165626_clear_stored_messages/up.sql new file mode 100644 index 0000000000..72cbfbeec5 --- /dev/null +++ b/comms/dht/migrations/2020-04-16-165626_clear_stored_messages/up.sql @@ -0,0 +1 @@ +DELETE FROM stored_messages; \ No newline at end of file diff --git a/comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/down.sql b/comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/down.sql new file mode 100644 index 0000000000..44efc5f9df --- /dev/null +++ b/comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/down.sql @@ -0,0 +1 @@ +ALTER TABLE dht_metadata RENAME TO dht_settings; diff --git a/comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/up.sql b/comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/up.sql new file mode 100644 index 0000000000..57074157cc --- /dev/null +++ b/comms/dht/migrations/2020-04-20-082924_rename_settings_to_metadata/up.sql @@ -0,0 +1 @@ +ALTER TABLE dht_settings RENAME TO dht_metadata; \ No newline at end of file diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 36bb2f9c25..ce8ede1ce6 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -31,7 +31,8 @@ use crate::{ broadcast_strategy::BroadcastStrategy, discovery::DhtDiscoveryError, outbound::{OutboundMessageRequester, SendMessageParams}, - proto::{dht::JoinMessage, envelope::DhtMessageType, store_forward::StoredMessagesRequest}, + proto::{dht::JoinMessage, envelope::DhtMessageType}, + storage::{DbConnection, DhtDatabase, DhtMetadataKey, StorageError}, DhtConfig, }; use chrono::{DateTime, Utc}; @@ -41,7 +42,6 @@ use futures::{ future, future::BoxFuture, stream::{Fuse, FuturesUnordered}, - FutureExt, SinkExt, StreamExt, }; @@ -49,6 +49,7 @@ use log::*; use std::{fmt, fmt::Display, sync::Arc}; use tari_comms::{ peer_manager::{ + node_id::NodeDistance, NodeId, NodeIdentity, Peer, @@ -60,9 +61,12 @@ use tari_comms::{ }, types::CommsPublicKey, }; -use 
tari_crypto::tari_utilities::ByteArray; use tari_shutdown::ShutdownSignal; -use tari_storage::IterationResult; +use tari_utilities::{ + message_format::{MessageFormat, MessageFormatError}, + ByteArray, +}; +use tokio::task; use ttl_cache::TtlCache; const LOG_TARGET: &str = "comms::dht::actor"; @@ -80,6 +84,11 @@ pub enum DhtActorError { SendFailed(String), DiscoveryError(DhtDiscoveryError), BlockingJoinError(tokio::task::JoinError), + StorageError(StorageError), + #[error(no_from)] + StoredValueFailedToDeserialize(MessageFormatError), + #[error(no_from)] + FailedToSerializeValue(MessageFormatError), } impl From<SendError> for DhtActorError { @@ -98,23 +107,24 @@ pub enum DhtRequest { /// Send a Join request to the network SendJoin, - /// Send a request for stored messages, optionally specifying a date time that the foreign node should - /// use to filter the returned messages. - SendRequestStoredMessages(Option<DateTime<Utc>>), /// Inserts a message signature to the msg hash cache. This operation replies with a boolean /// which is true if the signature already exists in the cache, otherwise false MsgHashCacheInsert(Vec<u8>, oneshot::Sender<bool>), /// Fetch selected peers according to the broadcast strategy SelectPeers(BroadcastStrategy, oneshot::Sender<Vec<Peer>>), + GetMetadata(DhtMetadataKey, oneshot::Sender<Result<Option<Vec<u8>>, DhtActorError>>), + SetMetadata(DhtMetadataKey, Vec<u8>), } impl Display for DhtRequest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use DhtRequest::*; match self { - DhtRequest::SendJoin => f.write_str("SendJoin"), - DhtRequest::SendRequestStoredMessages(d) => f.write_str(&format!("SendRequestStoredMessages ({:?})", d)), - DhtRequest::MsgHashCacheInsert(_, _) => f.write_str("MsgHashCacheInsert"), - DhtRequest::SelectPeers(s, _) => f.write_str(&format!("SelectPeers (Strategy={})", s)), + SendJoin => f.write_str("SendJoin"), + MsgHashCacheInsert(_, _) => f.write_str("MsgHashCacheInsert"), + SelectPeers(s, _) => f.write_str(&format!("SelectPeers (Strategy={})", s)), + GetMetadata(key, _) => f.write_str(&format!("GetSetting (key={})", key)), + SetMetadata(key, value) => f.write_str(&format!("SetSetting (key={}, value={} bytes)", key, value.len())), } } } @@ -150,17 +160,28 @@ impl DhtRequester { reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled) } - pub async fn send_request_stored_messages(&mut self) -> Result<(), DhtActorError> { - self.sender - .send(DhtRequest::SendRequestStoredMessages(None)) - .await - .map_err(Into::into) + pub async fn get_metadata<T: MessageFormat>(&mut self, key: DhtMetadataKey) -> Result<Option<T>, DhtActorError> { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender.send(DhtRequest::GetMetadata(key, reply_tx)).await?; + match reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)?? { + Some(bytes) => T::from_binary(&bytes) + .map(Some) + .map_err(DhtActorError::StoredValueFailedToDeserialize), + None => Ok(None), + } + } + + pub async fn set_metadata<T: MessageFormat>(&mut self, key: DhtMetadataKey, value: T) -> Result<(), DhtActorError> { + let bytes = value.to_binary().map_err(DhtActorError::FailedToSerializeValue)?; + self.sender.send(DhtRequest::SetMetadata(key, bytes)).await?; + Ok(()) + } }
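// A usage sketch (illustrative only): any type implementing MessageFormat can be
// round-tripped through the DHT database via these two calls. The key and the
// DateTime<Utc> type mirror the get_and_set_metadata test at the bottom of this file.
use chrono::{DateTime, Utc};

async fn last_offline(requester: &mut DhtRequester) -> Result<Option<DateTime<Utc>>, DhtActorError> {
    // get_metadata deserializes the stored bytes with T::from_binary
    requester.get_metadata::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp).await
}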
pub struct DhtActor<'a> { node_identity: Arc<NodeIdentity>, peer_manager: Arc<PeerManager>, + database: DhtDatabase, outbound_requester: OutboundMessageRequester, config: DhtConfig, shutdown_signal: Option<ShutdownSignal>, @@ -169,9 +190,17 @@ pub struct DhtActor<'a> { pending_jobs: FuturesUnordered<BoxFuture<'a, Result<(), DhtActorError>>>, } +impl DhtActor<'static> { + pub async fn spawn(self) -> Result<(), DhtActorError> { + task::spawn(Self::run(self)); + Ok(()) + } +} + impl<'a> DhtActor<'a> { pub fn new( config: DhtConfig, + conn: DbConnection, node_identity: Arc<NodeIdentity>, peer_manager: Arc<PeerManager>, outbound_requester: OutboundMessageRequester, @@ -182,6 +211,7 @@ impl<'a> DhtActor<'a> { Self { msg_hash_cache: TtlCache::new(config.msg_hash_cache_capacity), config, + database: DhtDatabase::new(conn), outbound_requester, peer_manager, node_identity, @@ -191,12 +221,25 @@ } } - pub async fn run(mut self) { + async fn run(mut self) { + let offline_ts = self + .database + .get_metadata_value::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp) + .await + .ok() + .flatten(); + info!( + target: LOG_TARGET, + "DhtActor started. {}", + offline_ts + .map(|dt| format!("Dht has been offline since '{}'", dt)) + .unwrap_or_else(String::new) + ); + let mut shutdown_signal = self .shutdown_signal .take() - .expect("DhtActor initialized without shutdown_signal") - .fuse(); + .expect("DhtActor initialized without shutdown_signal"); loop { futures::select! 
{ @@ -219,16 +262,23 @@ impl<'a> DhtActor<'a> { _ = shutdown_signal => { info!(target: LOG_TARGET, "DhtActor is shutting down because it received a shutdown signal."); + // Called with reference to database otherwise DhtActor is not Send + Self::mark_shutdown_time(&self.database).await; break; }, - complete => { - info!(target: LOG_TARGET, "DhtActor is shutting down because the request stream ended."); - break; - } } } } + async fn mark_shutdown_time(db: &DhtDatabase) { + if let Err(err) = db + .set_metadata_value(DhtMetadataKey::OfflineTimestamp, Utc::now()) + .await + { + error!(target: LOG_TARGET, "Failed to mark offline time: {:?}", err); + } + } + fn request_handler(&mut self, request: DhtRequest) -> BoxFuture<'a, Result<(), DhtActorError>> { use DhtRequest::*; match request { @@ -265,15 +315,26 @@ impl<'a> DhtActor<'a> { } }) }, - SendRequestStoredMessages(maybe_since) => { - let node_identity = Arc::clone(&self.node_identity); - let outbound_requester = self.outbound_requester.clone(); - Box::pin(Self::request_stored_messages( - node_identity, - outbound_requester, - self.config.num_neighbouring_nodes, - maybe_since, - )) + GetMetadata(key, reply_tx) => { + let db = self.database.clone(); + Box::pin(async move { + let _ = reply_tx.send(db.get_metadata_value_bytes(key).await.map_err(Into::into)); + Ok(()) + }) + }, + SetMetadata(key, value) => { + let db = self.database.clone(); + Box::pin(async move { + match db.set_metadata_value_bytes(key, value).await { + Ok(_) => { + info!(target: LOG_TARGET, "Dht setting '{}' set", key); + }, + Err(err) => { + error!(target: LOG_TARGET, "set_setting failed because {:?}", err); + }, + } + Ok(()) + }) }, } } @@ -310,32 +371,6 @@ impl<'a> DhtActor<'a> { Ok(()) } - async fn request_stored_messages( - node_identity: Arc, - mut outbound_requester: OutboundMessageRequester, - num_neighbouring_nodes: usize, - maybe_since: Option>, - ) -> Result<(), DhtActorError> - { - outbound_requester - .send_message_no_header( - SendMessageParams::new() - .closest( - node_identity.node_id().clone(), - num_neighbouring_nodes, - Vec::new(), - PeerFeatures::DHT_STORE_FORWARD, - ) - .with_dht_message_type(DhtMessageType::SafRequestMessages) - .finish(), - maybe_since.map(StoredMessagesRequest::since).unwrap_or_default(), - ) - .await - .map_err(|err| DhtActorError::SendFailed(format!("Failed to send request for stored messages: {}", err)))?; - - Ok(()) - } - async fn select_peers( config: DhtConfig, node_identity: Arc, @@ -372,12 +407,13 @@ impl<'a> DhtActor<'a> { &closest_request.node_id, closest_request.n, &closest_request.excluded_peers, + closest_request.peer_features, ) .await }, - Random(n) => { + Random(n, excluded) => { // Send to a random set of peers of size n that are Communication Nodes - peer_manager.random_peers(n).await.map_err(Into::into) + peer_manager.random_peers(n, excluded).await.map_err(Into::into) }, // TODO: This is a common and expensive search - values here should be cached Neighbours(exclude, include_all_communication_clients) => { @@ -388,11 +424,26 @@ impl<'a> DhtActor<'a> { node_identity.node_id(), config.num_neighbouring_nodes, &exclude, + PeerFeatures::MESSAGE_PROPAGATION, ) .await?; if include_all_communication_clients { - Self::add_all_communication_client_nodes(&peer_manager, &exclude, &mut candidates).await?; + let region_dist = peer_manager + .calc_region_threshold( + node_identity.node_id(), + config.num_neighbouring_nodes, + PeerFeatures::COMMUNICATION_CLIENT, + ) + .await?; + Self::add_communication_client_nodes_within_region( + 
&peer_manager, + node_identity.node_id(), + region_dist, + &exclude, + &mut candidates, + ) + .await?; } Ok(candidates) @@ -400,31 +451,39 @@ } } - async fn add_all_communication_client_nodes( + async fn add_communication_client_nodes_within_region( peer_manager: &PeerManager, + ref_node_id: &NodeId, + threshold_dist: NodeDistance, excluded_peers: &[CommsPublicKey], list: &mut Vec<Peer>, ) -> Result<(), DhtActorError> { - peer_manager - .for_each(|peer| { + let query = PeerQuery::new() + .select_where(|peer| { if peer.features != PeerFeatures::COMMUNICATION_CLIENT { - return IterationResult::Continue; + return false; } if peer.is_banned() || peer.is_offline() { - return IterationResult::Continue; + return false; } if excluded_peers.contains(&peer.public_key) { - return IterationResult::Continue; + return false; } - list.push(peer); + let dist = ref_node_id.distance(&peer.node_id); + if dist > threshold_dist { + return false; + } - IterationResult::Continue + true }) - .await?; + .sort_by(PeerQuerySortBy::DistanceFrom(ref_node_id)); + + let peers = peer_manager.perform_query(query).await?; + list.extend(peers); Ok(()) } @@ -442,6 +501,7 @@ node_id: &NodeId, n: usize, excluded_peers: &[CommsPublicKey], + features: PeerFeatures, ) -> Result<Vec<Peer>, DhtActorError> { // TODO: This query is expensive. We can probably cache a list of neighbouring peers which are online @@ -458,17 +518,19 @@ let mut filtered_out_node_count = 0; let query = PeerQuery::new() .select_where(|peer| { - trace!(target: LOG_TARGET, "Considering peer for broadcast: {}", peer.node_id); - - let is_banned = peer.is_banned(); - trace!(target: LOG_TARGET, "[{}] is banned: {}", peer.node_id, is_banned); - if is_banned { + if peer.is_banned() { + trace!(target: LOG_TARGET, "[{}] is banned", peer.node_id); banned_count += 1; return false; } - if !peer.features.contains(PeerFeatures::MESSAGE_PROPAGATION) { - trace!(target: LOG_TARGET, "[{}] is not a propagation node", peer.node_id); + if !peer.features.contains(features) { + trace!( + target: LOG_TARGET, + "[{}] does not have the required features {:?}", + peer.node_id, + features + ); filtered_out_node_count += 1; return false; } 
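// Both selection paths above now share one shape: build a PeerQuery whose predicate
// encodes the ban/feature/distance rules, then let the peer manager run it. A
// condensed, illustrative sketch of the feature filter:
let query = PeerQuery::new().select_where(|peer| {
    // exclude banned peers, keep only peers that can propagate messages
    !peer.is_banned() && peer.features.contains(PeerFeatures::MESSAGE_PROPAGATION)
});
let propagation_peers = peer_manager.perform_query(query).await?;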
@@ -534,12 +596,19 @@ mod test { use crate::{ broadcast_strategy::BroadcastClosestRequest, test_utils::{make_node_identity, make_peer_manager}, }; + use chrono::{DateTime, Utc}; use tari_comms::{ net_address::MultiaddressesWithStats, peer_manager::{PeerFeatures, PeerFlags}, }; use tari_shutdown::Shutdown; - use tokio::runtime; + use tari_test_utils::random; + + async fn db_connection() -> DbConnection { + let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); + conn.migrate().await.unwrap(); + conn + } #[tokio_macros::test_basic] async fn send_join_request() { @@ -552,6 +621,7 @@ let shutdown = Shutdown::new(); let actor = DhtActor::new( Default::default(), + db_connection().await, node_identity, peer_manager, outbound_requester, @@ -559,7 +629,7 @@ shutdown.to_signal(), ); - runtime::Handle::current().spawn(actor.run()); + actor.spawn().await.unwrap(); requester.send_join().await.unwrap(); let (params, _) = unwrap_oms_send_msg!(out_rx.next().await.unwrap()); @@ -577,6 +647,7 @@ let shutdown = Shutdown::new(); let actor = DhtActor::new( Default::default(), + db_connection().await, node_identity, peer_manager, outbound_requester, @@ -584,7 +655,7 @@ shutdown.to_signal(), ); - runtime::Handle::current().spawn(actor.run()); + actor.spawn().await.unwrap(); let signature = vec![1u8, 2, 3]; let is_dup = requester.insert_message_hash(signature.clone()).await.unwrap(); @@ -632,6 +703,7 @@ let shutdown = Shutdown::new(); let actor = DhtActor::new( Default::default(), + db_connection().await, Arc::clone(&node_identity), peer_manager, outbound_requester, @@ -639,7 +711,7 @@ shutdown.to_signal(), ); - runtime::Handle::current().spawn(actor.run()); + actor.spawn().await.unwrap(); let peers = requester .select_peers(BroadcastStrategy::Neighbours(Vec::new(), false)) @@ -675,4 +747,58 @@ assert_eq!(peers.len(), 1); } + + #[tokio_macros::test_basic] + async fn get_and_set_metadata() { + let node_identity = make_node_identity(); + let peer_manager = make_peer_manager(); + let (out_tx, _out_rx) = mpsc::channel(1); + let (actor_tx, actor_rx) = mpsc::channel(1); + let mut requester = DhtRequester::new(actor_tx); + let outbound_requester = OutboundMessageRequester::new(out_tx); + let shutdown = Shutdown::new(); + let actor = DhtActor::new( + Default::default(), + db_connection().await, + node_identity, + peer_manager, + outbound_requester, + actor_rx, + shutdown.to_signal(), + ); + + actor.spawn().await.unwrap(); + + assert!(requester + .get_metadata::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp,) + .await + .unwrap() + .is_none()); + let ts = Utc::now(); + requester + .set_metadata(DhtMetadataKey::OfflineTimestamp, ts) + .await + .unwrap(); + + let got_ts = requester + .get_metadata::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp) + .await + .unwrap() + .unwrap(); + assert_eq!(got_ts, ts); + + // Check upsert + let ts = Utc::now().checked_add_signed(chrono::Duration::seconds(123)).unwrap(); + requester + .set_metadata(DhtMetadataKey::OfflineTimestamp, ts) + .await + .unwrap(); + + let got_ts = requester + .get_metadata::<DateTime<Utc>>(DhtMetadataKey::OfflineTimestamp) + .await + .unwrap() + .unwrap(); + assert_eq!(got_ts, ts); + } } diff --git a/comms/dht/src/broadcast_strategy.rs b/comms/dht/src/broadcast_strategy.rs index 79c6637813..922aa52105 100644 --- a/comms/dht/src/broadcast_strategy.rs +++ b/comms/dht/src/broadcast_strategy.rs @@ -42,13 +42,13 @@ pub enum BroadcastStrategy { DirectPublicKey(Box<CommsPublicKey>), /// Send to all known peers Flood, - /// Send to a random set of peers of size n that are Communication Nodes - Random(usize), + /// Send to a random set of peers of size n that are Communication Nodes, excluding the given node IDs + Random(usize, Vec<NodeId>), /// Send to all n nearest Communication Nodes according to the given BroadcastClosestRequest Closest(Box<BroadcastClosestRequest>), /// A convenient strategy which behaves the same as the `Closest` strategy with the `NodeId` set - /// to this node and a pre-configured number of neighbours that have all the matching PeerFeatures flags. - /// This strategy excludes the given public keys. + /// to this node. Element 0 in the tuple is a public key exclusion list. If element 1 is set to true, all + /// neighbouring client peers are also included in addition to node peers. 
Neighbours(Vec<CommsPublicKey>, bool), } @@ -60,7 +60,7 @@ impl fmt::Display for BroadcastStrategy { DirectNodeId(node_id) => write!(f, "DirectNodeId({})", node_id), Flood => write!(f, "Flood"), Closest(request) => write!(f, "Closest({})", request.n), - Random(n) => write!(f, "Random({})", n), + Random(n, excluded) => write!(f, "Random({}, {} excluded)", n, excluded.len()), Neighbours(excluded, include_clients) => write!( f, "Neighbours({} excluded{})", @@ -72,6 +72,14 @@ } impl BroadcastStrategy { + pub fn is_broadcast(&self) -> bool { + use BroadcastStrategy::*; + match self { + Closest(_) | Flood | Neighbours(_, _) | Random(_, _) => true, + _ => false, + } + } + pub fn is_direct(&self) -> bool { use BroadcastStrategy::*; match self { @@ -128,7 +136,7 @@ mod test { .is_direct(), false ); - assert_eq!(BroadcastStrategy::Random(0).is_direct(), false); + assert_eq!(BroadcastStrategy::Random(0, vec![]).is_direct(), false); } #[test] @@ -151,7 +159,10 @@ })) .direct_public_key() .is_none(),); - assert!(BroadcastStrategy::Random(0).direct_public_key().is_none(), false); + assert!( + BroadcastStrategy::Random(0, vec![]).direct_public_key().is_none(), + false + ); } #[test] @@ -174,6 +185,6 @@ })) .direct_node_id() .is_none(),); - assert!(BroadcastStrategy::Random(0).direct_node_id().is_none(), false); + assert!(BroadcastStrategy::Random(0, vec![]).direct_node_id().is_none(), false); } } diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index ad6f175dc4..f897043739 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{outbound::DhtOutboundRequest, Dht, DhtConfig}; +use crate::{dht::DhtInitializationError, outbound::DhtOutboundRequest, DbConnectionUrl, Dht, DhtConfig}; use futures::channel::mpsc; use std::{sync::Arc, time::Duration}; use tari_comms::{ @@ -70,6 +70,11 @@ impl DhtBuilder { self } + pub fn disable_auto_store_and_forward_requests(mut self) -> Self { + self.config.saf_auto_request = false; + self + } + pub fn testnet(mut self) -> Self { self.config = DhtConfig::default_testnet(); self @@ -80,6 +85,11 @@ self } + pub fn with_database_url(mut self, database_url: DbConnectionUrl) -> Self { + self.config.database_url = database_url; + self + } + pub fn with_signature_cache_ttl(mut self, ttl: Duration) -> Self { self.config.msg_hash_cache_ttl = ttl; self @@ -100,11 +110,11 @@ self } - /// Build a Dht object. - /// - /// Will panic if an executor is not given AND not in a tokio runtime context - pub fn finish(self) -> Dht { - Dht::new( + /// Build and initialize a Dht object. + /// + /// Will panic if not in a tokio runtime context + pub async fn finish(self) -> Result<Dht, DhtInitializationError> { + Dht::initialize( self.config, self.node_identity, self.peer_manager, @@ -112,5 +122,6 @@ self.connection_manager, self.shutdown_signal, ) + .await } }
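// Taken together: constructing a Dht is now async and fallible. A sketch of the new
// call site (arguments as in the tests in dht.rs; DbConnectionUrl::Memory is also the
// default when no database URL is given):
let dht = DhtBuilder::new(node_identity, peer_manager, outbound_tx, connection_manager, shutdown_signal)
    .with_database_url(DbConnectionUrl::Memory)
    .finish()
    .await?; // DhtInitializationError if the database migration or a service spawn fails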
diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 0b3cf13bf8..c642910684 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -20,20 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::envelope::Network; +use crate::{envelope::Network, storage::DbConnectionUrl}; use std::time::Duration; /// The default maximum number of messages that can be stored using the Store-and-forward middleware pub const SAF_MSG_CACHE_STORAGE_CAPACITY: usize = 10_000; /// The default time-to-live duration used for storage of low priority messages by the Store-and-forward middleware -pub const SAF_LOW_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(6 * 60 * 60); +pub const SAF_LOW_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(6 * 60 * 60); // 6 hours /// The default time-to-live duration used for storage of high priority messages by the Store-and-forward middleware -pub const SAF_HIGH_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(24 * 60 * 60); +pub const SAF_HIGH_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(3 * 24 * 60 * 60); // 3 days /// The default number of peer nodes that a message has to be closer to, to be considered a neighbour pub const DEFAULT_NUM_NEIGHBOURING_NODES: usize = 10; #[derive(Debug, Clone)] pub struct DhtConfig { + /// The `DbConnectionUrl` for the Dht database. Default: In-memory database + pub database_url: DbConnectionUrl, /// The size of the buffer (channel) which holds pending outbound message requests. /// Default: 20 pub outbound_buffer_size: usize, @@ -53,13 +55,17 @@ pub struct DhtConfig { /// Default: 6 hours pub saf_low_priority_msg_storage_ttl: Duration, /// The time-to-live duration used for storage of high priority messages by the Store-and-forward middleware. - /// Default: 24 hours + /// Default: 3 days pub saf_high_priority_msg_storage_ttl: Duration, + /// The limit on the message size to store in SAF storage in bytes. Default 512 KiB + pub saf_max_message_size: usize, + /// When true, store and forward messages are requested from peers on connect (Default: true) + pub saf_auto_request: bool, /// The max capacity of the message hash cache - /// Default: 1000 + /// Default: 10000 pub msg_hash_cache_capacity: usize, /// The time-to-live for items in the message hash cache - /// Default: 300s + /// Default: 300s (5 mins) pub msg_hash_cache_ttl: Duration, /// Sets the number of failed attempts in-a-row to tolerate before temporarily excluding this peer from broadcast /// messages. 
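// The new fields compose with the existing defaults through struct update syntax,
// e.g. a node that wants a smaller SAF message limit and no automatic SAF requests
// (sketch; the values are illustrative):
let config = DhtConfig {
    saf_auto_request: false,
    saf_max_message_size: 256 * 1024,
    ..DhtConfig::default()
};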
@@ -92,6 +98,8 @@ impl DhtConfig { pub fn default_local_test() -> Self { Self { network: Network::LocalTest, + database_url: DbConnectionUrl::Memory, + saf_auto_request: false, ..Default::default() } } @@ -102,14 +110,17 @@ impl Default for DhtConfig { Self { num_neighbouring_nodes: DEFAULT_NUM_NEIGHBOURING_NODES, saf_num_closest_nodes: 10, - saf_max_returned_messages: 100, + saf_max_returned_messages: 50, outbound_buffer_size: 20, saf_msg_cache_storage_capacity: SAF_MSG_CACHE_STORAGE_CAPACITY, saf_low_priority_msg_storage_ttl: SAF_LOW_PRIORITY_MSG_STORAGE_TTL, saf_high_priority_msg_storage_ttl: SAF_HIGH_PRIORITY_MSG_STORAGE_TTL, - msg_hash_cache_capacity: 1000, - msg_hash_cache_ttl: Duration::from_secs(300), + saf_auto_request: true, + saf_max_message_size: 512 * 1024, // 512 KiB + msg_hash_cache_capacity: 10_000, + msg_hash_cache_ttl: Duration::from_secs(5 * 60), broadcast_cooldown_max_attempts: 3, + database_url: DbConnectionUrl::Memory, broadcast_cooldown_period: Duration::from_secs(60 * 30), discovery_request_timeout: Duration::from_secs(2 * 60), network: Network::TestNet, diff --git a/comms/dht/src/crypt.rs b/comms/dht/src/crypt.rs index 4f12b1ff81..2f07d5a097 100644 --- a/comms/dht/src/crypt.rs +++ b/comms/dht/src/crypt.rs @@ -38,17 +38,17 @@ where PK: PublicKey + DiffieHellmanSharedSecret<PK = PK> { } pub fn decrypt(cipher_key: &CommsPublicKey, cipher_text: &[u8]) -> Result<Vec<u8>, CipherError> { - ChaCha20::open_with_integral_nonce(cipher_text, cipher_key.as_bytes()) + ChaCha20::open_with_integral_nonce(&cipher_text.to_vec(), cipher_key.as_bytes()) } -pub fn encrypt(cipher_key: &CommsPublicKey, plain_text: &Vec<u8>) -> Result<Vec<u8>, CipherError> { - ChaCha20::seal_with_integral_nonce(plain_text, &cipher_key.to_vec()) +pub fn encrypt(cipher_key: &CommsPublicKey, plain_text: &[u8]) -> Result<Vec<u8>, CipherError> { + ChaCha20::seal_with_integral_nonce(&plain_text.to_vec(), cipher_key.as_bytes()) } #[cfg(test)] mod test { use super::*; - use tari_crypto::tari_utilities::hex::from_hex; + use tari_utilities::hex::from_hex; #[test] fn encrypt_decrypt() { 
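// With both helpers taking &[u8], callers can pass slices directly. A round-trip
// sketch (key derivation as used by the tests in dht.rs):
let key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key());
let ciphertext = crypt::encrypt(&key, b"plain text").unwrap();
assert_eq!(crypt::decrypt(&key, &ciphertext).unwrap(), b"plain text".to_vec());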
diff --git a/comms/dht/src/dedup.rs b/comms/dht/src/dedup.rs new file mode 100644 index 0000000000..25fd0effcb --- /dev/null +++ b/comms/dht/src/dedup.rs @@ -0,0 +1,232 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::{actor::DhtRequester, inbound::DhtInboundMessage, outbound::message::DhtOutboundMessage}; +use digest::Input; +use futures::{task::Context, Future}; +use log::*; +use std::task::Poll; +use tari_comms::{pipeline::PipelineError, types::Challenge}; +use tari_utilities::hex::Hex; +use tower::{layer::Layer, Service, ServiceExt}; + +const LOG_TARGET: &str = "comms::dht::dedup"; + +fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> { + Challenge::new().chain(&message.body).result().to_vec() +} + +fn hash_outbound_message(message: &DhtOutboundMessage) -> Vec<u8> { + Challenge::new().chain(&message.body.to_vec()).result().to_vec() +} + +/// # DHT Deduplication middleware +/// +/// Takes in a `DhtInboundMessage` and checks the message signature cache for duplicates. +/// If a duplicate message is detected, it is discarded. +#[derive(Clone)] +pub struct DedupMiddleware<S> { + next_service: S, + dht_requester: DhtRequester, +} + +impl<S> DedupMiddleware<S> { + pub fn new(service: S, dht_requester: DhtRequester) -> Self { + Self { + next_service: service, + dht_requester, + } + } +} + +impl<S> Service<DhtInboundMessage> for DedupMiddleware<S> +where S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clone +{ + type Error = PipelineError; + type Response = (); + + type Future = impl Future<Output = Result<Self::Response, Self::Error>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, message: DhtInboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + let mut dht_requester = self.dht_requester.clone(); + async move { + let hash = hash_inbound_message(&message); + trace!( + target: LOG_TARGET, + "Inserting message hash {} for message {}", + hash.to_hex(), + message.tag + ); + if dht_requester + .insert_message_hash(hash) + .await + .map_err(PipelineError::from_debug)? + { + info!( + target: LOG_TARGET, + "Received duplicate message {} from peer '{}'. Message discarded.", + message.tag, + message.source_peer.node_id.short_str(), + ); + return Ok(()); + } + + debug!(target: LOG_TARGET, "Passing message {} onto next service", message.tag); + next_service.oneshot(message).await + } + } +}
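// The outbound counterpart below applies the same cache on the way out: broadcast
// messages record their hash before leaving, so copies that are propagated back to
// this node are recognised and dropped by the inbound path above. Wiring the layer
// is the same in both directions (sketch, mirroring its registration in dht.rs):
let mut dedup_svc = DedupLayer::new(dht_requester.clone()).layer(next_service);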
+impl<S> Service<DhtOutboundMessage> for DedupMiddleware<S> +where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError> + Clone +{ + type Error = PipelineError; + type Response = (); + + type Future = impl Future<Output = Result<Self::Response, Self::Error>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, message: DhtOutboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + let mut dht_requester = self.dht_requester.clone(); + async move { + if message.is_broadcast { + let hash = hash_outbound_message(&message); + debug!( + target: LOG_TARGET, + "Dedup added message hash {} to cache for message {}", + hash.to_hex(), + message.tag + ); + if dht_requester + .insert_message_hash(hash) + .await + .map_err(PipelineError::from_debug)? + { + info!( + target: LOG_TARGET, + "Outgoing message is already in the cache ({}, next peer = {})", + message.tag, + message.destination_peer.node_id.short_str() + ); + } + } + + trace!(target: LOG_TARGET, "Passing message onto next service"); + next_service.oneshot(message).await + } + } +} + +pub struct DedupLayer { + dht_requester: DhtRequester, +} + +impl DedupLayer { + pub fn new(dht_requester: DhtRequester) -> Self { + Self { dht_requester } + } +} + +impl<S> Layer<S> for DedupLayer { + type Service = DedupMiddleware<S>; + + fn layer(&self, service: S) -> Self::Service { + DedupMiddleware::new(service, self.dht_requester.clone()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + envelope::DhtMessageFlags, + test_utils::{ + create_dht_actor_mock, + create_outbound_message, + make_dht_inbound_message, + make_node_identity, + service_spy, + DhtMockState, + }, + }; + use tari_test_utils::panic_context; + use tokio::runtime::Runtime; + + #[test] + fn process_message() { + let mut rt = Runtime::new().unwrap(); + let spy = service_spy(); + + let (dht_requester, mut mock) = create_dht_actor_mock(1); + let mock_state = DhtMockState::new(); + mock_state.set_signature_cache_insert(false); + mock.set_shared_state(mock_state.clone()); + rt.spawn(mock.run()); + + let mut dedup = DedupLayer::new(dht_requester).layer(spy.to_service::<PipelineError>()); + + panic_context!(cx); + + assert!(dedup.poll_ready(&mut cx).is_ready()); + let node_identity = make_node_identity(); + let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false); + + rt.block_on(dedup.call(msg.clone())).unwrap(); + assert_eq!(spy.call_count(), 1); + + mock_state.set_signature_cache_insert(true); + rt.block_on(dedup.call(msg)).unwrap(); + assert_eq!(spy.call_count(), 1); + // Drop dedup so that the DhtMock will stop running + drop(dedup); + } + + #[test] + fn deterministic_hash() { + const TEST_MSG: &[u8] = b"test123"; + const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224"; + let node_identity = make_node_identity(); + let msg = make_dht_inbound_message(&node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false); + let hash1 = hash_inbound_message(&msg); + let msg = create_outbound_message(&TEST_MSG); + let hash_out1 = hash_outbound_message(&msg); + + let node_identity = make_node_identity(); + let msg = make_dht_inbound_message(&node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false); + let hash2 = hash_inbound_message(&msg); + let msg = create_outbound_message(&TEST_MSG); + let hash_out2 = hash_outbound_message(&msg); + + assert_eq!(hash1, hash2); + let subjects = &[hash1, hash_out1, hash2, hash_out2]; + assert!(subjects.into_iter().all(|h| h.to_hex() == EXPECTED_HASH)); + } +}
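// For reference: the dedup key is simply the digest of the message body, so the same
// payload always maps to the same cache entry regardless of direction (this is what
// deterministic_hash asserts above). A minimal sketch using only the digest::Input
// API imported by this module:
fn dedup_hash(body: &[u8]) -> Vec<u8> {
    Challenge::new().chain(body).result().to_vec()
}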
diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index cf8adbcb9e..fcfaa99768 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -30,11 +30,15 @@ use crate::{ outbound, outbound::DhtOutboundRequest, proto::envelope::DhtMessageType, + storage::{DbConnection, StorageError}, store_forward, + store_forward::{StoreAndForwardError, StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}, tower_filter, - tower_filter::error::Error as FilterError, + DedupLayer, + DhtActorError, DhtConfig, }; +use derive_error::Error; use futures::{channel::mpsc, future, Future}; use log::*; use std::sync::Arc; @@ -45,9 +49,22 @@ use tari_comms::{ pipeline::PipelineError, }; use tari_shutdown::ShutdownSignal; -use tokio::task; use tower::{layer::Layer, Service, ServiceBuilder}; +const LOG_TARGET: &str = "comms::dht"; + +const DHT_ACTOR_CHANNEL_SIZE: usize = 100; +const DHT_DISCOVERY_CHANNEL_SIZE: usize = 100; +const DHT_SAF_SERVICE_CHANNEL_SIZE: usize = 100; + +#[derive(Debug, Error)] +pub enum DhtInitializationError { + /// Database initialization failed + DatabaseMigrationFailed(StorageError), + StoreAndForwardInitializationError(StoreAndForwardError), + DhtActorInitializationError(DhtActorError), +} + /// Responsible for starting the DHT actor, building the DHT middleware stack and as a factory /// for producing DHT requesters. pub struct Dht { @@ -61,24 +78,27 @@ pub struct Dht { outbound_tx: mpsc::Sender<DhtOutboundRequest>, /// Sender for DHT requests dht_sender: mpsc::Sender<DhtRequest>, - /// Sender for DHT requests + /// Sender for SAF requests + saf_sender: mpsc::Sender<StoreAndForwardRequest>, + /// Sender for DHT discovery requests discovery_sender: mpsc::Sender<DhtDiscoveryRequest>, /// Connection manager actor requester connection_manager: ConnectionManagerRequester, } impl Dht { - pub fn new( + pub async fn initialize( config: DhtConfig, node_identity: Arc<NodeIdentity>, peer_manager: Arc<PeerManager>, outbound_tx: mpsc::Sender<DhtOutboundRequest>, connection_manager: ConnectionManagerRequester, shutdown_signal: ShutdownSignal, - ) -> Self + ) -> Result<Self, DhtInitializationError> { - let (dht_sender, dht_receiver) = mpsc::channel(20); - let (discovery_sender, discovery_receiver) = mpsc::channel(20); + let (dht_sender, dht_receiver) = mpsc::channel(DHT_ACTOR_CHANNEL_SIZE); + let (discovery_sender, discovery_receiver) = mpsc::channel(DHT_DISCOVERY_CHANNEL_SIZE); + let (saf_sender, saf_receiver) = mpsc::channel(DHT_SAF_SERVICE_CHANNEL_SIZE); let dht = Self { node_identity, @@ -86,25 +106,37 @@ peer_manager, config, outbound_tx, dht_sender, + saf_sender, connection_manager, discovery_sender, }; - task::spawn(dht.actor(dht_receiver, shutdown_signal.clone()).run()); - task::spawn(dht.discovery_service(discovery_receiver, shutdown_signal).run()); + let conn = DbConnection::connect_and_migrate(dht.config.database_url.clone()) + .await + .map_err(DhtInitializationError::DatabaseMigrationFailed)?; - dht + dht.store_and_forward_service(conn.clone(), saf_receiver, shutdown_signal.clone()) + .spawn() + .await?; + dht.actor(conn, dht_receiver, shutdown_signal.clone()).spawn().await?; + dht.discovery_service(discovery_receiver, shutdown_signal).spawn(); + + info!(target: LOG_TARGET, "Dht initialization complete."); + + Ok(dht) } /// Create a DHT actor fn actor( &self, + conn: DbConnection, request_receiver: mpsc::Receiver<DhtRequest>, shutdown_signal: ShutdownSignal, ) -> DhtActor<'static> { DhtActor::new( self.config.clone(), + conn, Arc::clone(&self.node_identity), Arc::clone(&self.peer_manager), self.outbound_requester(), @@ -131,6 +163,26 @@ ) } + fn store_and_forward_service( + &self, + conn: DbConnection, + request_rx: mpsc::Receiver<StoreAndForwardRequest>, + shutdown_signal: ShutdownSignal, + ) -> StoreAndForwardService + { + StoreAndForwardService::new( + self.config.clone(), + conn, + self.node_identity.clone(), + self.peer_manager.clone(), + self.dht_requester(), + self.connection_manager.clone(), + self.outbound_requester(), + request_rx, + shutdown_signal, + ) + } + /// Return a new OutboundMessageRequester connected to the receiver pub fn outbound_requester(&self) -> OutboundMessageRequester { OutboundMessageRequester::new(self.outbound_tx.clone()) @@ -146,6 +198,11 @@ DhtDiscoveryRequester::new(self.discovery_sender.clone(), self.config.discovery_request_timeout) } 
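// Each subsystem runs behind its own bounded channel (size 100 per the constants
// above), so a requester is just a cheap clone of the corresponding sender. Typical
// consumer wiring once initialize() has returned (sketch):
let mut outbound = dht.outbound_requester();
let mut dht_requester = dht.dht_requester();
let mut discovery = dht.discovery_service_requester();
let mut saf = dht.store_and_forward_requester(); // defined just below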
+ /// Returns a requester for the StoreAndForwardService associated with this instance + pub fn store_and_forward_requester(&self) -> StoreAndForwardRequester { + StoreAndForwardRequester::new(self.saf_sender.clone()) + } + /// Returns the full DHT stack as a `tower::layer::Layer`. This can be composed with /// other inbound middleware services which expect a DecryptedDhtMessage pub fn inbound_middleware_layer( @@ -161,44 +218,35 @@ impl Dht { + Send, > where - S: Service<DecryptedDhtMessage, Response = ()> + Clone + Send + Sync + 'static, + S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + Send + Sync + 'static, S::Future: Send, - S::Error: std::error::Error + Send + Sync + 'static, { - let saf_storage = Arc::new(store_forward::SafStorage::new( - self.config.saf_msg_cache_storage_capacity, - )); - let builder = ServiceBuilder::new() - .layer(inbound::DeserializeLayer::new()) - .layer(inbound::ValidateLayer::new( - self.config.network, - self.outbound_requester(), - )) - .layer(inbound::DedupLayer::new(self.dht_requester())); - - // FIXME: There is an unresolved stack overflow issue on windows. Seems that we've reached the limit on stack - // page size. These layers are removed from windows builds for now as they are not critical to - // the functioning of the node. (issue #1416) - #[cfg(not(target_os = "windows"))] - let builder = builder + // FIXME: There is an unresolved stack overflow issue on windows in debug mode during runtime, but not in + // release mode, related to the amount of layers. (issue #1416) + ServiceBuilder::new() + .layer(inbound::DeserializeLayer) + .layer(inbound::ValidateLayer::new(self.config.network)) + .layer(DedupLayer::new(self.dht_requester())) .layer(tower_filter::FilterLayer::new(self.unsupported_saf_messages_filter())) - .layer(MessageLoggingLayer::new("Inbound message: ")); - - builder + .layer(MessageLoggingLayer::new(format!( + "Inbound [{}]", + self.node_identity.node_id().short_str() + ))) .layer(inbound::DecryptionLayer::new(Arc::clone(&self.node_identity))) .layer(store_forward::ForwardLayer::new( Arc::clone(&self.peer_manager), self.outbound_requester(), + self.node_identity.features().contains(PeerFeatures::DHT_STORE_FORWARD), )) .layer(store_forward::StoreLayer::new( self.config.clone(), Arc::clone(&self.peer_manager), Arc::clone(&self.node_identity), - Arc::clone(&saf_storage), + self.store_and_forward_requester(), )) .layer(store_forward::MessageHandlerLayer::new( self.config.clone(), - saf_storage, + self.store_and_forward_requester(), self.dht_requester(), Arc::clone(&self.node_identity), Arc::clone(&self.peer_manager), @@ -239,9 +287,12 @@ self.discovery_service_requester(), self.config.network, )) - .layer(MessageLoggingLayer::new("Outbound message: ")) - .layer(outbound::EncryptionLayer::new(Arc::clone(&self.node_identity))) - .layer(outbound::SerializeLayer::new(Arc::clone(&self.node_identity))) + .layer(DedupLayer::new(self.dht_requester())) + .layer(MessageLoggingLayer::new(format!( + "Outbound [{}]", + self.node_identity.node_id().short_str() + ))) + .layer(outbound::SerializeLayer) .into_inner() } @@ -249,7 +300,7 @@ /// supported by the node. fn unsupported_saf_messages_filter( &self, - ) -> impl tower_filter::Predicate<DhtInboundMessage, Future = future::Ready<Result<(), FilterError>>> + Clone + Send + ) -> impl tower_filter::Predicate<DhtInboundMessage, Future = future::Ready<Result<(), PipelineError>>> + Clone + Send { let node_identity = Arc::clone(&self.node_identity); move |msg: &DhtInboundMessage| { @@ -265,7 +316,9 @@ supported by this node. 
Discarding message.", msg.source_peer.public_key ); - future::ready(Err(FilterError::rejected())) + future::ready(Err(PipelineError::from_debug( + "Message filtered out because store and forward is not supported by this node", + ))) }, _ => future::ready(Ok(())), } @@ -292,7 +345,7 @@ mod test { use futures::{channel::mpsc, StreamExt}; use std::{sync::Arc, time::Duration}; use tari_comms::{ - message::{MessageExt, MessageFlags}, + message::MessageExt, pipeline::SinkService, test_utils::mocks::create_connection_manager_mock, wrap_in_envelope_body, @@ -319,23 +372,17 @@ mod test { shutdown.to_signal(), ) .local_test() - .finish(); + .finish() + .await + .unwrap(); let (out_tx, mut out_rx) = mpsc::channel(10); let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); - let dht_envelope = make_dht_envelope( - &node_identity, - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); - let inbound_message = make_comms_inbound_message( - &node_identity, - dht_envelope.to_encoded_bytes().unwrap().into(), - MessageFlags::empty(), - ); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); + let dht_envelope = make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::empty(), false); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); let msg = { service.call(inbound_message).await.unwrap(); @@ -356,7 +403,7 @@ mod test { let (connection_manager, _) = create_connection_manager_mock(1); // Dummy out channel, we are not testing outbound here. - let (out_tx, _) = mpsc::channel(10); + let (out_tx, _out_rx) = mpsc::channel(10); let shutdown = Shutdown::new(); let dht = DhtBuilder::new( @@ -366,22 +413,18 @@ mod test { connection_manager, shutdown.to_signal(), ) - .finish(); + .finish() + .await + .unwrap(); let (out_tx, mut out_rx) = mpsc::channel(10); let mut service = dht.inbound_middleware_layer().layer(SinkService::new(out_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); // Encrypt for self - let ecdh_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes().unwrap()).unwrap(); - let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED); - let inbound_message = make_comms_inbound_message( - &node_identity, - dht_envelope.to_encoded_bytes().unwrap().into(), - MessageFlags::empty(), - ); + let dht_envelope = make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::ENCRYPTED, true); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); let msg = { service.call(inbound_message).await.unwrap(); @@ -414,34 +457,25 @@ mod test { connection_manager, shutdown.to_signal(), ) - .finish(); + .finish() + .await + .unwrap(); let oms_mock_state = oms_mock.get_state(); task::spawn(oms_mock.run()); let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); - let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()).unwrap(); + let msg = wrap_in_envelope_body!(b"unencrypteable".to_vec()); // Encrypt for someone else let node_identity2 = make_node_identity(); let ecdh_key = crypt::generate_ecdh_secret(node_identity2.secret_key(), node_identity2.public_key()); - let encrypted_bytes = crypt::encrypt(&ecdh_key, 
&msg.to_encoded_bytes().unwrap()).unwrap(); - let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED); - - let origin_sig = dht_envelope - .header - .as_ref() - .unwrap() - .origin - .as_ref() - .unwrap() - .signature - .clone(); - let inbound_message = make_comms_inbound_message( - &node_identity, - dht_envelope.to_encoded_bytes().unwrap().into(), - MessageFlags::empty(), - ); + let encrypted_bytes = crypt::encrypt(&ecdh_key, &msg.to_encoded_bytes()).unwrap(); + let dht_envelope = make_dht_envelope(&node_identity, encrypted_bytes, DhtMessageFlags::ENCRYPTED, true); + + let origin_mac = dht_envelope.header.as_ref().unwrap().origin_mac.clone(); + assert_eq!(origin_mac.is_empty(), false); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap(); @@ -449,14 +483,12 @@ mod test { let (params, _) = oms_mock_state.pop_call().unwrap(); // Check that OMS got a request to forward with the original Dht Header - assert_eq!(params.dht_header.unwrap().origin.unwrap().signature, origin_sig); + assert_eq!(params.dht_header.unwrap().origin_mac, origin_mac); // Check the next service was not called assert!(next_service_rx.try_next().is_err()); } - // FIXME: This test is excluded for Windows builds due to an unresolved stack overflow issue (#1416) - #[cfg(not(target_os = "windows"))] #[tokio_macros::test_basic] async fn stack_filter_saf_message() { let node_identity = make_client_identity(); @@ -474,27 +506,22 @@ mod test { connection_manager, shutdown.to_signal(), ) - .finish(); + .finish() + .await + .unwrap(); let (next_service_tx, mut next_service_rx) = mpsc::channel(10); let mut service = dht.inbound_middleware_layer().layer(SinkService::new(next_service_tx)); - let msg = wrap_in_envelope_body!(b"secret".to_vec()).unwrap(); - let mut dht_envelope = make_dht_envelope( - &node_identity, - msg.to_encoded_bytes().unwrap(), - DhtMessageFlags::empty(), - ); + let msg = wrap_in_envelope_body!(b"secret".to_vec()); + let mut dht_envelope = + make_dht_envelope(&node_identity, msg.to_encoded_bytes(), DhtMessageFlags::empty(), false); dht_envelope.header.as_mut().and_then(|header| { header.message_type = DhtMessageType::SafStoredMessages as i32; Some(header) }); - let inbound_message = make_comms_inbound_message( - &node_identity, - dht_envelope.to_encoded_bytes().unwrap().into(), - MessageFlags::empty(), - ); + let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); // This seems like the best way to tell that an open channel is empty without the test blocking indefinitely diff --git a/comms/dht/src/discovery/error.rs b/comms/dht/src/discovery/error.rs index 6bcefc1f92..1f3d3e8fdd 100644 --- a/comms/dht/src/discovery/error.rs +++ b/comms/dht/src/discovery/error.rs @@ -23,15 +23,13 @@ use crate::outbound::DhtOutboundError; use derive_error::Error; use futures::channel::mpsc::SendError; -use tari_comms::peer_manager::PeerManagerError; +use tari_comms::{connection_manager::ConnectionManagerError, peer_manager::PeerManagerError}; #[derive(Debug, Error)] pub enum DhtDiscoveryError { /// The reply channel was canceled ReplyCanceled, DhtOutboundError(DhtOutboundError), - /// Received a discovery response which did not match an inflight discovery request - InflightDiscoveryRequestNotFound, /// Received public key in peer discovery response which does not match the requested public key 
DiscoveredPeerMismatch, /// Received an invalid `NodeId` @@ -42,9 +40,12 @@ pub enum DhtDiscoveryError { SendBufferFull, /// The discovery request timed out DiscoveryTimeout, + /// Failed to send discovery message + DiscoverySendFailed, PeerManagerError(PeerManagerError), #[error(msg_embedded, non_std, no_from)] InvalidPeerMultiaddr(String), + ConnectionManagerError(ConnectionManagerError), } impl DhtDiscoveryError { diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index 0166fab720..3c9787464d 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -29,51 +29,16 @@ use std::{ fmt::{Display, Error, Formatter}, time::Duration, }; -use tari_comms::{ - peer_manager::{NodeId, Peer}, - types::CommsPublicKey, -}; +use tari_comms::{peer_manager::Peer, types::CommsPublicKey}; use tokio::time; -#[derive(Debug)] -pub struct DiscoverPeerRequest { - /// The public key of the peer to be discovered. The message will be encrypted with a DH shared - /// secret using this public key. - pub dest_public_key: Box, - /// The node id of the peer to be discovered, if it is known. Providing the `NodeId` allows - /// discovery requests to reach their destination more quickly. - pub dest_node_id: Option, - /// The destination to include in the comms header. - /// `Undisclosed` will require nodes to propagate the message across the network, presumably eventually - /// reaching the destination node (the node which can decrypt the message). This will happen without - /// any intermediary nodes knowing who is being searched for. - /// `NodeId` will direct the discovery request closer to the destination or network region. - /// `PublicKey` will be propagated across the network. If any node knows the peer, the request can be - /// forwarded to them immediately. However, more nodes will know that this node is being searched for - /// which may slightly compromise privacy. 
- pub destination: NodeDestination, -} - -impl Display for DiscoverPeerRequest { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { - f.debug_struct("DiscoverPeerRequest") - .field("dest_public_key", &format!("{}", self.dest_public_key)) - .field( - "dest_node_id", - &self - .dest_node_id - .as_ref() - .map(|node_id| format!("{}", node_id)) - .unwrap_or_else(|| "None".to_string()), - ) - .field("destination", &format!("{}", self.destination)) - .finish() - } -} - #[derive(Debug)] pub enum DhtDiscoveryRequest { - DiscoverPeer(Box<(DiscoverPeerRequest, oneshot::Sender>)>), + DiscoverPeer( + Box, + NodeDestination, + oneshot::Sender>, + ), NotifyDiscoveryResponseReceived(Box), } @@ -81,8 +46,10 @@ impl Display for DhtDiscoveryRequest { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { use DhtDiscoveryRequest::*; match self { - DiscoverPeer(boxed) => write!(f, "DiscoverPeer({})", boxed.0), - NotifyDiscoveryResponseReceived(boxed) => write!(f, "NotifyDiscoveryResponseReceived({:#?})", *boxed), + DiscoverPeer(public_key, dest, _) => write!(f, "DiscoverPeer({}, {})", public_key, dest), + NotifyDiscoveryResponseReceived(discovery_resp) => { + write!(f, "NotifyDiscoveryResponseReceived({:#?})", discovery_resp) + }, } } } @@ -101,21 +68,34 @@ impl DhtDiscoveryRequester { } } + /// Initiate a peer discovery + /// + /// ## Arguments + /// - `dest_public_key` - The public key of the recipient used to create a shared ECDH key which in turn is used to + /// encrypt the discovery message + /// - `destination` - The `NodeDestination` to use in the DhtHeader when sending a discovery message. + /// - `Unknown` destination will maintain complete privacy; the trade-off is that discovery needs to propagate + /// across the entire network to reach the destination and so may take longer + /// - `NodeId` instructs propagation nodes to direct the message to peers closer to the given NodeId. The `NodeId` + /// may be directed to a region close to the real destination (somewhat private) or directed at a particular + /// node (not private) + /// - `PublicKey` - if any node on the network knows this public key, the message will be directed to that node. + /// This sacrifices privacy for more efficient discovery in terms of network bandwidth and may result in + /// quicker discovery times. pub async fn discover_peer( &mut self, dest_public_key: Box, - dest_node_id: Option, destination: NodeDestination, ) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); - let request = DiscoverPeerRequest { - dest_public_key, - dest_node_id, - destination, - }; + self.sender - .send(DhtDiscoveryRequest::DiscoverPeer(Box::new((request, reply_tx)))) + .send(DhtDiscoveryRequest::DiscoverPeer( + dest_public_key, + destination, + reply_tx, + )) .await?; time::timeout( diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 7b9c58c9fd..f603123b7f 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -21,10 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
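To make the reworked requester API concrete, here is a minimal sketch of a caller that tries the fully private discovery path first and falls back to a directed request. The find_peer helper, the fallback policy and the import paths are assumptions for illustration; only discover_peer, NodeDestination and the is_timeout check are taken from this changeset.

use tari_comms::{peer_manager::Peer, types::CommsPublicKey};
use tari_comms_dht::{
    discovery::{DhtDiscoveryError, DhtDiscoveryRequester},
    envelope::NodeDestination,
};

// Try the most private discovery first: an `Unknown` destination propagates the
// encrypted request across the network until the recipient can decrypt it. If
// that times out, direct the request at the public key, which is quicker but
// reveals to propagating nodes who is being looked up.
async fn find_peer(
    requester: &mut DhtDiscoveryRequester,
    pk: CommsPublicKey,
) -> Result<Peer, DhtDiscoveryError> {
    match requester
        .discover_peer(Box::new(pk.clone()), NodeDestination::Unknown)
        .await
    {
        Ok(peer) => Ok(peer),
        Err(err) if err.is_timeout() => {
            requester
                .discover_peer(Box::new(pk.clone()), NodeDestination::PublicKey(Box::new(pk)))
                .await
        },
        Err(err) => Err(err),
    }
}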
use crate::{ - discovery::{ - requester::{DhtDiscoveryRequest, DiscoverPeerRequest}, - DhtDiscoveryError, - }, + discovery::{requester::DhtDiscoveryRequest, DhtDiscoveryError}, envelope::{DhtMessageType, NodeDestination}, outbound::{OutboundEncryption, OutboundMessageRequester, SendMessageParams}, proto::dht::{DiscoveryMessage, DiscoveryResponseMessage}, @@ -33,11 +30,16 @@ use crate::{ use futures::{ channel::{mpsc, oneshot}, future::FutureExt, + stream::Fuse, StreamExt, }; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; use tari_comms::{ connection_manager::{ConnectionManagerError, ConnectionManagerRequester}, log_if_error, @@ -47,8 +49,9 @@ use tari_comms::{ validate_peer_addresses, ConnectionManagerEvent, }; -use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; use tari_shutdown::ShutdownSignal; +use tari_utilities::{hex::Hex, ByteArray}; +use tokio::{sync::broadcast, task, time}; const LOG_TARGET: &str = "comms::dht::discovery_service"; @@ -59,6 +62,17 @@ const MAX_FAILED_ATTEMPTS_MARK_PEER_OFFLINE: usize = 10; struct DiscoveryRequestState { reply_tx: oneshot::Sender>, public_key: Box, + start_ts: Instant, +} + +impl DiscoveryRequestState { + pub fn new(public_key: Box, reply_tx: oneshot::Sender>) -> Self { + Self { + public_key, + reply_tx, + start_ts: Instant::now(), + } + } } pub struct DhtDiscoveryService { @@ -95,7 +109,14 @@ impl DhtDiscoveryService { } } - pub async fn run(mut self) { + pub fn spawn(self) { + let connection_events = self.connection_manager.get_event_subscription().fuse(); + info!(target: LOG_TARGET, "Discovery service started"); + task::spawn(async move { self.run(connection_events).await }); + } + + pub async fn run(mut self, mut connection_events: Fuse>>) { + info!(target: LOG_TARGET, "Dht discovery service started"); let mut shutdown_signal = self .shutdown_signal .take() @@ -108,8 +129,6 @@ impl DhtDiscoveryService { .expect("DiscoveryService initialized without request_rx") .fuse(); - let mut connection_events = self.connection_manager.get_event_subscription().fuse(); - loop { futures::select! 
{ request = request_rx.select_next_some() => { @@ -120,7 +139,9 @@ impl DhtDiscoveryService { event = connection_events.select_next_some() => { if let Ok(event) = event { trace!(target: LOG_TARGET, "Received connection manager event '{}'", event); - self.handle_connection_manager_event(&event).await; + if let Err(err) = self.handle_connection_manager_event(&event).await { + error!(target: LOG_TARGET, "Error handling connection manager event: {:?}", err); + } } }, @@ -135,65 +156,96 @@ impl DhtDiscoveryService { async fn handle_request(&mut self, request: DhtDiscoveryRequest) { use DhtDiscoveryRequest::*; match request { - DiscoverPeer(boxed) => { - let (request, reply_tx) = *boxed; + DiscoverPeer(dest_pubkey, destination, reply_tx) => { log_if_error!( target: LOG_TARGET, - self.initiate_peer_discovery(request, reply_tx).await, + self.initiate_peer_discovery(dest_pubkey, destination, reply_tx).await, "Failed to initiate a discovery request because '{error}'", ); }, - NotifyDiscoveryResponseReceived(discovery_msg) => self.handle_discovery_response(*discovery_msg).await, + NotifyDiscoveryResponseReceived(discovery_msg) => self.handle_discovery_response(discovery_msg).await, } } - async fn handle_connection_manager_event(&mut self, event: &ConnectionManagerEvent) { + async fn handle_connection_manager_event( + &mut self, + event: &ConnectionManagerEvent, + ) -> Result<(), DhtDiscoveryError> + { use ConnectionManagerEvent::*; + // The connection manager could not dial the peer on any address match event { - // The connection manager could not dial the peer on any address - PeerConnectFailed(node_id, ConnectionManagerError::DialConnectFailedAllAddresses) => { - // Send out a discovery for that peer without keeping track of it as an inflight discovery - match self.peer_manager.find_by_node_id(node_id).await { - Ok(peer) => { - if peer.connection_stats.failed_attempts() > MAX_FAILED_ATTEMPTS_MARK_PEER_OFFLINE { - debug!( - target: LOG_TARGET, - "Deleting stale peer '{}' because this node failed to connect to them {} times", - peer.node_id.short_str(), - MAX_FAILED_ATTEMPTS_MARK_PEER_OFFLINE - ); - if let Err(err) = self.peer_manager.set_offline(&peer.public_key, true).await { - error!(target: LOG_TARGET, "Failed to mark peer as offline because '{:?}'", err); - } - } else { - debug!( - target: LOG_TARGET, - "Attempting to discover peer '{}' because we failed to connect on all addresses for \ - the peer", - peer.node_id.short_str() - ); - // Attempt to discover them - let request = DiscoverPeerRequest { - dest_public_key: Box::new(peer.public_key), - // TODO: This should be the node region, not the node id - dest_node_id: Some(peer.node_id), - destination: Default::default(), - }; - // Don't need to be notified for this discovery - let (reply_tx, _) = oneshot::channel(); - if let Err(err) = self.initiate_peer_discovery(request, reply_tx).await { - error!(target: LOG_TARGET, "Error sending discovery message: {:?}", err); - } - } - }, - Err(err) => error!(target: LOG_TARGET, "{:?}", err), + PeerConnectFailed(node_id, ConnectionManagerError::ConnectFailedMaximumAttemptsReached) => { + if self.connection_manager.get_num_active_connections().await? == 0 { + info!( + target: LOG_TARGET, + "Unsure if we're online because we have no connections. 
Ignoring connection failed event for \ + peer '{}'.", + node_id + ); + return Ok(()); + } + let peer = self.peer_manager.find_by_node_id(node_id).await?; + if peer.connection_stats.failed_attempts() > MAX_FAILED_ATTEMPTS_MARK_PEER_OFFLINE { + debug!( + target: LOG_TARGET, + "Marking peer '{}' as offline because this node failed to connect to them {} times", + peer.node_id.short_str(), + MAX_FAILED_ATTEMPTS_MARK_PEER_OFFLINE + ); + let neighbourhood_stats = self + .peer_manager + .get_region_stats( + self.node_identity.node_id(), + self.config.num_neighbouring_nodes, + PeerFeatures::COMMUNICATION_NODE, + ) + .await?; + // If the node_id is not in our neighbourhood, or it is but the ratio of offline neighbouring peers + // is at most 30%, mark the peer as offline + if !neighbourhood_stats.in_region(node_id) || neighbourhood_stats.offline_ratio() <= 0.3 { + self.peer_manager.set_offline(&peer.public_key, true).await?; + } else { + debug!( + target: LOG_TARGET, + "Not marking neighbouring peer '{}' as offline ({})", node_id, neighbourhood_stats + ); + } + } else { + // if !self.has_inflight_discovery(&peer.public_key) { + // debug!( + // target: LOG_TARGET, + // "Attempting to discover peer '{}' because we failed to connect on all addresses for the + // peer", + // peer.node_id.short_str() + // ); + // + // // Don't need to be notified for this discovery + // let (reply_tx, _) = oneshot::channel(); + // // Send out a discovery for that peer without keeping track of it as an inflight discovery + // let dest_pubkey = Box::new(peer.public_key); + // self.initiate_peer_discovery( + // dest_pubkey.clone(), + // NodeDestination::PublicKey(dest_pubkey), + // reply_tx, + // ) + // .await?; + // } } }, _ => {}, } + + Ok(()) } + // fn has_inflight_discovery(&self, public_key: &CommsPublicKey) -> bool { + // self.inflight_discoveries + // .values() + // .all(|state| &*state.public_key != public_key) + // } + fn collect_all_discovery_requests(&mut self, public_key: &CommsPublicKey) -> Vec { let mut requests = Vec::new(); let mut remaining_requests = HashMap::new(); @@ -217,40 +269,73 @@ impl DhtDiscoveryService { requests } - async fn handle_discovery_response(&mut self, discovery_msg: DiscoveryResponseMessage) { + async fn handle_discovery_response(&mut self, discovery_msg: Box) { trace!( target: LOG_TARGET, "Received discovery response message from {}", discovery_msg.node_id.to_hex() ); - if let Some(request) = log_if_error!( - target: LOG_TARGET, - self.inflight_discoveries - .remove(&discovery_msg.nonce) - .ok_or_else(|| DhtDiscoveryError::InflightDiscoveryRequestNotFound), - "{error}", - ) { - let DiscoveryRequestState { public_key, reply_tx } = request; - - let result = self.validate_then_add_peer(&public_key, discovery_msg).await; - - // Resolve any other pending discover requests if the peer was found - if let Ok(peer) = &result { - for request in self.collect_all_discovery_requests(&public_key) { - let _ = request.reply_tx.send(Ok(peer.clone())); - } - } - trace!(target: LOG_TARGET, "Discovery request is recognised and valid"); + match self.inflight_discoveries.remove(&discovery_msg.nonce) { + Some(request) => { + let DiscoveryRequestState { + public_key, + reply_tx, + start_ts, + } = request; + + let result = self.validate_then_add_peer(&public_key, discovery_msg).await; - let _ = reply_tx.send(result); + // Resolve any other pending discover requests if the peer was found + match &result { + Ok(peer) => { + info!( + target: LOG_TARGET, + "Received discovery response from peer {}.
Discovery completed in {}s", + peer.node_id, + (Instant::now() - start_ts).as_secs_f32() + ); + + for request in self.collect_all_discovery_requests(&public_key) { + if !reply_tx.is_canceled() { + let _ = request.reply_tx.send(Ok(peer.clone())); + } + } + + debug!( + target: LOG_TARGET, + "Discovery request for Node Id {} completed successfully", + peer.node_id.to_hex(), + ); + }, + Err(err) => { + info!( + target: LOG_TARGET, + "Failed to validate and add peer from discovery response from peer. {:?} Discovery \ + completed in {}s", + err, + (Instant::now() - start_ts).as_secs_f32() + ); + }, + } + + let _ = reply_tx.send(result); + }, + None => { + info!( + target: LOG_TARGET, + "Received a discovery response from peer '{}' that this node did not expect. It may have been \ + cancelled earlier.", + discovery_msg.node_id.to_hex() + ); + }, } } async fn validate_then_add_peer( &mut self, public_key: &CommsPublicKey, - discovery_msg: DiscoveryResponseMessage, + discovery_msg: Box, ) -> Result { let node_id = self.validate_raw_node_id(&public_key, &discovery_msg.node_id)?; @@ -311,6 +396,8 @@ impl DhtDiscoveryService { Some(node_id), Some(net_addresses), None, + None, + Some(false), Some(peer_features), None, None, @@ -339,13 +426,13 @@ impl DhtDiscoveryService { async fn initiate_peer_discovery( &mut self, - discovery_request: DiscoverPeerRequest, + dest_pubkey: Box, + destination: NodeDestination, reply_tx: oneshot::Sender>, ) -> Result<(), DhtDiscoveryError> { let nonce = OsRng.next_u64(); - let public_key = discovery_request.dest_public_key.clone(); - self.send_discover(nonce, discovery_request).await?; + self.send_discover(nonce, destination, dest_pubkey.clone()).await?; let inflight_count = self.inflight_discoveries.len(); @@ -363,12 +450,8 @@ impl DhtDiscoveryService { ); // Add the new inflight request. - let key_exists = self - .inflight_discoveries - .insert(nonce, DiscoveryRequestState { reply_tx, public_key }) - .is_some(); - // The nonce should never be chosen more than once - debug_assert!(!key_exists); + self.inflight_discoveries + .insert(nonce, DiscoveryRequestState::new(dest_pubkey, reply_tx)); trace!( target: LOG_TARGET, @@ -382,52 +465,62 @@ impl DhtDiscoveryService { async fn send_discover( &mut self, nonce: u64, - discovery_request: DiscoverPeerRequest, + destination: NodeDestination, + dest_public_key: Box, ) -> Result<(), DhtDiscoveryError> { - let DiscoverPeerRequest { - dest_node_id, - dest_public_key, - destination, - } = discovery_request; - - // If the destination node is is known, send to the closest peers we know. Otherwise... - let network_location_node_id = dest_node_id - .or_else(|| match &destination { - // ... 
if the destination is undisclosed or a public key, send discover to our closest peers - NodeDestination::Unknown | NodeDestination::PublicKey(_) => Some(self.node_identity.node_id().clone()), - // otherwise, send it to the closest peers to the given NodeId destination we know - NodeDestination::NodeId(node_id) => Some(*node_id.clone()), - }) - .expect("cannot fail"); - let discover_msg = DiscoveryMessage { node_id: self.node_identity.node_id().to_vec(), addresses: vec![self.node_identity.public_address().to_string()], peer_features: self.node_identity.features().bits(), nonce, }; - debug!( + info!( target: LOG_TARGET, - "Sending Discover message to (at most) {} closest peers", self.config.num_neighbouring_nodes + "Sending Discovery message for peer public key '{}' with destination {}", dest_public_key, destination ); - self.outbound_requester + let send_states = self + .outbound_requester .send_message_no_header( SendMessageParams::new() - .closest( - network_location_node_id, - self.config.num_neighbouring_nodes, - Vec::new(), - PeerFeatures::empty(), - ) + .neighbours_include_clients(Vec::new()) .with_destination(destination) .with_encryption(OutboundEncryption::EncryptFor(dest_public_key)) .with_dht_message_type(DhtMessageType::Discovery) .finish(), discover_msg, ) - .await?; + .await? + .resolve_ok() + .await + .ok_or_else(|| DhtDiscoveryError::DiscoverySendFailed)?; + + // Spawn a task to log how the sending of discovery went + task::spawn(async move { + info!( + target: LOG_TARGET, + "Discovery sent to {} peer(s). Waiting to see how many got through.", + send_states.len() + ); + let result = time::timeout(Duration::from_secs(10), send_states.wait_percentage_success(0.51)).await; + match result { + Ok((succeeded, failed)) => { + let num_succeeded = succeeded.len(); + let num_failed = failed.len(); + + info!( + target: LOG_TARGET, + "Discovery sent to a majority of neighbouring peers ({} succeeded, {} failed)", + num_succeeded, + num_failed + ); + }, + Err(_) => { + warn!(target: LOG_TARGET, "Failed to send discovery to a majority of peers"); + }, + } + }); Ok(()) } @@ -444,46 +537,45 @@ mod test { use std::time::Duration; use tari_comms::test_utils::mocks::create_connection_manager_mock; use tari_shutdown::Shutdown; - use tari_test_utils::runtime; - - #[test] - fn send_discovery() { - runtime::test_async(|rt| { - let node_identity = make_node_identity(); - let peer_manager = make_peer_manager(); - let (outbound_requester, outbound_mock) = create_outbound_service_mock(10); - let oms_mock_state = outbound_mock.get_state(); - rt.spawn(outbound_mock.run()); - - let (connection_manager, _) = create_connection_manager_mock(1); - let (sender, receiver) = mpsc::channel(10); - // Requester which timeout instantly - let mut requester = DhtDiscoveryRequester::new(sender, Duration::from_millis(1)); - let mut shutdown = Shutdown::new(); - - let service = DhtDiscoveryService::new( - DhtConfig::default(), - node_identity, - peer_manager, - outbound_requester, - connection_manager, - receiver, - shutdown.to_signal(), - ); - - rt.spawn(service.run()); - let dest_public_key = Box::new(CommsPublicKey::default()); - let result = rt.block_on(requester.discover_peer(dest_public_key.clone(), None, NodeDestination::Unknown)); - - assert!(result.unwrap_err().is_timeout()); + #[tokio_macros::test_basic] + async fn send_discovery() { + let node_identity = make_node_identity(); + let peer_manager = make_peer_manager(); + let (outbound_requester, outbound_mock) = create_outbound_service_mock(10); + let 
oms_mock_state = outbound_mock.get_state(); + task::spawn(outbound_mock.run()); + + let (connection_manager, _) = create_connection_manager_mock(1); + let (sender, receiver) = mpsc::channel(10); + // Requester which times out instantly + let mut requester = DhtDiscoveryRequester::new(sender, Duration::from_millis(1)); + let shutdown = Shutdown::new(); + + DhtDiscoveryService::new( + DhtConfig::default(), + node_identity, + peer_manager, + outbound_requester, + connection_manager, + receiver, + shutdown.to_signal(), + ) + .spawn(); + + let dest_public_key = Box::new(CommsPublicKey::default()); + let result = requester + .discover_peer( + dest_public_key.clone(), + NodeDestination::PublicKey(dest_public_key.clone()), + ) + .await; - oms_mock_state.wait_call_count(1, Duration::from_secs(5)).unwrap(); - let (params, _) = oms_mock_state.pop_call().unwrap(); - assert_eq!(params.dht_message_type, DhtMessageType::Discovery); - assert_eq!(params.encryption, OutboundEncryption::EncryptFor(dest_public_key)); + assert!(result.unwrap_err().is_timeout()); - shutdown.trigger().unwrap(); - }) + oms_mock_state.wait_call_count(1, Duration::from_secs(5)).unwrap(); + let (params, _) = oms_mock_state.pop_call().unwrap(); + assert_eq!(params.dht_message_type, DhtMessageType::Discovery); + assert_eq!(params.encryption, OutboundEncryption::EncryptFor(dest_public_key)); } } diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index 5393f36797..56aa246d1e 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -20,7 +20,6 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{consts::DHT_ENVELOPE_HEADER_VERSION, proto::envelope::DhtOrigin}; use bitflags::bitflags; use derive_error::Error; use serde::{Deserialize, Serialize}; @@ -29,11 +28,12 @@ use std::{ fmt, fmt::Display, }; -use tari_comms::{peer_manager::NodeId, types::CommsPublicKey, utils::signature}; -use tari_crypto::tari_utilities::{hex::Hex, ByteArray, ByteArrayError}; +use tari_comms::{peer_manager::NodeId, types::CommsPublicKey}; +use tari_utilities::{ByteArray, ByteArrayError}; // Re-export applicable protos pub use crate::proto::envelope::{dht_header::Destination, DhtEnvelope, DhtHeader, DhtMessageType, Network}; +use bytes::Bytes; #[derive(Debug, Error)] pub enum DhtMessageError { @@ -47,6 +47,8 @@ pub enum DhtMessageError { InvalidNetwork, /// Invalid or unrecognised DHT message flags InvalidMessageFlags, + /// Invalid ephemeral public key + InvalidEphemeralPublicKey, /// Header was omitted from the message HeaderOmitted, } @@ -69,46 +71,36 @@ bitflags!
{ } } -impl DhtMessageType { - pub fn is_dht_message(self) -> bool { - match self { - DhtMessageType::None => false, - _ => true, - } +impl DhtMessageFlags { + pub fn is_encrypted(self) -> bool { + self.contains(Self::ENCRYPTED) } } -#[derive(Clone, PartialEq, Eq)] -pub struct DhtMessageOrigin { - pub public_key: CommsPublicKey, - pub signature: Vec, -} - -impl fmt::Debug for DhtMessageOrigin { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("DhtMessageOrigin") - .field("public_key", &self.public_key.to_hex()) - .field("signature", &self.signature.to_hex()) - .finish() +impl DhtMessageType { + pub fn is_dht_message(self) -> bool { + self.is_dht_discovery() || self.is_dht_join() } -} -impl TryFrom for DhtMessageOrigin { - type Error = DhtMessageError; + pub fn is_dht_discovery(self) -> bool { + match self { + DhtMessageType::Discovery => true, + _ => false, + } + } - fn try_from(value: DhtOrigin) -> Result { - Ok(Self { - public_key: CommsPublicKey::from_bytes(&value.public_key).map_err(|_| DhtMessageError::InvalidOrigin)?, - signature: value.signature, - }) + pub fn is_dht_join(self) -> bool { + match self { + DhtMessageType::Join => true, + _ => false, + } } -} -impl From for DhtOrigin { - fn from(value: DhtMessageOrigin) -> Self { - Self { - public_key: value.public_key.to_vec(), - signature: value.signature, + pub fn is_saf_message(self) -> bool { + use DhtMessageType::*; + match self { + SafRequestMessages | SafStoredMessages => true, + _ => false, } } } @@ -119,30 +111,21 @@ impl From for DhtOrigin { pub struct DhtMessageHeader { pub version: u32, pub destination: NodeDestination, - /// Origin of the message. This can refer to the same peer that sent the message - /// or another peer if the message should be forwarded. - pub origin: Option, + /// Encoded DhtOrigin. This can refer to the same peer that sent the message + /// or another peer if the message is being propagated. 
+ pub origin_mac: Vec, + pub ephemeral_public_key: Option, pub message_type: DhtMessageType, pub network: Network, pub flags: DhtMessageFlags, } impl DhtMessageHeader { - pub fn new( - destination: NodeDestination, - message_type: DhtMessageType, - origin: Option, - network: Network, - flags: DhtMessageFlags, - ) -> Self - { - Self { - version: DHT_ENVELOPE_HEADER_VERSION, - destination, - origin, - message_type, - network, - flags, + pub fn is_valid(&self) -> bool { + if self.flags.contains(DhtMessageFlags::ENCRYPTED) { + !self.origin_mac.is_empty() && self.ephemeral_public_key.is_some() + } else { + true } } } @@ -151,8 +134,8 @@ impl Display for DhtMessageHeader { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!( f, - "DhtMessageHeader (Dest:{}, Origin:{:?}, Type:{:?}, Network:{:?}, Flags:{:?})", - self.destination, self.origin, self.message_type, self.network, self.flags + "DhtMessageHeader (Dest:{}, Type:{:?}, Network:{:?}, Flags:{:?})", + self.destination, self.message_type, self.network, self.flags ) } } @@ -168,15 +151,20 @@ impl TryFrom for DhtMessageHeader { .map(Option::unwrap) .ok_or_else(|| DhtMessageError::InvalidDestination)?; - let origin = match header.origin { - Some(origin) => Some(origin.try_into()?), - None => None, + let ephemeral_public_key = if header.ephemeral_public_key.is_empty() { + None + } else { + Some( + CommsPublicKey::from_bytes(&header.ephemeral_public_key) + .map_err(|_| DhtMessageError::InvalidEphemeralPublicKey)?, + ) }; Ok(Self { version: header.version, destination, - origin, + origin_mac: header.origin_mac, + ephemeral_public_key, message_type: DhtMessageType::from_i32(header.message_type) .ok_or_else(|| DhtMessageError::InvalidMessageType)?, network: Network::from_i32(header.network).ok_or_else(|| DhtMessageError::InvalidNetwork)?, @@ -200,7 +188,12 @@ impl From for DhtHeader { fn from(header: DhtMessageHeader) -> Self { Self { version: header.version, - origin: header.origin.map(Into::into), + ephemeral_public_key: header + .ephemeral_public_key + .as_ref() + .map(ByteArray::to_vec) + .unwrap_or_else(Vec::new), + origin_mac: header.origin_mac, destination: Some(header.destination.into()), message_type: header.message_type as i32, network: header.network as i32, @@ -210,44 +203,12 @@ impl From for DhtHeader { } impl DhtEnvelope { - pub fn new(header: DhtHeader, body: Vec) -> Self { + pub fn new(header: DhtHeader, body: Bytes) -> Self { Self { header: Some(header), - body, + body: body.to_vec(), } } - - /// Returns true if the header and origin are present, otherwise false - pub fn has_origin(&self) -> bool { - self.header.as_ref().map(|h| h.origin.is_some()).unwrap_or(false) - } - - /// Verifies the origin signature and returns true if it is valid. - /// - /// This method panics if called on an envelope without an origin. 
This should be checked before calling this - /// function by using the `DhtEnvelope::has_origin` method - pub fn is_origin_signature_valid(&self) -> bool { - self.header - .as_ref() - .and_then(|header| { - let origin = header - .origin - .as_ref() - .expect("call is_origin_signature_valid on envelope without origin"); - - CommsPublicKey::from_bytes(&origin.public_key) - .map(|pk| (pk, &origin.signature)) - .ok() - }) - .map(|(origin_public_key, origin_signature)| { - match signature::verify(&origin_public_key, origin_signature, &self.body) { - Ok(is_valid) => is_valid, - // error means that the signature could not deserialize, so is invalid - Err(_) => false, - } - }) - .unwrap_or(false) - } } /// Represents the ways a destination node can be represented. @@ -270,6 +231,41 @@ impl NodeDestination { NodeDestination::NodeId(node_id) => node_id.to_vec(), } } + + pub fn public_key(&self) -> Option<&CommsPublicKey> { + match self { + NodeDestination::Unknown => None, + NodeDestination::PublicKey(pk) => Some(pk), + NodeDestination::NodeId(_) => None, + } + } + + pub fn node_id(&self) -> Option<&NodeId> { + match self { + NodeDestination::Unknown => None, + NodeDestination::PublicKey(_) => None, + NodeDestination::NodeId(node_id) => Some(node_id), + } + } + + pub fn is_unknown(&self) -> bool { + match self { + NodeDestination::Unknown => true, + _ => false, + } + } +} + +impl PartialEq<&CommsPublicKey> for NodeDestination { + fn eq(&self, other: &&CommsPublicKey) -> bool { + self.public_key().map(|pk| pk == *other).unwrap_or(false) + } +} + +impl PartialEq<&NodeId> for NodeDestination { + fn eq(&self, other: &&NodeId) -> bool { + self.node_id().map(|node_id| node_id == *other).unwrap_or(false) + } } impl Display for NodeDestination { @@ -304,6 +300,18 @@ impl TryFrom for NodeDestination { } } +impl From for NodeDestination { + fn from(pk: CommsPublicKey) -> Self { + NodeDestination::PublicKey(Box::new(pk)) + } +} + +impl From for NodeDestination { + fn from(node_id: NodeId) -> Self { + NodeDestination::NodeId(Box::new(node_id)) + } +} + impl From for Destination { fn from(destination: NodeDestination) -> Self { use NodeDestination::*; diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index 8c79b4d191..3f55c0c9c3 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -22,17 +22,42 @@ use crate::{ crypt, - envelope::DhtMessageFlags, + envelope::{DhtMessageFlags, DhtMessageHeader}, inbound::message::{DecryptedDhtMessage, DhtInboundMessage}, + proto::envelope::OriginMac, }; +use derive_error::Error; use futures::{task::Context, Future}; use log::*; use prost::Message; use std::{sync::Arc, task::Poll}; -use tari_comms::{message::EnvelopeBody, peer_manager::NodeIdentity, pipeline::PipelineError}; +use tari_comms::{ + message::EnvelopeBody, + peer_manager::NodeIdentity, + pipeline::PipelineError, + types::CommsPublicKey, + utils::signature, +}; +use tari_utilities::ByteArray; use tower::{layer::Layer, Service, ServiceExt}; -const LOG_TARGET: &str = "comms::middleware::encryption"; +const LOG_TARGET: &str = "comms::middleware::decryption"; + +#[derive(Error, Debug)] +enum DecryptionError { + /// Failed to validate origin MAC signature + OriginMacInvalidSignature, + /// Origin MAC contained an invalid public key + OriginMacInvalidPublicKey, + /// Origin MAC not provided for encrypted message + OriginMacNotProvided, + /// Failed to decrypt origin MAC + OriginMacDecryptedFailed, + /// Failed to decode clear-text origin MAC + 
OriginMacClearTextDecodeFailed, + /// Failed to decrypt message body + MessageBodyDecryptionFailed, +} /// This layer is responsible for attempting to decrypt inbound messages. pub struct DecryptionLayer { @@ -70,9 +95,7 @@ impl DecryptionService { } impl Service for DecryptionService -where - S: Service + Clone, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service + Clone { type Error = PipelineError; type Response = (); @@ -89,9 +112,7 @@ where } impl DecryptionService -where - S: Service, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service { async fn handle_message( next_service: S, @@ -100,78 +121,140 @@ where ) -> Result<(), PipelineError> { let dht_header = &message.dht_header; + if !dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) { return Self::success_not_encrypted(next_service, message).await; } - let origin = dht_header - .origin + let e_pk = dht_header + .ephemeral_public_key .as_ref() - // TODO: #banheuristics - this should not have been sent/propagated - .ok_or_else(|| "Message origin field is required for encrypted messages")?; + // TODO: #banheuristic - encrypted message sent without ephemeral public key + .ok_or("Ephemeral public key not provided for encrypted message")?; - debug!(target: LOG_TARGET, "Attempting to decrypt message"); - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &origin.public_key); - match crypt::decrypt(&shared_secret, &message.body) { - Ok(decrypted) => Self::decryption_succeeded(next_service, message, &decrypted).await, + let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), e_pk); + + // Decrypt and verify the origin + let authenticated_origin = match Self::attempt_decrypt_origin_mac(&shared_secret, dht_header) { + Ok((public_key, signature)) => { + // If this fails, discard the message because we decrypted and deserialized the message with our shared + // ECDH secret but the message could not be authenticated + Self::authenticate_origin_mac(&public_key, &signature, &message.body) + .map_err(PipelineError::from_debug)?; + public_key + }, + Err(err) => { + debug!(target: LOG_TARGET, "Unable to decrypt message origin: {}", err); + return Self::decryption_failed(next_service, &node_identity, message).await; + }, + }; + + debug!( + target: LOG_TARGET, + "Attempting to decrypt message body from origin public key '{}'", authenticated_origin + ); + match Self::attempt_decrypt_message_body(&shared_secret, &message.body) { + Ok(message_body) => { + debug!(target: LOG_TARGET, "Message successfully decrypted"); + let msg = DecryptedDhtMessage::succeeded(message_body, Some(authenticated_origin), message); + next_service.oneshot(msg).await + }, Err(err) => { debug!(target: LOG_TARGET, "Unable to decrypt message: {}", err); - Self::decryption_failed(next_service, message).await + Self::decryption_failed(next_service, &node_identity, message).await }, } } - async fn decryption_succeeded( - next_service: S, - message: DhtInboundMessage, - decrypted: &[u8], - ) -> Result<(), PipelineError> + fn attempt_decrypt_origin_mac( + shared_secret: &CommsPublicKey, + dht_header: &DhtMessageHeader, + ) -> Result<(CommsPublicKey, Vec), DecryptionError> + { + let encrypted_origin_mac = Some(&dht_header.origin_mac) + .filter(|b| !b.is_empty()) + // TODO: #banheuristic - this should not have been sent/propagated + .ok_or_else(|| DecryptionError::OriginMacNotProvided)?; + let decrypted_bytes = crypt::decrypt(shared_secret, encrypted_origin_mac) + .map_err(|_| 
DecryptionError::OriginMacDecryptedFailed)?; + let origin_mac = + OriginMac::decode(decrypted_bytes.as_slice()).map_err(|_| DecryptionError::OriginMacDecryptedFailed)?; + // Check the public key here, because it is possible (rare but possible) for a failed decrypted message to pass + // protobuf decoding of the relatively simple OriginMac struct but with invalid data + let public_key = CommsPublicKey::from_bytes(&origin_mac.public_key) + .map_err(|_| DecryptionError::OriginMacInvalidPublicKey)?; + Ok((public_key, origin_mac.signature)) + } + + fn authenticate_origin_mac( + public_key: &CommsPublicKey, + signature: &[u8], + body: &[u8], + ) -> Result<(), DecryptionError> + { + if signature::verify(public_key, signature, body).unwrap_or(false) { + Ok(()) + } else { + Err(DecryptionError::OriginMacInvalidSignature) + } + } + + fn attempt_decrypt_message_body( + shared_secret: &CommsPublicKey, + message_body: &[u8], + ) -> Result { + let decrypted = + crypt::decrypt(shared_secret, message_body).map_err(|_| DecryptionError::MessageBodyDecryptionFailed)?; // Deserialization into an EnvelopeBody is done here to determine if the // decryption produced valid bytes or not. - let result = EnvelopeBody::decode(decrypted).and_then(|body| { - // Check if we received a body length of zero - // - // In addition to a peer sending a zero-length EnvelopeBody, decoding can erroneously succeed - // if the decrypted bytes happen to be valid protobuf encoding. This is very possible and - // the decrypt_inbound_fail test below _will_ sporadically fail without the following check. - // This is because proto3 will set fields to their default value if they don't exist in a valid encoding. - // - // For the parts of EnvelopeBody to be erroneously populated with bytes, all of these - // conditions would have to be true: - // 1. field type == 2 (length-delimited) - // 2. field number == 1 - // 3. the subsequent byte(s) would have to be varint-encoded length which does not overflow - // 4. the rest of the bytes would have to be valid protobuf encoding - // - // The chance of this happening is extremely negligible. - if body.is_empty() { - return Err(prost::DecodeError::new("EnvelopeBody has no parts")); - } - Ok(body) - }); - match result { - Ok(deserialized) => { - debug!(target: LOG_TARGET, "Message successfully decrypted"); - let msg = DecryptedDhtMessage::succeeded(deserialized, message); - next_service.oneshot(msg).await.map_err(PipelineError::from_debug) - }, - Err(err) => { - debug!(target: LOG_TARGET, "Unable to deserialize message: {}", err); - Self::decryption_failed(next_service, message).await - }, - } + EnvelopeBody::decode(decrypted.as_slice()) + .and_then(|body| { + // Check if we received a body length of zero + // + // In addition to a peer sending a zero-length EnvelopeBody, decoding can erroneously succeed + // if the decrypted bytes happen to be valid protobuf encoding. This is very possible and + // the decrypt_inbound_fail test below _will_ sporadically fail without the following check. + // This is because proto3 will set fields to their default value if they don't exist in a valid + // encoding. + // + // For the parts of EnvelopeBody to be erroneously populated with bytes, all of these + // conditions would have to be true: + // 1. field type == 2 (length-delimited) + // 2. field number == 1 + // 3. the subsequent byte(s) would have to be varint-encoded length which does not overflow + // 4.
the rest of the bytes would have to be valid protobuf encoding + // + // The chance of this happening is extremely negligible. + if body.is_empty() { + return Err(prost::DecodeError::new("EnvelopeBody has no parts")); + } + Ok(body) + }) + .map_err(|_| DecryptionError::MessageBodyDecryptionFailed) } async fn success_not_encrypted(next_service: S, message: DhtInboundMessage) -> Result<(), PipelineError> { + let authenticated_pk = if message.dht_header.origin_mac.is_empty() { + None + } else { + let origin_mac = OriginMac::decode(message.dht_header.origin_mac.as_slice()) + .map_err(|_| PipelineError::from_debug(DecryptionError::OriginMacClearTextDecodeFailed))?; + let public_key = CommsPublicKey::from_bytes(&origin_mac.public_key) + .map_err(|_| PipelineError::from_debug(DecryptionError::OriginMacInvalidPublicKey))?; + Self::authenticate_origin_mac(&public_key, &origin_mac.signature, &message.body) + .map_err(PipelineError::from_debug)?; + Some(public_key) + }; + match EnvelopeBody::decode(message.body.as_slice()) { Ok(deserialized) => { debug!( target: LOG_TARGET, "Message is not encrypted. Passing onto next service" ); - let msg = DecryptedDhtMessage::succeeded(deserialized, message); - next_service.oneshot(msg).await.map_err(PipelineError::from_debug) + let msg = DecryptedDhtMessage::succeeded(deserialized, authenticated_pk, message); + next_service.oneshot(msg).await }, Err(err) => { // Message was not encrypted but failed to deserialize - immediately discard @@ -185,9 +268,28 @@ where } } - async fn decryption_failed(next_service: S, message: DhtInboundMessage) -> Result<(), PipelineError> { + async fn decryption_failed( + next_service: S, + node_identity: &NodeIdentity, + message: DhtInboundMessage, + ) -> Result<(), PipelineError> + { + if message.dht_header.destination == node_identity.node_id() || + message.dht_header.destination == node_identity.public_key() + { + // TODO: #banheuristic - the origin of this message sent this node a message we could not decrypt + warn!( + target: LOG_TARGET, + "Received message from peer '{}' that is destined for this node that could not be decrypted. 
\ + Discarding message", + message.source_peer.node_id + ); + return Err( + "Message rejected because this node could not decrypt a message that was addressed to it".into(), + ); + } let msg = DecryptedDhtMessage::failed(message); - next_service.oneshot(msg).await.map_err(PipelineError::from_debug) + next_service.oneshot(msg).await } } @@ -226,10 +328,13 @@ mod test { let node_identity = make_node_identity(); let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); - let plain_text_msg = wrap_in_envelope_body!(Vec::new()).unwrap(); - let secret_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let encrypted = crypt::encrypt(&secret_key, &plain_text_msg.to_encoded_bytes().unwrap()).unwrap(); - let inbound_msg = make_dht_inbound_message(&node_identity, encrypted, DhtMessageFlags::ENCRYPTED); + let plain_text_msg = wrap_in_envelope_body!(b"Secret plans".to_vec()); + let inbound_msg = make_dht_inbound_message( + &node_identity, + plain_text_msg.to_encoded_bytes(), + DhtMessageFlags::ENCRYPTED, + true, + ); block_on(service.call(inbound_msg)).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); @@ -247,13 +352,35 @@ mod test { let node_identity = make_node_identity(); let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); - let nonsense = "Cannot Decrypt this".as_bytes().to_vec(); - let inbound_msg = make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED); + let some_secret = "Super secret message".as_bytes().to_vec(); + let some_other_node_identity = make_node_identity(); + let inbound_msg = + make_dht_inbound_message(&some_other_node_identity, some_secret, DhtMessageFlags::ENCRYPTED, true); - block_on(service.call(inbound_msg)).unwrap(); + block_on(service.call(inbound_msg.clone())).unwrap(); let decrypted = result.lock().unwrap().take().unwrap(); assert_eq!(decrypted.decryption_succeeded(), false); - assert_eq!(decrypted.decryption_result.unwrap_err(), nonsense); + assert_eq!(decrypted.decryption_result.unwrap_err(), inbound_msg.body); + } + + #[test] + fn decrypt_inbound_fail_destination() { + let result = Mutex::new(None); + let inner = service_fn(|msg: DecryptedDhtMessage| { + *result.lock().unwrap() = Some(msg); + future::ready(Result::<(), PipelineError>::Ok(())) + }); + let node_identity = make_node_identity(); + let mut service = DecryptionService::new(inner, Arc::clone(&node_identity)); + + let nonsense = "Cannot Decrypt this".as_bytes().to_vec(); + let mut inbound_msg = + make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true); + inbound_msg.dht_header.destination = node_identity.public_key().clone().into(); + + let err = block_on(service.call(inbound_msg)).unwrap_err(); + assert!(err.to_string().starts_with("Message rejected"),); + assert!(result.lock().unwrap().is_none()); } } diff --git a/comms/dht/src/inbound/dedup.rs b/comms/dht/src/inbound/dedup.rs deleted file mode 100644 index b3671b27e3..0000000000 --- a/comms/dht/src/inbound/dedup.rs +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2019, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::{actor::DhtRequester, inbound::DhtInboundMessage}; -use digest::Input; -use futures::{task::Context, Future}; -use log::*; -use std::task::Poll; -use tari_comms::{pipeline::PipelineError, types::Challenge}; -use tari_crypto::tari_utilities::hex::Hex; -use tower::{layer::Layer, Service, ServiceExt}; - -const LOG_TARGET: &str = "comms::dht::dedup"; - -/// # DHT Deduplication middleware -/// -/// Takes in a `DhtInboundMessage` and checks the message signature cache for duplicates. -/// If a duplicate message is detected, it is discarded. -#[derive(Clone)] -pub struct DedupMiddleware { - next_service: S, - dht_requester: DhtRequester, -} - -impl DedupMiddleware { - pub fn new(service: S, dht_requester: DhtRequester) -> Self { - Self { - next_service: service, - dht_requester, - } - } -} - -impl Service for DedupMiddleware -where - S: Service + Clone + 'static, - S::Error: std::error::Error + Send + Sync + 'static, -{ - type Error = PipelineError; - type Response = (); - - type Future = impl Future>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, msg: DhtInboundMessage) -> Self::Future { - Self::process_message(self.next_service.clone(), self.dht_requester.clone(), msg) - } -} - -impl DedupMiddleware -where - S: Service, - S::Error: std::error::Error + Send + Sync + 'static, -{ - pub async fn process_message( - next_service: S, - mut dht_requester: DhtRequester, - message: DhtInboundMessage, - ) -> Result<(), PipelineError> - { - trace!(target: LOG_TARGET, "Checking inbound message cache for duplicates"); - let hash = Self::hash_message(&message); - if dht_requester - .insert_message_hash(hash) - .await - .map_err(PipelineError::from_debug)? - { - warn!( - target: LOG_TARGET, - "Received duplicate message from peer {} (origin={:?}). 
Message discarded.", - message.source_peer.node_id, - message - .dht_header - .origin - .map(|o| o.public_key.to_hex()) - .unwrap_or_else(|| "".to_string()), - ); - return Ok(()); - } - next_service.oneshot(message).await.map_err(PipelineError::from_debug) - } - - fn hash_message(message: &DhtInboundMessage) -> Vec { - Challenge::new().chain(&message.body).result().to_vec() - } -} - -pub struct DedupLayer { - dht_requester: DhtRequester, -} - -impl DedupLayer { - pub fn new(dht_requester: DhtRequester) -> Self { - Self { dht_requester } - } -} - -impl Layer for DedupLayer { - type Service = DedupMiddleware; - - fn layer(&self, service: S) -> Self::Service { - DedupMiddleware::new(service, self.dht_requester.clone()) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::{ - envelope::DhtMessageFlags, - test_utils::{create_dht_actor_mock, make_dht_inbound_message, make_node_identity, service_spy, DhtMockState}, - }; - use tari_test_utils::panic_context; - use tokio::runtime::Runtime; - - #[test] - fn process_message() { - let mut rt = Runtime::new().unwrap(); - let spy = service_spy(); - - let (dht_requester, mut mock) = create_dht_actor_mock(1); - let mock_state = DhtMockState::new(); - mock_state.set_signature_cache_insert(false); - mock.set_shared_state(mock_state.clone()); - rt.spawn(mock.run()); - - let mut dedup = DedupLayer::new(dht_requester).layer(spy.to_service::()); - - panic_context!(cx); - - assert!(dedup.poll_ready(&mut cx).is_ready()); - let node_identity = make_node_identity(); - let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty()); - - rt.block_on(dedup.call(msg.clone())).unwrap(); - assert_eq!(spy.call_count(), 1); - - mock_state.set_signature_cache_insert(true); - rt.block_on(dedup.call(msg)).unwrap(); - assert_eq!(spy.call_count(), 1); - // Drop dedup so that the DhtMock will stop running - drop(dedup); - } -} diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index 18ea93004a..67390a9c01 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ b/comms/dht/src/inbound/deserialize.rs @@ -47,9 +47,7 @@ impl DhtDeserializeMiddleware { } impl Service for DhtDeserializeMiddleware -where - S: Service + Clone + 'static, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service + Clone + 'static { type Error = PipelineError; type Response = (); @@ -60,55 +58,45 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, msg: InboundMessage) -> Self::Future { - Self::deserialize(self.next_service.clone(), msg) - } -} - -impl DhtDeserializeMiddleware -where - S: Service, - S::Error: std::error::Error + Send + Sync + 'static, -{ - pub async fn deserialize(next_service: S, message: InboundMessage) -> Result<(), PipelineError> { - trace!(target: LOG_TARGET, "Deserializing InboundMessage"); - - let InboundMessage { - source_peer, mut body, .. - } = message; - - match DhtEnvelope::decode(&mut body) { - Ok(dht_envelope) => { - trace!(target: LOG_TARGET, "Deserialization succeeded. Checking signatures"); - if dht_envelope.has_origin() { - if dht_envelope.is_origin_signature_valid() { - trace!(target: LOG_TARGET, "Origin signature validation passed."); - } else { - // The origin signature is not valid, this message should never have been sent - warn!( - target: LOG_TARGET, - "SECURITY: Origin signature verification failed. 
Discarding message from NodeId {}", - source_peer.node_id - ); - return Ok(()); - } - } - - let inbound_msg = DhtInboundMessage::new( - dht_envelope.header.try_into().map_err(PipelineError::from_debug)?, - source_peer, - dht_envelope.body, - ); - - next_service - .oneshot(inbound_msg) - .await - .map_err(PipelineError::from_debug) - }, - Err(err) => { - error!(target: LOG_TARGET, "DHT deserialization failed: {}", err); - Err(PipelineError::from_debug(err)) - }, + fn call(&mut self, message: InboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + async move { + trace!(target: LOG_TARGET, "Deserializing InboundMessage"); + + let InboundMessage { + source_peer, + mut body, + tag, + .. + } = message; + + if body.is_empty() { + return Err(format!("Received empty message from peer '{}'", source_peer) + .as_str() + .into()); + } + + match DhtEnvelope::decode(&mut body) { + Ok(dht_envelope) => { + debug!( + target: LOG_TARGET, + "Deserialization succeeded. Passing message {} onto next service", tag + ); + + let inbound_msg = DhtInboundMessage::new( + tag, + dht_envelope.header.try_into().map_err(PipelineError::from_debug)?, + source_peer, + dht_envelope.body, + ); + + next_service.oneshot(inbound_msg).await + }, + Err(err) => { + error!(target: LOG_TARGET, "DHT deserialization failed: {}", err); + Err(PipelineError::from_debug(err)) + }, + } } } } @@ -138,7 +126,7 @@ mod test { test_utils::{make_comms_inbound_message, make_dht_envelope, make_node_identity, service_spy}, }; use futures::executor::block_on; - use tari_comms::message::{MessageExt, MessageFlags}; + use tari_comms::message::MessageExt; use tari_test_utils::panic_context; #[test] @@ -150,11 +138,10 @@ mod test { assert!(deserialize.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let dht_envelope = make_dht_envelope(&node_identity, b"A".to_vec(), DhtMessageFlags::empty()); + let dht_envelope = make_dht_envelope(&node_identity, b"A".to_vec(), DhtMessageFlags::empty(), false); block_on(deserialize.call(make_comms_inbound_message( &node_identity, - dht_envelope.to_encoded_bytes().unwrap().into(), - MessageFlags::empty(), + dht_envelope.to_encoded_bytes().into(), ))) .unwrap(); diff --git a/comms/dht/src/inbound/dht_handler/middleware.rs b/comms/dht/src/inbound/dht_handler/middleware.rs index 4f89c7eb5a..5e0a6c3567 100644 --- a/comms/dht/src/inbound/dht_handler/middleware.rs +++ b/comms/dht/src/inbound/dht_handler/middleware.rs @@ -68,9 +68,7 @@ impl DhtHandlerMiddleware { } impl Service for DhtHandlerMiddleware -where - S: Service + Clone, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service + Clone { type Error = PipelineError; type Response = (); @@ -78,7 +76,7 @@ where type Future = impl Future>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.next_service.poll_ready(cx).map_err(PipelineError::from_debug) + self.next_service.poll_ready(cx) } fn call(&mut self, message: DecryptedDhtMessage) -> Self::Future { diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index b9ced8ac84..92c2aa52df 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -40,7 +40,7 @@ use tari_comms::{ pipeline::PipelineError, types::CommsPublicKey, }; -use tari_crypto::tari_utilities::{hex::Hex, ByteArray}; +use tari_utilities::{hex::Hex, ByteArray}; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::dht_handler"; @@ -56,9 +56,7 @@ pub struct 
ProcessDhtMessage<S> { } impl<S> ProcessDhtMessage<S> -where - S: Service<DecryptedDhtMessage, Response = ()>, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> { pub fn new( config: DhtConfig, @@ -87,12 +85,12 @@ where .take() .expect("ProcessDhtMessage initialized without message"); - // If this message failed to decrypt, this middleware is not interested in it + // If this message failed to decrypt, we stop it from going any further at this layer if message.decryption_failed() { - self.next_service - .oneshot(message) - .await - .map_err(PipelineError::from_debug)?; + debug!( + target: LOG_TARGET, + "Message that failed to decrypt will be discarded here. DhtHeader={}", message.dht_header + ); return Ok(()); } @@ -110,10 +108,7 @@ where // Not a DHT message, call downstream middleware _ => { trace!(target: LOG_TARGET, "Passing message onto next service"); - self.next_service - .oneshot(message) - .await - .map_err(PipelineError::from_debug)? + self.next_service.oneshot(message).await?; }, } @@ -137,12 +132,13 @@ where Some(node_id), Some(net_addresses), None, + None, + Some(false), Some(peer_features), None, None, ) .await?; - peer_manager.set_offline(&pubkey, false).await?; } else { peer_manager .add_peer(Peer::new( @@ -180,25 +176,29 @@ where decryption_result, dht_header, source_peer, + authenticated_origin, + is_saf_message, .. } = message; - let origin = dht_header - .origin - .as_ref() - .ok_or_else(|| DhtInboundError::OriginRequired("Origin is required for this message type".to_string()))?; + let authenticated_pk = authenticated_origin.ok_or_else(|| { + DhtInboundError::OriginRequired("Authenticated origin is required for this message type".to_string()) + })?; - if &origin.public_key == self.node_identity.public_key() { - trace!(target: LOG_TARGET, "Received our own join message. Discarding it."); + if &authenticated_pk == self.node_identity.public_key() { + warn!(target: LOG_TARGET, "Received our own join message. Discarding it."); return Ok(()); } - trace!(target: LOG_TARGET, "Received Join Message from {}", origin.public_key); - let body = decryption_result.expect("already checked that this message decrypted successfully"); let join_msg = body .decode_part::<JoinMessage>(0)?
- .ok_or_else(|| DhtInboundError::InvalidJoinNetAddresses)?; + .ok_or_else(|| DhtInboundError::InvalidMessageBody)?; + + info!( + target: LOG_TARGET, + "Received join message from '{}' {}", authenticated_pk, join_msg + ); let addresses = join_msg .addresses @@ -210,11 +210,11 @@ where return Err(DhtInboundError::InvalidAddresses); } - let node_id = self.validate_raw_node_id(&origin.public_key, &join_msg.node_id)?; + let node_id = self.validate_raw_node_id(&authenticated_pk, &join_msg.node_id)?; let origin_peer = self .add_or_update_peer( - &origin.public_key, + &authenticated_pk, node_id, addresses, PeerFeatures::from_bits_truncate(join_msg.peer_features), @@ -250,13 +250,21 @@ where "Sending Join to joining peer with public key '{}'", origin_peer.public_key ); + self.send_join_direct(origin_peer.public_key).await?; } - trace!( + if is_saf_message { + debug!( + target: LOG_TARGET, + "Not re-propagating join message received from store and forward" + ); + return Ok(()); + } + + debug!( target: LOG_TARGET, - "Propagating join message to at most {} peer(s)", - self.config.num_neighbouring_nodes + "Propagating join message to at most {} peer(s)", self.config.num_neighbouring_nodes ); // Propagate message to closer peers @@ -266,12 +274,12 @@ where .closest( origin_peer.node_id, self.config.num_neighbouring_nodes, - vec![origin.public_key.clone(), source_peer.public_key.clone()], + vec![authenticated_pk, source_peer.public_key.clone()], PeerFeatures::MESSAGE_PROPAGATION, ) .with_dht_header(dht_header) .finish(), - body.to_encoded_bytes()?, + body.to_encoded_bytes(), ) .await?; @@ -304,10 +312,9 @@ where target: LOG_TARGET, "Received Discover Response Message from {}", message - .dht_header - .origin + .authenticated_origin .as_ref() - .map(|o| o.public_key.to_hex()) + .map(|pk| pk.to_hex()) .unwrap_or_else(|| "".to_string()) ); @@ -335,13 +342,13 @@ where .decode_part::<DiscoveryMessage>(0)? .ok_or_else(|| DhtInboundError::InvalidMessageBody)?; - let origin = message.dht_header.origin.ok_or_else(|| { + let authenticated_pk = message.authenticated_origin.ok_or_else(|| { DhtInboundError::OriginRequired("Origin header required for Discovery message".to_string()) })?; info!( target: LOG_TARGET, - "Received discovery message from '{}'", origin.public_key, + "Received discovery message from '{}', forwarded by {}", authenticated_pk, message.source_peer ); let addresses = discover_msg .addresses @@ -354,10 +361,10 @@ where return Err(DhtInboundError::InvalidAddresses); } - let node_id = self.validate_raw_node_id(&origin.public_key, &discover_msg.node_id)?; + let node_id = self.validate_raw_node_id(&authenticated_pk, &discover_msg.node_id)?; let origin_peer = self .add_or_update_peer( - &origin.public_key, + &authenticated_pk, node_id, addresses, PeerFeatures::from_bits_truncate(discover_msg.peer_features), @@ -368,13 +375,13 @@ where if origin_peer.is_banned() { warn!( target: LOG_TARGET, - "Received Discovery request for banned peer. This request will be ignored." + "Received Discovery request for banned peer '{}'. 
This request will be ignored.", authenticated_pk ); return Ok(()); } // Send the origin the current node's latest contact info - self.send_discovery_response(origin.public_key, discover_msg.nonce) + self.send_discovery_response(origin_peer.public_key, discover_msg.nonce) .await?; Ok(()) diff --git a/comms/dht/src/inbound/error.rs b/comms/dht/src/inbound/error.rs index 43e548f4a6..d383728263 100644 --- a/comms/dht/src/inbound/error.rs +++ b/comms/dht/src/inbound/error.rs @@ -28,7 +28,6 @@ use tari_comms::{message::MessageError, peer_manager::PeerManagerError}; #[derive(Debug, Error)] pub enum DhtInboundError { MessageError(MessageError), - // MessageFormatError(MessageFormatError), PeerManagerError(PeerManagerError), DhtOutboundError(DhtOutboundError), /// Failed to decode message @@ -39,8 +38,6 @@ pub enum DhtInboundError { InvalidNodeId, /// All given addresses were invalid InvalidAddresses, - /// One or more NetAddress in the join message were invalid - InvalidJoinNetAddresses, DhtDiscoveryError(DhtDiscoveryError), #[error(msg_embedded, no_from, non_std)] OriginRequired(String), diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index 4f5bd96a35..6559a716b8 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -20,27 +20,38 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{consts::DHT_ENVELOPE_HEADER_VERSION, envelope::DhtMessageHeader}; +use crate::{ + consts::DHT_ENVELOPE_HEADER_VERSION, + envelope::{DhtMessageFlags, DhtMessageHeader}, +}; use std::{ fmt::{Display, Error, Formatter}, sync::Arc, }; -use tari_comms::{message::EnvelopeBody, peer_manager::Peer, types::CommsPublicKey}; -use tari_crypto::tari_utilities::hex::Hex; +use tari_comms::{ + message::{EnvelopeBody, MessageTag}, + peer_manager::Peer, + types::CommsPublicKey, +}; #[derive(Debug, Clone)] pub struct DhtInboundMessage { + pub tag: MessageTag, pub version: u32, pub source_peer: Arc<Peer>, pub dht_header: DhtMessageHeader, + /// True if forwarded via store and forward, otherwise false + pub is_saf_message: bool, pub body: Vec<u8>, } impl DhtInboundMessage { - pub fn new(dht_header: DhtMessageHeader, source_peer: Arc<Peer>, body: Vec<u8>) -> Self { + pub fn new(tag: MessageTag, dht_header: DhtMessageHeader, source_peer: Arc<Peer>, body: Vec<u8>) -> Self { Self { + tag, version: DHT_ENVELOPE_HEADER_VERSION, dht_header, source_peer, + is_saf_message: false, body, } } @@ -50,11 +61,12 @@ impl Display for DhtInboundMessage { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { write!( f, - "DhtInboundMessage (v{}, Peer:{}, Header:{}, Body:{})", - self.version, + "\n---- Inbound Message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----", + self.body.len(), + self.dht_header.message_type, self.source_peer, self.dht_header, - self.body.to_hex() + self.tag, ) } } @@ -62,29 +74,43 @@ impl Display for DhtInboundMessage { /// Represents a decrypted InboundMessage. #[derive(Debug, Clone)] pub struct DecryptedDhtMessage { + pub tag: MessageTag, pub version: u32, /// The _connected_ peer which sent or forwarded this message. This may not be the peer /// which created this message. 
pub source_peer: Arc<Peer>, + pub authenticated_origin: Option<CommsPublicKey>, pub dht_header: DhtMessageHeader, + pub is_saf_message: bool, pub decryption_result: Result<EnvelopeBody, Vec<u8>>, } impl DecryptedDhtMessage { - pub fn succeeded(decrypted_message: EnvelopeBody, message: DhtInboundMessage) -> Self { + pub fn succeeded( + message_body: EnvelopeBody, + authenticated_origin: Option<CommsPublicKey>, + message: DhtInboundMessage, + ) -> Self + { Self { + tag: message.tag, version: message.version, source_peer: message.source_peer, + authenticated_origin, dht_header: message.dht_header, - decryption_result: Ok(decrypted_message), + is_saf_message: message.is_saf_message, + decryption_result: Ok(message_body), } } pub fn failed(message: DhtInboundMessage) -> Self { Self { + tag: message.tag, version: message.version, source_peer: message.source_peer, + authenticated_origin: None, dht_header: message.dht_header, + is_saf_message: message.is_saf_message, decryption_result: Err(message.body), } } @@ -113,11 +139,23 @@ impl DecryptedDhtMessage { self.decryption_result.is_err() } - pub fn origin_public_key(&self) -> &CommsPublicKey { - self.dht_header - .origin - .as_ref() - .map(|o| &o.public_key) - .unwrap_or(&self.source_peer.public_key) + pub fn authenticated_origin(&self) -> Option<&CommsPublicKey> { + self.authenticated_origin.as_ref() + } + + /// Returns true if the message is or was encrypted (i.e. the ENCRYPTED flag is set on the header) + pub fn is_encrypted(&self) -> bool { + self.dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) + } + + pub fn has_origin_mac(&self) -> bool { + !self.dht_header.origin_mac.is_empty() + } + + pub fn body_len(&self) -> usize { + match self.decryption_result.as_ref() { + Ok(b) => b.total_size(), + Err(b) => b.len(), + } + } } diff --git a/comms/dht/src/inbound/mod.rs b/comms/dht/src/inbound/mod.rs index f5373f41bf..7b9c137794 100644 --- a/comms/dht/src/inbound/mod.rs +++ b/comms/dht/src/inbound/mod.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod decryption; -mod dedup; mod deserialize; mod dht_handler; mod error; @@ -30,7 +29,6 @@ mod validate; pub use self::{ decryption::DecryptionLayer, - dedup::DedupLayer, deserialize::DeserializeLayer, dht_handler::DhtHandlerLayer, message::{DecryptedDhtMessage, DhtInboundMessage}, diff --git a/comms/dht/src/inbound/validate.rs b/comms/dht/src/inbound/validate.rs index 2b59d4195c..e2b2ea089a 100644 --- a/comms/dht/src/inbound/validate.rs +++ b/comms/dht/src/inbound/validate.rs @@ -20,19 +20,11 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
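The rewritten ValidateMiddleware that follows silently discards messages that fail validation rather than replying with a RejectMessage. A minimal sketch of the acceptance check it now performs (names as used in this patch; the surrounding service plumbing is elided):

// Sketch only: a message is passed on when its network matches this node's
// target network and the DHT header passes its basic validity checks.
fn should_accept(message: &DhtInboundMessage, target_network: Network) -> bool {
    message.dht_header.network == target_network && message.dht_header.is_valid()
}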
-use crate::{ - inbound::DhtInboundMessage, - outbound::{OutboundMessageRequester, SendMessageParams}, - proto::{ - dht::{RejectMessage, RejectMessageReason}, - envelope::{DhtMessageType, Network}, - }, -}; +use crate::{inbound::DhtInboundMessage, proto::envelope::Network}; use futures::{task::Context, Future}; use log::*; use std::task::Poll; -use tari_comms::{message::MessageExt, pipeline::PipelineError}; -use tari_crypto::tari_utilities::ByteArray; +use tari_comms::pipeline::PipelineError; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::validate"; @@ -45,23 +37,19 @@ const LOG_TARGET: &str = "comms::dht::validate"; pub struct ValidateMiddleware { next_service: S, target_network: Network, - outbound_requester: OutboundMessageRequester, } impl ValidateMiddleware { - pub fn new(service: S, target_network: Network, outbound_requester: OutboundMessageRequester) -> Self { + pub fn new(service: S, target_network: Network) -> Self { Self { next_service: service, target_network, - outbound_requester, } } } impl Service for ValidateMiddleware -where - S: Service + Clone + 'static, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service + Clone + 'static { type Error = PipelineError; type Response = (); @@ -72,78 +60,35 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, msg: DhtInboundMessage) -> Self::Future { - Self::process_message( - self.next_service.clone(), - self.target_network, - self.outbound_requester.clone(), - msg, - ) - } -} - -impl ValidateMiddleware -where - S: Service, - S::Error: std::error::Error + Send + Sync + 'static, -{ - pub async fn process_message( - next_service: S, - target_network: Network, - mut outbound_requester: OutboundMessageRequester, - message: DhtInboundMessage, - ) -> Result<(), PipelineError> - { - trace!( - target: LOG_TARGET, - "Checking the message target network is '{:?}'", - target_network - ); - if message.dht_header.network == target_network { - next_service.oneshot(message).await.map_err(PipelineError::from_debug)?; - } else { - debug!( - target: LOG_TARGET, - "Message is for another network (want = {:?} got = {:?}). Explicitly rejecting the message.", - target_network, - message.dht_header.network - ); - outbound_requester - .send_raw( - SendMessageParams::new() - .direct_public_key(message.source_peer.public_key.clone()) - .with_dht_message_type(DhtMessageType::RejectMsg) - .finish(), - RejectMessage { - signature: message - .dht_header - .origin - .map(|o| o.public_key.to_vec()) - .unwrap_or_default(), - reason: RejectMessageReason::UnsupportedNetwork as i32, - } - .to_encoded_bytes() - .map_err(PipelineError::from_debug)?, - ) - .await - .map_err(PipelineError::from_debug)?; + fn call(&mut self, message: DhtInboundMessage) -> Self::Future { + let next_service = self.next_service.clone(); + let target_network = self.target_network; + async move { + if message.dht_header.network == target_network && message.dht_header.is_valid() { + debug!(target: LOG_TARGET, "Passing message {} to next service", message.tag); + next_service.oneshot(message).await?; + } else { + warn!( + target: LOG_TARGET, + "Message is for another network (want = {:?} got = {:?}) or message header is invalid. 
Discarding \ + the message.", + target_network, + message.dht_header.network + ); + } + + Ok(()) } - - Ok(()) } } pub struct ValidateLayer { target_network: Network, - outbound_requester: OutboundMessageRequester, } impl ValidateLayer { - pub fn new(target_network: Network, outbound_requester: OutboundMessageRequester) -> Self { - Self { - target_network, - outbound_requester, - } + pub fn new(target_network: Network) -> Self { + Self { target_network } } } @@ -151,7 +96,7 @@ impl Layer for ValidateLayer { type Service = ValidateMiddleware; fn layer(&self, service: S) -> Self::Service { - ValidateMiddleware::new(service, self.target_network, self.outbound_requester.clone()) + ValidateMiddleware::new(service, self.target_network) } } @@ -159,8 +104,7 @@ impl Layer for ValidateLayer { mod test { use super::*; use crate::{ - envelope::{DhtMessageFlags, DhtMessageType}, - outbound::mock::create_outbound_service_mock, + envelope::DhtMessageFlags, test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, }; use tari_test_utils::panic_context; @@ -171,18 +115,13 @@ mod test { let mut rt = Runtime::new().unwrap(); let spy = service_spy(); - let (out_requester, mock) = create_outbound_service_mock(1); - let mock_state = mock.get_state(); - rt.spawn(mock.run()); - - let mut validate = - ValidateLayer::new(Network::LocalTest, out_requester).layer(spy.to_service::()); + let mut validate = ValidateLayer::new(Network::LocalTest).layer(spy.to_service::()); panic_context!(cx); assert!(validate.poll_ready(&mut cx).is_ready()); let node_identity = make_node_identity(); - let mut msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty()); + let mut msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false); msg.dht_header.network = Network::MainNet; rt.block_on(validate.call(msg.clone())).unwrap(); @@ -192,17 +131,5 @@ mod test { rt.block_on(validate.call(msg.clone())).unwrap(); assert_eq!(spy.call_count(), 1); - - let calls = mock_state.take_calls(); - assert_eq!(calls.len(), 1); - let params = calls[0].0.clone(); - assert_eq!(params.dht_message_type, DhtMessageType::RejectMsg); - assert_eq!( - params.broadcast_strategy.direct_public_key().unwrap(), - node_identity.public_key() - ); - - // Drop validate so that the mock will stop running - drop(validate); } } diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index 37f4e18686..eadbf96db9 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -107,6 +107,11 @@ // Details: https://doc.rust-lang.org/beta/unstable-book/language-features/type-alias-impl-trait.html #![feature(type_alias_impl_trait)] +#[macro_use] +extern crate diesel; +#[macro_use] +extern crate diesel_migrations; + #[macro_use] mod macros; @@ -127,14 +132,23 @@ mod consts; mod crypt; mod dht; -pub use dht::Dht; +pub use dht::{Dht, DhtInitializationError}; mod discovery; pub use discovery::DhtDiscoveryRequester; +mod storage; +pub use storage::DbConnectionUrl; + +mod dedup; +pub use dedup::DedupLayer; + mod logging_middleware; mod proto; mod tower_filter; +mod utils; + +mod schema; pub mod broadcast_strategy; pub mod domain_message; diff --git a/comms/dht/src/logging_middleware.rs b/comms/dht/src/logging_middleware.rs index e6b906b1d8..6201853a4f 100644 --- a/comms/dht/src/logging_middleware.rs +++ b/comms/dht/src/logging_middleware.rs @@ -22,47 +22,47 @@ use futures::{task::Context, Future, TryFutureExt}; use log::*; -use std::{fmt::Display, marker::PhantomData, task::Poll}; +use std::{borrow::Cow, 
fmt::Display, marker::PhantomData, task::Poll}; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::middleware::message_logging"; /// This layer is responsible for logging messages for debugging. -pub struct MessageLoggingLayer { - prefix_msg: &'static str, +pub struct MessageLoggingLayer<'a, R> { + prefix_msg: Cow<'a, str>, _r: PhantomData, } -impl MessageLoggingLayer { - pub fn new(prefix_msg: &'static str) -> Self { +impl<'a, R> MessageLoggingLayer<'a, R> { + pub fn new>>(prefix_msg: T) -> Self { Self { - prefix_msg, + prefix_msg: prefix_msg.into(), _r: PhantomData, } } } -impl Layer for MessageLoggingLayer +impl<'a, S, R> Layer for MessageLoggingLayer<'a, R> where S: Service, S::Error: std::error::Error + Send + Sync + 'static, R: Display, { - type Service = MessageLoggingService; + type Service = MessageLoggingService<'a, S>; fn layer(&self, service: S) -> Self::Service { - MessageLoggingService::new(self.prefix_msg, service) + MessageLoggingService::new(self.prefix_msg.clone(), service) } } #[derive(Clone)] -pub struct MessageLoggingService { - prefix_msg: &'static str, +pub struct MessageLoggingService<'a, S> { + prefix_msg: Cow<'a, str>, inner: S, } -impl MessageLoggingService { - pub fn new(prefix_msg: &'static str, service: S) -> Self { +impl<'a, S> MessageLoggingService<'a, S> { + pub fn new(prefix_msg: Cow<'a, str>, service: S) -> Self { Self { inner: service, prefix_msg, @@ -70,7 +70,7 @@ impl MessageLoggingService { } } -impl Service for MessageLoggingService +impl Service for MessageLoggingService<'_, S> where S: Service + Clone, S::Error: std::error::Error + Send + Sync + 'static, diff --git a/comms/dht/src/macros.rs b/comms/dht/src/macros.rs index ce4cb640fb..3911e96e99 100644 --- a/comms/dht/src/macros.rs +++ b/comms/dht/src/macros.rs @@ -28,7 +28,7 @@ macro_rules! acquire_lock { match $e.$m() { Ok(lock) => lock, Err(poisoned) => { - log::warn!(target: "dht", "Lock has been POISONED and will be silently recovered"); + log::warn!(target: "comms::dht", "Lock has been POISONED and will be silently recovered"); poisoned.into_inner() }, } @@ -37,9 +37,3 @@ macro_rules! acquire_lock { acquire_lock!($e, lock) }; } - -macro_rules! 
acquire_write_lock { - ($e:expr) => { - acquire_lock!($e, write) - }; -} diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 6289fe32d9..b0acd199ab 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -24,15 +24,18 @@ use super::{error::DhtOutboundError, message::DhtOutboundRequest}; use crate::{ actor::DhtRequester, broadcast_strategy::BroadcastStrategy, + crypt, discovery::DhtDiscoveryRequester, - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, outbound::{ message::{DhtOutboundMessage, OutboundEncryption}, message_params::FinalSendMessageParams, + message_send_state::MessageSendState, SendMessageResponse, }, - proto::envelope::{DhtMessageType, Network}, + proto::envelope::{DhtMessageType, Network, OriginMac}, }; +use bytes::Bytes; use futures::{ channel::oneshot, future, @@ -41,12 +44,18 @@ use futures::{ Future, }; use log::*; +use rand::rngs::OsRng; use std::{sync::Arc, task::Poll}; use tari_comms::{ - message::MessageFlags, - peer_manager::{NodeId, NodeIdentity, Peer}, + message::{MessageExt, MessageTag}, + peer_manager::{NodeIdentity, Peer}, pipeline::PipelineError, types::CommsPublicKey, + utils::signature, +}; +use tari_crypto::{ + keys::PublicKey, + tari_utilities::{message_format::MessageFormat, ByteArray}, }; use tower::{layer::Layer, Service, ServiceExt}; @@ -121,9 +130,7 @@ impl BroadcastMiddleware { } impl Service for BroadcastMiddleware -where - S: Service + Clone, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service + Clone { type Error = PipelineError; type Response = (); @@ -157,9 +164,7 @@ struct BroadcastTask { } impl BroadcastTask -where - S: Service, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service { pub fn new( service: S, @@ -221,10 +226,11 @@ where async fn handle_send_message( &mut self, params: FinalSendMessageParams, - body: Vec, + body: Bytes, reply_tx: oneshot::Sender, ) -> Result, DhtOutboundError> { + trace!(target: LOG_TARGET, "Send params: {:?}", params); if params .broadcast_strategy .direct_public_key() @@ -262,6 +268,8 @@ where is_discovery_enabled, ); + let is_broadcast = broadcast_strategy.is_broadcast(); + // Discovery is required if: // - Discovery is enabled for this request // - There where no peers returned @@ -283,7 +291,7 @@ where }, Ok(None) => { // Message sent to 0 peers - let _ = discovery_reply_tx.send(SendMessageResponse::Queued(vec![])); + let _ = discovery_reply_tx.send(SendMessageResponse::Queued(vec![].into())); return Ok(Vec::new()); }, Err(err) => { @@ -302,20 +310,21 @@ where dht_header, dht_message_flags, force_origin, + is_broadcast, body, ) .await { - Ok(msgs) => { - // Reply with the number of messages to be sent + Ok((msgs, send_states)) => { + // Reply with the `MessageTag`s for each message let _ = reply_tx .take() .expect("cannot fail") - .send(SendMessageResponse::Queued(msgs.iter().map(|m| m.tag).collect())); + .send(SendMessageResponse::Queued(send_states.into())); + Ok(msgs) }, Err(err) => { - // Reply 0 messages sent let _ = reply_tx.take().expect("cannot fail").send(SendMessageResponse::Failed); Err(err) }, @@ -349,22 +358,10 @@ where dest_public_key ); - // TODO: This works because we know that all non-DAN node IDs are/should be derived from the public key. 
- // Once the DAN launches, this may not be the case and we'll need to query the blockchain for the node id - let derived_node_id = NodeId::from_key(&*dest_public_key).ok(); - - // TODO: Target a general region instead of the actual destination node id - let regional_destination = derived_node_id - .as_ref() - .map(Clone::clone) - .map(Box::new) - .map(NodeDestination::NodeId) - .unwrap_or_else(|| NodeDestination::Unknown); - // Peer not found, let's try and discover it match self .dht_discovery_requester - .discover_peer(dest_public_key, derived_node_id, regional_destination) + .discover_peer(dest_public_key.clone(), NodeDestination::PublicKey(dest_public_key)) .await { // Peer found! @@ -387,7 +384,7 @@ where // Error during discovery Err(err) => { debug!(target: LOG_TARGET, "Peer discovery failed because '{}'.", err); - Ok(None) + Err(DhtOutboundError::DiscoveryFailed) }, } } @@ -402,54 +399,92 @@ where custom_header: Option, extra_flags: DhtMessageFlags, force_origin: bool, - body: Vec, - ) -> Result, DhtOutboundError> + is_broadcast: bool, + body: Bytes, + ) -> Result<(Vec, Vec), DhtOutboundError> { let dht_flags = encryption.flags() | extra_flags; - // Create a DHT header - let dht_header = custom_header - .or_else(|| { - // The origin is specified if encryption is turned on, otherwise it is not - let origin = if force_origin || encryption.is_encrypt() { - Some(DhtMessageOrigin { - // Origin public key used to identify the origin and verify the signature - public_key: self.node_identity.public_key().clone(), - // Signing will happen later in the pipeline (SerializeMiddleware), left empty to prevent double - // work - signature: Vec::new(), - }) - } else { - None - }; - - Some(DhtMessageHeader::new( - // Final destination for this message - destination, - dht_message_type, - origin, - self.target_network, - dht_flags, - )) - }) - .expect("always Some"); + let (ephemeral_public_key, origin_mac, body) = self.process_encryption(&encryption, force_origin, body)?; - // Construct a MessageEnvelope for each recipient + // Construct a DhtOutboundMessage for each recipient let messages = selected_peers .into_iter() .map(|peer| { - DhtOutboundMessage::new( - peer, - dht_header.clone(), - encryption.clone(), - MessageFlags::NONE, - body.clone(), + let (reply_tx, reply_rx) = oneshot::channel(); + let tag = MessageTag::new(); + let send_state = MessageSendState::new(tag, reply_rx); + ( + DhtOutboundMessage { + tag, + destination_peer: peer, + destination: destination.clone(), + dht_message_type, + network: self.target_network, + dht_flags, + custom_header: custom_header.clone(), + body: body.clone(), + reply_tx: reply_tx.into(), + ephemeral_public_key: ephemeral_public_key.clone(), + origin_mac: origin_mac.clone(), + is_broadcast, + }, + send_state, ) }) .collect::>(); - Ok(messages) + Ok(messages.into_iter().unzip()) } + + fn process_encryption( + &self, + encryption: &OutboundEncryption, + include_origin: bool, + body: Bytes, + ) -> Result<(Option>, Option, Bytes), DhtOutboundError> + { + match encryption { + OutboundEncryption::EncryptFor(public_key) => { + debug!(target: LOG_TARGET, "Encrypting message for {}", public_key); + // Generate ephemeral public/private key pair and ECDH shared secret + let (e_sk, e_pk) = CommsPublicKey::random_keypair(&mut OsRng); + let shared_ephemeral_secret = crypt::generate_ecdh_secret(&e_sk, &**public_key); + // Encrypt the message with the body + let encrypted_body = crypt::encrypt(&shared_ephemeral_secret, &body)?; + + // Sign the encrypted message + let 
origin_mac = create_origin_mac(&self.node_identity, &encrypted_body)?; + // Encrypt and set the origin field + let encrypted_origin_mac = crypt::encrypt(&shared_ephemeral_secret, &origin_mac)?; + Ok(( + Some(Arc::new(e_pk)), + Some(encrypted_origin_mac.into()), + encrypted_body.into(), + )) + }, + OutboundEncryption::None => { + debug!(target: LOG_TARGET, "Encryption not requested for message"); + + if include_origin { + let origin_mac = create_origin_mac(&self.node_identity, &body)?; + Ok((None, Some(origin_mac.into()), body)) + } else { + Ok((None, None, body)) + } + }, + } + } +} + +fn create_origin_mac(node_identity: &NodeIdentity, body: &[u8]) -> Result, DhtOutboundError> { + let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), body)?; + + let mac = OriginMac { + public_key: node_identity.public_key().to_vec(), + signature: signature.to_binary()?, + }; + Ok(mac.to_encoded_bytes()) } #[cfg(test)] @@ -531,7 +566,7 @@ mod test { rt.block_on(service.call(DhtOutboundRequest::SendMessage( Box::new(SendMessageParams::new().flood().finish()), - "custom_msg".as_bytes().to_vec(), + "custom_msg".as_bytes().into(), reply_tx, ))) .unwrap(); @@ -581,7 +616,7 @@ mod test { .with_discovery(false) .finish(), ), - "custom_msg".as_bytes().to_vec(), + Bytes::from_static(b"custom_msg"), reply_tx, )), ) @@ -632,7 +667,7 @@ mod test { .direct_public_key(peer_to_discover.public_key.clone()) .finish(), ), - "custom_msg".as_bytes().to_vec(), + "custom_msg".as_bytes().into(), reply_tx, )), ) diff --git a/comms/dht/src/outbound/encryption.rs b/comms/dht/src/outbound/encryption.rs deleted file mode 100644 index 9756eafe11..0000000000 --- a/comms/dht/src/outbound/encryption.rs +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2019, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
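An aside on why the recipient can decrypt under the ephemeral-key scheme above: ECDH is symmetric, so the sender's combination of the ephemeral secret key with the recipient's public key equals the recipient's combination of its own secret key with the attached ephemeral public key. A rough sketch under that assumption (illustrative only; the output type of `generate_ecdh_secret` is not shown in this diff):

// The sender derives the shared secret from (e_sk, dest_pk)...
let (e_sk, e_pk) = CommsPublicKey::random_keypair(&mut OsRng);
let (dest_sk, dest_pk) = CommsPublicKey::random_keypair(&mut OsRng);
let sender_side = crypt::generate_ecdh_secret(&e_sk, &dest_pk);
// ...and the recipient derives the same secret from (dest_sk, e_pk),
// which is what lets it decrypt the origin MAC and then the body.
let recipient_side = crypt::generate_ecdh_secret(&dest_sk, &e_pk);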
- -use crate::{ - crypt, - outbound::message::{DhtOutboundMessage, OutboundEncryption}, -}; -use futures::{task::Context, Future}; -use log::*; -use std::{sync::Arc, task::Poll}; -use tari_comms::{peer_manager::NodeIdentity, pipeline::PipelineError}; -use tower::{layer::Layer, Service, ServiceExt}; - -const LOG_TARGET: &str = "comms::middleware::encryption"; - -/// This layer is responsible for attempting to decrypt inbound messages. -pub struct EncryptionLayer { - node_identity: Arc, -} - -impl EncryptionLayer { - pub fn new(node_identity: Arc) -> Self { - Self { node_identity } - } -} - -impl Layer for EncryptionLayer { - type Service = EncryptionService; - - fn layer(&self, service: S) -> Self::Service { - EncryptionService::new(service, Arc::clone(&self.node_identity)) - } -} - -/// Responsible for decrypting InboundMessages and passing a DecryptedInboundMessage to the given service -#[derive(Clone)] -pub struct EncryptionService { - node_identity: Arc, - inner: S, -} - -impl EncryptionService { - pub fn new(service: S, node_identity: Arc) -> Self { - Self { - inner: service, - node_identity, - } - } -} - -impl Service for EncryptionService -where - S: Service + Clone, - S::Error: std::error::Error + Send + Sync + 'static, -{ - type Error = PipelineError; - type Response = (); - - type Future = impl Future>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, msg: DhtOutboundMessage) -> Self::Future { - Self::handle_message(self.inner.clone(), Arc::clone(&self.node_identity), msg) - } -} - -impl EncryptionService -where - S: Service, - S::Error: std::error::Error + Send + Sync + 'static, -{ - async fn handle_message( - next_service: S, - node_identity: Arc, - mut message: DhtOutboundMessage, - ) -> Result<(), PipelineError> - { - trace!(target: LOG_TARGET, "DHT Message flags: {:?}", message.dht_header.flags); - match &message.encryption { - OutboundEncryption::EncryptFor(public_key) => { - debug!(target: LOG_TARGET, "Encrypting message for {}", public_key); - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &**public_key); - message.body = crypt::encrypt(&shared_secret, &message.body).map_err(PipelineError::from_debug)?; - }, - OutboundEncryption::EncryptForPeer => { - debug!( - target: LOG_TARGET, - "Encrypting message for peer with public key {}", message.destination_peer.public_key - ); - let shared_secret = - crypt::generate_ecdh_secret(node_identity.secret_key(), &message.destination_peer.public_key); - message.body = crypt::encrypt(&shared_secret, &message.body).map_err(PipelineError::from_debug)? 
- }, - OutboundEncryption::None => { - debug!(target: LOG_TARGET, "Encryption not requested for message"); - }, - }; - - next_service.oneshot(message).await.map_err(PipelineError::from_debug) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::{ - envelope::DhtMessageFlags, - test_utils::{make_dht_header, make_node_identity, service_spy}, - }; - use futures::executor::block_on; - use tari_comms::{ - message::MessageFlags, - net_address::MultiaddressesWithStats, - peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags}, - types::CommsPublicKey, - }; - use tari_test_utils::panic_context; - - #[test] - fn no_encryption() { - let spy = service_spy(); - let node_identity = make_node_identity(); - let mut encryption = EncryptionLayer::new(Arc::clone(&node_identity)).layer(spy.to_service::()); - - panic_context!(cx); - assert!(encryption.poll_ready(&mut cx).is_ready()); - - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::empty()), - OutboundEncryption::None, - MessageFlags::empty(), - body.clone(), - ); - block_on(encryption.call(msg)).unwrap(); - - let msg = spy.pop_request().unwrap(); - assert_eq!(msg.body, body); - assert_eq!(msg.destination_peer.node_id, NodeId::default()); - } - - #[test] - fn encryption() { - let spy = service_spy(); - let node_identity = make_node_identity(); - let mut encryption = EncryptionLayer::new(Arc::clone(&node_identity)).layer(spy.to_service::()); - - panic_context!(cx); - assert!(encryption.poll_ready(&mut cx).is_ready()); - - let body = b"A".to_vec(); - let msg = DhtOutboundMessage::new( - Peer::new( - CommsPublicKey::default(), - NodeId::default(), - MultiaddressesWithStats::new(vec![]), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_NODE, - &[], - ), - make_dht_header(&node_identity, &body, DhtMessageFlags::ENCRYPTED), - OutboundEncryption::EncryptForPeer, - MessageFlags::empty(), - body.clone(), - ); - block_on(encryption.call(msg)).unwrap(); - - let msg = spy.pop_request().unwrap(); - assert_ne!(msg.body, body); - assert_eq!(msg.destination_peer.node_id, NodeId::default()); - } -} diff --git a/comms/dht/src/outbound/error.rs b/comms/dht/src/outbound/error.rs index 4af39d565c..76b58ad6f4 100644 --- a/comms/dht/src/outbound/error.rs +++ b/comms/dht/src/outbound/error.rs @@ -23,7 +23,10 @@ use derive_error::Error; use futures::channel::mpsc::SendError; use tari_comms::message::MessageError; -use tari_crypto::{signatures::SchnorrSignatureError, tari_utilities::message_format::MessageFormatError}; +use tari_crypto::{ + signatures::SchnorrSignatureError, + tari_utilities::{ciphers::cipher::CipherError, message_format::MessageFormatError}, +}; #[derive(Debug, Error)] pub enum DhtOutboundError { @@ -31,6 +34,7 @@ pub enum DhtOutboundError { MessageSerializationError(MessageError), MessageFormatError(MessageFormatError), SignatureError(SchnorrSignatureError), + CipherError(CipherError), /// Requester reply channel closed before response was received RequesterReplyChannelClosed, /// Peer selection failed @@ -41,4 +45,6 @@ pub enum DhtOutboundError { ReplyChannelCanceled, /// Attempted to send a message to ourselves SendToOurselves, + /// Discovery process failed + DiscoveryFailed, } diff --git a/comms/dht/src/outbound/message.rs b/comms/dht/src/outbound/message.rs index ebccbe9185..abaac4b113 100644 --- 
a/comms/dht/src/outbound/message.rs +++ b/comms/dht/src/outbound/message.rs @@ -21,17 +21,18 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::{DhtMessageFlags, DhtMessageHeader}, - outbound::message_params::FinalSendMessageParams, + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageType, Network, NodeDestination}, + outbound::{message_params::FinalSendMessageParams, message_send_state::MessageSendStates}, }; +use bytes::Bytes; use futures::channel::oneshot; -use std::{fmt, fmt::Display}; +use std::{fmt, fmt::Display, sync::Arc}; use tari_comms::{ - message::{MessageFlags, MessageTag}, + message::{MessageTag, MessagingReplyTx}, peer_manager::Peer, types::CommsPublicKey, }; -use tari_crypto::tari_utilities::hex::Hex; +use tari_utilities::hex::Hex; /// Determines if an outbound message should be Encrypted and, if so, for which public key #[derive(Debug, Clone, PartialEq, Eq)] @@ -40,18 +41,13 @@ pub enum OutboundEncryption { None, /// Message should be encrypted using a shared secret derived from the given public key EncryptFor(Box), - // TODO: Remove this option as it is redundant (message encryption only needed for forwarded private messages) - /// Message should be encrypted using a shared secret derived from the destination peer's - /// public key. Each message sent according to the broadcast strategy will be encrypted for - /// the destination peer. - EncryptForPeer, } impl OutboundEncryption { /// Return the correct DHT flags for the encryption setting pub fn flags(&self) -> DhtMessageFlags { match self { - OutboundEncryption::EncryptFor(_) | OutboundEncryption::EncryptForPeer => DhtMessageFlags::ENCRYPTED, + OutboundEncryption::EncryptFor(_) => DhtMessageFlags::ENCRYPTED, _ => DhtMessageFlags::NONE, } } @@ -61,7 +57,7 @@ impl OutboundEncryption { use OutboundEncryption::*; match self { None => false, - EncryptFor(_) | EncryptForPeer => true, + EncryptFor(_) => true, } } } @@ -71,7 +67,6 @@ impl Display for OutboundEncryption { match self { OutboundEncryption::None => write!(f, "None"), OutboundEncryption::EncryptFor(ref key) => write!(f, "EncryptFor:{}", key.to_hex()), - OutboundEncryption::EncryptForPeer => write!(f, "EncryptForPeer"), } } } @@ -86,7 +81,7 @@ impl Default for OutboundEncryption { pub enum SendMessageResponse { /// Returns the message tags which are queued for sending. These tags will be used in a subsequent OutboundEvent to /// indicate if the message succeeded/failed to send - Queued(Vec), + Queued(MessageSendStates), /// A failure occurred when sending Failed, /// DHT Discovery has been initiated. The caller may wait on the receiver @@ -101,19 +96,19 @@ impl SendMessageResponse { /// A `SendMessageResponse::Failed` will resolve immediately returning a `None`. /// If DHT discovery is initiated, this will resolve once discovery has completed, either /// succeeding (`Some(n)`) or failing (`None`). 
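    ///
    /// Illustrative use, assuming a `SendMessageResponse` obtained from the requester:
    /// ```ignore
    /// if let Some(send_states) = response.resolve_ok().await {
    ///     let (sent, failed) = send_states.wait_all().await;
    /// }
    /// ```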
- pub async fn resolve_ok(self) -> Option> { + pub async fn resolve_ok(self) -> Option { use SendMessageResponse::*; match self { - Queued(tags) => Some(tags), + Queued(send_states) => Some(send_states), Failed => None, PendingDiscovery(rx) => rx.await.ok()?.queued_or_failed(), } } - fn queued_or_failed(self) -> Option> { + fn queued_or_failed(self) -> Option { use SendMessageResponse::*; match self { - Queued(tags) => Some(tags), + Queued(send_states) => Some(send_states), Failed => None, PendingDiscovery(_) => panic!("ok_or_failed() called on PendingDiscovery"), } @@ -124,11 +119,7 @@ impl SendMessageResponse { #[derive(Debug)] pub enum DhtOutboundRequest { /// Send a message using the given broadcast strategy - SendMessage( - Box, - Vec, - oneshot::Sender, - ), + SendMessage(Box, Bytes, oneshot::Sender), } impl fmt::Display for DhtOutboundRequest { @@ -141,52 +132,75 @@ impl fmt::Display for DhtOutboundRequest { } } +/// Wrapper struct for a oneshot reply sender. When this struct is dropped, an automatic fail is sent on the oneshot if +/// a response has not already been sent. +#[derive(Debug)] +pub struct WrappedReplyTx(Option); + +impl WrappedReplyTx { + pub fn into_inner(mut self) -> Option { + self.0.take() + } + + #[cfg(test)] + pub(crate) fn none() -> Self { + Self(None) + } +} + +impl From for WrappedReplyTx { + fn from(inner: MessagingReplyTx) -> Self { + Self(Some(inner)) + } +} + +impl Drop for WrappedReplyTx { + fn drop(&mut self) { + // If this is dropped and the reply tx has not been used already, send an error reply + if let Some(reply_tx) = self.0.take() { + let _ = reply_tx.send(Err(())); + } + } +} + /// DhtOutboundMessage consists of the DHT and comms information required to /// send a message -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct DhtOutboundMessage { pub tag: MessageTag, pub destination_peer: Peer, - pub dht_header: DhtMessageHeader, - pub comms_flags: MessageFlags, - pub encryption: OutboundEncryption, - pub body: Vec, -} - -impl DhtOutboundMessage { - /// Create a new DhtOutboundMessage - pub fn new( - destination_peer: Peer, - dht_header: DhtMessageHeader, - encryption: OutboundEncryption, - comms_flags: MessageFlags, - body: Vec, - ) -> Self - { - Self { - tag: MessageTag::new(), - destination_peer, - dht_header, - encryption, - comms_flags, - body, - } - } + pub custom_header: Option, + pub body: Bytes, + pub ephemeral_public_key: Option>, + pub origin_mac: Option, + pub destination: NodeDestination, + pub dht_message_type: DhtMessageType, + pub reply_tx: WrappedReplyTx, + pub network: Network, + pub dht_flags: DhtMessageFlags, + pub is_broadcast: bool, } impl fmt::Display for DhtOutboundMessage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let header_str = self + .custom_header + .as_ref() + .map(|h| format!("{} (Propagated)", h)) + .unwrap_or_else(|| { + format!( + "Network: {:?}, Flags: {:?}, Destination: {}", + self.network, self.dht_flags, self.destination + ) + }); write!( f, - "\n---- DhtOutboundMessage ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {} \nFlags: \ - {:?}\nEncryption: {}\n{}\n----", + "\n---- Outgoing message ---- \nSize: {} byte(s)\nType: {}\nPeer: {}\nHeader: {}\n{}\n----", self.body.len(), - self.dht_header.message_type, + self.dht_message_type, self.destination_peer, - self.dht_header, - self.comms_flags, - self.encryption, - self.tag + header_str, + self.tag, ) } } diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 
dbe61ac262..10d528c1a6 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -36,11 +36,14 @@ use tari_comms::{ /// /// ```edition2018 /// # use tari_comms_dht::outbound::{SendMessageParams, OutboundEncryption}; +/// use tari_comms::types::CommsPublicKey; /// -/// // These params represent sending to 5 random peers, each encrypted for that peer +/// // These params represent sending to 5 random peers. The message will be able to be decrypted by +/// // the peer with the corresponding secret key of `dest_public_key`. +/// let dest_public_key = CommsPublicKey::default(); /// let params = SendMessageParams::new() /// .random(5) -/// .with_encryption(OutboundEncryption::EncryptForPeer) +/// .with_encryption(OutboundEncryption::EncryptFor(Box::new(dest_public_key))) /// .finish(); /// ``` #[derive(Debug, Clone)] @@ -147,9 +150,9 @@ impl SendMessageParams { self } - /// Set broadcast_strategy to Random + /// Set broadcast_strategy to Random. pub fn random(&mut self, n: usize) -> &mut Self { - self.params_mut().broadcast_strategy = BroadcastStrategy::Random(n); + self.params_mut().broadcast_strategy = BroadcastStrategy::Random(n, vec![]); self } diff --git a/comms/dht/src/outbound/message_send_state.rs b/comms/dht/src/outbound/message_send_state.rs new file mode 100644 index 0000000000..b481529d2f --- /dev/null +++ b/comms/dht/src/outbound/message_send_state.rs @@ -0,0 +1,230 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
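+//! Delivery-outcome tracking for queued outbound messages. Each
+//! `MessageSendState` pairs the `MessageTag` of a queued message with a
+//! oneshot receiver that resolves once the messaging protocol reports
+//! success or failure for that message.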
+ +use futures::{stream::FuturesUnordered, Future, StreamExt}; +use std::ops::Index; +use tari_comms::message::{MessageTag, MessagingReplyRx}; + +#[derive(Debug)] +pub struct MessageSendState { + pub tag: MessageTag, + reply_rx: MessagingReplyRx, +} +impl MessageSendState { + pub fn new(tag: MessageTag, reply_rx: MessagingReplyRx) -> Self { + Self { tag, reply_rx } + } +} + +#[derive(Debug)] +pub struct MessageSendStates { + inner: Vec<MessageSendState>, +} + +impl MessageSendStates { + /// The number of `MessageSendState`s held in this container + #[inline] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Returns true if there are no send states held in this container, otherwise false + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Wait for all send results to return. The return value contains the successful messages sent and the failed + /// messages respectively + pub async fn wait_all(self) -> (Vec<MessageTag>, Vec<MessageTag>) { + let mut succeeded = Vec::new(); + let mut failed = Vec::new(); + let mut unordered = self.into_futures_unordered(); + while let Some((tag, result)) = unordered.next().await { + match result { + Ok(_) => { + succeeded.push(tag); + }, + Err(_) => { + failed.push(tag); + }, + } + } + + (succeeded, failed) + } + + /// Wait for a certain percentage of successful sends + pub async fn wait_percentage_success(self, threshold_perc: f32) -> (Vec<MessageTag>, Vec<MessageTag>) { + if self.is_empty() { + return (Vec::new(), Vec::new()); + } + let total = self.len(); + let mut count = 0; + + let mut unordered = self.into_futures_unordered(); + let mut succeeded = Vec::new(); + let mut failed = Vec::new(); + loop { + match unordered.next().await { + Some((tag, result)) => { + match result { + Ok(_) => { + count += 1; + succeeded.push(tag); + }, + Err(_) => { + failed.push(tag); + }, + } + if (count as f32) / (total as f32) >= threshold_perc { + break; + } + }, + None => { + break; + }, + } + } + + (succeeded, failed) + } + + /// Wait for the result of a single send. This should not be used when this container contains multiple send states. + /// + /// ## Panics + /// + /// This function expects there to be exactly one MessageSendState contained in this object and will + /// panic in debug mode if this expectation is not met. It will panic for release builds if called + /// when empty. 
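+ ///
+ /// ```ignore
+ /// // Illustrative: exactly one message was queued for sending
+ /// let delivered: bool = send_states.wait_single().await;
+ /// ```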
+ pub async fn wait_single(mut self) -> bool { + let state = self + .inner + .pop() + .expect("wait_single called when MessageSendStates::len() is 0"); + + debug_assert!( + self.is_empty(), + "MessageSendStates::wait_single called with multiple message send states" + ); + + state + .reply_rx + .await + .expect("oneshot should never be canceled before sending") + .is_ok() + } + + pub fn into_futures_unordered(self) -> FuturesUnordered<impl Future<Output = (MessageTag, Result<(), ()>)>> { + let unordered = FuturesUnordered::new(); + self.inner.into_iter().for_each(|state| { + unordered.push(async move { + match state.reply_rx.await { + Ok(result) => (state.tag, result), + // Somewhere the reply sender was dropped without first sending a reply + // This should never happen because if the wrapped oneshot is dropped it sends an Err(()) + Err(_) => unreachable!(), + } + }); + }); + + unordered + } +} + +impl From<Vec<MessageSendState>> for MessageSendStates { + fn from(inner: Vec<MessageSendState>) -> Self { + Self { inner } + } +} + +impl Index<usize> for MessageSendStates { + type Output = MessageSendState; + + fn index(&self, index: usize) -> &Self::Output { + &self.inner[index] + } +} + +#[cfg(test)] +mod test { + use super::*; + use bitflags::_core::iter::repeat_with; + use futures::channel::oneshot; + use tari_comms::message::MessagingReplyTx; + + fn create_send_state() -> (MessageSendState, MessagingReplyTx) { + let (reply_tx, reply_rx) = oneshot::channel(); + let state = MessageSendState::new(MessageTag::new(), reply_rx); + (state, reply_tx) + } + + #[test] + fn is_empty() { + let states = MessageSendStates::from(vec![]); + assert!(states.is_empty()); + let (state, _) = create_send_state(); + let states = MessageSendStates::from(vec![state]); + assert_eq!(states.is_empty(), false); + } + + #[tokio_macros::test_basic] + async fn wait_single() { + let (state, reply_tx) = create_send_state(); + let states = MessageSendStates::from(vec![state]); + reply_tx.send(Ok(())).unwrap(); + assert_eq!(states.len(), 1); + assert_eq!(states.wait_single().await, true); + + let (state, reply_tx) = create_send_state(); + let states = MessageSendStates::from(vec![state]); + reply_tx.send(Err(())).unwrap(); + assert_eq!(states.len(), 1); + assert_eq!(states.wait_single().await, false); + } + + #[tokio_macros::test_basic] + async fn wait_percentage_success() { + let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); + let (states, mut reply_txs) = states.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); + let states = MessageSendStates::from(states); + reply_txs.drain(..4).for_each(|tx| tx.send(Err(())).unwrap()); + reply_txs.drain(..).for_each(|tx| tx.send(Ok(())).unwrap()); + + let (success, failed) = states.wait_percentage_success(0.3).await; + assert_eq!(success.len(), 3); + assert_eq!(failed.len(), 4); + } + + #[tokio_macros::test_basic] + async fn wait_all() { + let states = repeat_with(|| create_send_state()).take(10).collect::<Vec<_>>(); + let (states, mut reply_txs) = states.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); + let states = MessageSendStates::from(states); + reply_txs.drain(..4).for_each(|tx| tx.send(Err(())).unwrap()); + reply_txs.drain(..).for_each(|tx| tx.send(Ok(())).unwrap()); + + let (success, failed) = states.wait_all().await; + assert_eq!(success.len(), 6); + assert_eq!(failed.len(), 4); + } +} diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 89cc7c8e65..056d4d9eff 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -23,10 +23,16 @@ use crate::outbound::{ message::SendMessageResponse, 
message_params::FinalSendMessageParams, + message_send_state::MessageSendState, DhtOutboundRequest, OutboundMessageRequester, }; -use futures::{channel::mpsc, stream::Fuse, StreamExt}; +use bytes::Bytes; +use futures::{ + channel::{mpsc, oneshot}, + stream::Fuse, + StreamExt, +}; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, time::Duration, @@ -44,7 +50,7 @@ pub fn create_outbound_service_mock(size: usize) -> (OutboundMessageRequester, O #[derive(Clone, Default)] pub struct OutboundServiceMockState { #[allow(clippy::type_complexity)] - calls: Arc)>>>, + calls: Arc>>, next_response: Arc>>, call_count_cond_var: Arc, } @@ -88,7 +94,7 @@ impl OutboundServiceMockState { /// Wait for a call to be added or timeout. /// /// An error will be returned if the timeout expires. - pub fn wait_pop_call(&self, timeout: Duration) -> Result<(FinalSendMessageParams, Vec), String> { + pub fn wait_pop_call(&self, timeout: Duration) -> Result<(FinalSendMessageParams, Bytes), String> { let call_guard = acquire_lock!(self.calls); let (mut call_guard, timeout) = self .call_count_cond_var @@ -103,19 +109,19 @@ impl OutboundServiceMockState { } pub fn take_next_response(&self) -> Option { - acquire_write_lock!(self.next_response).take() + self.next_response.write().unwrap().take() } - pub fn add_call(&self, req: (FinalSendMessageParams, Vec)) { + pub fn add_call(&self, req: (FinalSendMessageParams, Bytes)) { acquire_lock!(self.calls).push(req); self.call_count_cond_var.notify_all(); } - pub fn take_calls(&self) -> Vec<(FinalSendMessageParams, Vec)> { + pub fn take_calls(&self) -> Vec<(FinalSendMessageParams, Bytes)> { acquire_lock!(self.calls).drain(..).collect() } - pub fn pop_call(&self) -> Option<(FinalSendMessageParams, Vec)> { + pub fn pop_call(&self) -> Option<(FinalSendMessageParams, Bytes)> { acquire_lock!(self.calls).pop() } } @@ -142,13 +148,19 @@ impl OutboundServiceMock { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { self.mock_state.add_call((*params, body)); + let (inner_reply_tx, inner_reply_rx) = oneshot::channel(); let response = self .mock_state .take_next_response() - .or_else(|| Some(SendMessageResponse::Queued(vec![MessageTag::new()]))) + .or_else(|| { + Some(SendMessageResponse::Queued( + vec![MessageSendState::new(MessageTag::new(), inner_reply_rx)].into(), + )) + }) .expect("never none"); reply_tx.send(response).expect("Reply channel cancelled"); + let _ = inner_reply_tx.send(Ok(())); }, } } diff --git a/comms/dht/src/outbound/mod.rs b/comms/dht/src/outbound/mod.rs index 354910eb81..fc9ade3a2e 100644 --- a/comms/dht/src/outbound/mod.rs +++ b/comms/dht/src/outbound/mod.rs @@ -21,22 +21,24 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
mod broadcast; -mod encryption; +pub use broadcast::BroadcastLayer; + mod error; +pub use error::DhtOutboundError; + pub(crate) mod message; +pub use message::{DhtOutboundRequest, OutboundEncryption, SendMessageResponse}; + mod message_params; +pub use message_params::SendMessageParams; + +mod message_send_state; + mod requester; +pub use requester::OutboundMessageRequester; + mod serialize; +pub use serialize::SerializeLayer; #[cfg(any(test, feature = "test-mocks"))] pub mod mock; - -pub use self::{ - broadcast::BroadcastLayer, - encryption::EncryptionLayer, - error::DhtOutboundError, - message::{DhtOutboundRequest, OutboundEncryption, SendMessageResponse}, - message_params::SendMessageParams, - requester::OutboundMessageRequester, - serialize::SerializeLayer, -}; diff --git a/comms/dht/src/outbound/requester.rs b/comms/dht/src/outbound/requester.rs index 9eb6eb0025..64e0793a01 100644 --- a/comms/dht/src/outbound/requester.rs +++ b/comms/dht/src/outbound/requester.rs @@ -195,7 +195,7 @@ impl OutboundMessageRequester { message ); } - let body = wrap_in_envelope_body!(message.to_header(), message.into_inner())?.to_encoded_bytes()?; + let body = wrap_in_envelope_body!(message.to_header(), message.into_inner()).to_encoded_bytes(); self.send_raw(params, body).await } @@ -211,7 +211,7 @@ impl OutboundMessageRequester { if cfg!(debug_assertions) { trace!(target: LOG_TARGET, "Send Message: {} {:?}", params, message); } - let body = wrap_in_envelope_body!(message)?.to_encoded_bytes()?; + let body = wrap_in_envelope_body!(message).to_encoded_bytes(); self.send_raw(params, body).await } @@ -224,7 +224,7 @@ impl OutboundMessageRequester { { let (reply_tx, reply_rx) = oneshot::channel(); self.sender - .send(DhtOutboundRequest::SendMessage(Box::new(params), body, reply_tx)) + .send(DhtOutboundRequest::SendMessage(Box::new(params), body.into(), reply_tx)) .await?; reply_rx diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index fb7694df62..58ce96dced 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -20,19 +20,20 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{outbound::message::DhtOutboundMessage, proto::envelope::DhtEnvelope};
+use crate::{
+    consts::DHT_ENVELOPE_HEADER_VERSION,
+    outbound::message::DhtOutboundMessage,
+    proto::envelope::{DhtEnvelope, DhtHeader},
+};
 use futures::{task::Context, Future};
 use log::*;
-use rand::rngs::OsRng;
-use std::{sync::Arc, task::Poll};
+use std::task::Poll;
 use tari_comms::{
     message::{MessageExt, OutboundMessage},
-    peer_manager::NodeIdentity,
     pipeline::PipelineError,
-    utils::signature,
     Bytes,
 };
-use tari_crypto::tari_utilities::{hex::Hex, message_format::MessageFormat};
+use tari_utilities::ByteArray;
 use tower::{layer::Layer, Service, ServiceExt};
 
 const LOG_TARGET: &str = "comms::dht::serialize";
@@ -40,22 +41,16 @@ const LOG_TARGET: &str = "comms::dht::serialize";
 #[derive(Clone)]
 pub struct SerializeMiddleware<S> {
     inner: S,
-    node_identity: Arc<NodeIdentity>,
 }
 
 impl<S> SerializeMiddleware<S> {
-    pub fn new(service: S, node_identity: Arc<NodeIdentity>) -> Self {
-        Self {
-            inner: service,
-            node_identity,
-        }
+    pub fn new(service: S) -> Self {
+        Self { inner: service }
     }
 }
 
 impl<S> Service<DhtOutboundMessage> for SerializeMiddleware<S>
-where
-    S: Service<OutboundMessage> + Clone + 'static,
-    S::Error: std::error::Error + Send + Sync + 'static,
+where S: Service<OutboundMessage, Response = (), Error = PipelineError> + Clone + 'static
 {
     type Error = PipelineError;
     type Response = ();
@@ -66,84 +61,56 @@ where
         Poll::Ready(Ok(()))
     }
 
-    fn call(&mut self, msg: DhtOutboundMessage) -> Self::Future {
-        Self::serialize(self.inner.clone(), Arc::clone(&self.node_identity), msg)
-    }
-}
+    fn call(&mut self, message: DhtOutboundMessage) -> Self::Future {
+        let next_service = self.inner.clone();
+        async move {
+            debug!(target: LOG_TARGET, "Serializing outbound message {:?}", message.tag);
 
-impl<S> SerializeMiddleware<S>
-where
-    S: Service<OutboundMessage>,
-    S::Error: std::error::Error + Send + Sync + 'static,
-{
-    pub async fn serialize(
-        next_service: S,
-        node_identity: Arc<NodeIdentity>,
-        message: DhtOutboundMessage,
-    ) -> Result<(), PipelineError>
-    {
-        debug!(target: LOG_TARGET, "Serializing outbound message {:?}", message.tag);
-
-        let DhtOutboundMessage {
-            mut dht_header,
-            body,
-            destination_peer,
-            comms_flags,
-            ..
-        } = message;
-
-        // The message is being forwarded if the origin public_key is specified and it is not this node
-        let is_forwarded = dht_header
-            .origin
-            .as_ref()
-            .map(|o| &o.public_key != node_identity.public_key())
-            .unwrap_or(false);
-
-        // If forwarding the message, the DhtHeader already has a signature that should not change
-        if is_forwarded {
-            trace!(
-                target: LOG_TARGET,
-                "Forwarded message {:?}. Message will not be signed",
-                message.tag
-            );
-        } else {
-            // Sign the body if the origin public key was previously specified.
-            if let Some(origin) = dht_header.origin.as_mut() {
-                let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), &body)
-                    .map_err(PipelineError::from_debug)?;
-                origin.signature = signature.to_binary().map_err(PipelineError::from_debug)?;
-                trace!(
-                    target: LOG_TARGET,
-                    "Signed message {:?}: {}",
-                    message.tag,
-                    origin.signature.to_hex()
-                );
-            }
-        }
-
-        let envelope = DhtEnvelope::new(dht_header.into(), body);
-
-        let body = Bytes::from(envelope.to_encoded_bytes().map_err(PipelineError::from_debug)?);
-
-        next_service
-            .oneshot(OutboundMessage::with_tag(
-                message.tag,
-                destination_peer.node_id,
-                comms_flags,
+            let DhtOutboundMessage {
+                tag,
+                destination_peer,
+                custom_header,
                 body,
-            ))
-            .await
-            .map_err(PipelineError::from_debug)
+                ephemeral_public_key,
+                destination,
+                dht_message_type,
+                network,
+                dht_flags,
+                origin_mac,
+                reply_tx,
+                ..
+            } = message;
+
+            let dht_header = custom_header.map(DhtHeader::from).unwrap_or_else(|| DhtHeader {
+                version: DHT_ENVELOPE_HEADER_VERSION,
+                origin_mac: origin_mac.map(|b| b.to_vec()).unwrap_or_else(Vec::new),
+                ephemeral_public_key: ephemeral_public_key.map(|e| e.to_vec()).unwrap_or_else(Vec::new),
+                message_type: dht_message_type as i32,
+                network: network as i32,
+                flags: dht_flags.bits(),
+                destination: Some(destination.into()),
+            });
+            let envelope = DhtEnvelope::new(dht_header, body);
+
+            let body = Bytes::from(envelope.to_encoded_bytes());
+
+            next_service
+                .oneshot(OutboundMessage {
+                    tag,
+                    peer_node_id: destination_peer.node_id,
+                    reply_tx: reply_tx.into_inner(),
+                    body,
+                })
+                .await
+        }
     }
 }
 
-pub struct SerializeLayer {
-    node_identity: Arc<NodeIdentity>,
-}
+pub struct SerializeLayer;
 
 impl SerializeLayer {
-    pub fn new(node_identity: Arc<NodeIdentity>) -> Self {
-        Self { node_identity }
+    pub fn new() -> Self {
+        Self
     }
 }
 
@@ -151,52 +118,29 @@ impl<S> Layer<S> for SerializeLayer {
     type Service = SerializeMiddleware<S>;
 
     fn layer(&self, service: S) -> Self::Service {
-        SerializeMiddleware::new(service, Arc::clone(&self.node_identity))
+        SerializeMiddleware::new(service)
     }
 }
 
 #[cfg(test)]
 mod test {
     use super::*;
-    use crate::{
-        envelope::DhtMessageFlags,
-        outbound::OutboundEncryption,
-        test_utils::{make_dht_header, make_node_identity, service_spy},
-    };
+    use crate::test_utils::{create_outbound_message, service_spy};
     use futures::executor::block_on;
     use prost::Message;
-    use tari_comms::{
-        message::MessageFlags,
-        net_address::MultiaddressesWithStats,
-        peer_manager::{NodeId, Peer, PeerFeatures, PeerFlags},
-        types::CommsPublicKey,
-    };
+    use tari_comms::peer_manager::NodeId;
    use tari_test_utils::panic_context;
 
     #[test]
     fn serialize() {
         let spy = service_spy();
-        let node_identity = make_node_identity();
-        let mut serialize = SerializeLayer::new(Arc::clone(&node_identity)).layer(spy.to_service::<PipelineError>());
+        let mut serialize = SerializeLayer.layer(spy.to_service::<PipelineError>());
 
         panic_context!(cx);
 
         assert!(serialize.poll_ready(&mut cx).is_ready());
-        let body = b"A".to_vec();
-        let msg = DhtOutboundMessage::new(
-            Peer::new(
-                CommsPublicKey::default(),
-                NodeId::default(),
-                MultiaddressesWithStats::new(vec![]),
-                PeerFlags::empty(),
-                PeerFeatures::COMMUNICATION_NODE,
-                &[],
-            ),
-            make_dht_header(&node_identity, &body, DhtMessageFlags::empty()),
-            OutboundEncryption::None,
-            MessageFlags::empty(),
-            body,
-        );
+        let body = b"A";
+        let msg = create_outbound_message(body);
         block_on(serialize.call(msg)).unwrap();
 
         let mut msg = spy.pop_request().unwrap();
diff --git a/comms/dht/src/proto/envelope.proto b/comms/dht/src/proto/envelope.proto
index 43f1815cee..d2a16d77dc 100644
--- a/comms/dht/src/proto/envelope.proto
+++ b/comms/dht/src/proto/envelope.proto
@@ -22,8 +22,7 @@ enum DhtMessageType {
 message DhtHeader {
     uint32 version = 1;
     oneof destination {
-        // The sender has chosen not to disclose the message destination, or the destination is
-        // the peer being sent to.
+        // The sender has chosen not to disclose the message destination
         bool unknown = 2;
         // Destined for a particular public key
         bytes public_key = 3;
@@ -32,14 +31,17 @@ message DhtHeader {
 
     // Origin public key of the message. This can be the same peer that sent the message
-    // or another peer if the message should be forwarded. This is optional but must be specified
+    // or another peer if the message should be forwarded. This is optional but MUST be specified
     // if the ENCRYPTED flag is set.
-    DhtOrigin origin = 5;
+    // If an ephemeral_public_key is specified, this MUST be encrypted using a derived ECDH shared key
+    bytes origin_mac = 5;
+    // Ephemeral public key component of the ECDH shared key. MUST be specified if the ENCRYPTED flag is set.
+    bytes ephemeral_public_key = 6;
     // The type of message
-    DhtMessageType message_type = 6;
+    DhtMessageType message_type = 7;
     // The network for which this message is intended (e.g. TestNet, MainNet etc.)
-    Network network = 7;
-    uint32 flags = 8;
+    Network network = 8;
+    uint32 flags = 9;
 }
 
 enum Network {
@@ -56,7 +58,8 @@ message DhtEnvelope {
     bytes body = 2;
 }
 
-message DhtOrigin {
+// The Message Authentication Code (MAC) message format of the decrypted `DhtHeader::origin_mac` field
+message OriginMac {
     bytes public_key = 1;
     bytes signature = 2;
 }
\ No newline at end of file
diff --git a/comms/dht/src/proto/mod.rs b/comms/dht/src/proto/mod.rs
index ac940a0019..906664d1f7 100644
--- a/comms/dht/src/proto/mod.rs
+++ b/comms/dht/src/proto/mod.rs
@@ -22,6 +22,7 @@
 
 use crate::proto::envelope::Network;
 use std::fmt;
+use tari_utilities::hex::Hex;
 
 #[path = "tari.dht.envelope.rs"]
 pub mod envelope;
@@ -67,3 +68,17 @@ impl fmt::Display for dht::RejectMessage {
         write!(f, "RejectMessage(Reason = {})", self.reason)
     }
 }
+
+//---------------------------------- JoinMessage --------------------------------------------//
+
+impl fmt::Display for dht::JoinMessage {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "JoinMessage(NodeId = {}, Addresses = {:?}, Features = {:?})",
+            self.node_id.to_hex(),
+            self.addresses,
+            self.peer_features
+        )
+    }
+}
diff --git a/comms/dht/src/proto/store_forward.proto b/comms/dht/src/proto/store_forward.proto
index 702aee4225..896728fb8c 100644
--- a/comms/dht/src/proto/store_forward.proto
+++ b/comms/dht/src/proto/store_forward.proto
@@ -11,6 +11,8 @@ package tari.dht.store_forward;
 // will be sent.
 message StoredMessagesRequest {
     google.protobuf.Timestamp since = 1;
+    uint32 request_id = 2;
+    bytes dist_threshold = 3;
 }
 
 // Storage for a single message envelope, including the date and time when the element was stored
@@ -18,10 +20,24 @@ message StoredMessage {
     google.protobuf.Timestamp stored_at = 1;
     uint32 version = 2;
     tari.dht.envelope.DhtHeader dht_header = 3;
-    bytes encrypted_body = 4;
+    bytes body = 4;
 }
 
 // The StoredMessages contains the set of applicable messages retrieved from a neighbouring peer node.
 message StoredMessagesResponse {
     repeated StoredMessage messages = 1;
+    uint32 request_id = 2;
+    enum SafResponseType {
+        // Messages for the requested public key or node ID
+        ForMe = 0;
+        // Discovery messages that could be for the requester
+        Discovery = 1;
+        // Join messages that the requester could be interested in
+        Join = 2;
+        // Messages without an explicit destination and with an unidentified encrypted source
+        Anonymous = 3;
+        // Messages within the requesting node's region
+        InRegion = 4;
+    }
+    SafResponseType response_type = 3;
 }
diff --git a/comms/dht/src/proto/tari.dht.envelope.rs b/comms/dht/src/proto/tari.dht.envelope.rs
index f7e24f31b2..b9dda194e5 100644
--- a/comms/dht/src/proto/tari.dht.envelope.rs
+++ b/comms/dht/src/proto/tari.dht.envelope.rs
@@ -3,17 +3,21 @@ pub struct DhtHeader {
     #[prost(uint32, tag = "1")]
     pub version: u32,
     /// Origin public key of the message. This can be the same peer that sent the message
-    /// or another peer if the message should be forwarded. This is optional but must be specified
+    /// or another peer if the message should be forwarded. This is optional but MUST be specified
     /// if the ENCRYPTED flag is set.
-    #[prost(message, optional, tag = "5")]
-    pub origin: ::std::option::Option<DhtOrigin>,
+    /// If an ephemeral_public_key is specified, this MUST be encrypted using a derived ECDH shared key
+    #[prost(bytes, tag = "5")]
+    pub origin_mac: std::vec::Vec<u8>,
+    /// Ephemeral public key component of the ECDH shared key. MUST be specified if the ENCRYPTED flag is set.
+    #[prost(bytes, tag = "6")]
+    pub ephemeral_public_key: std::vec::Vec<u8>,
     /// The type of message
-    #[prost(enumeration = "DhtMessageType", tag = "6")]
+    #[prost(enumeration = "DhtMessageType", tag = "7")]
     pub message_type: i32,
     /// The network for which this message is intended (e.g. TestNet, MainNet etc.)
-    #[prost(enumeration = "Network", tag = "7")]
+    #[prost(enumeration = "Network", tag = "8")]
     pub network: i32,
-    #[prost(uint32, tag = "8")]
+    #[prost(uint32, tag = "9")]
     pub flags: u32,
     #[prost(oneof = "dht_header::Destination", tags = "2, 3, 4")]
     pub destination: ::std::option::Option<dht_header::Destination>,
@@ -21,8 +25,7 @@ pub struct DhtHeader {
 pub mod dht_header {
     #[derive(Clone, PartialEq, ::prost::Oneof)]
     pub enum Destination {
-        /// The sender has chosen not to disclose the message destination, or the destination is
-        /// the peer being sent to.
+        /// The sender has chosen not to disclose the message destination
         #[prost(bool, tag = "2")]
         Unknown(bool),
         /// Destined for a particular public key
@@ -40,8 +43,9 @@ pub struct DhtEnvelope {
     #[prost(bytes, tag = "2")]
     pub body: std::vec::Vec<u8>,
 }
+/// The Message Authentication Code (MAC) message format of the decrypted `DhtHeader::origin_mac` field
 #[derive(Clone, PartialEq, ::prost::Message)]
-pub struct DhtOrigin {
+pub struct OriginMac {
     #[prost(bytes, tag = "1")]
     pub public_key: std::vec::Vec<u8>,
     #[prost(bytes, tag = "2")]
    pub signature: std::vec::Vec<u8>,
}
diff --git a/comms/dht/src/proto/tari.dht.store_forward.rs b/comms/dht/src/proto/tari.dht.store_forward.rs
index f1f881adaf..af958f75cc 100644
--- a/comms/dht/src/proto/tari.dht.store_forward.rs
+++ b/comms/dht/src/proto/tari.dht.store_forward.rs
@@ -5,6 +5,10 @@
 pub struct StoredMessagesRequest {
     #[prost(message, optional, tag = "1")]
     pub since: ::std::option::Option<::prost_types::Timestamp>,
+    #[prost(uint32, tag = "2")]
+    pub request_id: u32,
+    #[prost(bytes, tag = "3")]
+    pub dist_threshold: std::vec::Vec<u8>,
 }
 /// Storage for a single message envelope, including the date and time when the element was stored
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -16,11 +20,31 @@ pub struct StoredMessage {
     #[prost(message, optional, tag = "3")]
     pub dht_header: ::std::option::Option<super::envelope::DhtHeader>,
     #[prost(bytes, tag = "4")]
-    pub encrypted_body: std::vec::Vec<u8>,
+    pub body: std::vec::Vec<u8>,
 }
 /// The StoredMessages contains the set of applicable messages retrieved from a neighbouring peer node.
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct StoredMessagesResponse {
     #[prost(message, repeated, tag = "1")]
     pub messages: ::std::vec::Vec<StoredMessage>,
+    #[prost(uint32, tag = "2")]
+    pub request_id: u32,
+    #[prost(enumeration = "stored_messages_response::SafResponseType", tag = "3")]
+    pub response_type: i32,
+}
+pub mod stored_messages_response {
+    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+    #[repr(i32)]
+    pub enum SafResponseType {
+        /// Messages for the requested public key or node ID
+        ForMe = 0,
+        /// Discovery messages that could be for the requester
+        Discovery = 1,
+        /// Join messages that the requester could be interested in
+        Join = 2,
+        /// Messages without an explicit destination and with an unidentified encrypted source
+        Anonymous = 3,
+        /// Messages within the requesting node's region
+        InRegion = 4,
+    }
 }
diff --git a/comms/dht/src/schema.rs b/comms/dht/src/schema.rs
new file mode 100644
index 0000000000..f56a16c168
--- /dev/null
+++ b/comms/dht/src/schema.rs
@@ -0,0 +1,25 @@
+table! {
+    dht_metadata (id) {
+        id -> Integer,
+        key -> Text,
+        value -> Binary,
+    }
+}
+
+table! {
+    stored_messages (id) {
+        id -> Integer,
+        version -> Integer,
+        origin_pubkey -> Nullable<Text>,
+        message_type -> Integer,
+        destination_pubkey -> Nullable<Text>,
+        destination_node_id -> Nullable<Text>,
+        header -> Binary,
+        body -> Binary,
+        is_encrypted -> Bool,
+        priority -> Integer,
+        stored_at -> Timestamp,
+    }
+}
+
+allow_tables_to_appear_in_same_query!(dht_metadata, stored_messages,);
diff --git a/comms/dht/src/storage/connection.rs b/comms/dht/src/storage/connection.rs
new file mode 100644
index 0000000000..856a94315a
--- /dev/null
+++ b/comms/dht/src/storage/connection.rs
@@ -0,0 +1,151 @@
+// Copyright 2020. The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
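The `table!` definitions above give the DHT's store-and-forward queries access to diesel's typed DSL. As a rough illustration only (not taken from this patch; `conn` is assumed to be an established `SqliteConnection` and the in-crate module paths are assumptions):

    // Sketch: counting low-priority stored messages via the generated schema DSL.
    use crate::schema::stored_messages;
    use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl};

    let low_priority: i64 = stored_messages::table
        .filter(stored_messages::priority.eq(1)) // StoredMessagePriority::Low as i32
        .count()
        .get_result(conn)?;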
+
+use crate::storage::error::StorageError;
+use diesel::{Connection, SqliteConnection};
+use log::*;
+use std::{
+    io,
+    path::PathBuf,
+    sync::{Arc, Mutex},
+};
+use tokio::task;
+
+const LOG_TARGET: &str = "comms::dht::storage::connection";
+
+#[derive(Clone, Debug)]
+pub enum DbConnectionUrl {
+    /// In-memory database. Each connection has its own database
+    Memory,
+    /// In-memory database shared with more than one in-process connection according to the given identifier
+    MemoryShared(String),
+    /// Database persisted on disk
+    File(PathBuf),
+}
+
+impl DbConnectionUrl {
+    pub fn to_url_string(&self) -> String {
+        use DbConnectionUrl::*;
+        match self {
+            Memory => ":memory:".to_owned(),
+            MemoryShared(identifier) => format!("file:{}?mode=memory&cache=shared", identifier),
+            File(path) => path
+                .to_str()
+                .expect("Invalid non-UTF8 character in database path")
+                .to_owned(),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct DbConnection {
+    inner: Arc<Mutex<SqliteConnection>>,
+}
+
+impl DbConnection {
+    #[cfg(test)]
+    pub async fn connect_memory(name: String) -> Result<Self, StorageError> {
+        Self::connect_url(DbConnectionUrl::MemoryShared(name)).await
+    }
+
+    pub async fn connect_url(db_url: DbConnectionUrl) -> Result<Self, StorageError> {
+        debug!(target: LOG_TARGET, "Connecting to database using '{:?}'", db_url);
+        let conn = task::spawn_blocking(move || {
+            let conn = SqliteConnection::establish(&db_url.to_url_string())?;
+            conn.execute("PRAGMA foreign_keys = ON; PRAGMA busy_timeout = 60000;")?;
+            Result::<_, StorageError>::Ok(conn)
+        })
+        .await??;
+
+        Ok(Self::new(conn))
+    }
+
+    pub async fn connect_and_migrate(db_url: DbConnectionUrl) -> Result<Self, StorageError> {
+        let conn = Self::connect_url(db_url).await?;
+        let output = conn.migrate().await?;
+        info!(target: LOG_TARGET, "DHT database migration: {}", output.trim());
+        Ok(conn)
+    }
+
+    fn new(conn: SqliteConnection) -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(conn)),
+        }
+    }
+
+    pub async fn migrate(&self) -> Result<String, StorageError> {
+        embed_migrations!("./migrations");
+
+        self.with_connection_async(|conn| {
+            let mut buf = io::Cursor::new(Vec::new());
+            embedded_migrations::run_with_output(conn, &mut buf)
+                .map_err(|err| StorageError::DatabaseMigrationFailed(format!("Database migration failed {}", err)))?;
+            Ok(String::from_utf8_lossy(&buf.into_inner()).to_string())
+        })
+        .await
+    }
+
+    pub async fn with_connection_async<F, R>(&self, f: F) -> Result<R, StorageError>
+    where
+        F: FnOnce(&SqliteConnection) -> Result<R, StorageError> + Send + 'static,
+        R: Send + 'static,
+    {
+        let conn_mutex = self.inner.clone();
+        let ret = task::spawn_blocking(move || {
+            let lock = acquire_lock!(conn_mutex);
+            f(&*lock)
+        })
+        .await??;
+        Ok(ret)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use diesel::{expression::sql_literal::sql, sql_types::Integer, RunQueryDsl};
+    use tari_test_utils::random;
+
+    #[tokio_macros::test_basic]
+    async fn connect_and_migrate() {
+        let conn = DbConnection::connect_memory(random::string(8)).await.unwrap();
+        let output = conn.migrate().await.unwrap();
+        assert!(output.starts_with("Running migration"));
+    }
+
+    #[tokio_macros::test_basic]
+    async fn memory_connections() {
+        let id = random::string(8);
+        let conn = DbConnection::connect_memory(id.clone()).await.unwrap();
+        conn.migrate().await.unwrap();
+        let conn = DbConnection::connect_memory(id).await.unwrap();
+        let count: i32 = conn
+            .with_connection_async(|c| {
+                sql::<Integer>("SELECT COUNT(*) FROM stored_messages")
+                    .get_result(c)
+                    .map_err(Into::into)
+            })
+            .await
+            .unwrap();
+        assert_eq!(count, 0);
+    }
+}
diff --git a/comms/dht/src/storage/database.rs
b/comms/dht/src/storage/database.rs
new file mode 100644
index 0000000000..eb922906f1
--- /dev/null
+++ b/comms/dht/src/storage/database.rs
@@ -0,0 +1,87 @@
+// Copyright 2020, The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use super::{dht_setting_entry::DhtMetadataEntry, DbConnection, StorageError};
+use crate::{
+    schema::dht_metadata,
+    storage::{dht_setting_entry::NewDhtMetadataEntry, DhtMetadataKey},
+};
+use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl};
+use tari_utilities::message_format::MessageFormat;
+
+#[derive(Clone)]
+pub struct DhtDatabase {
+    connection: DbConnection,
+}
+
+impl DhtDatabase {
+    pub fn new(connection: DbConnection) -> Self {
+        Self { connection }
+    }
+
+    pub async fn get_metadata_value<T: MessageFormat>(&self, key: DhtMetadataKey) -> Result<Option<T>, StorageError> {
+        match self.get_metadata_value_bytes(key).await?
{
+            Some(bytes) => T::from_binary(&bytes).map(Some).map_err(Into::into),
+            None => Ok(None),
+        }
+    }
+
+    pub async fn get_metadata_value_bytes(&self, key: DhtMetadataKey) -> Result<Option<Vec<u8>>, StorageError> {
+        self.connection
+            .with_connection_async(move |conn| {
+                dht_metadata::table
+                    .filter(dht_metadata::key.eq(key.to_string()))
+                    .first(conn)
+                    .map(|rec: DhtMetadataEntry| Some(rec.value))
+                    .or_else(|err| match err {
+                        diesel::result::Error::NotFound => Ok(None),
+                        err => Err(err.into()),
+                    })
+            })
+            .await
+    }
+
+    pub async fn set_metadata_value<T: MessageFormat>(
+        &self,
+        key: DhtMetadataKey,
+        value: T,
+    ) -> Result<(), StorageError>
+    {
+        let bytes = value.to_binary()?;
+        self.set_metadata_value_bytes(key, bytes).await
+    }
+
+    pub async fn set_metadata_value_bytes(&self, key: DhtMetadataKey, value: Vec<u8>) -> Result<(), StorageError> {
+        self.connection
+            .with_connection_async(move |conn| {
+                diesel::replace_into(dht_metadata::table)
+                    .values(NewDhtMetadataEntry {
+                        key: key.to_string(),
+                        value,
+                    })
+                    .execute(conn)
+                    .map(|_| ())
+                    .map_err(Into::into)
+            })
+            .await
+    }
+}
diff --git a/comms/dht/src/storage/dht_setting_entry.rs b/comms/dht/src/storage/dht_setting_entry.rs
new file mode 100644
index 0000000000..73cb39fe69
--- /dev/null
+++ b/comms/dht/src/storage/dht_setting_entry.rs
@@ -0,0 +1,51 @@
+// Copyright 2020, The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
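Taken together, DbConnection and DhtDatabase give the DHT a small async persistence layer. A usage sketch follows; the in-crate module path and the serde-backed MessageFormat impl for chrono's DateTime are assumptions, and the sketch is not taken from the patch:

    // Sketch: opening the DHT database and recording a shutdown timestamp.
    use crate::storage::{DbConnection, DbConnectionUrl, DhtDatabase, DhtMetadataKey, StorageError};
    use chrono::{DateTime, Utc};

    async fn open_and_mark_offline() -> Result<(), StorageError> {
        // connect_and_migrate establishes the SQLite connection and runs the
        // embedded diesel migrations before handing it over
        let conn = DbConnection::connect_and_migrate(DbConnectionUrl::File("/tmp/dht.db".into())).await?;
        let db = DhtDatabase::new(conn);

        // set_metadata_value serializes any MessageFormat value under the given key
        db.set_metadata_value(DhtMetadataKey::OfflineTimestamp, Utc::now()).await?;
        let _last_offline: Option<DateTime<Utc>> =
            db.get_metadata_value(DhtMetadataKey::OfflineTimestamp).await?;
        Ok(())
    }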
+
+use crate::schema::dht_metadata;
+use std::fmt;
+
+#[derive(Debug, Clone, Copy)]
+pub enum DhtMetadataKey {
+    /// Timestamp each time the DHT is shut down
+    OfflineTimestamp,
+}
+
+impl fmt::Display for DhtMetadataKey {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+#[derive(Clone, Debug, Insertable)]
+#[table_name = "dht_metadata"]
+pub struct NewDhtMetadataEntry {
+    pub key: String,
+    pub value: Vec<u8>,
+}
+
+#[derive(Clone, Debug, Queryable, Identifiable)]
+#[table_name = "dht_metadata"]
+pub struct DhtMetadataEntry {
+    pub id: i32,
+    pub key: String,
+    pub value: Vec<u8>,
+}
diff --git a/comms/dht/src/storage/error.rs b/comms/dht/src/storage/error.rs
new file mode 100644
index 0000000000..de94d7557c
--- /dev/null
+++ b/comms/dht/src/storage/error.rs
@@ -0,0 +1,37 @@
+// Copyright 2020, The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use derive_error::Error;
+use tari_utilities::message_format::MessageFormatError;
+use tokio::task;
+
+#[derive(Debug, Error)]
+pub enum StorageError {
+    /// Database path contained non-UTF8 characters that are not supported by the host OS
+    InvalidUnicodePath,
+    JoinError(task::JoinError),
+    ConnectionError(diesel::ConnectionError),
+    #[error(msg_embedded, no_from, non_std)]
+    DatabaseMigrationFailed(String),
+    ResultError(diesel::result::Error),
+    MessageFormatError(MessageFormatError),
+}
diff --git a/base_layer/core/src/proof_of_work/diff_adj_manager/error.rs b/comms/dht/src/storage/mod.rs
similarity index 81%
rename from base_layer/core/src/proof_of_work/diff_adj_manager/error.rs
rename to comms/dht/src/storage/mod.rs
index 8e2e7f5830..3b65b199f5 100644
--- a/base_layer/core/src/proof_of_work/diff_adj_manager/error.rs
+++ b/comms/dht/src/storage/mod.rs
@@ -20,13 +20,14 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
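One detail worth noting from dht_setting_entry.rs above: DhtMetadataKey's Display implementation reuses its Debug form, so the variant name itself becomes the text key stored in the dht_metadata table. Illustrative only:

    // The key column for OfflineTimestamp entries is the literal variant name
    assert_eq!(DhtMetadataKey::OfflineTimestamp.to_string(), "OfflineTimestamp");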
-use crate::{chain_storage::ChainStorageError, proof_of_work::error::DifficultyAdjustmentError};
-use derive_error::Error;
+mod connection;
+pub use connection::{DbConnection, DbConnectionUrl};
 
-#[derive(Debug, Error, Clone, PartialEq)]
-pub enum DiffAdjManagerError {
-    DifficultyAdjustmentError(DifficultyAdjustmentError),
-    ChainStorageError(ChainStorageError),
-    EmptyBlockchain,
-    PoisonedAccess,
-}
+mod error;
+pub use error::StorageError;
+
+mod dht_setting_entry;
+pub use dht_setting_entry::{DhtMetadataEntry, DhtMetadataKey};
+
+mod database;
+pub use database::DhtDatabase;
diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs
new file mode 100644
index 0000000000..945797895b
--- /dev/null
+++ b/comms/dht/src/store_forward/database/mod.rs
@@ -0,0 +1,278 @@
+// Copyright 2020, The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+mod stored_message;
+pub use stored_message::{NewStoredMessage, StoredMessage};
+
+use crate::{
+    envelope::DhtMessageType,
+    schema::stored_messages,
+    storage::{DbConnection, StorageError},
+    store_forward::message::StoredMessagePriority,
+};
+use chrono::{DateTime, NaiveDateTime, Utc};
+use diesel::{BoolExpressionMethods, ExpressionMethods, QueryDsl, RunQueryDsl};
+use tari_comms::{
+    peer_manager::{node_id::NodeDistance, NodeId},
+    types::CommsPublicKey,
+};
+use tari_utilities::hex::Hex;
+
+pub struct StoreAndForwardDatabase {
+    connection: DbConnection,
+}
+
+impl StoreAndForwardDatabase {
+    pub fn new(connection: DbConnection) -> Self {
+        Self { connection }
+    }
+
+    pub async fn insert_message(&self, message: NewStoredMessage) -> Result<(), StorageError> {
+        self.connection
+            .with_connection_async(|conn| {
+                diesel::insert_into(stored_messages::table)
+                    .values(message)
+                    .execute(conn)?;
+                Ok(())
+            })
+            .await
+    }
+
+    pub async fn find_messages_for_peer(
+        &self,
+        public_key: &CommsPublicKey,
+        node_id: &NodeId,
+        since: Option<DateTime<Utc>>,
+        limit: i64,
+    ) -> Result<Vec<StoredMessage>, StorageError>
+    {
+        let pk_hex = public_key.to_hex();
+        let node_id_hex = node_id.to_hex();
+        self.connection
+            .with_connection_async::<_, Vec<StoredMessage>>(move |conn| {
+                let mut query = stored_messages::table
+                    .select(stored_messages::all_columns)
+                    .filter(
+                        stored_messages::destination_pubkey
+                            .eq(pk_hex)
+                            .or(stored_messages::destination_node_id.eq(node_id_hex)),
+                    )
+                    .filter(stored_messages::message_type.eq(DhtMessageType::None as i32))
+                    .into_boxed();
+
+                if let Some(since) = since {
+                    query = query.filter(stored_messages::stored_at.gt(since.naive_utc()));
+                }
+
+                query
+                    .order_by(stored_messages::stored_at.desc())
+                    .limit(limit)
+                    .get_results(conn)
+                    .map_err(Into::into)
+            })
+            .await
+    }
+
+    pub async fn find_regional_messages(
+        &self,
+        node_id: &NodeId,
+        dist_threshold: Option<Box<NodeDistance>>,
+        since: Option<DateTime<Utc>>,
+        limit: i64,
+    ) -> Result<Vec<StoredMessage>, StorageError>
+    {
+        let node_id_hex = node_id.to_hex();
+        let results = self
+            .connection
+            .with_connection_async::<_, Vec<StoredMessage>>(move |conn| {
+                let mut query = stored_messages::table
+                    .select(stored_messages::all_columns)
+                    .filter(stored_messages::destination_node_id.ne(node_id_hex))
+                    .filter(stored_messages::destination_node_id.is_not_null())
+                    .filter(stored_messages::message_type.eq(DhtMessageType::None as i32))
+                    .into_boxed();
+
+                if let Some(since) = since {
+                    query = query.filter(stored_messages::stored_at.gt(since.naive_utc()));
+                }
+
+                query
+                    .order_by(stored_messages::stored_at.desc())
+                    .limit(limit)
+                    .get_results(conn)
+                    .map_err(Into::into)
+            })
+            .await?;
+
+        match dist_threshold {
+            Some(dist_threshold) => {
+                // Filter node ids that are within the distance threshold from the source node id
+                let results = results
+                    .into_iter()
+                    // TODO: Investigate if we could do this in sqlite using XOR (^)
+                    .filter(|message| match message.destination_node_id {
+                        Some(ref dest_node_id) => match NodeId::from_hex(dest_node_id).ok() {
+                            Some(dest_node_id) => {
+                                &dest_node_id == node_id || &dest_node_id.distance(node_id) <= &*dist_threshold
+                            },
+                            None => false,
+                        },
+                        None => true,
+                    })
+                    .collect();
+                Ok(results)
+            },
+            None => Ok(results),
+        }
+    }
+
+    pub async fn find_anonymous_messages(
+        &self,
+        since: Option<DateTime<Utc>>,
+        limit: i64,
+    ) -> Result<Vec<StoredMessage>, StorageError>
+    {
+        self.connection
+            .with_connection_async(move |conn| {
+                let mut query = stored_messages::table
+                    .select(stored_messages::all_columns)
+                    .filter(stored_messages::origin_pubkey.is_null())
+
                    .filter(stored_messages::destination_pubkey.is_null())
+                    .filter(stored_messages::is_encrypted.eq(true))
+                    .filter(stored_messages::message_type.eq(DhtMessageType::None as i32))
+                    .into_boxed();
+
+                if let Some(since) = since {
+                    query = query.filter(stored_messages::stored_at.gt(since.naive_utc()));
+                }
+
+                query
+                    .order_by(stored_messages::stored_at.desc())
+                    .limit(limit)
+                    .get_results(conn)
+                    .map_err(Into::into)
+            })
+            .await
+    }
+
+    pub async fn find_join_messages(
+        &self,
+        since: Option<DateTime<Utc>>,
+        limit: i64,
+    ) -> Result<Vec<StoredMessage>, StorageError>
+    {
+        self.connection
+            .with_connection_async(move |conn| {
+                let mut query = stored_messages::table
+                    .select(stored_messages::all_columns)
+                    .filter(stored_messages::message_type.eq(DhtMessageType::Join as i32))
+                    .into_boxed();
+
+                if let Some(since) = since {
+                    query = query.filter(stored_messages::stored_at.gt(since.naive_utc()));
+                }
+
+                query
+                    .order_by(stored_messages::stored_at.desc())
+                    .limit(limit)
+                    .get_results(conn)
+                    .map_err(Into::into)
+            })
+            .await
+    }
+
+    pub async fn find_messages_of_type_for_pubkey(
+        &self,
+        public_key: &CommsPublicKey,
+        message_type: DhtMessageType,
+        since: Option<DateTime<Utc>>,
+        limit: i64,
+    ) -> Result<Vec<StoredMessage>, StorageError>
+    {
+        let pk_hex = public_key.to_hex();
+        self.connection
+            .with_connection_async(move |conn| {
+                let mut query = stored_messages::table
+                    .select(stored_messages::all_columns)
+                    .filter(stored_messages::destination_pubkey.eq(pk_hex))
+                    .filter(stored_messages::message_type.eq(message_type as i32))
+                    .into_boxed();
+
+                if let Some(since) = since {
+                    query = query.filter(stored_messages::stored_at.gt(since.naive_utc()));
+                }
+
+                query
+                    .order_by(stored_messages::stored_at.desc())
+                    .limit(limit)
+                    .get_results(conn)
+                    .map_err(Into::into)
+            })
+            .await
+    }
+
+    #[cfg(test)]
+    pub(crate) async fn get_all_messages(&self) -> Result<Vec<StoredMessage>, StorageError> {
+        self.connection
+            .with_connection_async(|conn| {
+                stored_messages::table
+                    .select(stored_messages::all_columns)
+                    .get_results(conn)
+                    .map_err(Into::into)
+            })
+            .await
+    }
+
+    pub(crate) async fn delete_messages_with_priority_older_than(
+        &self,
+        priority: StoredMessagePriority,
+        since: NaiveDateTime,
+    ) -> Result<usize, StorageError>
+    {
+        self.connection
+            .with_connection_async(move |conn| {
+                diesel::delete(stored_messages::table)
+                    .filter(stored_messages::stored_at.lt(since))
+                    .filter(stored_messages::priority.eq(priority as i32))
+                    .execute(conn)
+                    .map_err(Into::into)
+            })
+            .await
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tari_test_utils::random;
+
+    #[tokio_macros::test_basic]
+    async fn insert_messages() {
+        let conn = DbConnection::connect_memory(random::string(8)).await.unwrap();
+        // let conn = DbConnection::connect_path("/tmp/tmp.db").await.unwrap();
+        conn.migrate().await.unwrap();
+        let db = StoreAndForwardDatabase::new(conn);
+        db.insert_message(Default::default()).await.unwrap();
+        let messages = db.get_all_messages().await.unwrap();
+        assert_eq!(messages.len(), 1);
+    }
+}
diff --git a/comms/dht/src/store_forward/database/stored_message.rs b/comms/dht/src/store_forward/database/stored_message.rs
new file mode 100644
index 0000000000..a3c8039e1b
--- /dev/null
+++ b/comms/dht/src/store_forward/database/stored_message.rs
@@ -0,0 +1,93 @@
+// Copyright 2020, The Tari Project
+//
+// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
+// following conditions are met:
+//
+// 1.
Redistributions of source code must retain the above copyright notice, this list of conditions and the following
+// disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
+// following disclaimer in the documentation and/or other materials provided with the distribution.
+//
+// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
+// products derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+use crate::{
+    inbound::DecryptedDhtMessage,
+    proto::envelope::DhtHeader,
+    schema::stored_messages,
+    store_forward::message::StoredMessagePriority,
+};
+use chrono::NaiveDateTime;
+use std::convert::TryInto;
+use tari_comms::message::MessageExt;
+use tari_utilities::hex::Hex;
+
+#[derive(Clone, Debug, Insertable, Default)]
+#[table_name = "stored_messages"]
+pub struct NewStoredMessage {
+    pub version: i32,
+    pub origin_pubkey: Option<String>,
+    pub message_type: i32,
+    pub destination_pubkey: Option<String>,
+    pub destination_node_id: Option<String>,
+    pub header: Vec<u8>,
+    pub body: Vec<u8>,
+    pub is_encrypted: bool,
+    pub priority: i32,
+}
+
+impl NewStoredMessage {
+    pub fn try_construct(message: DecryptedDhtMessage, priority: StoredMessagePriority) -> Option<Self> {
+        let DecryptedDhtMessage {
+            version,
+            authenticated_origin,
+            decryption_result,
+            dht_header,
+            ..
+        } = message;
+
+        let body = match decryption_result {
+            Ok(envelope_body) => envelope_body.to_encoded_bytes(),
+            Err(encrypted_body) => encrypted_body,
+        };
+
+        Some(Self {
+            version: version.try_into().ok()?,
+            origin_pubkey: authenticated_origin.as_ref().map(|pk| pk.to_hex()),
+            message_type: dht_header.message_type as i32,
+            destination_pubkey: dht_header.destination.public_key().map(|pk| pk.to_hex()),
+            destination_node_id: dht_header.destination.node_id().map(|node_id| node_id.to_hex()),
+            is_encrypted: dht_header.flags.is_encrypted(),
+            priority: priority as i32,
+            header: {
+                let dht_header: DhtHeader = dht_header.into();
+                dht_header.to_encoded_bytes()
+            },
+            body,
+        })
+    }
+}
+
+#[derive(Clone, Debug, Queryable, Identifiable)]
+pub struct StoredMessage {
+    pub id: i32,
+    pub version: i32,
+    pub origin_pubkey: Option<String>,
+    pub message_type: i32,
+    pub destination_pubkey: Option<String>,
+    pub destination_node_id: Option<String>,
+    pub header: Vec<u8>,
+    pub body: Vec<u8>,
+    pub is_encrypted: bool,
+    pub priority: i32,
+    pub stored_at: NaiveDateTime,
+}
diff --git a/comms/dht/src/store_forward/error.rs b/comms/dht/src/store_forward/error.rs
index dc0f084949..ac4e70ff74 100644
--- a/comms/dht/src/store_forward/error.rs
+++ b/comms/dht/src/store_forward/error.rs
@@ -20,12 +20,12 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-use crate::{actor::DhtActorError, envelope::DhtMessageError, outbound::DhtOutboundError};
+use crate::{actor::DhtActorError, envelope::DhtMessageError, outbound::DhtOutboundError, storage::StorageError};
 use derive_error::Error;
 use prost::DecodeError;
 use std::io;
 use tari_comms::{message::MessageError, peer_manager::PeerManagerError};
-use tari_crypto::tari_utilities::ciphers::cipher::CipherError;
+use tari_utilities::{byte_array::ByteArrayError, ciphers::cipher::CipherError};
 
 #[derive(Debug, Error)]
 pub enum StoreAndForwardError {
@@ -36,9 +36,11 @@ pub enum StoreAndForwardError {
     /// Received stored message has an invalid destination
     InvalidDestination,
     /// Received stored message has an invalid origin signature
-    InvalidSignature,
+    InvalidOriginMac,
     /// Invalid envelope body
     InvalidEnvelopeBody,
+    /// DHT header is invalid
+    InvalidDhtHeader,
     /// Received stored message which is not encrypted
     StoredMessageNotEncrypted,
     /// Unable to decrypt received stored message
@@ -58,4 +60,22 @@ pub enum StoreAndForwardError {
     MessageOriginRequired,
     /// The message was malformed
     MalformedMessage,
+
+    StorageError(StorageError),
+    /// The store and forward service requester channel closed
+    RequesterChannelClosed,
+    /// The request was cancelled by the store and forward service
+    RequestCancelled,
+    /// The message was not valid for store and forward
+    InvalidStoreMessage,
+    /// The envelope version is invalid
+    InvalidEnvelopeVersion,
+    MalformedNodeId(ByteArrayError),
+    /// NodeDistance threshold was invalid
+    InvalidNodeDistanceThreshold,
+    /// DHT message type should not have been forwarded
+    InvalidDhtMessageType,
+    /// Failed to send request for store and forward messages
+    #[error(no_from)]
+    RequestMessagesFailed(DhtOutboundError),
 }
diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs
index 254e1accbd..8d10c9f668 100644
--- a/comms/dht/src/store_forward/forward.rs
+++ b/comms/dht/src/store_forward/forward.rs
@@ -30,22 +30,28 @@ use crate::{
 use futures::{task::Context, Future};
 use log::*;
 use
std::{sync::Arc, task::Poll};
-use tari_comms::{peer_manager::PeerManager, pipeline::PipelineError, types::CommsPublicKey};
+use tari_comms::{
+    peer_manager::{Peer, PeerManager},
+    pipeline::PipelineError,
+    types::CommsPublicKey,
+};
 use tower::{layer::Layer, Service, ServiceExt};
 
-const LOG_TARGET: &str = "comms::store_forward::forward";
+const LOG_TARGET: &str = "comms::dht::storeforward::forward";
 
 /// This layer is responsible for forwarding messages which have failed to decrypt
 pub struct ForwardLayer {
     peer_manager: Arc<PeerManager>,
     outbound_service: OutboundMessageRequester,
+    is_enabled: bool,
 }
 
 impl ForwardLayer {
-    pub fn new(peer_manager: Arc<PeerManager>, outbound_service: OutboundMessageRequester) -> Self {
+    pub fn new(peer_manager: Arc<PeerManager>, outbound_service: OutboundMessageRequester, is_enabled: bool) -> Self {
         Self {
             peer_manager,
             outbound_service,
+            is_enabled,
         }
     }
 }
@@ -59,6 +65,7 @@ impl<S> Layer<S> for ForwardLayer {
             // Pass in just the config item needed by the middleware for almost free copies
             Arc::clone(&self.peer_manager),
             self.outbound_service.clone(),
+            self.is_enabled,
         )
     }
 }
@@ -71,22 +78,28 @@ pub struct ForwardMiddleware<S> {
     next_service: S,
     peer_manager: Arc<PeerManager>,
     outbound_service: OutboundMessageRequester,
+    is_enabled: bool,
 }
 
 impl<S> ForwardMiddleware<S> {
-    pub fn new(service: S, peer_manager: Arc<PeerManager>, outbound_service: OutboundMessageRequester) -> Self {
+    pub fn new(
+        service: S,
+        peer_manager: Arc<PeerManager>,
+        outbound_service: OutboundMessageRequester,
+        is_enabled: bool,
+    ) -> Self
+    {
         Self {
             next_service: service,
            peer_manager,
            outbound_service,
+            is_enabled,
        }
    }
 }
 
 impl<S> Service<DecryptedDhtMessage> for ForwardMiddleware<S>
-where
-    S: Service<DecryptedDhtMessage> + Clone + 'static,
-    S::Error: std::error::Error + Send + Sync + 'static,
+where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + 'static
 {
     type Error = PipelineError;
     type Response = ();
@@ -97,13 +110,20 @@ where
         Poll::Ready(Ok(()))
     }
 
-    fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future {
-        Forwarder::new(
-            self.next_service.clone(),
-            Arc::clone(&self.peer_manager),
-            self.outbound_service.clone(),
-        )
-        .handle(msg)
+    fn call(&mut self, message: DecryptedDhtMessage) -> Self::Future {
+        let next_service = self.next_service.clone();
+        let peer_manager = Arc::clone(&self.peer_manager);
+        let outbound_service = self.outbound_service.clone();
+        let is_enabled = self.is_enabled;
+        async move {
+            if !is_enabled {
+                trace!(target: LOG_TARGET, "Passing message to next service (Not enabled)");
+                return next_service.oneshot(message).await;
+            }
+
+            let forwarder = Forwarder::new(next_service, peer_manager, outbound_service);
+            forwarder.handle(message).await
+        }
     }
 }
 
@@ -126,9 +146,7 @@ impl<S> Forwarder<S> {
 }
 
 impl<S> Forwarder<S>
-where
-    S: Service<DecryptedDhtMessage>,
-    S::Error: std::error::Error + Send + Sync + 'static,
+where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
 {
     async fn handle(mut self, message: DecryptedDhtMessage) -> Result<(), PipelineError> {
         if message.decryption_failed() {
@@ -138,10 +156,7 @@ where
 
         // The message has been forwarded, but other middleware may be interested (i.e. StoreMiddleware)
         trace!(target: LOG_TARGET, "Passing message to next service");
-        self.next_service
-            .oneshot(message)
-            .await
-            .map_err(PipelineError::from_debug)?;
+        self.next_service.oneshot(message).await?;
 
         Ok(())
     }
@@ -150,18 +165,36 @@ where
             source_peer,
             decryption_result,
             dht_header,
+            authenticated_origin,
             ..
         } = message;
 
+        if self.destination_matches_source(&dht_header.destination, &source_peer) {
+            // TODO: #banheuristic - the origin of this message was the destination. Two things are wrong here:
+            //       1.
The origin/destination should not have forwarded this (the destination node didn't do
+            //          this destination_matches_source check)
+            //       2. The source sent a message that the destination could not decrypt
+            //       The authenticated source should be banned (malicious), and origin should be temporarily banned
+            //       (bug?)
+            warn!(
+                target: LOG_TARGET,
+                "Received message from peer '{}' that is destined for that peer. Discarding message",
+                source_peer.node_id.short_str()
+            );
+            return Ok(());
+        }
+
         let body = decryption_result
             .clone()
             .err()
             .expect("previous check that decryption failed");
 
-        let mut message_params = self
-            .get_send_params(&dht_header, vec![source_peer.public_key.clone()])
-            .await?;
+        let mut excluded_peers = vec![source_peer.public_key.clone()];
+        if let Some(pk) = authenticated_origin.as_ref() {
+            excluded_peers.push(pk.clone());
+        }
+        let mut message_params = self.get_send_params(&dht_header, excluded_peers).await?;
 
         message_params.with_dht_header(dht_header.clone());
         self.outbound_service.send_raw(message_params.finish(), body).await?;
@@ -223,6 +256,18 @@ where
 
         Ok(params)
     }
+
+    fn destination_matches_source(&self, destination: &NodeDestination, source: &Peer) -> bool {
+        if let Some(pk) = destination.public_key() {
+            return pk == &source.public_key;
+        }
+
+        if let Some(node_id) = destination.node_id() {
+            return node_id == &source.node_id;
+        }
+
+        false
+    }
 }
 
 #[cfg(test)]
@@ -243,10 +288,15 @@ mod test {
         let peer_manager = make_peer_manager();
         let (oms_tx, mut oms_rx) = mpsc::channel(1);
         let oms = OutboundMessageRequester::new(oms_tx);
-        let mut service = ForwardLayer::new(peer_manager, oms).layer(spy.to_service::<PipelineError>());
-
-        let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty());
-        let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg);
+        let mut service = ForwardLayer::new(peer_manager, oms, true).layer(spy.to_service::<PipelineError>());
+
+        let node_identity = make_node_identity();
+        let inbound_msg = make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false);
+        let msg = DecryptedDhtMessage::succeeded(
+            wrap_in_envelope_body!(Vec::new()),
+            Some(node_identity.public_key().clone()),
+            inbound_msg,
+        );
         block_on(service.call(msg)).unwrap();
         assert!(spy.is_called());
         assert!(oms_rx.try_next().is_err());
@@ -261,16 +311,25 @@ mod test {
         let oms_mock_state = oms_mock.get_state();
         rt.spawn(oms_mock.run());
-        let mut service = ForwardLayer::new(peer_manager, oms_requester).layer(spy.to_service::<PipelineError>());
+        let mut service = ForwardLayer::new(peer_manager, oms_requester, true).layer(spy.to_service::<PipelineError>());
 
-        let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty());
+        let sample_body = b"Lorem ipsum";
+        let inbound_msg = make_dht_inbound_message(
+            &make_node_identity(),
+            sample_body.to_vec(),
+            DhtMessageFlags::empty(),
+            false,
+        );
+        let header = inbound_msg.dht_header.clone();
         let msg = DecryptedDhtMessage::failed(inbound_msg);
         rt.block_on(service.call(msg)).unwrap();
         assert!(spy.is_called());
 
         assert_eq!(oms_mock_state.call_count(), 1);
-        let (params, _) = oms_mock_state.pop_call().unwrap();
+        let (params, body) = oms_mock_state.pop_call().unwrap();
 
-        assert!(params.dht_header.is_some());
+        // Header and body are preserved when forwarding
+        assert_eq!(&body.to_vec(), &sample_body);
+        assert_eq!(params.dht_header.unwrap(), header);
     }
 }
diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs
index 3c98383b82..a9d141140e 100644
--- a/comms/dht/src/store_forward/message.rs
+++ b/comms/dht/src/store_forward/message.rs
@@ -21,40 +21,80 @@
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 use crate::{
-    envelope::DhtMessageHeader,
-    proto::store_forward::{StoredMessage, StoredMessagesRequest, StoredMessagesResponse},
+    proto::{
+        envelope::DhtHeader,
+        store_forward::{StoredMessage, StoredMessagesRequest, StoredMessagesResponse},
+    },
+    store_forward::{database, StoreAndForwardError},
 };
-use chrono::{DateTime, Utc};
+use chrono::{DateTime, NaiveDateTime, Utc};
+use prost::Message;
 use prost_types::Timestamp;
+use rand::{rngs::OsRng, RngCore};
+use std::{
+    cmp,
+    convert::{TryFrom, TryInto},
+};
 
-/// Utility function that converts a `chrono::DateTime` to a `prost::Timestamp`
+/// Utility function that converts a `chrono::DateTime<Utc>` to a `prost::Timestamp`
 pub(crate) fn datetime_to_timestamp(datetime: DateTime<Utc>) -> Timestamp {
     Timestamp {
         seconds: datetime.timestamp(),
-        nanos: datetime.timestamp_subsec_nanos() as i32,
+        nanos: datetime.timestamp_subsec_nanos().try_into().unwrap_or(std::i32::MAX),
     }
 }
 
+/// Utility function that converts a `prost::Timestamp` to a `chrono::DateTime<Utc>`
+pub(crate) fn timestamp_to_datetime(timestamp: Timestamp) -> DateTime<Utc> {
+    let naive = NaiveDateTime::from_timestamp(timestamp.seconds, cmp::max(0, timestamp.nanos) as u32);
+    DateTime::from_utc(naive, Utc)
+}
+
 impl StoredMessagesRequest {
+    pub fn new() -> Self {
+        Self {
+            since: None,
+            request_id: OsRng.next_u32(),
+            dist_threshold: Vec::new(),
+        }
+    }
+
+    #[allow(unused)]
     pub fn since(since: DateTime<Utc>) -> Self {
         Self {
             since: Some(datetime_to_timestamp(since)),
+            request_id: OsRng.next_u32(),
+            dist_threshold: Vec::new(),
         }
     }
 }
 
+#[cfg(test)]
 impl StoredMessage {
-    pub fn new(version: u32, dht_header: DhtMessageHeader, encrypted_body: Vec<u8>) -> Self {
+    pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, body: Vec<u8>) -> Self {
         Self {
             version,
             dht_header: Some(dht_header.into()),
-            encrypted_body,
+            body,
             stored_at: Some(datetime_to_timestamp(Utc::now())),
         }
     }
+}
 
-    pub fn has_required_fields(&self) -> bool {
-        self.dht_header.is_some()
+impl TryFrom<database::StoredMessage> for StoredMessage {
+    type Error = StoreAndForwardError;
+
+    fn try_from(message: database::StoredMessage) -> Result<Self, Self::Error> {
+        let dht_header = DhtHeader::decode(message.header.as_slice())?;
+        Ok(Self {
+            stored_at: Some(datetime_to_timestamp(DateTime::from_utc(message.stored_at, Utc))),
+            version: message
+                .version
+                .try_into()
+                .map_err(|_| StoreAndForwardError::InvalidEnvelopeVersion)?,
+            body: message.body,
+            dht_header: Some(dht_header),
+        })
     }
 }
 
@@ -64,8 +104,8 @@ impl StoredMessagesResponse {
     }
 }
 
-impl From<Vec<StoredMessage>> for StoredMessagesResponse {
-    fn from(messages: Vec<StoredMessage>) -> Self {
-        Self { messages }
-    }
+#[derive(Debug, Copy, Clone)]
+pub enum StoredMessagePriority {
+    Low = 1,
+    High = 10,
 }
diff --git a/comms/dht/src/store_forward/mod.rs b/comms/dht/src/store_forward/mod.rs
index 9b9e08b282..aa8d9a91e9 100644
--- a/comms/dht/src/store_forward/mod.rs
+++ b/comms/dht/src/store_forward/mod.rs
@@ -20,17 +20,24 @@
 // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
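The conversion helpers above clamp rather than fail: datetime_to_timestamp saturates the nanosecond component at i32::MAX, and timestamp_to_datetime floors negative nanos at zero, so a round trip preserves the value at second precision. A small illustration, using only the names defined above:

    // Sketch: the round trip keeps second-level precision.
    let now = chrono::Utc::now();
    let ts = datetime_to_timestamp(now);
    let back = timestamp_to_datetime(ts);
    assert_eq!(back.timestamp(), now.timestamp());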
+type SafResult<T> = Result<T, StoreAndForwardError>;
+
+mod service;
+pub use service::{StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService};
+
+mod database;
+pub use database::StoredMessage;
+
 mod error;
+pub use error::StoreAndForwardError;
+
 mod forward;
+pub use forward::ForwardLayer;
+
 mod message;
+
 mod saf_handler;
-mod state;
-mod store;
+pub use saf_handler::MessageHandlerLayer;
 
-pub use self::{
-    error::StoreAndForwardError,
-    forward::ForwardLayer,
-    saf_handler::MessageHandlerLayer,
-    state::SafStorage,
-    store::StoreLayer,
-};
+mod store;
+pub use store::StoreLayer;
diff --git a/comms/dht/src/store_forward/saf_handler/layer.rs b/comms/dht/src/store_forward/saf_handler/layer.rs
index 6622a3b713..d829d46838 100644
--- a/comms/dht/src/store_forward/saf_handler/layer.rs
+++ b/comms/dht/src/store_forward/saf_handler/layer.rs
@@ -21,14 +21,19 @@
 // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 use super::middleware::MessageHandlerMiddleware;
-use crate::{actor::DhtRequester, config::DhtConfig, outbound::OutboundMessageRequester, store_forward::SafStorage};
+use crate::{
+    actor::DhtRequester,
+    config::DhtConfig,
+    outbound::OutboundMessageRequester,
+    store_forward::StoreAndForwardRequester,
+};
 use std::sync::Arc;
 use tari_comms::peer_manager::{NodeIdentity, PeerManager};
 use tower::layer::Layer;
 
 pub struct MessageHandlerLayer {
     config: DhtConfig,
-    store: Arc<SafStorage>,
+    saf_requester: StoreAndForwardRequester,
     dht_requester: DhtRequester,
     peer_manager: Arc<PeerManager>,
     node_identity: Arc<NodeIdentity>,
@@ -38,7 +43,7 @@ pub struct MessageHandlerLayer {
 impl MessageHandlerLayer {
     pub fn new(
         config: DhtConfig,
-        store: Arc<SafStorage>,
+        saf_requester: StoreAndForwardRequester,
         dht_requester: DhtRequester,
         node_identity: Arc<NodeIdentity>,
         peer_manager: Arc<PeerManager>,
@@ -47,7 +52,7 @@ impl MessageHandlerLayer {
     {
         Self {
             config,
-            store,
+            saf_requester,
             dht_requester,
             node_identity,
             peer_manager,
@@ -63,7 +68,7 @@ impl<S> Layer<S> for MessageHandlerLayer {
         MessageHandlerMiddleware::new(
             self.config.clone(),
             service,
-            Arc::clone(&self.store),
+            self.saf_requester.clone(),
             self.dht_requester.clone(),
             Arc::clone(&self.node_identity),
             Arc::clone(&self.peer_manager),
diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs
index 94736306b0..2a01ac713f 100644
--- a/comms/dht/src/store_forward/saf_handler/middleware.rs
+++ b/comms/dht/src/store_forward/saf_handler/middleware.rs
@@ -26,7 +26,7 @@ use crate::{
     config::DhtConfig,
     inbound::DecryptedDhtMessage,
     outbound::OutboundMessageRequester,
-    store_forward::SafStorage,
+    store_forward::StoreAndForwardRequester,
 };
 use futures::{task::Context, Future};
 use std::{sync::Arc, task::Poll};
@@ -40,7 +40,7 @@ use tower::Service;
 pub struct MessageHandlerMiddleware<S> {
     config: DhtConfig,
     next_service: S,
-    store: Arc<SafStorage>,
+    saf_requester: StoreAndForwardRequester,
     dht_requester: DhtRequester,
     peer_manager: Arc<PeerManager>,
     node_identity: Arc<NodeIdentity>,
@@ -51,7 +51,7 @@ impl<S> MessageHandlerMiddleware<S> {
     pub fn new(
         config: DhtConfig,
         next_service: S,
-        store: Arc<SafStorage>,
+        saf_requester: StoreAndForwardRequester,
         dht_requester: DhtRequester,
         node_identity: Arc<NodeIdentity>,
         peer_manager: Arc<PeerManager>,
@@ -60,7 +60,7 @@ impl<S> MessageHandlerMiddleware<S> {
     {
         Self {
             config,
-            store,
+            saf_requester,
             dht_requester,
             next_service,
             node_identity,
@@ -86,7 +86,7 @@ where S: Service<DecryptedDhtMessage> + Clone + 'static
         MessageHandlerTask::new(
             self.config.clone(),
             self.next_service.clone(),
-            Arc::clone(&self.store),
+            self.saf_requester.clone(),
             self.dht_requester.clone(),
             Arc::clone(&self.peer_manager),
self.outbound_service.clone(), diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index 0d18a7b7ce..de9f1abfee 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -24,14 +24,25 @@ use crate::{ actor::DhtRequester, config::DhtConfig, crypt, - envelope::{Destination, DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, inbound::{DecryptedDhtMessage, DhtInboundMessage}, outbound::{OutboundMessageRequester, SendMessageParams}, proto::{ - envelope::DhtMessageType, - store_forward::{StoredMessage, StoredMessagesRequest, StoredMessagesResponse}, + envelope::{DhtMessageType, OriginMac}, + store_forward::{ + stored_messages_response::SafResponseType, + StoredMessage as ProtoStoredMessage, + StoredMessagesRequest, + StoredMessagesResponse, + }, + }, + store_forward::{ + error::StoreAndForwardError, + message::timestamp_to_datetime, + service::FetchStoredMessageQuery, + StoreAndForwardRequester, }, - store_forward::{error::StoreAndForwardError, SafStorage}, + utils::try_convert_all, }; use digest::Digest; use futures::{future, stream, Future, StreamExt}; @@ -39,16 +50,16 @@ use log::*; use prost::Message; use std::{convert::TryInto, sync::Arc}; use tari_comms::{ - message::EnvelopeBody, - peer_manager::{NodeIdentity, Peer, PeerManager, PeerManagerError}, + message::{EnvelopeBody, MessageTag}, + peer_manager::{node_id::NodeDistance, NodeIdentity, Peer, PeerFeatures, PeerManager, PeerManagerError}, pipeline::PipelineError, - types::Challenge, + types::{Challenge, CommsPublicKey}, utils::signature, }; -use tari_crypto::tari_utilities::ByteArray; +use tari_utilities::ByteArray; use tower::{Service, ServiceExt}; -const LOG_TARGET: &str = "comms::dht::store_forward"; +const LOG_TARGET: &str = "comms::dht::storeforward::handler"; pub struct MessageHandlerTask { config: DhtConfig, @@ -58,7 +69,7 @@ pub struct MessageHandlerTask { outbound_service: OutboundMessageRequester, node_identity: Arc, message: Option, - store: Arc, + saf_requester: StoreAndForwardRequester, } impl MessageHandlerTask @@ -68,7 +79,7 @@ where S: Service pub fn new( config: DhtConfig, next_service: S, - store: Arc, + saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, peer_manager: Arc, outbound_service: OutboundMessageRequester, @@ -78,7 +89,7 @@ where S: Service { Self { config, - store, + saf_requester, dht_requester, next_service, peer_manager, @@ -94,20 +105,32 @@ where S: Service .take() .expect("DhtInboundMessageTask initialized without message"); - if message.dht_header.message_type.is_dht_message() && message.decryption_failed() { + if message.dht_header.message_type.is_saf_message() && message.decryption_failed() { debug!( target: LOG_TARGET, - "Received SAFRetrieveMessages message which could not decrypt from NodeId={}. Discarding message.", + "Received store and forward message which could not decrypt from NodeId={}. Discarding message.", message.source_peer.node_id ); return Ok(()); } match message.dht_header.message_type { - DhtMessageType::SafRequestMessages => self - .handle_stored_messages_request(message) - .await - .map_err(PipelineError::from_debug)?, + DhtMessageType::SafRequestMessages => { + if self.node_identity.has_peer_features(PeerFeatures::DHT_STORE_FORWARD) { + self.handle_stored_messages_request(message) + .await + .map_err(PipelineError::from_debug)? 
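// An aside on the feature gating above: has_peer_features boils down to a
// bitflags containment check. A rough stand-in using the bitflags crate (the
// real PeerFeatures type lives in tari_comms; the flag values here are made
// up):
use bitflags::bitflags;

bitflags! {
    struct Features: u64 {
        const MESSAGE_PROPAGATION = 0b01;
        const DHT_STORE_FORWARD = 0b10;
    }
}

fn main() {
    let node_features = Features::MESSAGE_PROPAGATION | Features::DHT_STORE_FORWARD;
    // Equivalent of node_identity.has_peer_features(PeerFeatures::DHT_STORE_FORWARD)
    assert!(node_features.contains(Features::DHT_STORE_FORWARD));
}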
+ } else { + // TODO: #banheuristics - requester should not have requested store and forward messages from this + // node + info!( + target: LOG_TARGET, + "Received store and forward request from peer '{}' however, this node is not a store and \ + forward node. Request ignored.", + message.source_peer.node_id.short_str() + ); + } + }, DhtMessageType::SafStoredMessages => self .handle_stored_messages(message) @@ -116,9 +139,7 @@ where S: Service // Not a SAF message, call downstream middleware _ => { trace!(target: LOG_TARGET, "Passing message onto next service"); - if let Err(err) = self.next_service.oneshot(message).await { - return Err(PipelineError::from_debug(err)); - } + self.next_service.oneshot(message).await?; }, } @@ -143,70 +164,68 @@ where S: Service .decode_part::(0)? .ok_or_else(|| StoreAndForwardError::InvalidEnvelopeBody)?; - if !self - .peer_manager - .in_network_region( - &message.source_peer.node_id, - self.node_identity.node_id(), - self.config.saf_num_closest_nodes, - ) - .await? - { + let source_pubkey = Box::new(message.source_peer.public_key.clone()); + let source_node_id = Box::new(message.source_peer.node_id.clone()); + + // Compile a set of stored messages for the requesting peer + let mut query = FetchStoredMessageQuery::new(source_pubkey, source_node_id.clone()); + + if let Some(since) = retrieve_msgs.since.map(timestamp_to_datetime) { debug!( target: LOG_TARGET, - "Received store and forward message requests from node outside of this nodes network region" + "Peer '{}' requested all messages since '{}'", + source_node_id.short_str(), + since ); - return Ok(()); + query.since(since); } - // Compile a set of stored messages for the requesting peer - let messages = self.store.with_lock(|mut store| { - store - .iter() - // All messages within start_time (if specified) - .filter(|(_, msg)| { - retrieve_msgs.since.as_ref().map(|since| msg.stored_at.as_ref().map(|s| since.seconds <= s.seconds).unwrap_or( false)).unwrap_or( true) - }) - .filter(|(_, msg)|{ - if msg.dht_header.is_none() { - warn!(target: LOG_TARGET, "Message was stored without a header. This should never happen!"); - return false; - } - let dht_header = msg.dht_header.as_ref().expect("previously checked"); - - match &dht_header.destination { - None=> false, - // The stored message was sent with an undisclosed recipient. Perhaps this node - // is interested in it - Some(Destination::Unknown(_)) => true, - // Was the stored message sent for the requesting node public key? - Some(Destination::PublicKey(pk)) => pk.as_slice() == message.source_peer.public_key.as_bytes(), - // Was the stored message sent for the requesting node node id? 
- Some( Destination::NodeId(node_id)) => node_id.as_slice() == message.source_peer.node_id.as_bytes(), - } - }) - .take(self.config.saf_max_returned_messages) - .map(|(_, msg)| msg) - .cloned() - .collect::>() - }); - - let stored_messages: StoredMessagesResponse = messages.into(); + if !retrieve_msgs.dist_threshold.is_empty() { + let dist_threshold = Box::new( + NodeDistance::from_bytes(&retrieve_msgs.dist_threshold) + .map_err(|_| StoreAndForwardError::InvalidNodeDistanceThreshold)?, + ); + query.with_dist_threshold(dist_threshold); + } - trace!( - target: LOG_TARGET, - "Responding to received message retrieval request with {} message(s)", - stored_messages.messages().len() - ); - self.outbound_service - .send_message_no_header( - SendMessageParams::new() - .direct_public_key(message.source_peer.public_key.clone()) - .with_dht_message_type(DhtMessageType::SafStoredMessages) - .finish(), - stored_messages, - ) - .await?; + let response_types = vec![SafResponseType::ForMe]; + + for resp_type in response_types { + query.with_response_type(resp_type); + let messages = self.saf_requester.fetch_messages(query.clone()).await?; + + if messages.is_empty() { + info!( + target: LOG_TARGET, + "No {:?} stored messages for peer '{}'", + resp_type, + message.source_peer.node_id.short_str() + ); + continue; + } + + let stored_messages = StoredMessagesResponse { + messages: try_convert_all(messages)?, + request_id: retrieve_msgs.request_id, + response_type: resp_type as i32, + }; + + info!( + target: LOG_TARGET, + "Responding to received message retrieval request with {} {:?} message(s)", + stored_messages.messages().len(), + resp_type + ); + self.outbound_service + .send_message_no_header( + SendMessageParams::new() + .direct_public_key(message.source_peer.public_key.clone()) + .with_dht_message_type(DhtMessageType::SafStoredMessages) + .finish(), + stored_messages, + ) + .await?; + } Ok(()) } @@ -226,10 +245,14 @@ where S: Service .ok_or_else(|| StoreAndForwardError::InvalidEnvelopeBody)?; let source_peer = Arc::new(message.source_peer); - debug!( + info!( target: LOG_TARGET, - "Received {} stored messages from peer", - response.messages().len() + "Received {} stored messages of type {} from peer", + response.messages().len(), + SafResponseType::from_i32(response.response_type) + .as_ref() + .map(|t| format!("{:?}", t)) + .unwrap_or("".to_string()), ); let tasks = response @@ -311,7 +334,7 @@ where S: Service fn process_incoming_stored_message( &self, source_peer: Arc, - message: StoredMessage, + message: ProtoStoredMessage, ) -> impl Future> { let node_identity = Arc::clone(&self.node_identity); @@ -330,29 +353,49 @@ where S: Service .try_into() .map_err(StoreAndForwardError::DhtMessageError)?; - let dht_flags = dht_header.flags; - - let origin = dht_header - .origin - .as_ref() - .ok_or_else(|| StoreAndForwardError::MessageOriginRequired)?; + if !dht_header.is_valid() { + return Err(StoreAndForwardError::InvalidDhtHeader); + } + let message_type = dht_header.message_type; + + if message_type.is_dht_message() { + if !message_type.is_dht_discovery() { + warn!( + target: LOG_TARGET, + "Discarding {} message from peer '{}'", + message_type, + source_peer.node_id.short_str() + ); + return Err(StoreAndForwardError::InvalidDhtMessageType); + } + if dht_header.destination.is_unknown() { + warn!( + target: LOG_TARGET, + "Discarding anonymous discovery message from peer '{}'", + source_peer.node_id.short_str() + ); + return Err(StoreAndForwardError::InvalidDhtMessageType); + } + } - // Check that the 
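// The response assembly above converts database records into protobuf
// StoredMessages with try_convert_all. A presumed shape for that helper
// (hypothetical reimplementation; the real one lives in the crate's utils
// module): convert every element, failing fast on the first bad one.
use std::convert::TryFrom;

fn try_convert_all<T, U>(items: Vec<T>) -> Result<Vec<U>, U::Error>
where U: TryFrom<T> {
    // Collecting an iterator of Result<U, E> into Result<Vec<U>, E> stops at
    // the first conversion error
    items.into_iter().map(U::try_from).collect()
}

fn main() {
    let raw: Vec<i64> = vec![1, 2, 3];
    let converted: Vec<u32> = try_convert_all(raw).unwrap();
    assert_eq!(converted, vec![1, 2, 3]);
}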
destination is either undisclosed + // Check that the destination is either undisclosed, for us or for our network region Self::check_destination(&config, &peer_manager, &node_identity, &dht_header).await?; - // Verify the signature - Self::check_signature(origin, &message.encrypted_body)?; // Check that the message has not already been received. - // The current thread runtime is used because calls to the DHT actor are async - // let mut rt = runtime::Builder::new().basic_scheduler().build()?; - Self::check_duplicate(&mut dht_requester, &message.encrypted_body).await?; + Self::check_duplicate(&mut dht_requester, &message.body).await?; // Attempt to decrypt the message (if applicable), and deserialize it - let decrypted_body = - Self::maybe_decrypt_and_deserialize(&node_identity, origin, dht_flags, &message.encrypted_body)?; + let (authenticated_pk, decrypted_body) = + Self::authenticate_and_decrypt_if_required(&node_identity, &dht_header, &message.body)?; - let inbound_msg = DhtInboundMessage::new(dht_header, Arc::clone(&source_peer), message.encrypted_body); + let mut inbound_msg = + DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body); + inbound_msg.is_saf_message = true; - Ok(DecryptedDhtMessage::succeeded(decrypted_body, inbound_msg)) + Ok(DecryptedDhtMessage::succeeded( + decrypted_body, + authenticated_pk, + inbound_msg, + )) } } @@ -390,34 +433,65 @@ where S: Service } } - fn check_signature(origin: &DhtMessageOrigin, body: &[u8]) -> Result<(), StoreAndForwardError> { - signature::verify(&origin.public_key, &origin.signature, body) - .map_err(|_| StoreAndForwardError::InvalidSignature) - .and_then(|is_valid| { - if is_valid { - Ok(()) - } else { - Err(StoreAndForwardError::InvalidSignature) - } - }) - } - - fn maybe_decrypt_and_deserialize( + fn authenticate_and_decrypt_if_required( node_identity: &NodeIdentity, - origin: &DhtMessageOrigin, - flags: DhtMessageFlags, + header: &DhtMessageHeader, body: &[u8], - ) -> Result + ) -> Result<(Option, EnvelopeBody), StoreAndForwardError> { - if flags.contains(DhtMessageFlags::ENCRYPTED) { - let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), &origin.public_key); + if header.flags.contains(DhtMessageFlags::ENCRYPTED) { + let ephemeral_public_key = header.ephemeral_public_key.as_ref().expect( + "[store and forward] DHT header is invalid after validity check because it did not contain an \ + ephemeral_public_key", + ); + + trace!( + target: LOG_TARGET, + "Attempting to decrypt origin mac ({} byte(s))", + header.origin_mac.len() + ); + let shared_secret = crypt::generate_ecdh_secret(node_identity.secret_key(), ephemeral_public_key); + let decrypted = crypt::decrypt(&shared_secret, &header.origin_mac)?; + let authenticated_pk = Self::authenticate_message(&decrypted, body)?; + + trace!( + target: LOG_TARGET, + "Attempting to decrypt message body ({} byte(s))", + body.len() + ); let decrypted_bytes = crypt::decrypt(&shared_secret, body)?; - EnvelopeBody::decode(decrypted_bytes.as_slice()).map_err(|_| StoreAndForwardError::DecryptionFailed) + let envelope_body = + EnvelopeBody::decode(decrypted_bytes.as_slice()).map_err(|_| StoreAndForwardError::DecryptionFailed)?; + if envelope_body.is_empty() { + return Err(StoreAndForwardError::InvalidEnvelopeBody); + } + Ok((Some(authenticated_pk), envelope_body)) } else { - // Malformed cleartext messages should never have been forwarded by the peer - EnvelopeBody::decode(body).map_err(|_| StoreAndForwardError::MalformedMessage) + let 
authenticated_pk = if !header.origin_mac.is_empty() { + Some(Self::authenticate_message(&header.origin_mac, body)?) + } else { + None + }; + let envelope_body = EnvelopeBody::decode(body).map_err(|_| StoreAndForwardError::MalformedMessage)?; + Ok((authenticated_pk, envelope_body)) } } + + fn authenticate_message(origin_mac_body: &[u8], body: &[u8]) -> Result { + let origin_mac = OriginMac::decode(origin_mac_body)?; + let public_key = + CommsPublicKey::from_bytes(&origin_mac.public_key).map_err(|_| StoreAndForwardError::InvalidOriginMac)?; + signature::verify(&public_key, &origin_mac.signature, body) + .map_err(|_| StoreAndForwardError::InvalidOriginMac) + .and_then(|is_valid| { + if is_valid { + Ok(()) + } else { + Err(StoreAndForwardError::InvalidOriginMac) + } + })?; + Ok(public_key) + } } #[cfg(test)] @@ -425,10 +499,14 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, - store_forward::message::datetime_to_timestamp, + proto::envelope::DhtHeader, + store_forward::{message::StoredMessagePriority, StoredMessage}, test_utils::{ create_dht_actor_mock, + create_store_and_forward_mock, + make_dht_header, make_dht_inbound_message, + make_keypair, make_node_identity, make_peer_manager, service_spy, @@ -438,17 +516,33 @@ mod test { use chrono::Utc; use futures::channel::mpsc; use prost::Message; - use std::time::Duration; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; + use tari_utilities::hex::Hex; use tokio::runtime::Handle; // TODO: unit tests for static functions (check_signature, etc) + fn make_stored_message(node_identity: &NodeIdentity, dht_header: DhtMessageHeader) -> StoredMessage { + StoredMessage { + id: 1, + version: 0, + origin_pubkey: Some(node_identity.public_key().to_hex()), + message_type: DhtMessageType::None as i32, + destination_pubkey: None, + destination_node_id: None, + header: DhtHeader::from(dht_header).to_encoded_bytes(), + body: b"A".to_vec(), + is_encrypted: false, + priority: StoredMessagePriority::High as i32, + stored_at: Utc::now().naive_utc(), + } + } + #[tokio_macros::test_basic] async fn request_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); - let storage = Arc::new(SafStorage::new(10)); + let (requester, mock_state) = create_store_and_forward_mock(); let peer_manager = make_peer_manager(); let (oms_tx, mut oms_rx) = mpsc::channel(1); @@ -456,34 +550,22 @@ mod test { let node_identity = make_node_identity(); // Recent message - let inbound_msg = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty()); - storage.insert( - vec![0], - StoredMessage::new(0, inbound_msg.dht_header, b"A".to_vec()), - Duration::from_secs(60), - ); - - // Expired message - let inbound_msg = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty()); - storage.insert( - vec![1], - StoredMessage::new(0, inbound_msg.dht_header, vec![]), - Duration::from_secs(0), - ); - - // Out of time range - let inbound_msg = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty()); - let mut msg = StoredMessage::new(0, inbound_msg.dht_header, vec![]); - msg.stored_at = Some(datetime_to_timestamp( - Utc::now().checked_sub_signed(chrono::Duration::days(1)).unwrap(), - )); + let (e_sk, e_pk) = make_keypair(); + let dht_header = make_dht_header(&node_identity, &e_pk, &e_sk, &[], DhtMessageFlags::empty(), false); + mock_state + .add_message(make_stored_message(&node_identity, dht_header)) + .await; + let since = Utc::now().checked_sub_signed(chrono::Duration::seconds(60)).unwrap(); let mut 
message = DecryptedDhtMessage::succeeded( - wrap_in_envelope_body!(StoredMessagesRequest::since( - Utc::now().checked_sub_signed(chrono::Duration::seconds(60)).unwrap() - )) - .unwrap(), - make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), + wrap_in_envelope_body!(StoredMessagesRequest::since(since)), + None, + make_dht_inbound_message( + &node_identity, + b"Keep this for others please".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ), ); message.dht_header.message_type = DhtMessageType::SafRequestMessages; @@ -493,83 +575,79 @@ mod test { let task = MessageHandlerTask::new( Default::default(), spy.to_service::(), - storage, + requester, dht_requester, peer_manager, OutboundMessageRequester::new(oms_tx), - node_identity, + node_identity.clone(), message, ); rt_handle.spawn(task.run()); let (_, body) = unwrap_oms_send_msg!(oms_rx.next().await.unwrap()); + let body = body.to_vec(); let body = EnvelopeBody::decode(body.as_slice()).unwrap(); let msg = body.decode_part::(0).unwrap().unwrap(); assert_eq!(msg.messages().len(), 1); - assert_eq!(msg.messages()[0].encrypted_body, b"A"); + assert_eq!(msg.messages()[0].body, b"A"); assert!(!spy.is_called()); + + assert_eq!(mock_state.call_count(), 1); + let calls = mock_state.take_calls().await; + assert!(calls[0].contains("FetchMessages")); + assert!(calls[0].contains(node_identity.public_key().to_hex().as_str())); + assert!(calls[0].contains(format!("{:?}", since).as_str())); } #[tokio_macros::test_basic] async fn receive_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); - let storage = Arc::new(SafStorage::new(10)); + let (requester, _) = create_store_and_forward_mock(); let peer_manager = make_peer_manager(); let (oms_tx, _) = mpsc::channel(1); let node_identity = make_node_identity(); - let shared_key = crypt::generate_ecdh_secret(node_identity.secret_key(), node_identity.public_key()); - let msg_a = crypt::encrypt( - &shared_key, - &wrap_in_envelope_body!(&b"A".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(), - ) - .unwrap(); - - let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a.clone(), DhtMessageFlags::ENCRYPTED); + let msg_a = wrap_in_envelope_body!(&b"A".to_vec()).to_encoded_bytes(); + + let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a.clone(), DhtMessageFlags::ENCRYPTED, true); // Need to know the peer to process a stored message peer_manager .add_peer(Clone::clone(&*inbound_msg_a.source_peer)) .await .unwrap(); - let msg_b = crypt::encrypt( - &shared_key, - &wrap_in_envelope_body!(b"B".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(), - ) - .unwrap(); - - let inbound_msg_b = make_dht_inbound_message(&node_identity, msg_b.clone(), DhtMessageFlags::ENCRYPTED); + + let msg_b = &wrap_in_envelope_body!(b"B".to_vec()).to_encoded_bytes(); + let inbound_msg_b = make_dht_inbound_message(&node_identity, msg_b.clone(), DhtMessageFlags::ENCRYPTED, true); // Need to know the peer to process a stored message peer_manager .add_peer(Clone::clone(&*inbound_msg_b.source_peer)) .await .unwrap(); - let msg1 = StoredMessage::new(0, inbound_msg_a.dht_header.clone(), msg_a); - let msg2 = StoredMessage::new(0, inbound_msg_b.dht_header, msg_b); + let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body); + let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body); // Cleartext message - let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()) - .unwrap() - .to_encoded_bytes() - .unwrap(); + let 
clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()).to_encoded_bytes(); let clear_header = - make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty()).dht_header; - let msg_clear = StoredMessage::new(0, clear_header, clear_msg); + make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty(), false).dht_header; + let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg); let mut message = DecryptedDhtMessage::succeeded( wrap_in_envelope_body!(StoredMessagesResponse { messages: vec![msg1.clone(), msg2, msg_clear], - }) - .unwrap(), - make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), + request_id: 123, + response_type: 0 + }), + None, + make_dht_inbound_message( + &node_identity, + b"Stored message".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ), ); message.dht_header.message_type = DhtMessageType::SafStoredMessages; @@ -581,7 +659,7 @@ mod test { let task = MessageHandlerTask::new( Default::default(), spy.to_service::(), - storage, + requester, dht_requester, peer_manager, OutboundMessageRequester::new(oms_tx), @@ -596,11 +674,10 @@ mod test { // Deserialize each request into the message (a vec of a single byte in this case) let msgs = requests .into_iter() - .map(|req| req.success().unwrap().decode_part::>(0).unwrap().unwrap()) + .map(|req| req.success().unwrap().decode_part::>(0).unwrap().unwrap()) .collect::>>(); assert!(msgs.contains(&b"A".to_vec())); assert!(msgs.contains(&b"B".to_vec())); assert!(msgs.contains(&b"Clear".to_vec())); - assert_eq!(mock_state.call_count(), msgs.len()); } } diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs new file mode 100644 index 0000000000..41b474c0e2 --- /dev/null +++ b/comms/dht/src/store_forward/service.rs @@ -0,0 +1,430 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +use super::{ + database::{NewStoredMessage, StoreAndForwardDatabase, StoredMessage}, + message::StoredMessagePriority, + SafResult, + StoreAndForwardError, +}; +use crate::{ + envelope::DhtMessageType, + outbound::{OutboundMessageRequester, SendMessageParams}, + proto::store_forward::{stored_messages_response::SafResponseType, StoredMessagesRequest}, + storage::{DbConnection, DhtMetadataKey}, + DhtConfig, + DhtRequester, +}; +use chrono::{DateTime, NaiveDateTime, Utc}; +use futures::{ + channel::{mpsc, oneshot}, + stream::Fuse, + SinkExt, + StreamExt, +}; +use log::*; +use std::{convert::TryFrom, sync::Arc, time::Duration}; +use tari_comms::{ + connection_manager::ConnectionManagerRequester, + peer_manager::{node_id::NodeDistance, NodeId, PeerFeatures}, + types::CommsPublicKey, + ConnectionManagerEvent, + NodeIdentity, + PeerManager, +}; +use tari_shutdown::ShutdownSignal; +use tari_utilities::ByteArray; +use tokio::{sync::broadcast, task, time}; + +const LOG_TARGET: &str = "comms::dht::storeforward::actor"; +/// The interval to initiate a database cleanup. +/// This involves cleaning up messages which have been stored too long according to their priority +const CLEANUP_INTERVAL: Duration = Duration::from_secs(10 * 60); // 10 mins + +#[derive(Debug, Clone)] +pub struct FetchStoredMessageQuery { + public_key: Box, + node_id: Box, + since: Option>, + dist_threshold: Option>, + response_type: SafResponseType, +} + +impl FetchStoredMessageQuery { + pub fn new(public_key: Box, node_id: Box) -> Self { + Self { + public_key, + node_id, + since: None, + response_type: SafResponseType::Anonymous, + dist_threshold: None, + } + } + + pub fn since(&mut self, since: DateTime) -> &mut Self { + self.since = Some(since); + self + } + + pub fn with_response_type(&mut self, response_type: SafResponseType) -> &mut Self { + self.response_type = response_type; + self + } + + pub fn with_dist_threshold(&mut self, dist_threshold: Box) -> &mut Self { + self.dist_threshold = Some(dist_threshold); + self + } +} + +#[derive(Debug)] +pub enum StoreAndForwardRequest { + FetchMessages(FetchStoredMessageQuery, oneshot::Sender>>), + InsertMessage(NewStoredMessage), + SendStoreForwardRequestToPeer(Box), + SendStoreForwardRequestNeighbours, +} + +#[derive(Clone)] +pub struct StoreAndForwardRequester { + sender: mpsc::Sender, +} + +impl StoreAndForwardRequester { + pub fn new(sender: mpsc::Sender) -> Self { + Self { sender } + } + + pub async fn fetch_messages(&mut self, request: FetchStoredMessageQuery) -> SafResult> { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender + .send(StoreAndForwardRequest::FetchMessages(request, reply_tx)) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + reply_rx.await.map_err(|_| StoreAndForwardError::RequestCancelled)? 
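// StoreAndForwardRequester::fetch_messages above uses the classic actor
// request-reply idiom: a oneshot reply channel is bundled into the request and
// sent over the mpsc channel to the service task. A self-contained sketch of
// just that idiom (hypothetical Request type; futures 0.3 channels as used by
// the patch):
use futures::{
    channel::{mpsc, oneshot},
    SinkExt, StreamExt,
};

enum Request {
    Fetch(oneshot::Sender<Vec<u8>>),
}

async fn requester_side(mut tx: mpsc::Sender<Request>) -> Result<Vec<u8>, &'static str> {
    let (reply_tx, reply_rx) = oneshot::channel();
    // Send the request with its reply channel, then await the reply
    tx.send(Request::Fetch(reply_tx)).await.map_err(|_| "requester channel closed")?;
    reply_rx.await.map_err(|_| "request cancelled")
}

async fn actor_side(mut rx: mpsc::Receiver<Request>) {
    while let Some(request) = rx.next().await {
        match request {
            Request::Fetch(reply_tx) => {
                let _ = reply_tx.send(b"stored message bytes".to_vec());
            },
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(1);
    tokio::spawn(actor_side(rx));
    let reply = requester_side(tx).await.unwrap();
    assert_eq!(reply, b"stored message bytes".to_vec());
}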
+ } + + pub async fn insert_message(&mut self, message: NewStoredMessage) -> SafResult<()> { + self.sender + .send(StoreAndForwardRequest::InsertMessage(message)) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + Ok(()) + } + + pub async fn request_saf_messages_from_peer(&mut self, node_id: NodeId) -> SafResult<()> { + self.sender + .send(StoreAndForwardRequest::SendStoreForwardRequestToPeer(Box::new(node_id))) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + Ok(()) + } + + pub async fn request_saf_messages_from_neighbours(&mut self) -> SafResult<()> { + self.sender + .send(StoreAndForwardRequest::SendStoreForwardRequestNeighbours) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + Ok(()) + } +} + +pub struct StoreAndForwardService { + config: DhtConfig, + node_identity: Arc, + dht_requester: DhtRequester, + database: StoreAndForwardDatabase, + peer_manager: Arc, + connection_events: Fuse>>, + outbound_requester: OutboundMessageRequester, + request_rx: Fuse>, + shutdown_signal: Option, +} + +impl StoreAndForwardService { + pub fn new( + config: DhtConfig, + conn: DbConnection, + node_identity: Arc, + peer_manager: Arc, + dht_requester: DhtRequester, + connection_manager: ConnectionManagerRequester, + outbound_requester: OutboundMessageRequester, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, + ) -> Self + { + Self { + config, + database: StoreAndForwardDatabase::new(conn), + node_identity, + peer_manager, + dht_requester, + request_rx: request_rx.fuse(), + connection_events: connection_manager.get_event_subscription().fuse(), + outbound_requester, + shutdown_signal: Some(shutdown_signal), + } + } + + pub async fn spawn(self) -> SafResult<()> { + info!(target: LOG_TARGET, "Store and forward service started"); + task::spawn(Self::run(self)); + Ok(()) + } + + async fn run(mut self) { + let mut shutdown_signal = self + .shutdown_signal + .take() + .expect("StoreAndForwardActor initialized without shutdown_signal"); + + let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL).fuse(); + + loop { + futures::select! 
{ + request = self.request_rx.select_next_some() => { + self.handle_request(request).await; + }, + + event = self.connection_events.select_next_some() => { + if let Ok(event) = event { + if let Err(err) = self.handle_connection_manager_event(&event).await { + error!(target: LOG_TARGET, "Error handling connection manager event: {:?}", err); + } + } + }, + + _ = cleanup_ticker.select_next_some() => { + if let Err(err) = self.cleanup().await { + error!(target: LOG_TARGET, "Error when performing store and forward cleanup: {:?}", err); + } + }, + + _ = shutdown_signal => { + info!(target: LOG_TARGET, "StoreAndForwardActor is shutting down because the shutdown signal was triggered"); + break; + } + } + } + } + + async fn handle_request(&mut self, request: StoreAndForwardRequest) { + use StoreAndForwardRequest::*; + trace!(target: LOG_TARGET, "Request: {:?}", request); + match request { + FetchMessages(query, reply_tx) => match self.handle_fetch_message_query(query).await { + Ok(messages) => { + let _ = reply_tx.send(Ok(messages)); + }, + Err(err) => { + error!( + target: LOG_TARGET, + "Failed to fetch stored messages because '{:?}'", err + ); + let _ = reply_tx.send(Err(err)); + }, + }, + InsertMessage(msg) => { + let public_key = msg.destination_pubkey.clone(); + let node_id = msg.destination_node_id.clone(); + match self.database.insert_message(msg).await { + Ok(_) => info!( + target: LOG_TARGET, + "Stored message for {}", + public_key + .map(|p| format!("public key '{}'", p)) + .or_else(|| node_id.map(|n| format!("node id '{}'", n))) + .unwrap_or_else(|| "".to_string()) + ), + Err(err) => { + error!(target: LOG_TARGET, "InsertMessage failed because '{:?}'", err); + }, + } + }, + SendStoreForwardRequestToPeer(node_id) => { + if let Err(err) = self.request_stored_messages_from_peer(&node_id).await { + error!(target: LOG_TARGET, "Error sending store and forward request: {:?}", err); + } + }, + SendStoreForwardRequestNeighbours => { + if let Err(err) = self.request_stored_messages_neighbours().await { + error!( + target: LOG_TARGET, + "Error sending store and forward request to neighbours: {:?}", err + ); + } + }, + } + } + + async fn handle_connection_manager_event(&mut self, event: &ConnectionManagerEvent) -> SafResult<()> { + use ConnectionManagerEvent::*; + if !self.config.saf_auto_request { + debug!( + target: LOG_TARGET, + "Auto store and forward request disabled. Ignoring connection manager event" + ); + return Ok(()); + } + + match event { + PeerConnected(conn) => { + // Whenever we connect to a peer, request SAF messages + let features = self.peer_manager.get_peer_features(conn.peer_node_id()).await?; + if features.contains(PeerFeatures::DHT_STORE_FORWARD) { + info!( + target: LOG_TARGET, + "Connected peer '{}' is a SAF node. 
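// The run loop above multiplexes requests, connection events, a periodic
// cleanup tick and the shutdown signal with futures::select!. A self-contained
// sketch of that loop shape (assumes the tokio 0.2 + futures 0.3 APIs this
// patch builds against, where time::interval implements Stream):
use futures::{FutureExt, StreamExt};
use std::time::Duration;
use tokio::{sync::oneshot, time};

#[tokio::main]
async fn main() {
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let mut shutdown = shutdown_rx.fuse();
    let mut cleanup_ticker = time::interval(Duration::from_millis(10)).fuse();

    // Trigger the shutdown signal a little later, as the real node would
    tokio::spawn(async move {
        time::delay_for(Duration::from_millis(35)).await;
        let _ = shutdown_tx.send(());
    });

    loop {
        futures::select! {
            _ = cleanup_ticker.select_next_some() => {
                // the periodic store-and-forward cleanup would run here
            },
            _ = shutdown => {
                // mirror of the shutdown_signal arm: break out of the loop
                break;
            }
        }
    }
}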
Requesting stored messages.", + conn.peer_node_id().short_str() + ); + self.request_stored_messages_from_peer(conn.peer_node_id()).await?; + } + }, + _ => {}, + } + + Ok(()) + } + + async fn request_stored_messages_from_peer(&mut self, node_id: &NodeId) -> SafResult<()> { + let request = self.get_saf_request().await?; + info!( + target: LOG_TARGET, + "Sending store and forward request to peer '{}' (Since = {:?})", node_id, request.since + ); + + self.outbound_requester + .send_message_no_header( + SendMessageParams::new() + .direct_node_id(node_id.clone()) + .with_dht_message_type(DhtMessageType::SafRequestMessages) + .finish(), + request, + ) + .await + .map_err(StoreAndForwardError::RequestMessagesFailed)?; + + Ok(()) + } + + async fn request_stored_messages_neighbours(&mut self) -> SafResult<()> { + let request = self.get_saf_request().await?; + info!( + target: LOG_TARGET, + "Sending store and forward request to neighbours (Since = {:?})", request.since + ); + self.outbound_requester + .send_message_no_header( + SendMessageParams::new() + .neighbours(vec![]) + .with_dht_message_type(DhtMessageType::SafRequestMessages) + .finish(), + request, + ) + .await + .map_err(StoreAndForwardError::RequestMessagesFailed)?; + + Ok(()) + } + + async fn get_saf_request(&mut self) -> SafResult { + let mut request = self + .dht_requester + .get_metadata(DhtMetadataKey::OfflineTimestamp) + .await? + .map(StoredMessagesRequest::since) + .unwrap_or_else(StoredMessagesRequest::new); + + // Calculate the network region threshold for our node id. + // i.e. "Give me all messages that are this close to my node ID" + let threshold = self + .peer_manager + .calc_region_threshold( + self.node_identity.node_id(), + self.config.num_neighbouring_nodes, + PeerFeatures::DHT_STORE_FORWARD, + ) + .await?; + + request.dist_threshold = threshold.to_vec(); + + Ok(request) + } + + async fn handle_fetch_message_query(&self, query: FetchStoredMessageQuery) -> SafResult> { + use SafResponseType::*; + let limit = i64::try_from(self.config.saf_max_returned_messages) + .ok() + .or(Some(std::i64::MAX)) + .unwrap(); + let db = &self.database; + let messages = match query.response_type { + ForMe => { + db.find_messages_for_peer(&query.public_key, &query.node_id, query.since, limit) + .await? + }, + Join => db.find_join_messages(query.since, limit).await?, + Discovery => { + db.find_messages_of_type_for_pubkey(&query.public_key, DhtMessageType::Discovery, query.since, limit) + .await? + }, + Anonymous => db.find_anonymous_messages(query.since, limit).await?, + InRegion => { + db.find_regional_messages(&query.node_id, query.dist_threshold, query.since, limit) + .await? 
+ }, + }; + + Ok(messages) + } + + async fn cleanup(&self) -> SafResult<()> { + let num_removed = self + .database + .delete_messages_with_priority_older_than( + StoredMessagePriority::Low, + since(self.config.saf_low_priority_msg_storage_ttl), + ) + .await?; + info!(target: LOG_TARGET, "Cleaned {} old low priority messages", num_removed); + + let num_removed = self + .database + .delete_messages_with_priority_older_than( + StoredMessagePriority::High, + since(self.config.saf_high_priority_msg_storage_ttl), + ) + .await?; + info!(target: LOG_TARGET, "Cleaned {} old high priority messages", num_removed); + Ok(()) + } +} + +fn since(period: Duration) -> NaiveDateTime { + use chrono::Duration as OldDuration; + let period = OldDuration::from_std(period).expect("period was out of range for chrono::Duration"); + Utc::now() + .naive_utc() + .checked_sub_signed(period) + .expect("period overflowed when used with checked_sub_signed") +} diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index da37f18871..98ef4f15b8 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -20,31 +20,35 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use super::StoreAndForwardRequester; use crate::{ - envelope::{DhtMessageFlags, NodeDestination}, + envelope::NodeDestination, inbound::DecryptedDhtMessage, - proto::store_forward::StoredMessage, - store_forward::{error::StoreAndForwardError, state::SafStorage}, + store_forward::{ + database::NewStoredMessage, + error::StoreAndForwardError, + message::StoredMessagePriority, + SafResult, + }, DhtConfig, }; use futures::{task::Context, Future}; use log::*; use std::{sync::Arc, task::Poll}; use tari_comms::{ - message::MessageExt, - peer_manager::{NodeIdentity, PeerManager}, + peer_manager::{NodeIdentity, PeerFeatures, PeerManager}, pipeline::PipelineError, }; use tower::{layer::Layer, Service, ServiceExt}; -const LOG_TARGET: &str = "comms::middleware::forward"; +const LOG_TARGET: &str = "comms::dht::storeforward::store"; /// This layer is responsible for storing messages which have failed to decrypt pub struct StoreLayer { peer_manager: Arc, config: DhtConfig, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, } impl StoreLayer { @@ -52,14 +56,14 @@ impl StoreLayer { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, ) -> Self { Self { peer_manager, config, node_identity, - storage, + saf_requester, } } } @@ -73,7 +77,7 @@ impl Layer for StoreLayer { self.config.clone(), Arc::clone(&self.peer_manager), Arc::clone(&self.node_identity), - Arc::clone(&self.storage), + self.saf_requester.clone(), ) } } @@ -84,8 +88,7 @@ pub struct StoreMiddleware { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - - storage: Arc, + saf_requester: StoreAndForwardRequester, } impl StoreMiddleware { @@ -94,7 +97,7 @@ impl StoreMiddleware { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, ) -> Self { Self { @@ -102,15 +105,13 @@ impl StoreMiddleware { config, peer_manager, node_identity, - storage, + saf_requester, } } } impl Service for StoreMiddleware -where - S: Service + Clone + 'static, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service + Clone + 'static { type Error = 
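// The cleanup pass above computes its cutoffs with the `since` helper: a std
// Duration TTL is converted into a chrono cutoff timestamp. A standalone
// sketch of the same arithmetic (illustrative; the 6 hour TTL is made up):
use chrono::{NaiveDateTime, Utc};
use std::time::Duration;

fn since(period: Duration) -> NaiveDateTime {
    let period = chrono::Duration::from_std(period).expect("period was out of range for chrono::Duration");
    Utc::now()
        .naive_utc()
        .checked_sub_signed(period)
        .expect("period overflowed when used with checked_sub_signed")
}

fn main() {
    // e.g. a low-priority TTL: anything stored before this instant would be
    // eligible for deletion
    let cutoff = since(Duration::from_secs(6 * 60 * 60));
    println!("delete messages stored before {}", cutoff);
}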
PipelineError; type Response = (); @@ -127,7 +128,7 @@ where self.config.clone(), Arc::clone(&self.peer_manager), Arc::clone(&self.node_identity), - Arc::clone(&self.storage), + self.saf_requester.clone(), ) .handle(msg) } @@ -137,7 +138,10 @@ where /// to the next service. struct StoreTask { next_service: S, - storage: Option, + peer_manager: Arc, + config: DhtConfig, + node_identity: Arc, + saf_requester: StoreAndForwardRequester, } impl StoreTask { @@ -146,116 +150,230 @@ impl StoreTask { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, ) -> Self { Self { - storage: Some(InnerStorage { - config, - peer_manager, - node_identity, - storage, - }), + config, + peer_manager, + node_identity, + saf_requester, next_service, } } } impl StoreTask -where - S: Service, - S::Error: std::error::Error + Send + Sync + 'static, +where S: Service { + /// Determine if this is a message we should store for our peers and, if so, store it. + /// + /// The criteria for storing a message is: + /// 1. Messages MUST have a message origin set and be encrypted (Join messages are the exception) + /// 1. Unencrypted Join messages - this increases the knowledge the network has of peers (Low priority) + /// 1. Encrypted Discovery messages - so that nodes are aware of other nodes that are looking for them (High + /// priority) 1. Encrypted messages addressed to the neighbourhood - some node in the neighbourhood may be + /// interested in this message (High priority) 1. Encrypted messages addressed to a particular public key or + /// node id that this node knows about async fn handle(mut self, message: DecryptedDhtMessage) -> Result<(), PipelineError> { + if !self.node_identity.features().contains(PeerFeatures::DHT_STORE_FORWARD) { + trace!(target: LOG_TARGET, "Passing message to next service (Not a SAF node)"); + self.next_service.oneshot(message).await?; + return Ok(()); + } + + if let Some(priority) = self + .get_storage_priority(&message) + .await + .map_err(PipelineError::from_debug)? 
+        {
+            self.store(priority, message.clone())
+                .await
+                .map_err(PipelineError::from_debug)?;
+        }
+
+        debug!(target: LOG_TARGET, "Passing message {} to next service", message.tag);
+        self.next_service.oneshot(message).await?;
+
+        Ok(())
+    }
+
+    async fn get_storage_priority(&self, message: &DecryptedDhtMessage) -> SafResult<Option<StoredMessagePriority>> {
+        let log_not_eligible = |reason: &str| {
+            debug!(
+                target: LOG_TARGET,
+                "Message from peer '{}' not eligible for SAF storage because {}",
+                message.source_peer.node_id.short_str(),
+                reason
+            );
+        };
+
+        if message.body_len() > self.config.saf_max_message_size {
+            log_not_eligible(&format!(
+                "the message body exceeded the maximum storage size (body size={}, max={})",
+                message.body_len(),
+                self.config.saf_max_message_size
+            ));
+            return Ok(None);
+        }
+
+        if message.dht_header.message_type.is_saf_message() {
+            log_not_eligible("it is a SAF message");
+            return Ok(None);
+        }
+
+        if message.dht_header.message_type.is_dht_join() {
+            log_not_eligible("it is a join message");
+            return Ok(None);
+        }
+
+        if message
+            .authenticated_origin()
+            .map(|pk| pk == self.node_identity.public_key())
+            .unwrap_or(false)
+        {
+            log_not_eligible("this message originates from this node");
+            return Ok(None);
+        }
+
         match message.success() {
+            // The message decryption was successful, or the message was not encrypted
             Some(_) => {
-                // If message was not originally encrypted and has an origin we want to store a copy for others
-                if message.dht_header.origin.is_some() && !message.dht_header.flags.contains(DhtMessageFlags::ENCRYPTED)
-                {
-                    debug!(
-                        target: LOG_TARGET,
-                        "Cleartext message sent from origin {}. Adding to SAF storage.",
-                        message.origin_public_key()
-                    );
-                    let mut storage = self.storage.take().expect("StoreTask intialized without storage");
-                    let msg_clone = message.clone();
-                    storage.store(msg_clone).await.map_err(PipelineError::from_debug)?;
+                // If the message doesn't have an origin we won't store it
+                if !message.has_origin_mac() {
+                    log_not_eligible("it is a cleartext message and does not have an origin MAC");
+                    return Ok(None);
                 }
-                trace!(target: LOG_TARGET, "Passing message to next service");
-                self.next_service
-                    .oneshot(message)
-                    .await
-                    .map_err(PipelineError::from_debug)?;
+                // If this node decrypted the message (message.success() above), no need to store it
+                if message.is_encrypted() {
+                    log_not_eligible("the message was encrypted for this node");
+                    return Ok(None);
+                }
+
+                // If this is a join message, we may want to store it if it's for our neighbourhood
+                // if message.dht_header.message_type.is_dht_join() {
+                //     return match self.get_priority_for_dht_join(message).await? {
+                //         Some(priority) => Ok(Some(priority)),
+                //         None => {
+                //             log_not_eligible("the join message was not considered in this node's neighbourhood");
+                //             Ok(None)
+                //         },
+                //     };
+                // }
+
+                log_not_eligible("it is not an eligible DhtMessageType");
+                // Otherwise, don't store
                 Ok(None)
             },
+            // This node could not decrypt the message
             None => {
-                if message.dht_header.origin.is_none() {
-                    // TODO: #banheuristic
+                if !message.has_origin_mac() {
+                    // TODO: #banheuristics - the source peer should not have propagated this message
                     warn!(
                         target: LOG_TARGET,
-                        "Store task received an encrypted message with no source. This message is invalid and should \
-                         not be stored or propagated. Dropping message. Sent by node '{}'",
+                        "Store task received an encrypted message with no origin MAC. This message is invalid and \
+                         should not be stored or propagated. Dropping message.
Sent by node '{}'", message.source_peer.node_id.short_str() ); - return Ok(()); + return Ok(None); } - debug!( - target: LOG_TARGET, - "Decryption failed for message. Adding to SAF storage." - ); - let mut storage = self.storage.take().expect("StoreTask intialized without storage"); - storage.store(message).await.map_err(PipelineError::from_debug)?; + + // The destination of the message will determine if we store it + self.get_priority_by_destination(message).await }, } - - Ok(()) } -} - -struct InnerStorage { - peer_manager: Arc, - config: DhtConfig, - node_identity: Arc, - storage: Arc, -} - -impl InnerStorage { - async fn store(&mut self, message: DecryptedDhtMessage) -> Result<(), StoreAndForwardError> { - let DecryptedDhtMessage { - version, - decryption_result, - dht_header, - .. - } = message; - let origin = dht_header.origin.as_ref().expect("already checked"); - - let body = match decryption_result { - Ok(body) => body.to_encoded_bytes()?, - Err(encrypted_body) => encrypted_body, + // async fn get_priority_for_dht_join( + // &self, + // message: &DecryptedDhtMessage, + // ) -> SafResult> + // { + // debug_assert!(message.dht_header.message_type.is_dht_join() && !message.is_encrypted()); + // + // let body = message + // .decryption_result + // .as_ref() + // .expect("already checked that this message is not encrypted"); + // let join_msg = body + // .decode_part::(0)? + // .ok_or_else(|| StoreAndForwardError::InvalidEnvelopeBody)?; + // let node_id = NodeId::from_bytes(&join_msg.node_id).map_err(StoreAndForwardError::MalformedNodeId)?; + // + // // If this join request is for a peer that we'd consider to be a neighbour, store it for other neighbours + // if self + // .peer_manager + // .in_network_region( + // &node_id, + // self.node_identity.node_id(), + // self.config.num_neighbouring_nodes, + // ) + // .await? 
+ // { + // if self.saf_requester.query_messages( + // DhtMessageType::Join, + // ) + // return Ok(Some(StoredMessagePriority::Low)); + // } + // + // Ok(None) + // } + + async fn get_priority_by_destination( + &self, + message: &DecryptedDhtMessage, + ) -> SafResult> + { + let log_not_eligible = |reason: &str| { + debug!( + target: LOG_TARGET, + "Message from peer '{}' not eligible for SAF storage because {}", + message.source_peer.node_id.short_str(), + reason + ); }; let peer_manager = &self.peer_manager; let node_identity = &self.node_identity; - match &dht_header.destination { - NodeDestination::Unknown => { - self.storage.insert( - origin.signature.clone(), - StoredMessage::new(version, dht_header, body), - self.config.saf_low_priority_msg_storage_ttl, - ); + if message.dht_header.destination == node_identity.public_key() || + message.dht_header.destination == node_identity.node_id() + { + log_not_eligible("the message is destined for this node"); + return Ok(None); + } + + use NodeDestination::*; + match &message.dht_header.destination { + Unknown => { + // No destination provided, + if message.dht_header.message_type.is_dht_discovery() { + log_not_eligible("it is an anonymous discovery message"); + Ok(None) + } else { + Ok(Some(StoredMessagePriority::Low)) + } }, - NodeDestination::PublicKey(dest_public_key) => { - if peer_manager.exists(&dest_public_key).await { - self.storage.insert( - origin.signature.clone(), - StoredMessage::new(version, dht_header, body), - self.config.saf_high_priority_msg_storage_ttl, - ); + PublicKey(dest_public_key) => { + // If we know the destination peer, keep the message for them + match peer_manager.find_by_public_key(&dest_public_key).await { + Ok(peer) => { + if peer.is_banned() { + log_not_eligible( + "origin peer is banned. ** This should not happen because it should have been checked \ + earlier in the pipeline **", + ); + Ok(None) + } else { + Ok(Some(StoredMessagePriority::High)) + } + }, + Err(err) if err.is_peer_not_found() => Ok(Some(StoredMessagePriority::Low)), + Err(err) => Err(err.into()), } }, - NodeDestination::NodeId(dest_node_id) => { + NodeId(dest_node_id) => { if peer_manager.exists_node_id(&dest_node_id).await || peer_manager .in_network_region( @@ -265,14 +383,30 @@ impl InnerStorage { ) .await? 
{ - self.storage.insert( - origin.signature.clone(), - StoredMessage::new(version, dht_header, body), - self.config.saf_high_priority_msg_storage_ttl, - ); + Ok(Some(StoredMessagePriority::High)) + } else { + log_not_eligible(&format!( + "this node does not know the destination node id '{}' or does not consider it a neighbouring \ + node id", + dest_node_id + )); + Ok(None) } }, - }; + } + } + + async fn store(&mut self, priority: StoredMessagePriority, message: DecryptedDhtMessage) -> SafResult<()> { + debug!( + target: LOG_TARGET, + "Storing message from peer '{}' ({} bytes)", + message.source_peer.node_id.short_str(), + message.body_len(), + ); + + let stored_message = NewStoredMessage::try_construct(message, priority) + .ok_or_else(|| StoreAndForwardError::InvalidStoreMessage)?; + self.saf_requester.insert_message(stored_message).await?; Ok(()) } @@ -283,87 +417,145 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, - test_utils::{make_dht_inbound_message, make_node_identity, make_peer_manager, service_spy}, + proto::{dht::JoinMessage, envelope::DhtMessageType}, + test_utils::{ + create_store_and_forward_mock, + make_dht_inbound_message, + make_node_identity, + make_peer_manager, + service_spy, + }, }; - use chrono::{DateTime, Utc}; - use std::time::{Duration, UNIX_EPOCH}; - use tari_comms::wrap_in_envelope_body; + use chrono::Utc; + use std::time::Duration; + use tari_comms::{message::MessageExt, wrap_in_envelope_body}; + use tari_test_utils::async_assert_eventually; + use tari_utilities::{hex::Hex, ByteArray}; #[tokio_macros::test_basic] async fn cleartext_message_no_origin() { - let storage = Arc::new(SafStorage::new(1)); + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, storage.clone()) + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let mut inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); - inbound_msg.dht_header.origin = None; - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); + let inbound_msg = + make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty(), false); + let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()), None, inbound_msg); service.call(msg).await.unwrap(); assert!(spy.is_called()); - storage.with_lock(|mut lock| { - assert_eq!(lock.iter().count(), 0); - }); + let messages = mock_state.get_messages().await; + assert_eq!(messages.len(), 0); } + #[ignore] #[tokio_macros::test_basic] - async fn cleartext_message_with_origin() { - let storage = Arc::new(SafStorage::new(1)); + async fn cleartext_join_message() { + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, storage.clone()) - .layer(spy.to_service::()); + let join_msg_bytes = JoinMessage { + node_id: node_identity.node_id().to_vec(), + addresses: vec![], + peer_features: 0, + } + .to_encoded_bytes(); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); - let msg = 
DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) + .layer(spy.to_service::()); + let sender_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message(&sender_identity, b"".to_vec(), DhtMessageFlags::empty(), true); + + let mut msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(join_msg_bytes), + Some(sender_identity.public_key().clone()), + inbound_msg, + ); + msg.dht_header.message_type = DhtMessageType::Join; service.call(msg).await.unwrap(); assert!(spy.is_called()); - storage.with_lock(|mut lock| { - assert_eq!(lock.iter().count(), 1); - }); + + // Because we dont wait for the message to reach the mock/service before continuing (for efficiency and it's not + // necessary) we need to wait for the call to happen eventually - it should be almost instant + async_assert_eventually!( + mock_state.call_count(), + expect = 1, + max_attempts = 10, + interval = Duration::from_millis(10), + ); + let messages = mock_state.get_messages().await; + assert_eq!(messages[0].message_type, DhtMessageType::Join as i32); } #[tokio_macros::test_basic] async fn decryption_succeeded_no_store() { - let storage = Arc::new(SafStorage::new(1)); + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, storage.clone()) + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::ENCRYPTED); - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(b"secret".to_vec()).unwrap(), inbound_msg); + let msg_node_identity = make_node_identity(); + let inbound_msg = make_dht_inbound_message( + &msg_node_identity, + b"This shouldnt be stored".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ); + let msg = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(b"secret".to_vec()), + Some(msg_node_identity.public_key().clone()), + inbound_msg, + ); service.call(msg).await.unwrap(); assert!(spy.is_called()); - storage.with_lock(|mut lock| { - assert_eq!(lock.iter().count(), 0); - }); + + assert_eq!(mock_state.call_count(), 0); } #[tokio_macros::test_basic] async fn decryption_failed_should_store() { - let storage = Arc::new(SafStorage::new(1)); + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); + let origin_node_identity = make_node_identity(); + peer_manager.add_peer(origin_node_identity.to_peer()).await.unwrap(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, Arc::clone(&storage)) + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); + let mut inbound_msg = make_dht_inbound_message( + &origin_node_identity, + b"Will you keep this for me?".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + ); + inbound_msg.dht_header.destination = + NodeDestination::PublicKey(Box::new(origin_node_identity.public_key().clone())); let msg = 
DecryptedDhtMessage::failed(inbound_msg.clone()); service.call(msg).await.unwrap(); - assert_eq!(spy.is_called(), false); - let msg = storage - .remove(&inbound_msg.dht_header.origin.unwrap().signature) - .unwrap(); - let timestamp: DateTime = (UNIX_EPOCH + Duration::from_secs(msg.stored_at.unwrap().seconds as u64)).into(); - assert!((Utc::now() - timestamp).num_seconds() <= 5); + assert_eq!(spy.is_called(), true); + + async_assert_eventually!( + mock_state.call_count(), + expect = 1, + max_attempts = 10, + interval = Duration::from_millis(10), + ); + + let message = mock_state.get_messages().await.remove(0); + assert_eq!( + message.destination_pubkey.unwrap(), + origin_node_identity.public_key().to_hex() + ); + let duration = Utc::now().naive_utc().signed_duration_since(message.stored_at); + assert!(duration.num_seconds() <= 5); } } diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index a30445ce5b..ba6435f2b3 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -19,13 +19,20 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#![allow(dead_code)] -use crate::actor::{DhtRequest, DhtRequester}; +use crate::{ + actor::{DhtRequest, DhtRequester}, + storage::DhtMetadataKey, +}; use futures::{channel::mpsc, stream::Fuse, StreamExt}; -use std::sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, - Arc, - RwLock, +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, + RwLock, + }, }; use tari_comms::peer_manager::Peer; @@ -39,6 +46,7 @@ pub struct DhtMockState { signature_cache_insert: Arc, call_count: Arc, select_peers: Arc>>, + settings: Arc>>>, } impl DhtMockState { @@ -47,6 +55,7 @@ impl DhtMockState { signature_cache_insert: Arc::new(AtomicBool::new(false)), call_count: Arc::new(AtomicUsize::new(0)), select_peers: Arc::new(RwLock::new(Vec::new())), + settings: Arc::new(RwLock::new(HashMap::new())), } } @@ -56,7 +65,7 @@ impl DhtMockState { } pub fn set_select_peers_response(&self, peers: Vec) -> &Self { - *acquire_write_lock!(self.select_peers) = peers; + *self.select_peers.write().unwrap() = peers; self } @@ -64,8 +73,8 @@ impl DhtMockState { self.call_count.fetch_add(1, Ordering::SeqCst); } - pub fn call_count(&self) -> usize { - self.call_count.load(Ordering::SeqCst) + pub fn get_setting(&self, key: &DhtMetadataKey) -> Option> { + self.settings.read().unwrap().get(&key.to_string()).map(Clone::clone) } } @@ -105,7 +114,18 @@ impl DhtActorMock { let lock = self.state.select_peers.read().unwrap(); reply_tx.send(lock.clone()).unwrap(); }, - SendRequestStoredMessages(_) => {}, + GetMetadata(key, reply_tx) => { + let _ = reply_tx.send(Ok(self + .state + .settings + .read() + .unwrap() + .get(&key.to_string()) + .map(Clone::clone))); + }, + SetMetadata(key, value) => { + self.state.settings.write().unwrap().insert(key.to_string(), value); + }, } } } diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index ae07a83cd0..70575e2ae0 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -61,7 +61,7 @@ impl DhtDiscoveryMockState { } pub fn set_discover_peer_response(&self, 
peer: Peer) -> &Self { - *acquire_write_lock!(self.discover_peer) = peer; + *self.discover_peer.write().unwrap() = peer; self } @@ -102,8 +102,7 @@ impl DhtDiscoveryMock { trace!(target: LOG_TARGET, "DhtDiscoveryMock received request {:?}", req); self.state.inc_call_count(); match req { - DiscoverPeer(boxed) => { - let (_, reply_tx) = *boxed; + DiscoverPeer(_, _, reply_tx) => { let lock = self.state.discover_peer.read().unwrap(); reply_tx.send(Ok(lock.clone())).unwrap(); }, diff --git a/comms/dht/src/test_utils/makers.rs b/comms/dht/src/test_utils/makers.rs index 1a473628c8..e34fadfe06 100644 --- a/comms/dht/src/test_utils/makers.rs +++ b/comms/dht/src/test_utils/makers.rs @@ -20,21 +20,27 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{ - envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + crypt, + envelope::{DhtMessageFlags, DhtMessageHeader, NodeDestination}, inbound::DhtInboundMessage, - proto::envelope::{DhtEnvelope, DhtMessageType, Network}, + outbound::message::{DhtOutboundMessage, WrappedReplyTx}, + proto::envelope::{DhtEnvelope, DhtMessageType, Network, OriginMac}, }; use rand::rngs::OsRng; -use std::sync::Arc; +use std::{convert::TryInto, sync::Arc}; use tari_comms::{ - message::{InboundMessage, MessageEnvelopeHeader, MessageFlags}, + message::{InboundMessage, MessageExt, MessageTag}, multiaddr::Multiaddr, - peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, - types::CommsDatabase, + net_address::MultiaddressesWithStats, + peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, + types::{CommsDatabase, CommsPublicKey, CommsSecretKey}, utils::signature, Bytes, }; -use tari_crypto::tari_utilities::message_format::MessageFormat; +use tari_crypto::{ + keys::PublicKey, + tari_utilities::{message_format::MessageFormat, ByteArray}, +}; use tari_storage::lmdb_store::LMDBBuilder; use tari_test_utils::{paths::create_temporary_data_path, random}; @@ -72,7 +78,7 @@ pub fn make_client_identity() -> Arc { ) } -pub fn make_comms_inbound_message(node_identity: &NodeIdentity, message: Bytes, flags: MessageFlags) -> InboundMessage { +pub fn make_comms_inbound_message(node_identity: &NodeIdentity, message: Bytes) -> InboundMessage { InboundMessage::new( Arc::new(Peer::new( node_identity.public_key().clone(), @@ -82,40 +88,68 @@ pub fn make_comms_inbound_message(node_identity: &NodeIdentity, message: Bytes, PeerFeatures::COMMUNICATION_NODE, &[], )), - MessageEnvelopeHeader { - public_key: node_identity.public_key().clone(), - signature: Bytes::new(), - flags, - }, message, ) } -pub fn make_dht_header(node_identity: &NodeIdentity, message: &Vec, flags: DhtMessageFlags) -> DhtMessageHeader { +pub fn make_dht_header( + node_identity: &NodeIdentity, + e_pk: &CommsPublicKey, + e_sk: &CommsSecretKey, + message: &[u8], + flags: DhtMessageFlags, + include_origin: bool, +) -> DhtMessageHeader +{ DhtMessageHeader { version: 0, destination: NodeDestination::Unknown, - origin: Some(DhtMessageOrigin { - public_key: node_identity.public_key().clone(), - signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), message) - .unwrap() - .to_binary() - .unwrap(), - }), + ephemeral_public_key: if flags.is_encrypted() { Some(e_pk.clone()) } else { None }, + origin_mac: if include_origin { + make_valid_origin_mac(node_identity, &e_sk, message, flags) + } else { + 
Vec::new() + }, message_type: DhtMessageType::None, network: Network::LocalTest, flags, } } +pub fn make_valid_origin_mac( + node_identity: &NodeIdentity, + e_sk: &CommsSecretKey, + body: &[u8], + flags: DhtMessageFlags, +) -> Vec +{ + let mac = OriginMac { + public_key: node_identity.public_key().to_vec(), + signature: signature::sign(&mut OsRng, node_identity.secret_key().clone(), body) + .unwrap() + .to_binary() + .unwrap(), + }; + let body = mac.to_encoded_bytes(); + if flags.is_encrypted() { + let shared_secret = crypt::generate_ecdh_secret(e_sk, node_identity.public_key()); + crypt::encrypt(&shared_secret, &body).unwrap() + } else { + body + } +} + pub fn make_dht_inbound_message( node_identity: &NodeIdentity, body: Vec, flags: DhtMessageFlags, + include_origin: bool, ) -> DhtInboundMessage { + let envelope = make_dht_envelope(node_identity, body, flags, include_origin); DhtInboundMessage::new( - make_dht_header(node_identity, &body, flags), + MessageTag::new(), + envelope.header.unwrap().try_into().unwrap(), Arc::new(Peer::new( node_identity.public_key().clone(), node_identity.node_id().clone(), @@ -124,12 +158,28 @@ pub fn make_dht_inbound_message( PeerFeatures::COMMUNICATION_NODE, &[], )), - body, + envelope.body, ) } -pub fn make_dht_envelope(node_identity: &NodeIdentity, message: Vec, flags: DhtMessageFlags) -> DhtEnvelope { - DhtEnvelope::new(make_dht_header(node_identity, &message, flags).into(), message) +pub fn make_keypair() -> (CommsSecretKey, CommsPublicKey) { + CommsPublicKey::random_keypair(&mut OsRng) +} + +pub fn make_dht_envelope( + node_identity: &NodeIdentity, + mut message: Vec, + flags: DhtMessageFlags, + include_origin: bool, +) -> DhtEnvelope +{ + let (e_sk, e_pk) = make_keypair(); + if flags.is_encrypted() { + let shared_secret = crypt::generate_ecdh_secret(&e_sk, node_identity.public_key()); + message = crypt::encrypt(&shared_secret, &message).unwrap(); + } + let header = make_dht_header(node_identity, &e_pk, &e_sk, &message, flags, include_origin).into(); + DhtEnvelope::new(header, message.into()) } pub fn make_peer_manager() -> Arc { @@ -137,7 +187,7 @@ pub fn make_peer_manager() -> Arc { let path = create_temporary_data_path(); let datastore = LMDBBuilder::new() .set_path(path.to_str().unwrap()) - .set_environment_size(10) + .set_environment_size(50) .set_max_number_of_databases(1) .add_database(&database_name, lmdb_zero::db::CREATE) .build() @@ -149,3 +199,27 @@ pub fn make_peer_manager() -> Arc { .map(Arc::new) .unwrap() } + +pub fn create_outbound_message(body: &[u8]) -> DhtOutboundMessage { + DhtOutboundMessage { + tag: MessageTag::new(), + destination_peer: Peer::new( + CommsPublicKey::default(), + NodeId::default(), + MultiaddressesWithStats::new(vec![]), + PeerFlags::empty(), + PeerFeatures::COMMUNICATION_NODE, + &[], + ), + destination: Default::default(), + dht_message_type: Default::default(), + network: Network::LocalTest, + dht_flags: Default::default(), + custom_header: None, + body: body.to_vec().into(), + ephemeral_public_key: None, + reply_tx: WrappedReplyTx::none(), + origin_mac: None, + is_broadcast: false, + } +} diff --git a/comms/dht/src/test_utils/mod.rs b/comms/dht/src/test_utils/mod.rs index 90f9358d43..56cfab4713 100644 --- a/comms/dht/src/test_utils/mod.rs +++ b/comms/dht/src/test_utils/mod.rs @@ -32,17 +32,22 @@ macro_rules! 
unwrap_oms_send_msg { ($var:expr) => { unwrap_oms_send_msg!( $var, - reply_value = $crate::outbound::SendMessageResponse::Queued(vec![]) + reply_value = $crate::outbound::SendMessageResponse::Queued(vec![].into()) ); }; } mod dht_actor_mock; -mod dht_discovery_mock; -mod makers; -mod service; - pub use dht_actor_mock::{create_dht_actor_mock, DhtMockState}; + +mod dht_discovery_mock; pub use dht_discovery_mock::{create_dht_discovery_mock, DhtDiscoveryMockState}; + +mod makers; pub use makers::*; + +mod service; pub use service::{service_fn, service_spy}; + +mod store_and_forward_mock; +pub use store_and_forward_mock::{create_store_and_forward_mock, StoreAndForwardMockState}; diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs new file mode 100644 index 0000000000..0f3bfaae08 --- /dev/null +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -0,0 +1,136 @@ +// Copyright 2019, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +use crate::store_forward::{StoreAndForwardRequest, StoreAndForwardRequester, StoredMessage}; +use chrono::Utc; +use futures::{channel::mpsc, stream::Fuse, StreamExt}; +use log::*; +use rand::{rngs::OsRng, RngCore}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tokio::{runtime, sync::RwLock}; + +const LOG_TARGET: &str = "comms::dht::store_forward_mock"; + +pub fn create_store_and_forward_mock() -> (StoreAndForwardRequester, StoreAndForwardMockState) { + let (tx, rx) = mpsc::channel(10); + + let mock = StoreAndForwardMock::new(rx.fuse()); + let state = mock.get_shared_state(); + runtime::Handle::current().spawn(mock.run()); + (StoreAndForwardRequester::new(tx), state) +} + +#[derive(Debug, Clone)] +pub struct StoreAndForwardMockState { + call_count: Arc, + stored_messages: Arc>>, + calls: Arc>>, +} + +impl StoreAndForwardMockState { + pub fn new() -> Self { + Self { + call_count: Arc::new(AtomicUsize::new(0)), + stored_messages: Arc::new(RwLock::new(Vec::new())), + calls: Arc::new(RwLock::new(Vec::new())), + } + } + + pub fn inc_call_count(&self) { + self.call_count.fetch_add(1, Ordering::SeqCst); + } + + pub async fn add_call(&self, call: &StoreAndForwardRequest) { + self.inc_call_count(); + self.calls.write().await.push(format!("{:?}", call)); + } + + pub fn call_count(&self) -> usize { + self.call_count.load(Ordering::SeqCst) + } + + pub async fn get_messages(&self) -> Vec { + self.stored_messages.read().await.clone() + } + + pub async fn add_message(&self, message: StoredMessage) { + self.stored_messages.write().await.push(message) + } + + pub async fn take_calls(&self) -> Vec { + self.calls.write().await.drain(..).collect() + } +} + +pub struct StoreAndForwardMock { + receiver: Fuse>, + state: StoreAndForwardMockState, +} + +impl StoreAndForwardMock { + pub fn new(receiver: Fuse>) -> Self { + Self { + receiver, + state: StoreAndForwardMockState::new(), + } + } + + pub fn get_shared_state(&self) -> StoreAndForwardMockState { + self.state.clone() + } + + pub async fn run(mut self) { + while let Some(req) = self.receiver.next().await { + self.handle_request(req).await; + } + } + + async fn handle_request(&self, req: StoreAndForwardRequest) { + use StoreAndForwardRequest::*; + trace!(target: LOG_TARGET, "StoreAndForwardMock received request {:?}", req); + self.state.add_call(&req).await; + match req { + FetchMessages(_, reply_tx) => { + let msgs = self.state.stored_messages.read().await; + let _ = reply_tx.send(Ok(msgs.clone())); + }, + InsertMessage(msg) => self.state.stored_messages.write().await.push(StoredMessage { + id: OsRng.next_u32() as i32, + version: msg.version, + origin_pubkey: msg.origin_pubkey, + message_type: msg.message_type, + destination_pubkey: msg.destination_pubkey, + destination_node_id: msg.destination_node_id, + header: msg.header, + body: msg.body, + is_encrypted: msg.is_encrypted, + priority: msg.priority, + stored_at: Utc::now().naive_utc(), + }), + SendStoreForwardRequestToPeer(_) => {}, + SendStoreForwardRequestNeighbours => {}, + } + } +} diff --git a/comms/dht/src/tower_filter/error.rs b/comms/dht/src/tower_filter/error.rs deleted file mode 100644 index b2643f77f4..0000000000 --- a/comms/dht/src/tower_filter/error.rs +++ /dev/null @@ -1,46 +0,0 @@ -//! Error types - -use std::{error, fmt}; - -/// Error produced by `Filter` -#[derive(Debug)] -pub struct Error { - source: Option, -} - -pub(crate) type Source = Box; - -impl Error { - /// Create a new `Error` representing a rejected request.
- pub fn rejected() -> Error { - Error { source: None } - } - - /// Create a new `Error` representing an inner service error. - pub fn inner(source: E) -> Error - where E: Into { - Error { - source: Some(source.into()), - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - if self.source.is_some() { - write!(fmt, "inner service errored") - } else { - write!(fmt, "rejected") - } - } -} - -impl error::Error for Error { - fn source(&self) -> Option<&(dyn error::Error + 'static)> { - if let Some(ref err) = self.source { - Some(&**err) - } else { - None - } - } -} diff --git a/comms/dht/src/tower_filter/future.rs b/comms/dht/src/tower_filter/future.rs index 0faa301245..25c6357c80 100644 --- a/comms/dht/src/tower_filter/future.rs +++ b/comms/dht/src/tower_filter/future.rs @@ -1,6 +1,5 @@ //! Future types -use super::error::{self, Error}; use futures::ready; use pin_project::{pin_project, project}; use std::{ @@ -8,6 +7,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; +use tari_comms::pipeline::PipelineError; use tower::Service; /// Filtered response future @@ -37,9 +37,8 @@ enum State { impl ResponseFuture where - F: Future>, - S: Service, - S::Error: Into, + F: Future>, + S: Service, { pub(crate) fn new(request: Request, check: F, service: S) -> Self { ResponseFuture { @@ -52,11 +51,10 @@ where impl Future for ResponseFuture where - F: Future>, - S: Service, - S::Error: Into, + F: Future>, + S: Service, { - type Output = Result; + type Output = Result; #[project] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -83,7 +81,7 @@ where } }, State::WaitResponse(response) => { - return Poll::Ready(ready!(response.poll(cx)).map_err(Error::inner)); + return Poll::Ready(ready!(response.poll(cx))); }, } } diff --git a/comms/dht/src/tower_filter/mod.rs b/comms/dht/src/tower_filter/mod.rs index 673a855393..92c2c85f93 100644 --- a/comms/dht/src/tower_filter/mod.rs +++ b/comms/dht/src/tower_filter/mod.rs @@ -4,7 +4,6 @@ //! Conditionally dispatch requests to the inner service based on the result of //! a predicate. -pub mod error; pub mod future; mod layer; mod predicate; @@ -12,15 +11,12 @@ mod predicate; pub use layer::FilterLayer; pub use predicate::Predicate; -use error::Error; use future::ResponseFuture; use futures::ready; use std::task::{Context, Poll}; +use tari_comms::pipeline::PipelineError; use tower::Service; -#[cfg(test)] -mod test; - /// Conditionally dispatch requests to the inner service based on a predicate. #[derive(Clone, Debug)] pub struct Filter { @@ -37,16 +33,15 @@ impl Filter { impl Service for Filter where - T: Service + Clone, - T::Error: Into, + T: Service + Clone, U: Predicate, { - type Error = Error; + type Error = PipelineError; type Future = ResponseFuture; type Response = T::Response; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - Poll::Ready(ready!(self.inner.poll_ready(cx)).map_err(error::Error::inner)) + Poll::Ready(ready!(self.inner.poll_ready(cx)).map_err(PipelineError::from_debug)) } fn call(&mut self, request: Request) -> Self::Future { diff --git a/comms/dht/src/tower_filter/predicate.rs b/comms/dht/src/tower_filter/predicate.rs index 52b3936aa6..f86b9cc406 100644 --- a/comms/dht/src/tower_filter/predicate.rs +++ b/comms/dht/src/tower_filter/predicate.rs @@ -1,10 +1,10 @@ -use super::error::Error; use std::future::Future; +use tari_comms::pipeline::PipelineError; /// Checks a request pub trait Predicate { /// The future returned by `check`. 
- type Future: Future>; + type Future: Future>; /// Check whether the given request should be forwarded. /// @@ -15,7 +15,7 @@ pub trait Predicate { impl Predicate for F where F: Fn(&T) -> U, - U: Future>, + U: Future>, { type Future = U; diff --git a/comms/dht/src/tower_filter/test.rs b/comms/dht/src/tower_filter/test.rs deleted file mode 100644 index 0990ecc11d..0000000000 --- a/comms/dht/src/tower_filter/test.rs +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2020, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
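With tower_filter's bespoke error type deleted above, a predicate now resolves directly to Result<(), PipelineError>. A sketch of a size-limiting predicate under the new bounds; the PipelineError alias below is a stand-in for tari_comms::pipeline::PipelineError, not its real definition:

    use futures::future::{ready, Ready};

    // Stand-in alias for tari_comms::pipeline::PipelineError (an assumption for this sketch)
    type PipelineError = Box<dyn std::error::Error + Send + Sync>;

    // A predicate that rejects messages over a size limit. Any
    // Fn(&T) -> impl Future<Output = Result<(), PipelineError>> satisfies Predicate.
    fn size_limit(max: usize) -> impl Fn(&Vec<u8>) -> Ready<Result<(), PipelineError>> {
        move |msg: &Vec<u8>| {
            if msg.len() <= max {
                ready(Ok(()))
            } else {
                ready(Err("message exceeds size limit".into()))
            }
        }
    }

Collapsing onto PipelineError lets the filter slot into the comms pipeline without the Error::inner/Error::rejected wrapping that the deleted module needed.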
- -use super::{error::Error, Filter}; -use futures_util::{future::poll_fn, pin_mut}; -use std::future::Future; -use tokio::runtime::Handle; -use tokio_test::task; -use tower::Service; -use tower_test::{assert_request_eq, mock}; - -#[tokio_macros::test] -async fn passthrough_sync() { - let (mut service, handle) = new_service(|_| async { Ok(()) }); - - let handle = Handle::current().spawn(async move { - // Receive the requests and respond - pin_mut!(handle); - for i in 0..10 { - assert_request_eq!(handle, format!("ping-{}", i)).send_response(format!("pong-{}", i)); - } - }); - - let mut responses = vec![]; - - for i in 0usize..10 { - let request = format!("ping-{}", i); - poll_fn(|cx| service.poll_ready(cx)).await.unwrap(); - let exchange = service.call(request); - let exchange = async move { - let response = exchange.await.unwrap(); - let expect = format!("pong-{}", i); - assert_eq!(response.as_str(), expect.as_str()); - }; - - responses.push(exchange); - } - - futures_util::future::join_all(responses).await; - handle.await.unwrap(); -} - -#[test] -fn rejected_sync() { - task::spawn(async { - let (mut service, _handle) = new_service(|_| async { Err(Error::rejected()) }); - service.call("hello".into()).await.unwrap_err(); - }); -} - -type Mock = mock::Mock; -type MockHandle = mock::Handle; - -fn new_service(f: F) -> (Filter, MockHandle) -where - F: Fn(&String) -> U, - U: Future>, -{ - let (service, handle) = mock::pair(); - let service = Filter::new(service, f); - (service, handle) -} diff --git a/comms/dht/src/store_forward/state.rs b/comms/dht/src/utils.rs similarity index 61% rename from comms/dht/src/store_forward/state.rs rename to comms/dht/src/utils.rs index 40a1bf513b..b99e334815 100644 --- a/comms/dht/src/store_forward/state.rs +++ b/comms/dht/src/utils.rs @@ -20,37 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
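This rename retires the TtlCache-backed SafStorage (its role is taken over by the database-backed store-and-forward service), leaving utils.rs with only the try_convert_all helper shown below. A usage sketch, assuming the helper is in scope; the u32-to-u8 conversion stands in for the proto-to-domain conversions it exists for:

    fn demo() {
        // Every element converts, so the full batch comes back
        let ok: Vec<u8> = try_convert_all(vec![1u32, 2, 3]).unwrap();
        assert_eq!(ok, vec![1u8, 2, 3]);

        // 300 does not fit in a u8: the first failing conversion aborts the whole batch
        let err: Result<Vec<u8>, _> = try_convert_all(vec![300u32]);
        assert!(err.is_err());
    }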
-use crate::proto::store_forward::StoredMessage; -use std::{ - sync::{RwLock, RwLockWriteGuard}, - time::Duration, -}; -use ttl_cache::TtlCache; +use std::convert::TryInto; -pub type SignatureBytes = Vec; - -pub struct SafStorage { - message_cache: RwLock>, -} - -impl SafStorage { - pub fn new(cache_capacity: usize) -> Self { - Self { - message_cache: RwLock::new(TtlCache::new(cache_capacity)), - } - } - - pub fn insert(&self, key: SignatureBytes, message: StoredMessage, ttl: Duration) -> Option { - acquire_write_lock!(self.message_cache).insert(key, message, ttl) - } - - pub fn with_lock(&self, f: F) -> T - where F: FnOnce(RwLockWriteGuard>) -> T { - f(acquire_write_lock!(self.message_cache)) - } - - #[cfg(test)] - pub fn remove(&self, key: &SignatureBytes) -> Option { - acquire_write_lock!(self.message_cache).remove(key) +/// Tries to convert a series of `T`s to `U`s, returning an error at the first failure +pub fn try_convert_all(into_iter: I) -> Result, T::Error> +where + I: IntoIterator, + T: TryInto, +{ + let iter = into_iter.into_iter(); + let mut result = Vec::with_capacity(iter.size_hint().0); + for item in iter { + result.push(item.try_into()?); } + Ok(result) } diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 7118df8f94..9cc33b59b3 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -20,28 +20,46 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::channel::mpsc; +use futures::{channel::mpsc, StreamExt}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_comms::{ backoff::ConstantBackoff, - peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerStorage}, + message::MessageExt, + peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerStorage}, pipeline, pipeline::SinkService, + protocol::messaging::MessagingEvent, transports::MemoryTransport, types::CommsDatabase, + wrap_in_envelope_body, CommsBuilder, CommsNode, }; -use tari_comms_dht::{envelope::NodeDestination, inbound::DecryptedDhtMessage, Dht, DhtBuilder}; +use tari_comms_dht::{ + domain_message::OutboundDomainMessage, + envelope::NodeDestination, + inbound::DecryptedDhtMessage, + outbound::{OutboundEncryption, SendMessageParams}, + DbConnectionUrl, + Dht, + DhtBuilder, +}; use tari_storage::{lmdb_store::LMDBBuilder, LMDBWrapper}; -use tari_test_utils::{async_assert_eventually, paths::create_temporary_data_path, random}; +use tari_test_utils::{ + async_assert_eventually, + collect_stream, + paths::create_temporary_data_path, + random, + unpack_enum, +}; +use tokio::time; use tower::ServiceBuilder; struct TestNode { comms: CommsNode, dht: Dht, - _ims_rx: mpsc::Receiver, + ims_rx: mpsc::Receiver, } impl TestNode { @@ -52,6 +70,10 @@ impl TestNode { pub fn to_peer(&self) -> Peer { self.comms.node_identity().to_peer() } + + pub async fn next_inbound_message(&mut self, timeout: Duration) -> Option { + time::timeout(timeout, self.ims_rx.next()).await.ok()? 
+ } } fn make_node_identity(features: PeerFeatures) -> Arc { @@ -63,7 +85,7 @@ fn create_peer_storage(peers: Vec) -> CommsDatabase { let database_name = random::string(8); let datastore = LMDBBuilder::new() .set_path(create_temporary_data_path().to_str().unwrap()) - .set_environment_size(10) + .set_environment_size(50) .set_max_number_of_databases(1) .add_database(&database_name, lmdb_zero::db::CREATE) .build() @@ -81,15 +103,14 @@ fn create_peer_storage(peers: Vec) -> CommsDatabase { async fn make_node(features: PeerFeatures, seed_peer: Option) -> TestNode { let node_identity = make_node_identity(features); + make_node_with_node_identity(node_identity, seed_peer).await +} - let (tx, ims_rx) = mpsc::channel(1); +async fn make_node_with_node_identity(node_identity: Arc, seed_peer: Option) -> TestNode { + let (tx, ims_rx) = mpsc::channel(10); let (comms, dht) = setup_comms_dht(node_identity, create_peer_storage(seed_peer.into_iter().collect()), tx).await; - TestNode { - comms, - dht, - _ims_rx: ims_rx, - } + TestNode { comms, dht, ims_rx } } async fn setup_comms_dht( @@ -120,9 +141,13 @@ async fn setup_comms_dht( comms.shutdown_signal(), ) .local_test() + .disable_auto_store_and_forward_requests() + .with_database_url(DbConnectionUrl::MemoryShared(random::string(8))) .with_discovery_timeout(Duration::from_secs(60)) .with_num_neighbouring_nodes(8) - .finish(); + .finish() + .await + .unwrap(); let dht_outbound_layer = dht.outbound_middleware_layer(); @@ -215,9 +240,7 @@ async fn dht_discover_propagation() { .discovery_service_requester() .discover_peer( Box::new(node_D.node_identity().public_key().clone()), - None, - // Sending to a nonsense NodeId, this should still propagate towards D in a network of 4 - NodeDestination::NodeId(Box::new(Default::default())), + NodeDestination::Unknown, ) .await .unwrap(); @@ -239,3 +262,217 @@ async fn dht_discover_propagation() { node_C.comms.shutdown().await; node_D.comms.shutdown().await; } + +#[tokio_macros::test] +#[allow(non_snake_case)] +async fn dht_store_forward() { + let node_C_node_identity = make_node_identity(PeerFeatures::COMMUNICATION_NODE); + // Node B knows about Node C + let node_B = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + // Node A knows about Node B + let node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + log::info!( + "NodeA = {}, NodeB = {}, Node C = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + node_C_node_identity.node_id().short_str(), + ); + + let dest_public_key = Box::new(node_C_node_identity.public_key().clone()); + let params = SendMessageParams::new() + .neighbours(vec![]) + .with_encryption(OutboundEncryption::EncryptFor(dest_public_key)) + .with_destination(NodeDestination::NodeId(Box::new( + node_C_node_identity.node_id().clone(), + ))) + .finish(); + + let secret_msg1 = b"NCZW VUSX PNYM INHZ XMQX SFWX WLKJ AHSH"; + let secret_msg2 = b"NMCO CCAK UQPM KCSM HKSE INJU SBLK"; + + let node_B_msg_events = node_B.comms.subscribe_messaging_events(); + node_A + .dht + .outbound_requester() + .send_raw( + params.clone(), + wrap_in_envelope_body!(secret_msg1.to_vec()).to_encoded_bytes(), + ) + .await + .unwrap(); + node_A + .dht + .outbound_requester() + .send_raw(params, wrap_in_envelope_body!(secret_msg2.to_vec()).to_encoded_bytes()) + .await + .unwrap(); + + // Wait for node B to receive 2 propagation messages + collect_stream!(node_B_msg_events, take = 2, timeout = Duration::from_secs(20)); + + let mut node_C = 
make_node_with_node_identity(node_C_node_identity, Some(node_B.to_peer())).await; + let node_C_msg_events = node_C.comms.subscribe_messaging_events(); + // Ask node B for messages + node_C + .dht + .store_and_forward_requester() + .request_saf_messages_from_peer(node_B.node_identity().node_id().clone()) + .await + .unwrap(); + // Wait for node C to send 1 SAF request and receive a response + collect_stream!(node_C_msg_events, take = 2, timeout = Duration::from_secs(20)); + + let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); + assert_eq!( + msg.authenticated_origin.as_ref().unwrap(), + node_A.comms.node_identity().public_key() + ); + let secret = msg.success().unwrap().decode_part::>(0).unwrap().unwrap(); + assert_eq!(secret, secret_msg1.to_vec()); + let msg = node_C.next_inbound_message(Duration::from_secs(5)).await.unwrap(); + assert_eq!( + msg.authenticated_origin.as_ref().unwrap(), + node_A.comms.node_identity().public_key() + ); + let secret = msg.success().unwrap().decode_part::>(0).unwrap().unwrap(); + assert_eq!(secret, secret_msg2.to_vec()); + + node_A.comms.shutdown().await; + node_B.comms.shutdown().await; + node_C.comms.shutdown().await; +} + +#[tokio_macros::test] +#[allow(non_snake_case)] +async fn dht_propagate_dedup() { + env_logger::init(); + // Node D knows no one + let mut node_D = make_node(PeerFeatures::COMMUNICATION_NODE, None).await; + // Node C knows about Node D + let mut node_C = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_D.to_peer())).await; + // Node B knows about Node C + let mut node_B = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_C.to_peer())).await; + // Node A knows about Node B and C + let mut node_A = make_node(PeerFeatures::COMMUNICATION_NODE, Some(node_B.to_peer())).await; + node_A.comms.peer_manager().add_peer(node_C.to_peer()).await.unwrap(); + log::info!( + "NodeA = {}, NodeB = {}, Node C = {}, Node D = {}", + node_A.node_identity().node_id().short_str(), + node_B.node_identity().node_id().short_str(), + node_C.node_identity().node_id().short_str(), + node_D.node_identity().node_id().short_str(), + ); + + // Connect the peers that should be connected + async fn connect_nodes(node1: &mut TestNode, node2: &mut TestNode) { + node1 + .comms + .connection_manager() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + } + // Pre-connect the nodes; this helps make message passing more deterministic + connect_nodes(&mut node_A, &mut node_B).await; + connect_nodes(&mut node_A, &mut node_C).await; + connect_nodes(&mut node_B, &mut node_C).await; + connect_nodes(&mut node_C, &mut node_D).await; + + let mut node_A_messaging = node_A.comms.subscribe_messaging_events(); + let mut node_B_messaging = node_B.comms.subscribe_messaging_events(); + let mut node_C_messaging = node_C.comms.subscribe_messaging_events(); + let mut node_D_messaging = node_D.comms.subscribe_messaging_events(); + + #[derive(Clone, PartialEq, ::prost::Message)] + struct Person { + #[prost(string, tag = "1")] + name: String, + #[prost(uint32, tag = "2")] + age: u32, + } + + let out_msg = OutboundDomainMessage::new(123, Person { + name: "John Conway".into(), + age: 82, + }); + node_A + .dht + .outbound_requester() + .propagate( + // Node D is a client node, so a destination is required for domain messages + NodeDestination::Unknown, // NodeId(Box::new(node_D.node_identity().node_id().clone())), + OutboundEncryption::EncryptFor(Box::new(node_D.node_identity().public_key().clone())), + vec![], + out_msg, + ) + .await + .unwrap();
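// The checks that follow pin down the dedup behaviour: node D must receive exactly one
// decrypted copy of the message even though A propagates to both B and C and they
// re-propagate to each other. The MessageSent/MessageReceived event counts collected
// per node afterwards bound how many copies actually crossed each link.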
+ + let msg = node_D + .next_inbound_message(Duration::from_secs(10)) + .await + .expect("Node D expected an inbound message but it never arrived"); + assert!(msg.decryption_succeeded()); + let person = msg + .decryption_result + .unwrap() + .decode_part::(1) + .unwrap() + .unwrap(); + assert_eq!(person.name, "John Conway"); + + let node_A_id = node_A.node_identity().node_id().clone(); + let node_B_id = node_B.node_identity().node_id().clone(); + let node_C_id = node_C.node_identity().node_id().clone(); + let node_D_id = node_D.node_identity().node_id().clone(); + + node_A.comms.shutdown().await; + node_B.comms.shutdown().await; + node_C.comms.shutdown().await; + node_D.comms.shutdown().await; + + // Check the message flow BEFORE deduping + let (sent, received) = partition_events(collect_stream!(node_A_messaging, timeout = Duration::from_secs(20))); + assert_eq!(sent.len(), 2); + // Expected race condition: If A->(B|C)->(C|B) before A->(C|B) then (C|B)->A + if received.len() > 0 { + assert_eq!(count_messages_received(&received, &[&node_B_id, &node_C_id]), 1); + } + + let (sent, received) = partition_events(collect_stream!(node_B_messaging, timeout = Duration::from_secs(20))); + assert_eq!(sent.len(), 1); + let recv_count = count_messages_received(&received, &[&node_A_id, &node_C_id]); + // Expected race condition: If A->B->C before A->C then C->B does not happen + assert!(recv_count >= 1 && recv_count <= 2); + + let (sent, received) = partition_events(collect_stream!(node_C_messaging, timeout = Duration::from_secs(20))); + let recv_count = count_messages_received(&received, &[&node_A_id, &node_B_id]); + assert_eq!(recv_count, 2); + assert_eq!(sent.len(), 2); + assert_eq!(count_messages_received(&received, &[&node_D_id]), 0); + + let (sent, received) = partition_events(collect_stream!(node_D_messaging, timeout = Duration::from_secs(20))); + assert_eq!(sent.len(), 0); + assert_eq!(received.len(), 1); + assert_eq!(count_messages_received(&received, &[&node_C_id]), 1); +} + +fn partition_events( + events: Vec, tokio::sync::broadcast::RecvError>>, +) -> (Vec>, Vec>) { + events.into_iter().map(Result::unwrap).partition(|e| match &**e { + MessagingEvent::MessageReceived(_, _) => false, + MessagingEvent::MessageSent(_) => true, + _ => unreachable!(), + }) +} + +fn count_messages_received(events: &[Arc], node_ids: &[&NodeId]) -> usize { + events + .into_iter() + .filter(|event| { + unpack_enum!(MessagingEvent::MessageReceived(recv_node_id, _tag) = &***event); + node_ids.into_iter().any(|n| &**recv_node_id == *n) + }) + .count() +} diff --git a/comms/examples/tor.rs b/comms/examples/tor.rs index e9ca5beac4..3845b49ed5 100644 --- a/comms/examples/tor.rs +++ b/comms/examples/tor.rs @@ -112,7 +112,6 @@ async fn run() -> Result<(), Error> { outbound_tx1 .send(OutboundMessage::new( comms_node2.node_identity().node_id().clone(), - Default::default(), Bytes::from_static(b"START"), )) .await?; @@ -152,7 +151,7 @@ async fn setup_node_with_tor>( { let datastore = LMDBBuilder::new() .set_path(database_path.to_str().unwrap()) - .set_environment_size(10) + .set_environment_size(50) .set_max_number_of_databases(1) .add_database("peerdb", lmdb_zero::db::CREATE) .build() @@ -272,5 +271,5 @@ async fn start_ping_ponger( fn make_msg(node_id: &NodeId, msg: String) -> OutboundMessage { let msg = Bytes::copy_from_slice(msg.as_bytes()); - OutboundMessage::new(node_id.clone(), Default::default(), msg) + OutboundMessage::new(node_id.clone(), msg) } diff --git a/comms/src/builder/comms_node.rs 
b/comms/src/builder/comms_node.rs index b930076cdd..ba258d85d6 100644 --- a/comms/src/builder/comms_node.rs +++ b/comms/src/builder/comms_node.rs @@ -275,6 +275,11 @@ impl CommsNode { self.messaging_event_tx.subscribe() } + /// Return a clone of the messaging event Sender to allow other services to create subscriptions + pub fn message_event_sender(&self) -> messaging::MessagingEventSender { + self.messaging_event_tx.clone() + } + /// Return an owned copy of a ConnectionManagerRequester. Used to initiate connections to peers. pub fn connection_manager(&self) -> ConnectionManagerRequester { self.connection_manager_requester.clone() diff --git a/comms/src/builder/tests.rs b/comms/src/builder/tests.rs index 4778a6d677..65576e4a40 100644 --- a/comms/src/builder/tests.rs +++ b/comms/src/builder/tests.rs @@ -206,7 +206,6 @@ async fn peer_to_peer_messaging() { for i in 0..NUM_MSGS { let outbound_msg = OutboundMessage::new( node_identity2.node_id().clone(), - Default::default(), format!("#{:0>3} - comms messaging is so hot right now!", i).into(), ); outbound_tx1.send(outbound_msg).await.unwrap(); @@ -227,7 +226,6 @@ async fn peer_to_peer_messaging() { for i in 0..NUM_MSGS { let outbound_msg = OutboundMessage::new( node_identity1.node_id().clone(), - Default::default(), format!("#{:0>3} - comms messaging is so hot right now!", i).into(), ); outbound_tx2.send(outbound_msg).await.unwrap(); @@ -295,7 +293,6 @@ async fn peer_to_peer_messaging_simultaneous() { for i in 0..NUM_MSGS { let outbound_msg = OutboundMessage::new( node_identity2.node_id().clone(), - Default::default(), format!("#{:0>3} - comms messaging is so hot right now!", i).into(), ); outbound_tx1.send(outbound_msg).await.unwrap(); @@ -306,7 +303,6 @@ async fn peer_to_peer_messaging_simultaneous() { for i in 0..NUM_MSGS { let outbound_msg = OutboundMessage::new( node_identity1.node_id().clone(), - Default::default(), format!("#{:0>3} - comms messaging is so hot right now!", i).into(), ); outbound_tx2.send(outbound_msg).await.unwrap(); diff --git a/comms/src/connection_manager/common.rs b/comms/src/connection_manager/common.rs index ef06c4a1c7..acf94d0208 100644 --- a/comms/src/connection_manager/common.rs +++ b/comms/src/connection_manager/common.rs @@ -139,6 +139,8 @@ pub async fn validate_and_add_peer_from_peer_identity( Some(peer_node_id.clone()), Some(addresses), None, + None, + Some(false), Some(PeerFeatures::from_bits_truncate(peer_identity.features)), Some(conn_stats), Some(supported_protocols), diff --git a/comms/src/connection_manager/dialer.rs b/comms/src/connection_manager/dialer.rs index 594baba131..ac5595963f 100644 --- a/comms/src/connection_manager/dialer.rs +++ b/comms/src/connection_manager/dialer.rs @@ -441,7 +441,7 @@ where // Inflight dial was cancelled (state, Err(ConnectionManagerError::DialCancelled)) => break (state, Err(ConnectionManagerError::DialCancelled)), (mut state, Err(err)) => { - if state.num_attempts() > max_attempts { + if state.num_attempts() >= max_attempts { break (state, Err(ConnectionManagerError::ConnectFailedMaximumAttemptsReached)); } diff --git a/comms/src/connection_manager/error.rs b/comms/src/connection_manager/error.rs index c7ab43da91..49b1c2205b 100644 --- a/comms/src/connection_manager/error.rs +++ b/comms/src/connection_manager/error.rs @@ -74,8 +74,6 @@ pub enum ConnectionManagerError { IdentityProtocolError(IdentityProtocolError), /// The dial was cancelled DialCancelled, - /// The peer is offline and will not be dialed - PeerOffline, #[error(msg_embedded, no_from, non_std)]
InvalidMultiaddr(String), /// Failed to send wire format byte diff --git a/comms/src/connection_manager/manager.rs b/comms/src/connection_manager/manager.rs index dc5c54afe0..6b42e6ac45 100644 --- a/comms/src/connection_manager/manager.rs +++ b/comms/src/connection_manager/manager.rs @@ -315,7 +315,7 @@ where use ConnectionManagerRequest::*; trace!(target: LOG_TARGET, "Connection manager got request: {:?}", request); match request { - DialPeer(node_id, is_forced, reply_tx) => match self.get_active_connection(&node_id) { + DialPeer(node_id, reply_tx) => match self.get_active_connection(&node_id) { Some(conn) => { debug!(target: LOG_TARGET, "[{}] Found existing active connection", conn); log_if_error_fmt!( @@ -333,7 +333,7 @@ where self.node_identity.node_id().short_str(), node_id.short_str() ); - self.dial_peer(node_id, reply_tx, is_forced).await + self.dial_peer(node_id, reply_tx).await }, }, NotifyListening(reply_tx) => match self.listener_address.as_ref() { @@ -348,7 +348,21 @@ where let _ = reply_tx.send(self.active_connections.get(&node_id).map(Clone::clone)); }, GetActiveConnections(reply_tx) => { - let _ = reply_tx.send(self.active_connections.values().cloned().collect()); + let _ = reply_tx.send( + self.active_connections + .values() + .filter(|conn| conn.is_connected()) + .cloned() + .collect(), + ); + }, + GetNumActiveConnections(reply_tx) => { + let _ = reply_tx.send( + self.active_connections + .values() + .filter(|conn| conn.is_connected()) + .count(), + ); }, DisconnectPeer(node_id, reply_tx) => match self.active_connections.remove(&node_id) { Some(mut conn) => { @@ -412,7 +426,7 @@ where self.send_dialer_request(DialerRequest::CancelPendingDial(node_id.clone())) .await; - match self.active_connections.remove(&node_id) { + match self.active_connections.get(&node_id) { Some(existing_conn) => { debug!( target: LOG_TARGET, @@ -421,7 +435,7 @@ where existing_conn.peer_node_id() ); - if self.tie_break_existing_connection(&existing_conn, &new_conn) { + if self.tie_break_existing_connection(existing_conn, &new_conn) { debug!( target: LOG_TARGET, "Disconnecting existing {} connection to peer '{}' because of simultaneous dial", @@ -434,8 +448,14 @@ where Box::new(existing_conn.peer_node_id().clone()), existing_conn.direction(), )); + + // Replace existing connection with new one + let existing_conn = self + .active_connections + .insert(node_id, new_conn.clone()) + .expect("Already checked"); + self.delayed_disconnect(existing_conn); - self.active_connections.insert(node_id, new_conn.clone()); self.publish_event(PeerConnected(new_conn)); } else { debug!( @@ -447,7 +467,6 @@ where ); self.delayed_disconnect(new_conn); - self.active_connections.insert(node_id, existing_conn); } }, None => { @@ -567,38 +586,12 @@ where &mut self, node_id: NodeId, reply_tx: oneshot::Sender>, - force_dial: bool, ) { match self.peer_manager.find_by_node_id(&node_id).await { Ok(peer) => { - if !force_dial && peer.is_recently_offline() { - debug!( - target: LOG_TARGET, - "Peer '{}' is offline (i.e. 
we failed to connect to them recently).", - peer.node_id.short_str() - ); - let _ = reply_tx.send(Err(ConnectionManagerError::PeerOffline)); - self.publish_event(ConnectionManagerEvent::PeerConnectFailed( - Box::new(peer.node_id), - ConnectionManagerError::PeerOffline, - )); - return; - } - - if let Err(err) = self.dialer_tx.try_send(DialerRequest::Dial(Box::new(peer), reply_tx)) { + if let Err(err) = self.dialer_tx.send(DialerRequest::Dial(Box::new(peer), reply_tx)).await { error!(target: LOG_TARGET, "Failed to send request to dialer because '{}'", err); - // TODO: If the channel is full - we'll fail to dial. This function should block until the dial - // request channel has cleared - - if let DialerRequest::Dial(_, reply_tx) = err.into_inner() { - log_if_error_fmt!( - target: LOG_TARGET, - reply_tx.send(Err(ConnectionManagerError::EstablisherChannelError)), - "Failed to send dial peer result for peer '{}'", - node_id.short_str() - ); - } } }, Err(err) => { diff --git a/comms/src/connection_manager/peer_connection.rs b/comms/src/connection_manager/peer_connection.rs index 5daa7f933c..86fd477949 100644 --- a/comms/src/connection_manager/peer_connection.rs +++ b/comms/src/connection_manager/peer_connection.rs @@ -46,6 +46,7 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, }, + time::{Duration, Instant}, }; use tari_shutdown::Shutdown; @@ -107,6 +108,7 @@ pub struct PeerConnection { request_tx: mpsc::Sender, address: Multiaddr, direction: ConnectionDirection, + started_at: Instant, } impl PeerConnection { @@ -124,6 +126,7 @@ impl PeerConnection { peer_node_id: Arc::new(peer_node_id), address, direction, + started_at: Instant::now(), } } @@ -135,6 +138,10 @@ impl PeerConnection { self.direction } + pub fn address(&self) -> &Multiaddr { + &self.address + } + pub fn id(&self) -> ConnId { self.id } @@ -143,6 +150,10 @@ impl PeerConnection { !self.request_tx.is_closed() } + pub fn connected_since(&self) -> Duration { + self.started_at.elapsed() + } + pub fn reference_count(&self) -> usize { Arc::strong_count(&self.peer_node_id) } diff --git a/comms/src/connection_manager/requester.rs b/comms/src/connection_manager/requester.rs index 0284bb0177..b2d33b315a 100644 --- a/comms/src/connection_manager/requester.rs +++ b/comms/src/connection_manager/requester.rs @@ -35,18 +35,15 @@ pub enum ConnectionManagerRequest { /// Dial a given peer by node id. /// Parameters: /// 1. Node Id to dial - /// 1. If true, attempt to dial the peer even if we recently failed to dial them and they are considered offline - DialPeer( - NodeId, - bool, - oneshot::Sender>, - ), + DialPeer(NodeId, oneshot::Sender>), /// Register a oneshot to get triggered when the node is listening, or has failed to listen NotifyListening(oneshot::Sender), /// Retrieve an active connection for a given node id if one exists. GetActiveConnection(NodeId, oneshot::Sender>), /// Retrieve all active connections GetActiveConnections(oneshot::Sender>), + /// Retrieve the number of active connections + GetNumActiveConnections(oneshot::Sender), /// Disconnect a peer DisconnectPeer(NodeId, oneshot::Sender>), } @@ -100,6 +97,8 @@ macro_rules! 
request_fn { impl ConnectionManagerRequester { request_fn!(get_active_connections() -> Vec, request = ConnectionManagerRequest::GetActiveConnections); + request_fn!(get_num_active_connections() -> usize, request = ConnectionManagerRequest::GetNumActiveConnections); + request_fn!(get_active_connection(node_id: NodeId) -> Option, request = ConnectionManagerRequest::GetActiveConnection); request_fn!(disconnect_peer(node_id: NodeId) -> Result<(), ConnectionManagerError>, request = ConnectionManagerRequest::DisconnectPeer); @@ -111,23 +110,9 @@ impl ConnectionManagerRequester { /// Attempt to connect to a remote peer pub async fn dial_peer(&mut self, node_id: NodeId) -> Result { - self.send_dial_peer(node_id, false).await - } - - /// Attempt to connect to a remote peer, even if we failed to contact the peer recently - pub async fn dial_peer_forced(&mut self, node_id: NodeId) -> Result { - self.send_dial_peer(node_id, true).await - } - - async fn send_dial_peer( - &mut self, - node_id: NodeId, - is_forced: bool, - ) -> Result - { let (reply_tx, reply_rx) = oneshot::channel(); self.sender - .send(ConnectionManagerRequest::DialPeer(node_id, is_forced, reply_tx)) + .send(ConnectionManagerRequest::DialPeer(node_id, reply_tx)) .await .map_err(|_| ConnectionManagerError::SendToActorFailed)?; reply_rx diff --git a/comms/src/connection_manager/tests/manager.rs b/comms/src/connection_manager/tests/manager.rs index 0fabbf3e0d..e7b7a0386c 100644 --- a/comms/src/connection_manager/tests/manager.rs +++ b/comms/src/connection_manager/tests/manager.rs @@ -43,9 +43,8 @@ use std::time::Duration; use tari_shutdown::Shutdown; use tari_test_utils::{collect_stream, unpack_enum}; use tokio::{runtime::Handle, sync::broadcast}; -use tokio_macros as r#async; -#[r#async::test_basic] +#[tokio_macros::test_basic] async fn connect_to_nonexistent_peer() { let rt_handle = Handle::current(); let node_identity = build_node_identity(PeerFeatures::empty()); @@ -85,7 +84,7 @@ async fn connect_to_nonexistent_peer() { shutdown.trigger().unwrap(); } -#[r#async::test_basic] +#[tokio_macros::test_basic] async fn dial_success() { const TEST_PROTO: ProtocolId = ProtocolId::from_static(b"/test/valid"); let shutdown = Shutdown::new(); @@ -184,57 +183,11 @@ where .count() } -#[r#async::test_basic] -async fn dial_offline_peer() { - let shutdown = Shutdown::new(); - - let node_identity = build_node_identity(PeerFeatures::empty()); - - let peer_manager = build_peer_manager(); - let mut conn_man = build_connection_manager( - TestNodeConfig { - node_identity: node_identity.clone(), - ..Default::default() - }, - peer_manager.clone(), - Protocols::new(), - shutdown.to_signal(), - ); - - let public_address = conn_man.wait_until_listening().await.unwrap(); - let mut subscription = conn_man.get_event_subscription(); - - let mut peer = Peer::new( - node_identity.public_key().clone(), - node_identity.node_id().clone(), - vec![public_address].into(), - PeerFlags::empty(), - PeerFeatures::COMMUNICATION_CLIENT, - &[], - ); - - peer.connection_stats.set_connection_failed(); - assert_eq!(peer.is_recently_offline(), false); - peer.connection_stats.set_connection_failed(); - assert_eq!(peer.is_recently_offline(), true); - - peer_manager.add_peer(peer).await.unwrap(); - - let err = conn_man.dial_peer(node_identity.node_id().clone()).await.unwrap_err(); - unpack_enum!(ConnectionManagerError::PeerOffline = err); - - let event = subscription.next().await.unwrap().unwrap(); - - unpack_enum!(ConnectionManagerEvent::PeerConnectFailed(node_id, err) = &*event); - 
assert_eq!(&**node_id, node_identity.node_id()); - unpack_enum!(ConnectionManagerError::PeerOffline = err); -} - -#[r#async::test_basic] +#[tokio_macros::test_basic] async fn simultaneous_dial_events() { let mut shutdown = Shutdown::new(); - let node_identities = ordered_node_identities(2); + let node_identities = ordered_node_identities(2, Default::default()); // Setup connection manager 1 let peer_manager1 = build_peer_manager(); diff --git a/comms/src/connection_manager/wire_mode.rs b/comms/src/connection_manager/wire_mode.rs index 7dd0e9a820..b2771cd254 100644 --- a/comms/src/connection_manager/wire_mode.rs +++ b/comms/src/connection_manager/wire_mode.rs @@ -23,7 +23,7 @@ use std::convert::TryFrom; pub enum WireMode { - Comms = 0x01, + Comms = 0x02, Liveness = 0x45, // E } @@ -32,7 +32,7 @@ impl TryFrom for WireMode { fn try_from(value: u8) -> Result { match value { - 0x01 => Ok(WireMode::Comms), + 0x02 => Ok(WireMode::Comms), 0x45 => Ok(WireMode::Liveness), _ => Err(()), } diff --git a/comms/src/consts.rs b/comms/src/consts.rs index 0af8fdc2cc..3773cc15c0 100644 --- a/comms/src/consts.rs +++ b/comms/src/consts.rs @@ -28,7 +28,3 @@ pub const PEER_MANAGER_MAX_FLOOD_PEERS: usize = 1000; /// The amount of time to consider a peer to be offline (i.e. dial to peer will fail without trying) after a failed /// connection attempt pub const PEER_OFFLINE_COOLDOWN_PERIOD: Duration = Duration::from_secs(60); - -/// The envelope version. This should be increased any time a change is made to the -/// envelope proto files. -pub const ENVELOPE_VERSION: u32 = 0; diff --git a/comms/src/message/envelope.rs b/comms/src/message/envelope.rs index fcdf975c57..61a2c0d749 100644 --- a/comms/src/message/envelope.rs +++ b/comms/src/message/envelope.rs @@ -20,123 +20,21 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use super::{MessageError, MessageFlags}; -use crate::{ - consts::ENVELOPE_VERSION, - types::{CommsPublicKey, CommsSecretKey}, - utils::signature, -}; -use bytes::Bytes; -use rand::rngs::OsRng; -use std::convert::TryInto; -use tari_crypto::tari_utilities::{message_format::MessageFormat, ByteArray}; +use super::MessageError; // Re-export protos pub use crate::proto::envelope::*; -/// Represents data that every message contains. 
-/// As described in [RFC-0172](https://rfc.tari.com/RFC-0172_PeerToPeerMessagingProtocol.html#messaging-structure) -#[derive(Clone, Debug, PartialEq)] -pub struct MessageEnvelopeHeader { - pub public_key: CommsPublicKey, - pub signature: Bytes, - pub flags: MessageFlags, -} - -impl Envelope { - /// Sign a message, construct an Envelope with a Header - pub fn construct_signed( - secret_key: &CommsSecretKey, - public_key: &CommsPublicKey, - body: Bytes, - flags: MessageFlags, - ) -> Result - { - // Sign this body - let header_signature = { - let sig = - signature::sign(&mut OsRng, secret_key.clone(), &body).map_err(MessageError::SchnorrSignatureError)?; - sig.to_binary().map_err(MessageError::MessageFormatError) - }?; - - Ok(Envelope { - version: ENVELOPE_VERSION, - header: Some(EnvelopeHeader { - public_key: public_key.to_vec(), - signature: header_signature, - flags: flags.bits(), - }), - body: body.to_vec(), - }) - } - - /// Verify that the signature provided is valid for the given body - pub fn verify_signature(&self) -> Result { - match self - .header - .as_ref() - .map(|header| (header, header.get_comms_public_key())) - { - Some((header, Some(public_key))) => signature::verify(&public_key, &header.signature, &self.body), - _ => Ok(false), - } - } - - /// Returns true if the message contains a valid public key in the header, otherwise - /// false - pub fn is_valid(&self) -> bool { - self.get_public_key().is_some() - } - - /// Returns a valid public key from the header of this envelope, or None if the - /// public key is invalid - pub fn get_public_key(&self) -> Option { - self.header.as_ref().and_then(|header| header.get_comms_public_key()) - } -} - -impl EnvelopeHeader { - pub fn get_comms_public_key(&self) -> Option { - CommsPublicKey::from_bytes(&self.public_key).ok() - } -} - -impl TryInto for EnvelopeHeader { - type Error = MessageError; - - fn try_into(self) -> Result { - Ok(MessageEnvelopeHeader { - public_key: self - .get_comms_public_key() - .ok_or_else(|| MessageError::InvalidHeaderPublicKey)?, - signature: self.signature.into(), - flags: MessageFlags::from_bits_truncate(self.flags), - }) - } -} - /// Wraps a number of `prost::Message`s in a EnvelopeBody #[macro_export] macro_rules! 
wrap_in_envelope_body { ($($e:expr),+) => {{ use $crate::message::MessageExt; let mut envelope_body = $crate::message::EnvelopeBody::new(); - let mut error = None; $( - match $e.to_encoded_bytes() { - Ok(bytes) => envelope_body.push_part(bytes), - Err(err) => { - if error.is_none() { - error = Some(err); - } - } - } + envelope_body.push_part($e.to_encoded_bytes()); )* - - match error { - Some(err) => Err(err), - None => Ok(envelope_body), - } + envelope_body }} } @@ -151,6 +49,10 @@ impl EnvelopeBody { self.parts.len() } + pub fn total_size(&self) -> usize { + self.parts.iter().fold(0, |acc, b| acc + b.len()) + } + pub fn is_empty(&self) -> bool { self.parts.is_empty() } @@ -181,42 +83,3 @@ impl EnvelopeBody { } } } - -#[cfg(test)] -mod test { - use super::*; - use crate::message::MessageFlags; - use rand::rngs::OsRng; - use tari_crypto::keys::PublicKey; - - #[test] - fn construct_signed() { - let (sk, pk) = CommsPublicKey::random_keypair(&mut OsRng); - let envelope = Envelope::construct_signed(&sk, &pk, Bytes::new(), MessageFlags::all()).unwrap(); - assert_eq!(envelope.get_public_key().unwrap(), pk); - assert!(envelope.verify_signature().unwrap()); - } - - #[test] - fn header_try_into() { - let header = EnvelopeHeader { - public_key: CommsPublicKey::default().to_vec(), - flags: MessageFlags::all().bits(), - signature: vec![1, 2, 3], - }; - - let msg_header: MessageEnvelopeHeader = header.try_into().unwrap(); - assert_eq!(msg_header.public_key, CommsPublicKey::default()); - assert_eq!(msg_header.flags, MessageFlags::all()); - assert_eq!(msg_header.signature, vec![1, 2, 3]); - } - - #[test] - fn is_valid() { - let (sk, pk) = CommsPublicKey::random_keypair(&mut OsRng); - let mut envelope = Envelope::construct_signed(&sk, &pk, Bytes::new(), MessageFlags::all()).unwrap(); - assert_eq!(envelope.is_valid(), true); - envelope.header = None; - assert_eq!(envelope.is_valid(), false); - } -} diff --git a/comms/src/message/error.rs b/comms/src/message/error.rs index 76b23f2761..f56c4cdb4d 100644 --- a/comms/src/message/error.rs +++ b/comms/src/message/error.rs @@ -22,7 +22,7 @@ use crate::peer_manager::node_id::NodeIdError; use derive_error::Error; -use prost::{DecodeError, EncodeError}; +use prost::DecodeError; use tari_crypto::{ signatures::SchnorrSignatureError, tari_utilities::{ciphers::cipher::CipherError, message_format::MessageFormatError}, @@ -53,8 +53,6 @@ pub enum MessageError { InvalidHeaderPublicKey, /// Failed to decode protobuf message DecodeError(DecodeError), - /// Failed to encode protobuf message - EncodeError(EncodeError), /// Failed to decode message part of envelope body EnvelopeBodyDecodeFailed, } diff --git a/comms/src/message/inbound.rs b/comms/src/message/inbound.rs index 59216fe62c..e4c76f0041 100644 --- a/comms/src/message/inbound.rs +++ b/comms/src/message/inbound.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use super::{MessageEnvelopeHeader, MessageTag}; +use super::MessageTag; use crate::peer_manager::Peer; use bytes::Bytes; use std::sync::Arc; @@ -29,8 +29,6 @@ use std::sync::Arc; #[derive(Clone, Debug)] pub struct InboundMessage { pub tag: MessageTag, - /// The deserialized message envelope header - pub envelope_header: MessageEnvelopeHeader, /// The connected peer which sent this message pub source_peer: Arc, /// The raw message envelope @@ -39,11 +37,10 @@ pub struct InboundMessage { impl InboundMessage { /// Construct a new InboundMessage - pub fn new(source_peer: Arc, envelope_header: MessageEnvelopeHeader, body: Bytes) -> Self { + pub fn new(source_peer: Arc, body: Bytes) -> Self { Self { tag: MessageTag::new(), source_peer, - envelope_header, body, } } diff --git a/comms/src/message/mod.rs b/comms/src/message/mod.rs index d56383b369..0590a36298 100644 --- a/comms/src/message/mod.rs +++ b/comms/src/message/mod.rs @@ -59,12 +59,10 @@ //! [MessageHeader]: ./message/struct.MessageHeader.html //! [MessageData]: ./message/struct.MessageData.html //! [DomainConnector]: ../domain_connector/struct.DomainConnector.html -use bitflags::*; -use serde::{Deserialize, Serialize}; #[macro_use] mod envelope; -pub use envelope::{Envelope, EnvelopeBody, EnvelopeHeader, MessageEnvelopeHeader}; +pub use envelope::EnvelopeBody; mod error; pub use error::MessageError; @@ -73,28 +71,21 @@ mod inbound; pub use inbound::InboundMessage; mod outbound; -pub use outbound::OutboundMessage; +pub use outbound::{MessagingReplyRx, MessagingReplyTx, OutboundMessage}; mod tag; pub use tag::MessageTag; pub trait MessageExt: prost::Message { /// Encodes a message, allocating the buffer on the heap as necessary - fn to_encoded_bytes(&self) -> Result, MessageError> + fn to_encoded_bytes(&self) -> Vec where Self: Sized { let mut buf = Vec::with_capacity(self.encoded_len()); - self.encode(&mut buf)?; - Ok(buf) + self.encode(&mut buf).expect( + "prost::Message::encode documentation says it is infallible unless the buffer has insufficient capacity. \ + This buffer's capacity was set with encoded_len", + ); + buf } } impl MessageExt for T {} - -bitflags! { - /// Used to indicate characteristics of the incoming or outgoing message, such - /// as whether the message is encrypted. - #[derive(Default, Deserialize, Serialize)] - pub struct MessageFlags: u32 { - const NONE = 0b0000_0000; - const ENCRYPTED = 0b0000_0001; - } -} diff --git a/comms/src/message/outbound.rs b/comms/src/message/outbound.rs index 0b74e6f864..26287d7769 100644 --- a/comms/src/message/outbound.rs +++ b/comms/src/message/outbound.rs @@ -20,43 +20,59 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{ - message::{MessageFlags, MessageTag}, - peer_manager::NodeId, -}; +use crate::{message::MessageTag, peer_manager::NodeId}; use bytes::Bytes; +use futures::channel::oneshot; use std::{ fmt, fmt::{Error, Formatter}, }; +pub type MessagingReplyTx = oneshot::Sender>; +pub type MessagingReplyRx = oneshot::Receiver>; + /// Contains details required to build a message envelope and send a message to a peer. OutboundMessage will not copy /// the body bytes when cloned and is 'cheap to clone(tm)'. 
-#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug)] pub struct OutboundMessage { pub tag: MessageTag, pub peer_node_id: NodeId, - pub flags: MessageFlags, pub body: Bytes, + pub reply_tx: Option, } impl OutboundMessage { - /// Create a new OutboundMessage - pub fn new(peer_node_id: NodeId, flags: MessageFlags, body: Bytes) -> OutboundMessage { - Self::with_tag(MessageTag::new(), peer_node_id, flags, body) - } - - /// Create a new OutboundMessage with the specified MessageTag - pub fn with_tag(tag: MessageTag, peer_node_id: NodeId, flags: MessageFlags, body: Bytes) -> OutboundMessage { - OutboundMessage { - tag, + pub fn new(peer_node_id: NodeId, body: Bytes) -> Self { + Self { + tag: MessageTag::new(), peer_node_id, - flags, body, + reply_tx: None, + } + } + + pub fn reply_fail(&mut self) { + self.oneshot_reply(Err(())); + } + + pub fn reply_success(&mut self) { + self.oneshot_reply(Ok(())); + } + + #[inline] + fn oneshot_reply(&mut self, result: Result<(), ()>) { + if let Some(reply_tx) = self.reply_tx.take() { + let _ = reply_tx.send(result); } } } +impl Drop for OutboundMessage { + fn drop(&mut self) { + self.reply_fail(); + } +} + impl fmt::Display for OutboundMessage { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { write!( @@ -78,7 +94,12 @@ mod test { static TEST_MSG: Bytes = Bytes::from_static(b"The ghost brigades"); let node_id = NodeId::new(); let tag = MessageTag::new(); - let subject = OutboundMessage::with_tag(tag, node_id.clone(), MessageFlags::empty(), TEST_MSG.clone()); + let subject = OutboundMessage { + tag, + peer_node_id: node_id.clone(), + reply_tx: None, + body: TEST_MSG.clone(), + }; assert_eq!(tag, subject.tag); assert_eq!(subject.body, TEST_MSG); assert_eq!(subject.peer_node_id, node_id); diff --git a/comms/src/peer_manager/connection_stats.rs b/comms/src/peer_manager/connection_stats.rs index 88a934dcc5..74b396093f 100644 --- a/comms/src/peer_manager/connection_stats.rs +++ b/comms/src/peer_manager/connection_stats.rs @@ -89,17 +89,17 @@ impl PeerConnectionStats { impl fmt::Display for PeerConnectionStats { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.last_connected_at.as_ref() { - Some(dt) => { - write!(f, "Last connected at '{}'.", dt)?; - - if self.last_failed_at().is_some() { - write!(f, " {}", self.last_connection_attempt)?; - } + match self.last_failed_at() { + Some(_) => { + write!(f, "{}", self.last_connection_attempt)?; }, - None => { - write!(f, "Never connected to this peer.")?; - write!(f, " {}", self.last_connection_attempt)?; + None => match self.last_connected_at.as_ref() { + Some(dt) => { + write!(f, "Last connected at {}", dt.format("%Y-%m-%d %H:%M:%S"))?; + }, + None => { + write!(f, "{}", self.last_connection_attempt)?; + }, }, } @@ -145,8 +145,9 @@ impl Display for LastConnectionAttempt { num_attempts, } => write!( f, - "Connection failed at {} after {} attempt(s)", - failed_at, num_attempts + "Connection failed at {} ({} attempt(s))", + failed_at.format("%Y-%m-%d %H:%M:%S"), + num_attempts ), } } diff --git a/comms/src/peer_manager/manager.rs b/comms/src/peer_manager/manager.rs index b19b5fdde6..2ef42f2ba0 100644 --- a/comms/src/peer_manager/manager.rs +++ b/comms/src/peer_manager/manager.rs @@ -23,10 +23,10 @@ use crate::{ peer_manager::{ connection_stats::PeerConnectionStats, - node_id::NodeId, + node_id::{NodeDistance, NodeId}, peer::{Peer, PeerFlags}, peer_id::PeerId, - peer_storage::PeerStorage, + peer_storage::{PeerStorage, RegionStats}, PeerFeatures, PeerManagerError, PeerQuery, @@ -35,6 
+35,7 @@ use crate::{ types::{CommsDatabase, CommsPublicKey}, }; use multiaddr::Multiaddr; +use std::time::Duration; use tari_storage::IterationResult; use tokio::sync::RwLock; @@ -68,6 +69,8 @@ impl PeerManager { node_id: Option, net_addresses: Option>, flags: Option, + #[allow(clippy::option_option)] banned_until: Option>, + #[allow(clippy::option_option)] is_offline: Option, peer_features: Option, connection_stats: Option, supported_protocols: Option>, @@ -78,6 +81,8 @@ impl PeerManager { node_id, net_addresses, flags, + banned_until, + is_offline, peer_features, connection_stats, supported_protocols, @@ -89,13 +94,14 @@ impl PeerManager { let mut storage = self.peer_storage.write().await; let mut peer = storage.find_by_node_id(node_id)?; peer.connection_stats.set_connection_success(); - peer.flags.remove(PeerFlags::OFFLINE); storage.update_peer( &peer.public_key, None, None, None, None, + Some(false), + None, Some(peer.connection_stats), None, ) @@ -112,6 +118,8 @@ impl PeerManager { None, None, None, + None, + None, Some(peer.connection_stats), None, ) @@ -186,21 +194,26 @@ impl PeerManager { self.peer_storage.read().await.for_each(f) } - /// Fetch n nearest neighbour Communication Nodes + /// Fetch n nearest neighbours. If features are supplied, the function will return the closest peers matching that + /// feature pub async fn closest_peers( &self, node_id: &NodeId, n: usize, excluded_peers: &[CommsPublicKey], + features: Option, ) -> Result, PeerManagerError> { - self.peer_storage.read().await.closest_peers(node_id, n, excluded_peers) + self.peer_storage + .read() + .await + .closest_peers(node_id, n, excluded_peers, features) } /// Fetch n random peers - pub async fn random_peers(&self, n: usize) -> Result, PeerManagerError> { + pub async fn random_peers(&self, n: usize, excluded: Vec) -> Result, PeerManagerError> { // Send to a random set of peers of size n that are Communication Nodes - self.peer_storage.read().await.random_peers(n) + self.peer_storage.read().await.random_peers(n, excluded) } /// Check if a specific node_id is in the network region of the N nearest neighbours of the region specified by @@ -218,9 +231,27 @@ impl PeerManager { .in_network_region(node_id, region_node_id, n) } - /// Changes the ban flag bit of the peer - pub async fn set_banned(&self, public_key: &CommsPublicKey, ban_flag: bool) -> Result { - self.peer_storage.write().await.set_banned(public_key, ban_flag) + pub async fn calc_region_threshold( + &self, + region_node_id: &NodeId, + n: usize, + features: PeerFeatures, + ) -> Result + { + self.peer_storage + .read() + .await + .calc_region_threshold(region_node_id, n, features) + } + + /// Unbans the peer if it is banned. This function is idempotent. 
+ pub async fn unban(&self, public_key: &CommsPublicKey) -> Result<NodeId, PeerManagerError> { + self.peer_storage.write().await.unban(public_key) + } + + /// Ban the peer for a length of time specified by the duration + pub async fn ban_for(&self, public_key: &CommsPublicKey, duration: Duration) -> Result<NodeId, PeerManagerError> { + self.peer_storage.write().await.ban_for(public_key, duration) + } /// Changes the offline flag bit of the peer @@ -232,6 +263,45 @@ pub async fn add_net_address(&self, node_id: &NodeId, net_address: &Multiaddr) -> Result<(), PeerManagerError> { self.peer_storage.write().await.add_net_address(node_id, net_address) } + + pub async fn update_each<F>(&self, mut f: F) -> Result<usize, PeerManagerError> + where F: FnMut(Peer) -> Option<Peer> { + let mut lock = self.peer_storage.write().await; + let mut peers_to_update = Vec::new(); + lock.for_each(|peer| { + if let Some(peer) = (f)(peer) { + peers_to_update.push(peer); + } + IterationResult::Continue + })?; + + let updated_count = peers_to_update.len(); + for p in peers_to_update { + lock.add_peer(p)?; + } + + Ok(updated_count) + } + + /// Return some basic stats about the region around region_node_id + pub async fn get_region_stats<'a>( + &self, + region_node_id: &'a NodeId, + n: usize, + features: PeerFeatures, + ) -> Result<RegionStats<'a>, PeerManagerError> + { + self.peer_storage + .read() + .await + .get_region_stats(region_node_id, n, features) + } + + pub async fn get_peer_features(&self, node_id: &NodeId) -> Result<PeerFeatures, PeerManagerError> { + // TODO: #sqliterefactor fetch the features with a sql query + let peer = self.find_by_node_id(node_id).await?; + Ok(peer.features) + } } #[cfg(test)] @@ -249,19 +319,14 @@ mod test { use tari_crypto::{keys::PublicKey, ristretto::RistrettoPublicKey}; use tari_storage::HashmapDatabase; - fn create_test_peer(ban_flag: bool) -> Peer { + fn create_test_peer(ban_flag: bool, features: PeerFeatures) -> Peer { let (_sk, pk) = RistrettoPublicKey::random_keypair(&mut OsRng); let node_id = NodeId::from_key(&pk).unwrap(); let net_addresses = MultiaddressesWithStats::from("/ip4/1.2.3.4/tcp/8000".parse::<Multiaddr>().unwrap()); - let mut peer = Peer::new( - pk, - node_id, - net_addresses, - PeerFlags::default(), - PeerFeatures::MESSAGE_PROPAGATION, - &[], - ); - peer.set_banned(ban_flag); + let mut peer = Peer::new(pk, node_id, net_addresses, PeerFlags::default(), features, &[]); + if ban_flag { + peer.ban_for(Duration::from_secs(1000)); + } peer } @@ -271,19 +336,19 @@ mod test { let peer_manager = PeerManager::new(HashmapDatabase::new()).unwrap(); let mut test_peers = Vec::new(); // Create 20 peers where the 1st and last ones are bad - test_peers.push(create_test_peer(true)); + test_peers.push(create_test_peer(true, PeerFeatures::COMMUNICATION_NODE)); assert!(peer_manager .add_peer(test_peers[test_peers.len() - 1].clone()) .await .is_ok()); for _i in 0..18 { - test_peers.push(create_test_peer(false)); + test_peers.push(create_test_peer(false, PeerFeatures::COMMUNICATION_NODE)); assert!(peer_manager .add_peer(test_peers[test_peers.len() - 1].clone()) .await .is_ok()); } - test_peers.push(create_test_peer(true)); + test_peers.push(create_test_peer(true, PeerFeatures::COMMUNICATION_NODE)); assert!(peer_manager .add_peer(test_peers[test_peers.len() - 1].clone()) .await @@ -298,7 +363,7 @@ mod test { assert_eq!(selected_peers.node_id, test_peers[2].node_id); assert_eq!(selected_peers.public_key, test_peers[2].public_key); // Test Invalid Direct - let unmanaged_peer = create_test_peer(false); + let unmanaged_peer = create_test_peer(false, PeerFeatures::COMMUNICATION_NODE); assert!(peer_manager
.direct_identity_node_id(&unmanaged_peer.node_id) .await @@ -321,7 +386,7 @@ mod test { // Test Closest - No exclusions let selected_peers = peer_manager - .closest_peers(&unmanaged_peer.node_id, 3, &Vec::new()) + .closest_peers(&unmanaged_peer.node_id, 3, &[], None) .await .unwrap(); assert_eq!(selected_peers.len(), 3); @@ -349,7 +414,7 @@ mod test { selected_peers[0].public_key.clone(), // ,selected_peers[1].public_key.clone() ]; let selected_peers = peer_manager - .closest_peers(&unmanaged_peer.node_id, 3, &excluded_peers) + .closest_peers(&unmanaged_peer.node_id, 3, &excluded_peers, None) .await .unwrap(); assert_eq!(selected_peers.len(), 3); @@ -373,51 +438,110 @@ mod test { } // Test Random - let identities1 = peer_manager.random_peers(10).await.unwrap(); - let identities2 = peer_manager.random_peers(10).await.unwrap(); + let identities1 = peer_manager.random_peers(10, vec![]).await.unwrap(); + let identities2 = peer_manager.random_peers(10, vec![]).await.unwrap(); assert_ne!(identities1, identities2); } #[tokio_macros::test_basic] - async fn test_in_network_region() { - let _rng = rand::rngs::OsRng; + async fn calc_region_threshold() { + let n = 5; // Create peer manager with random peers let peer_manager = PeerManager::new(HashmapDatabase::new()).unwrap(); - let network_region_node_id = create_test_peer(false).node_id; - // Create peers - let mut test_peers: Vec = Vec::new(); - for _ in 0..10 { - test_peers.push(create_test_peer(false)); - assert!(peer_manager - .add_peer(test_peers[test_peers.len() - 1].clone()) - .await - .is_ok()); + let network_region_node_id = create_test_peer(false, Default::default()).node_id; + let mut test_peers = (0..10) + .map(|_| create_test_peer(false, PeerFeatures::COMMUNICATION_NODE)) + .chain((0..10).map(|_| create_test_peer(false, PeerFeatures::COMMUNICATION_CLIENT))) + .collect::>(); + + for p in &test_peers { + peer_manager.add_peer(p.clone()).await.unwrap(); } - test_peers[0].set_banned(true); - test_peers[1].set_banned(true); - // Get nearest neighbours - let n = 5; - let nearest_identities = peer_manager - .closest_peers(&network_region_node_id, n, &Vec::new()) + test_peers.sort_by(|a, b| { + let a_dist = network_region_node_id.distance(&a.node_id); + let b_dist = network_region_node_id.distance(&b.node_id); + a_dist.partial_cmp(&b_dist).unwrap() + }); + + let node_region_threshold = peer_manager + .calc_region_threshold(&network_region_node_id, n, PeerFeatures::COMMUNICATION_NODE) .await .unwrap(); - for peer in &test_peers { - if nearest_identities + // First 5 base nodes should be within the region + for peer in test_peers + .iter() + .filter(|p| p.features == PeerFeatures::COMMUNICATION_NODE) + .take(n) + { + assert!(peer.node_id.distance(&network_region_node_id) <= node_region_threshold); + } + + // Next 5 should not be in the region + for peer in test_peers + .iter() + .filter(|p| p.features == PeerFeatures::COMMUNICATION_NODE) + .skip(n) + { + assert!(peer.node_id.distance(&network_region_node_id) > node_region_threshold); + } + + let node_region_threshold = peer_manager + .calc_region_threshold(&network_region_node_id, n, PeerFeatures::COMMUNICATION_CLIENT) + .await + .unwrap(); + + // First 5 clients should be in region + for peer in test_peers + .iter() + .filter(|p| p.features == PeerFeatures::COMMUNICATION_CLIENT) + .take(5) + { + assert!(peer.node_id.distance(&network_region_node_id) <= node_region_threshold); + } + + // Next 5 should not be in the region + for peer in test_peers + .iter() + .filter(|p| p.features == 
PeerFeatures::COMMUNICATION_CLIENT) + .skip(5) + { + assert!(peer.node_id.distance(&network_region_node_id) > node_region_threshold); + } + } + + #[tokio_macros::test_basic] + async fn closest_peers() { + let n = 5; + // Create peer manager with random peers + let peer_manager = PeerManager::new(HashmapDatabase::new()).unwrap(); + let network_region_node_id = create_test_peer(false, Default::default()).node_id; + let test_peers = (0..10) + .map(|_| create_test_peer(false, PeerFeatures::COMMUNICATION_NODE)) + .chain((0..10).map(|_| create_test_peer(false, PeerFeatures::COMMUNICATION_CLIENT))) + .collect::>(); + + for p in &test_peers { + peer_manager.add_peer(p.clone()).await.unwrap(); + } + + for features in &[PeerFeatures::COMMUNICATION_NODE, PeerFeatures::COMMUNICATION_CLIENT] { + let node_threshold = peer_manager + .peer_storage + .read() + .await + .calc_region_threshold(&network_region_node_id, n, *features) + .unwrap(); + + let closest = peer_manager + .closest_peers(&network_region_node_id, n, &[], Some(*features)) + .await + .unwrap(); + + assert!(closest .iter() - .any(|peer_identity| peer.node_id == peer_identity.node_id) - { - assert!(peer_manager - .in_network_region(&peer.node_id, &network_region_node_id, n) - .await - .unwrap()); - } else { - assert!(!peer_manager - .in_network_region(&peer.node_id, &network_region_node_id, n) - .await - .unwrap()); - } + .all(|p| network_region_node_id.distance(&p.node_id) <= node_threshold)); } } } diff --git a/comms/src/peer_manager/node_id.rs b/comms/src/peer_manager/node_id.rs index c7b1552cfc..52c00c53a2 100644 --- a/comms/src/peer_manager/node_id.rs +++ b/comms/src/peer_manager/node_id.rs @@ -68,7 +68,7 @@ impl NodeDistance { nd } - pub fn max_distance() -> NodeDistance { + pub const fn max_distance() -> NodeDistance { NodeDistance([255; NODE_ID_ARRAY_SIZE]) } } @@ -94,6 +94,21 @@ impl TryFrom<&[u8]> for NodeDistance { } } +impl ByteArray for NodeDistance { + /// Try and convert the given byte array to a NodeDistance. Any failures (incorrect array length, + /// implementation-specific checks, etc) return a [ByteArrayError](enum.ByteArrayError.html). + fn from_bytes(bytes: &[u8]) -> Result { + bytes + .try_into() + .map_err(|err| ByteArrayError::ConversionError(format!("{:?}", err))) + } + + /// Return the NodeDistance as a byte array + fn as_bytes(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Display for NodeDistance { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", to_hex(&self.0)) @@ -101,7 +116,7 @@ impl fmt::Display for NodeDistance { } /// A Node Identity is used as a unique identifier for a node in the Tari communications network. 
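The node_id.rs changes in the next hunk add a test asserting that distance is commutative. Assuming the XOR-style metric the node IDs use, that property falls straight out of the operator: x ^ y == y ^ x byte for byte. A toy illustration, not the crate's `NodeDistance` type:

```rust
// Toy XOR metric over fixed-size ids, for illustration only.
fn xor_distance(a: &[u8; 3], b: &[u8; 3]) -> [u8; 3] {
    [a[0] ^ b[0], a[1] ^ b[1], a[2] ^ b[2]]
}

fn main() {
    let id1 = [0x12, 0x34, 0x56];
    let id2 = [0xab, 0xcd, 0xef];
    // Commutative regardless of argument order.
    assert_eq!(xor_distance(&id1, &id2), xor_distance(&id2, &id1));
}
```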
-#[derive(Clone, Debug, Eq, Deserialize, Serialize, Default)] +#[derive(Clone, Eq, Deserialize, Serialize, Default)] pub struct NodeId(NodeIdArray); impl NodeId { @@ -254,6 +269,12 @@ impl fmt::Display for NodeId { } } +impl fmt::Debug for NodeId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NodeId({})", to_hex(&self.0)) + } +} + pub fn deserialize_node_id_from_hex<'de, D>(des: D) -> Result where D: Deserializer<'de> { struct KeyStringVisitor { @@ -355,6 +376,12 @@ mod test { assert!(n1_to_n2_dist < n1_to_n3_dist); assert_eq!(n1_to_n2_dist, desired_n1_to_n2_dist); assert_eq!(n1_to_n3_dist, desired_n1_to_n3_dist); + + // Commutative + let n1_to_n2_dist = node_id1.distance(&node_id2); + let n2_to_n1_dist = node_id2.distance(&node_id1); + + assert_eq!(n1_to_n2_dist, n2_to_n1_dist); } #[test] diff --git a/comms/src/peer_manager/peer.rs b/comms/src/peer_manager/peer.rs index 3bd6be2893..a887609eff 100644 --- a/comms/src/peer_manager/peer.rs +++ b/comms/src/peer_manager/peer.rs @@ -31,19 +31,19 @@ use crate::{ net_address::MultiaddressesWithStats, protocol::ProtocolId, types::CommsPublicKey, + utils::datetime::safe_future_datetime_from_duration, }; use bitflags::bitflags; use chrono::{DateTime, NaiveDateTime, Utc}; use multiaddr::Multiaddr; use serde::{Deserialize, Serialize}; -use std::fmt::Display; +use std::{fmt::Display, time::Duration}; use tari_crypto::tari_utilities::hex::serialize_to_hex; bitflags! { #[derive(Default, Deserialize, Serialize)] pub struct PeerFlags: u8 { - const BANNED = 0x01; - const OFFLINE = 0x02; + const NONE = 0x00; } } @@ -68,8 +68,10 @@ pub struct Peer { pub node_id: NodeId, /// Peer's addresses pub addresses: MultiaddressesWithStats, - /// Flags for the peer. Indicates if the peer is banned. + /// Flags for the peer. 
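The peer.rs hunks that follow retire the BANNED and OFFLINE flag bits in favour of `banned_until` and `offline_at` timestamps, so a ban simply lapses once its deadline passes, with no cleanup pass needed. A minimal standalone sketch of the expiry check, mirroring the `banned_until()` filter added below:

```rust
use chrono::{NaiveDateTime, Utc};

struct PeerBanState {
    banned_until: Option<NaiveDateTime>,
}

impl PeerBanState {
    // Banned iff a deadline is set and still in the future.
    fn is_banned(&self) -> bool {
        self.banned_until
            .as_ref()
            .filter(|dt| *dt > &Utc::now().naive_utc())
            .is_some()
    }
}

fn main() {
    let now = Utc::now().naive_utc();
    let lapsed = PeerBanState { banned_until: Some(now - chrono::Duration::seconds(1)) };
    let active = PeerBanState { banned_until: Some(now + chrono::Duration::seconds(1000)) };
    assert!(!lapsed.is_banned());
    assert!(active.is_banned());
}
```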
pub flags: PeerFlags, + pub banned_until: Option, + pub offline_at: Option, /// Features supported by the peer pub features: PeerFeatures, /// Connection statics for the peer @@ -99,6 +101,8 @@ impl Peer { addresses, flags, features, + banned_until: None, + offline_at: None, connection_stats: Default::default(), added_at: Utc::now().naive_utc(), supported_protocols: supported_protocols.into_iter().cloned().collect(), @@ -137,7 +141,7 @@ impl Peer { /// Returns true if the peer is marked as offline pub fn is_offline(&self) -> bool { - self.flags.contains(PeerFlags::OFFLINE) + self.offline_at.is_some() } /// TODO: Remove once we don't have to sync wallet and base node db @@ -146,7 +150,6 @@ impl Peer { } pub(super) fn set_id(&mut self, id: PeerId) { - debug_assert!(self.id.is_none()); self.id = Some(id); } @@ -160,6 +163,8 @@ impl Peer { node_id: Option, net_addresses: Option>, flags: Option, + #[allow(clippy::option_option)] banned_until: Option>, + #[allow(clippy::option_option)] is_offline: Option, features: Option, connection_stats: Option, supported_protocols: Option>, @@ -174,6 +179,14 @@ impl Peer { if let Some(new_flags) = flags { self.flags = new_flags } + if let Some(banned_until) = banned_until { + self.banned_until = banned_until + .map(safe_future_datetime_from_duration) + .map(|dt| dt.naive_utc()); + } + if let Some(is_offline) = is_offline { + self.set_offline(is_offline); + } if let Some(new_features) = features { self.features = new_features; } @@ -197,33 +210,66 @@ impl Peer { /// Returns the ban status of the peer pub fn is_banned(&self) -> bool { - self.flags.contains(PeerFlags::BANNED) + self.banned_until().is_some() } - /// Changes the BANNED flag bit of the peer - pub fn set_banned(&mut self, ban_flag: bool) { - self.flags.set(PeerFlags::BANNED, ban_flag); + /// Bans the peer for a specified duration + pub fn ban_for(&mut self, duration: Duration) { + let dt = safe_future_datetime_from_duration(duration); + self.banned_until = Some(dt.naive_utc()); } - /// Changes the OFFLINE flag bit of the peer + /// Unban the peer + pub fn unban(&mut self) { + self.banned_until = None; + } + + pub fn banned_until(&self) -> Option<&NaiveDateTime> { + self.banned_until.as_ref().filter(|dt| *dt > &Utc::now().naive_utc()) + } + + /// Marks the peer as offline pub fn set_offline(&mut self, is_offline: bool) { - self.flags.set(PeerFlags::OFFLINE, is_offline); + if is_offline { + self.offline_at = Some(Utc::now().naive_utc()); + } else { + self.offline_at = None; + } } } /// Display Peer as `[peer_id]: ` impl Display for Peer { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let flags_str = if self.flags == PeerFlags::empty() { + "".to_string() + } else { + format!("{:?}", self.flags) + }; + + let status_str = { + let mut s = Vec::new(); + if let Some(offline_at) = self.offline_at.as_ref() { + s.push(format!("OFFLINE since {}", offline_at)); + } + + if let Some(dt) = self.banned_until() { + s.push(format!("BANNED until {}", dt)); + } + s.join(", ") + }; f.write_str(&format!( - "{}[{}] PK={} {} {:?} {}", - if self.is_banned() { "BANNED " } else { "" }, + "{}[{}] PK={} ({}) {} {:?} {}", + flags_str, self.node_id.short_str(), self.public_key, self.addresses - .address_iter() - .next() + .addresses + .iter() .map(ToString::to_string) - .unwrap_or_else(|| "".to_string()), + .collect::>() + .join(","), + status_str, match self.features { PeerFeatures::COMMUNICATION_NODE => "BASE_NODE".to_string(), PeerFeatures::COMMUNICATION_CLIENT => "WALLET".to_string(), @@ -246,16 +292,16 
@@ mod test { }; #[test] - fn test_is_and_set_banned() { + fn test_is_banned_and_ban_for() { let mut rng = rand::rngs::OsRng; let (_sk, pk) = RistrettoPublicKey::random_keypair(&mut rng); let node_id = NodeId::from_key(&pk).unwrap(); let addresses = MultiaddressesWithStats::from("/ip4/123.0.0.123/tcp/8000".parse::().unwrap()); let mut peer: Peer = Peer::new(pk, node_id, addresses, PeerFlags::default(), PeerFeatures::empty(), &[]); assert_eq!(peer.is_banned(), false); - peer.set_banned(true); + peer.ban_for(Duration::from_millis(std::u64::MAX)); assert_eq!(peer.is_banned(), true); - peer.set_banned(false); + peer.ban_for(Duration::from_millis(0)); assert_eq!(peer.is_banned(), false); } @@ -282,7 +328,9 @@ mod test { peer.update( Some(node_id2.clone()), Some(vec![net_address2.clone(), net_address3.clone()]), - Some(PeerFlags::BANNED), + None, + Some(Some(Duration::from_secs(1000))), + None, Some(PeerFeatures::MESSAGE_PROPAGATION), Some(PeerConnectionStats::new()), Some(vec![protocol::IDENTITY_PROTOCOL.clone()]), @@ -305,7 +353,7 @@ mod test { .addresses .iter() .any(|net_address_with_stats| net_address_with_stats.address == net_address3)); - assert_eq!(peer.flags, PeerFlags::BANNED); + assert!(peer.is_banned()); assert_eq!(peer.has_features(PeerFeatures::MESSAGE_PROPAGATION), true); assert_eq!(peer.supported_protocols, vec![protocol::IDENTITY_PROTOCOL.clone()]); } diff --git a/comms/src/peer_manager/peer_features.rs b/comms/src/peer_manager/peer_features.rs index 402840a320..50b3c332f4 100644 --- a/comms/src/peer_manager/peer_features.rs +++ b/comms/src/peer_manager/peer_features.rs @@ -22,6 +22,7 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; +use std::fmt; bitflags! { #[derive(Serialize, Deserialize)] @@ -40,3 +41,9 @@ impl Default for PeerFeatures { PeerFeatures::NONE } } + +impl fmt::Display for PeerFeatures { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} diff --git a/comms/src/peer_manager/peer_query.rs b/comms/src/peer_manager/peer_query.rs index e0776d1f9a..a5da8363ed 100644 --- a/comms/src/peer_manager/peer_query.rs +++ b/comms/src/peer_manager/peer_query.rs @@ -218,7 +218,7 @@ mod test { }; use multiaddr::Multiaddr; use rand::rngs::OsRng; - use std::iter::repeat_with; + use std::{iter::repeat_with, time::Duration}; use tari_crypto::{keys::PublicKey, ristretto::RistrettoPublicKey}; use tari_storage::HashmapDatabase; @@ -234,7 +234,9 @@ mod test { PeerFeatures::MESSAGE_PROPAGATION, &[], ); - peer.set_banned(ban_flag); + if ban_flag { + peer.ban_for(Duration::from_secs(1000)); + } peer } diff --git a/comms/src/peer_manager/peer_storage.rs b/comms/src/peer_manager/peer_storage.rs index 7c488d8d0d..8f7e46d974 100644 --- a/comms/src/peer_manager/peer_storage.rs +++ b/comms/src/peer_manager/peer_storage.rs @@ -37,7 +37,7 @@ use crate::{ use log::*; use multiaddr::Multiaddr; use rand::{rngs::OsRng, Rng}; -use std::{cmp::min, collections::HashMap}; +use std::{cmp, collections::HashMap, fmt, time::Duration}; use tari_storage::{IterationResult, KeyValueStore}; const LOG_TARGET: &str = "comms::peer_manager::peer_storage"; @@ -121,6 +121,8 @@ where DS: KeyValueStore node_id: Option, net_addresses: Option>, flags: Option, + #[allow(clippy::option_option)] banned_until: Option>, + #[allow(clippy::option_option)] is_offline: Option, peer_features: Option, connection_stats: Option, supported_protocols: Option>, @@ -149,6 +151,8 @@ where DS: KeyValueStore node_id, net_addresses, flags, + banned_until, + is_offline, peer_features, 
connection_stats, supported_protocols, @@ -301,13 +305,18 @@ where DS: KeyValueStore node_id: &NodeId, n: usize, excluded_peers: &[CommsPublicKey], + features: Option, ) -> Result, PeerManagerError> { let mut peer_keys = Vec::new(); let mut dists = Vec::new(); self.peer_db .for_each_ok(|(peer_key, peer)| { - if !peer.is_banned() && !excluded_peers.contains(&peer.public_key) { + if features.map(|f| peer.features == f).unwrap_or(true) && + !peer.is_banned() && + !peer.is_offline() && + !excluded_peers.contains(&peer.public_key) + { peer_keys.push(peer_key); dists.push(node_id.distance(&peer.node_id)); } @@ -315,7 +324,7 @@ where DS: KeyValueStore }) .map_err(PeerManagerError::DatabaseError)?; // Use all available peers up to a maximum of N - let max_available = min(peer_keys.len(), n); + let max_available = cmp::min(peer_keys.len(), n); if max_available == 0 { return Ok(Vec::new()); } @@ -340,17 +349,22 @@ where DS: KeyValueStore Ok(nearest_identities) } - /// Compile a random list of peers of size _n_ - pub fn random_peers(&self, n: usize) -> Result, PeerManagerError> { - // TODO: Send to a random set of Communication Nodes + /// Compile a random list of communication node peers of size _n_ that are not banned or offline + pub fn random_peers(&self, n: usize, exclude_peers: Vec) -> Result, PeerManagerError> { let mut peer_keys = self .peer_db - .filter(|(_, peer)| !peer.is_banned()) + .filter(|(_, peer)| { + !peer.is_recently_offline() && + !peer.is_offline() && + !peer.is_banned() && + peer.features == PeerFeatures::COMMUNICATION_NODE && + !exclude_peers.contains(&peer.node_id) + }) .map(|pairs| pairs.into_iter().map(|(k, _)| k).collect::>()) .map_err(PeerManagerError::DatabaseError)?; // Use all available peers up to a maximum of N - let max_available = min(peer_keys.len(), n); + let max_available = cmp::min(peer_keys.len(), n); if max_available == 0 { return Ok(Vec::new()); } @@ -383,35 +397,52 @@ where DS: KeyValueStore n: usize, ) -> Result { - let region2node_dist = region_node_id.distance(node_id); - let mut dists = vec![NodeDistance::max_distance(); n]; - let last_index = dists.len() - 1; - self.peer_db - .for_each_ok(|(_, peer)| { - if !peer.is_banned() { - let curr_dist = region_node_id.distance(&peer.node_id); - for i in 0..dists.len() { - if dists[i] > curr_dist { - dists.insert(i, curr_dist); - dists.pop(); - break; - } - } + let region_node_distance = region_node_id.distance(node_id); + let node_threshold = self.calc_region_threshold(region_node_id, n, PeerFeatures::COMMUNICATION_NODE)?; + // Is node ID in the base node threshold? + if region_node_distance <= node_threshold { + return Ok(true); + } + let client_threshold = self.calc_region_threshold(region_node_id, n, PeerFeatures::COMMUNICATION_CLIENT)?; + // Is node ID in the base client threshold? + Ok(region_node_distance <= client_threshold) + } - if region2node_dist > dists[last_index] { - return IterationResult::Break; - } - } + pub fn calc_region_threshold( + &self, + region_node_id: &NodeId, + n: usize, + features: PeerFeatures, + ) -> Result + { + self.get_region_stats(region_node_id, n, features) + .map(|stats| stats.distance) + } - IterationResult::Continue - }) - .map_err(PeerManagerError::DatabaseError)?; + /// Unban the peer + pub fn unban(&mut self, public_key: &CommsPublicKey) -> Result { + let peer_key = *self + .public_key_index + .get(&public_key) + .ok_or_else(|| PeerManagerError::PeerNotFoundError)?; + let mut peer = self + .peer_db + .get(&peer_key) + .map_err(PeerManagerError::DatabaseError)? 
+ .ok_or_else(|| PeerManagerError::PeerNotFoundError)?; + let node_id = peer.node_id.clone(); - Ok(region2node_dist <= dists[last_index]) + if peer.banned_until.is_some() { + peer.unban(); + self.peer_db + .insert(peer_key, peer) + .map_err(PeerManagerError::DatabaseError)?; + } + Ok(node_id) } - /// Changes the ban flag bit of the peer - pub fn set_banned(&mut self, public_key: &CommsPublicKey, ban_flag: bool) -> Result { + /// Ban the peer for the given duration + pub fn ban_for(&mut self, public_key: &CommsPublicKey, duration: Duration) -> Result { let peer_key = *self .public_key_index .get(&public_key) @@ -421,7 +452,7 @@ where DS: KeyValueStore .get(&peer_key) .map_err(PeerManagerError::DatabaseError)? .ok_or_else(|| PeerManagerError::PeerNotFoundError)?; - peer.set_banned(ban_flag); + peer.ban_for(duration); let node_id = peer.node_id.clone(); self.peer_db .insert(peer_key, peer) @@ -464,6 +495,66 @@ where DS: KeyValueStore .insert(peer_key, peer) .map_err(PeerManagerError::DatabaseError) } + + /// Return some basic stats for the region surrounding the region_node_id + pub fn get_region_stats<'a>( + &self, + region_node_id: &'a NodeId, + n: usize, + features: PeerFeatures, + ) -> Result, PeerManagerError> + { + let mut dists = vec![NodeDistance::max_distance(); n]; + let last_index = n - 1; + + let mut neighbours = vec![None; n]; + self.peer_db + .for_each_ok(|(_, peer)| { + if peer.features != features { + return IterationResult::Continue; + } + + if peer.is_banned() { + return IterationResult::Continue; + } + if peer.is_offline() { + return IterationResult::Continue; + } + + let curr_dist = region_node_id.distance(&peer.node_id); + for i in 0..dists.len() { + if dists[i] > curr_dist { + dists.insert(i, curr_dist); + dists.pop(); + neighbours.insert(i, Some(peer)); + neighbours.pop(); + break; + } + } + + IterationResult::Continue + }) + .map_err(PeerManagerError::DatabaseError)?; + + let distance = dists.remove(last_index); + let total = neighbours.iter().filter(|p| p.is_some()).count(); + let num_offline = neighbours + .iter() + .filter(|p| p.as_ref().map(|p| p.is_offline()).unwrap_or(false)) + .count(); + let num_banned = neighbours + .iter() + .filter(|p| p.as_ref().map(|p| p.is_banned()).unwrap_or(false)) + .count(); + + Ok(RegionStats { + distance, + ref_node_id: region_node_id, + total, + num_offline, + num_banned, + }) + } } impl Into for PeerStorage { @@ -472,6 +563,38 @@ impl Into for PeerStorage { } } +pub struct RegionStats<'a> { + distance: NodeDistance, + ref_node_id: &'a NodeId, + total: usize, + num_offline: usize, + num_banned: usize, +} + +impl RegionStats<'_> { + pub fn in_region(&self, node_id: &NodeId) -> bool { + node_id.distance(self.ref_node_id) <= self.distance + } + + pub fn offline_ratio(&self) -> f32 { + self.num_offline as f32 / self.total as f32 + } + + pub fn banned_ratio(&self) -> f32 { + self.num_banned as f32 / self.total as f32 + } +} + +impl fmt::Display for RegionStats<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "RegionStats(distance = {}, total = {}, num offline = {}, num banned = {})", + self.distance, self.total, self.num_offline, self.num_banned + ) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/comms/src/pipeline/error.rs b/comms/src/pipeline/error.rs index 9384dab4b1..f225b27760 100644 --- a/comms/src/pipeline/error.rs +++ b/comms/src/pipeline/error.rs @@ -22,7 +22,6 @@ use std::{error, fmt}; -#[derive(Debug)] pub struct PipelineError { err_string: String, } @@ -35,9 +34,18 @@ impl 
PipelineError { } } +impl fmt::Debug for PipelineError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("PipelineError: ")?; + f.write_str(&self.err_string) + } +} + impl From<&str> for PipelineError { fn from(s: &str) -> Self { - Self::from_debug(s) + Self { + err_string: s.to_owned(), + } } } diff --git a/comms/src/pipeline/outbound.rs b/comms/src/pipeline/outbound.rs index 8dc316771f..14c2785be9 100644 --- a/comms/src/pipeline/outbound.rs +++ b/comms/src/pipeline/outbound.rs @@ -128,13 +128,8 @@ mod test { #[tokio_macros::test_basic] async fn run() { const NUM_ITEMS: usize = 10; - let items = (0..NUM_ITEMS).map(|i| { - OutboundMessage::new( - Default::default(), - Default::default(), - Bytes::copy_from_slice(&i.to_be_bytes()), - ) - }); + let items = + (0..NUM_ITEMS).map(|i| OutboundMessage::new(Default::default(), Bytes::copy_from_slice(&i.to_be_bytes()))); let stream = stream::iter(items).fuse(); let (out_tx, out_rx) = mpsc::channel(NUM_ITEMS); let (msg_tx, msg_rx) = mpsc::channel(NUM_ITEMS); diff --git a/comms/src/proto/control_service/header.proto b/comms/src/proto/control_service/header.proto deleted file mode 100644 index 3ad1c403bd..0000000000 --- a/comms/src/proto/control_service/header.proto +++ /dev/null @@ -1,17 +0,0 @@ -syntax = "proto3"; - -package tari.comms.control_service; - -enum MessageType { - MessageTypeNone = 0; - MessageTypeRequestConnection = 1; - MessageTypePing = 2; - MessageTypeAcceptPeerConnection = 3; - MessageTypeRejectPeerConnection = 4; - MessageTypePong = 5; - MessageTypeConnectRequestOutcome = 6; -} - -message MessageHeader { - MessageType message_type = 1; -} diff --git a/comms/src/proto/control_service/ping.proto b/comms/src/proto/control_service/ping.proto deleted file mode 100644 index dc57a07082..0000000000 --- a/comms/src/proto/control_service/ping.proto +++ /dev/null @@ -1,6 +0,0 @@ -syntax = "proto3"; - -package tari.comms.control_service; - -message PingMessage { } -message PongMessage { } diff --git a/comms/src/proto/control_service/request_connection.proto b/comms/src/proto/control_service/request_connection.proto deleted file mode 100644 index ec50a53d9a..0000000000 --- a/comms/src/proto/control_service/request_connection.proto +++ /dev/null @@ -1,35 +0,0 @@ -syntax = "proto3"; - -package tari.comms.control_service; - -message RequestConnectionMessage { - string control_service_address = 1; - bytes node_id = 2; - uint64 features = 3; -} - -// Represents the reason for a peer connection request being rejected -enum RejectReason { - // No reject reason given - RejectReasonNone = 0; - // Peer already has an existing active peer connection - RejectReasonExistingConnection = 1; - // A connection collision has been detected, foreign node should abandon the connection attempt - RejectReasonCollisionDetected = 2; -} - -// Represents an outcome for the request to establish a new [PeerConnection]. 
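With the pipeline change above, `PipelineError` keeps only its `err_string` and writes its own `Debug` form, so log lines read as a single readable string. A standalone mirror showing construction via `From<&str>` and the resulting output:

```rust
use std::fmt;

// Standalone mirror of the revised PipelineError, for illustration only.
struct PipelineError {
    err_string: String,
}

impl fmt::Debug for PipelineError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PipelineError: ")?;
        f.write_str(&self.err_string)
    }
}

impl From<&str> for PipelineError {
    fn from(s: &str) -> Self {
        Self { err_string: s.to_owned() }
    }
}

fn main() {
    let err: PipelineError = "substream closed".into();
    assert_eq!(format!("{:?}", err), "PipelineError: substream closed");
}
```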
-// -// [PeerConnection]: ../../connection/peer_connection/index.html -message RequestConnectionOutcome { - // True if the connection is accepted, otherwise false - bool accepted = 1; - // The zeroMQ Curve public key to use for the peer connection - bytes curve_public_key = 2; - /// The address of the open port to connect to - string address = 3; - /// If this connection was not accepted, the rejection reason is given - RejectReason reject_reason = 4; - /// The identity to use when connecting - bytes identity = 5; -} \ No newline at end of file diff --git a/comms/src/proto/envelope.proto b/comms/src/proto/envelope.proto index 56d37f31b7..ca8f91867e 100644 --- a/comms/src/proto/envelope.proto +++ b/comms/src/proto/envelope.proto @@ -2,20 +2,6 @@ syntax = "proto3"; package tari.comms.envelope; -/// Represents a message which is about to go on or has just come off the wire. -/// As described in [RFC-0172](https://rfc.tari.com/RFC-0172_PeerToPeerMessagingProtocol.html#messaging-structure) -message Envelope { - uint32 version = 1; - EnvelopeHeader header = 3; - bytes body = 4; -} - -message EnvelopeHeader { - bytes public_key = 1; - bytes signature = 2; - uint32 flags = 3; -} - // Parts contained within an Envelope. This is used to tell if an encrypted // message was successfully decrypted, by decrypting the envelope body and checking // if deserialization succeeds. diff --git a/comms/src/proto/mod.rs b/comms/src/proto/mod.rs index b9a2db02d5..ccdd88dd42 100644 --- a/comms/src/proto/mod.rs +++ b/comms/src/proto/mod.rs @@ -23,8 +23,5 @@ #[path = "tari.comms.envelope.rs"] pub(crate) mod envelope; -#[path = "tari.comms.control_service.rs"] -pub(crate) mod control_service; - #[path = "tari.comms.identity.rs"] pub(crate) mod identity; diff --git a/comms/src/proto/tari.comms.control_service.rs b/comms/src/proto/tari.comms.control_service.rs deleted file mode 100644 index ca33091768..0000000000 --- a/comms/src/proto/tari.comms.control_service.rs +++ /dev/null @@ -1,61 +0,0 @@ -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct MessageHeader { - #[prost(enumeration = "MessageType", tag = "1")] - pub message_type: i32, -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum MessageType { - None = 0, - RequestConnection = 1, - Ping = 2, - AcceptPeerConnection = 3, - RejectPeerConnection = 4, - Pong = 5, - ConnectRequestOutcome = 6, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestConnectionMessage { - #[prost(string, tag = "1")] - pub control_service_address: std::string::String, - #[prost(bytes, tag = "2")] - pub node_id: std::vec::Vec, - #[prost(uint64, tag = "3")] - pub features: u64, -} -/// Represents an outcome for the request to establish a new [PeerConnection]. 
-/// -/// [PeerConnection]: ../../connection/peer_connection/index.html -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct RequestConnectionOutcome { - /// True if the connection is accepted, otherwise false - #[prost(bool, tag = "1")] - pub accepted: bool, - /// The zeroMQ Curve public key to use for the peer connection - #[prost(bytes, tag = "2")] - pub curve_public_key: std::vec::Vec, - //// The address of the open port to connect to - #[prost(string, tag = "3")] - pub address: std::string::String, - //// If this connection was not accepted, the rejection reason is given - #[prost(enumeration = "RejectReason", tag = "4")] - pub reject_reason: i32, - //// The identity to use when connecting - #[prost(bytes, tag = "5")] - pub identity: std::vec::Vec, -} -/// Represents the reason for a peer connection request being rejected -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum RejectReason { - /// No reject reason given - None = 0, - /// Peer already has an existing active peer connection - ExistingConnection = 1, - /// A connection collision has been detected, foreign node should abandon the connection attempt - CollisionDetected = 2, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PingMessage {} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PongMessage {} diff --git a/comms/src/proto/tari.comms.envelope.rs b/comms/src/proto/tari.comms.envelope.rs index f445fabea7..e72ec33b17 100644 --- a/comms/src/proto/tari.comms.envelope.rs +++ b/comms/src/proto/tari.comms.envelope.rs @@ -1,23 +1,3 @@ -//// Represents a message which is about to go on or has just come off the wire. -//// As described in [RFC-0172](https://rfc.tari.com/RFC-0172_PeerToPeerMessagingProtocol.html#messaging-structure) -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Envelope { - #[prost(uint32, tag = "1")] - pub version: u32, - #[prost(message, optional, tag = "3")] - pub header: ::std::option::Option, - #[prost(bytes, tag = "4")] - pub body: std::vec::Vec, -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct EnvelopeHeader { - #[prost(bytes, tag = "1")] - pub public_key: std::vec::Vec, - #[prost(bytes, tag = "2")] - pub signature: std::vec::Vec, - #[prost(uint32, tag = "3")] - pub flags: u32, -} /// Parts contained within an Envelope. This is used to tell if an encrypted /// message was successfully decrypted, by decrypting the envelope body and checking /// if deserialization succeeds. 
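With `Envelope` and `EnvelopeHeader` removed from envelope.proto, `EnvelopeBody` is the only wire message left, and the retained comment describes the implicit addressing check: a decrypt followed by a successful decode means the ciphertext was meant for us. A hedged sketch of that check, where `decrypted` is assumed to be the output of some prior decryption step and `EnvelopeBody` is the prost-generated type re-exported from the message module:

```rust
use prost::Message;
use tari_comms::message::EnvelopeBody;

// Sketch only: a successful decode doubles as the "was this for us?" signal.
fn try_open(decrypted: &[u8]) -> Option<EnvelopeBody> {
    EnvelopeBody::decode(decrypted).ok()
}
```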
diff --git a/comms/src/protocol/identity.rs b/comms/src/protocol/identity.rs index 78f734a77c..bd52299620 100644 --- a/comms/src/protocol/identity.rs +++ b/comms/src/protocol/identity.rs @@ -88,8 +88,7 @@ where features: node_identity.features().bits(), supported_protocols, } - .to_encoded_bytes() - .map_err(|_| IdentityProtocolError::ProtobufEncodingError)?; + .to_encoded_bytes(); sink.send(msg_bytes.into()).await?; sink.close().await?; diff --git a/comms/src/protocol/messaging/error.rs b/comms/src/protocol/messaging/error.rs index f8b6d14829..ff595332f8 100644 --- a/comms/src/protocol/messaging/error.rs +++ b/comms/src/protocol/messaging/error.rs @@ -22,7 +22,7 @@ use crate::{ connection_manager::PeerConnectionError, - message::{MessageError, OutboundMessage}, + message::MessageError, peer_manager::PeerManagerError, protocol::ProtocolError, }; @@ -31,12 +31,6 @@ use derive_error::Error; #[derive(Debug, Error)] pub enum InboundMessagingError { PeerManagerError(PeerManagerError), - /// Inbound message signatures are invalid - InvalidMessageSignature, - /// The received envelope is invalid - InvalidEnvelope, - /// The connected peer sent a public key which did not match the public key of the connected peer - PeerPublicKeyMismatch, /// Failed to decode message MessageDecodeError(prost::DecodeError), MessageError(MessageError), @@ -45,7 +39,7 @@ pub enum InboundMessagingError { pub enum MessagingProtocolError { /// Failed to send message #[error(no_from, non_std)] - MessageSendFailed(OutboundMessage), // Msg returned to sender + MessageSendFailed, ProtocolError(ProtocolError), PeerConnectionError(PeerConnectionError), /// Failed to dial peer diff --git a/comms/src/protocol/messaging/inbound.rs b/comms/src/protocol/messaging/inbound.rs deleted file mode 100644 index b66d294ab6..0000000000 --- a/comms/src/protocol/messaging/inbound.rs +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2020, The Tari Project -// -// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the -// following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following -// disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the -// following disclaimer in the documentation and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -use crate::{ - message::{Envelope, InboundMessage}, - peer_manager::Peer, - protocol::messaging::error::InboundMessagingError, -}; -use bytes::Bytes; -use log::*; -use prost::Message; -use std::{convert::TryInto, sync::Arc}; - -const LOG_TARGET: &str = "comms::protocol::messaging::inbound"; - -pub struct InboundMessaging; - -impl InboundMessaging { - /// Process a single received message from its raw serialized form i.e. a FrameSet - pub async fn process_message( - &self, - source_peer: Arc, - msg: &mut Bytes, - ) -> Result - { - let envelope = Envelope::decode(msg)?; - - let public_key = envelope - .get_public_key() - .ok_or_else(|| InboundMessagingError::InvalidEnvelope)?; - - trace!( - target: LOG_TARGET, - "Received message envelope version {} from peer '{}'", - envelope.version, - source_peer.node_id.short_str() - ); - - if source_peer.public_key != public_key { - return Err(InboundMessagingError::PeerPublicKeyMismatch); - } - - if !envelope.verify_signature()? { - return Err(InboundMessagingError::InvalidMessageSignature); - } - - // -- Message is authenticated -- - let Envelope { header, body, .. } = envelope; - let header = header.expect("already checked").try_into().expect("already checked"); - - let inbound_message = InboundMessage::new(source_peer, header, body.into()); - - Ok(inbound_message) - } -} diff --git a/comms/src/protocol/messaging/mod.rs b/comms/src/protocol/messaging/mod.rs index 0cb78ccfc0..af1f6d20a8 100644 --- a/comms/src/protocol/messaging/mod.rs +++ b/comms/src/protocol/messaging/mod.rs @@ -21,10 +21,9 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod error; -mod inbound; mod outbound; -mod protocol; +mod protocol; pub use protocol::{ MessagingEvent, MessagingEventReceiver, diff --git a/comms/src/protocol/messaging/outbound.rs b/comms/src/protocol/messaging/outbound.rs index c41d4134da..bb1038914d 100644 --- a/comms/src/protocol/messaging/outbound.rs +++ b/comms/src/protocol/messaging/outbound.rs @@ -23,11 +23,10 @@ use super::{error::MessagingProtocolError, MessagingEvent, MessagingProtocol, SendFailReason, MESSAGING_PROTOCOL}; use crate::{ connection_manager::{ConnectionManagerError, ConnectionManagerRequester, NegotiatedSubstream, PeerConnection}, - message::{Envelope, MessageExt, OutboundMessage}, + message::OutboundMessage, peer_manager::{NodeId, NodeIdentity}, types::CommsSubstream, }; -use bytes::Bytes; use futures::{channel::mpsc, SinkExt, StreamExt}; use log::*; use std::sync::Arc; @@ -87,17 +86,6 @@ impl OutboundMessaging { ); continue; }, - Err(err @ ConnectionManagerError::PeerOffline) => { - error!( - target: LOG_TARGET, - "MessagingProtocol failed to dial peer '{}' because '{:?}'", - self.peer_node_id.short_str(), - err - ); - self.flush_all_messages_to_failed_event(SendFailReason::PeerOffline) - .await; - break Err(MessagingProtocolError::PeerDialFailed); - }, Err(err) => { error!( target: LOG_TARGET, @@ -136,36 +124,16 @@ impl OutboundMessaging { async fn start_forwarding_messages(mut self, substream: CommsSubstream) -> Result<(), MessagingProtocolError> { let mut framed = MessagingProtocol::framed(substream); - while let Some(out_msg) = self.request_rx.next().await { - match self.to_envelope_bytes(&out_msg).await { - Ok(body) => { - trace!( - target: LOG_TARGET, - "Sending message ({} bytes) ({:?}) on outbound messaging substream", - body.len(), - out_msg.tag, - ); - if let Err(err) = framed.send(body).await { - debug!( - target: LOG_TARGET, - "[ThisNode={}] OutboundMessaging failed to send message to peer 
'{}' because '{}'", - self.node_identity.node_id().short_str(), - self.peer_node_id.short_str(), - err - ); - let _ = self - .messaging_events_tx - .send(MessagingEvent::SendMessageFailed( - out_msg, - SendFailReason::SubstreamSendFailed, - )) - .await; - // FATAL: Failed to send on the substream - self.flush_all_messages_to_failed_event(SendFailReason::SubstreamSendFailed) - .await; - return Err(MessagingProtocolError::OutboundSubstreamFailure); - } - + while let Some(mut out_msg) = self.request_rx.next().await { + trace!( + target: LOG_TARGET, + "Sending message ({} bytes) ({:?}) on outbound messaging substream", + out_msg.body.len(), + out_msg.tag, + ); + match framed.send(out_msg.body.clone()).await { + Ok(_) => { + out_msg.reply_success(); let _ = self .messaging_events_tx .send(MessagingEvent::MessageSent(out_msg.tag)) @@ -174,18 +142,23 @@ impl OutboundMessaging { Err(err) => { debug!( target: LOG_TARGET, - "Failed to send message to peer '{}' because '{:?}'", - out_msg.peer_node_id.short_str(), + "[ThisNode={}] OutboundMessaging failed to send message to peer '{}' because '{}'", + self.node_identity.node_id().short_str(), + self.peer_node_id.short_str(), err ); - + out_msg.reply_fail(); let _ = self .messaging_events_tx .send(MessagingEvent::SendMessageFailed( out_msg, - SendFailReason::EnvelopeFailedToSerialize, + SendFailReason::SubstreamSendFailed, )) .await; + // FATAL: Failed to send on the substream + self.flush_all_messages_to_failed_event(SendFailReason::SubstreamSendFailed) + .await; + return Err(MessagingProtocolError::OutboundSubstreamFailure); }, } } @@ -204,31 +177,4 @@ impl OutboundMessaging { .await; } } - - async fn to_envelope_bytes(&self, out_msg: &OutboundMessage) -> Result { - let OutboundMessage { - flags, - body, - peer_node_id, - .. 
- } = out_msg; - - let envelope = Envelope::construct_signed( - self.node_identity.secret_key(), - self.node_identity.public_key(), - body.clone(), - *flags, - )?; - let body = envelope.to_encoded_bytes()?; - - trace!( - target: LOG_TARGET, - "[Node={}] Sending message ({} bytes) to peer '{}'", - self.node_identity.node_id().short_str(), - body.len(), - peer_node_id.short_str(), - ); - - Ok(body.into()) - } } diff --git a/comms/src/protocol/messaging/protocol.rs b/comms/src/protocol/messaging/protocol.rs index cf461822e4..e8238bddb9 100644 --- a/comms/src/protocol/messaging/protocol.rs +++ b/comms/src/protocol/messaging/protocol.rs @@ -26,11 +26,7 @@ use crate::{ connection_manager::{ConnectionManagerEvent, ConnectionManagerRequester}, message::{InboundMessage, MessageTag, OutboundMessage}, peer_manager::{NodeId, NodeIdentity, Peer, PeerManagerError}, - protocol::{ - messaging::{inbound::InboundMessaging, outbound::OutboundMessaging}, - ProtocolEvent, - ProtocolNotification, - }, + protocol::{messaging::outbound::OutboundMessaging, ProtocolEvent, ProtocolNotification}, runtime::current_executor, types::CommsSubstream, PeerManager, @@ -65,19 +61,15 @@ pub enum MessagingRequest { /// occurred #[derive(Debug, Error, Copy, Clone)] pub enum SendFailReason { - /// Dial was not attempted because the peer is offline - PeerOffline, /// Dial was attempted, but failed PeerDialFailed, - /// Outbound message envelope failed to serialize - EnvelopeFailedToSerialize, /// Failed to open a messaging substream to peer SubstreamOpenFailed, /// Failed to send on substream channel SubstreamSendFailed, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub enum MessagingEvent { MessageReceived(Box, MessageTag), InvalidMessageReceived(Box), @@ -92,7 +84,7 @@ impl fmt::Display for MessagingEvent { MessageReceived(node_id, tag) => write!(f, "MessageReceived({}, {})", node_id.short_str(), tag), InvalidMessageReceived(node_id) => write!(f, "InvalidMessageReceived({})", node_id.short_str()), SendMessageFailed(out_msg, reason) => write!(f, "SendMessageFailed({}, Reason = {})", out_msg, reason), - MessageSent(tag) => write!(f, "SendMessageSucceeded({})", tag), + MessageSent(tag) => write!(f, "MessageSent({})", tag), } } } @@ -305,8 +297,9 @@ impl MessagingProtocol { } async fn send_message(&mut self, out_msg: OutboundMessage) -> Result<(), MessagingProtocolError> { + let peer_node_id = out_msg.peer_node_id.clone(); let sender = loop { - match self.active_queues.entry(Box::new(out_msg.peer_node_id.clone())) { + match self.active_queues.entry(Box::new(peer_node_id.clone())) { Entry::Occupied(entry) => { if entry.get().is_closed() { entry.remove(); @@ -320,7 +313,7 @@ impl MessagingProtocol { self.node_identity.clone(), self.connection_manager_requester.clone(), self.internal_messaging_event_tx.clone(), - out_msg.peer_node_id.clone(), + peer_node_id.clone(), ) .await?; break entry.insert(sender); @@ -328,18 +321,18 @@ impl MessagingProtocol { } }; - match sender.send(out_msg.clone()).await { + match sender.send(out_msg).await { Ok(_) => Ok(()), Err(err) => { debug!( target: LOG_TARGET, "Failed to send message on channel because '{:?}'", err ); - // Lazily remove Senders from the active queue if the MessagingProtocolHandler has shut down + // Lazily remove Senders from the active queue if the `OutboundMessaging` task has shut down if err.is_disconnected() { - self.active_queues.remove(&out_msg.peer_node_id); + self.active_queues.remove(&peer_node_id); } - Err(MessagingProtocolError::MessageSendFailed(out_msg)) + 
Err(MessagingProtocolError::MessageSendFailed) }, } } @@ -363,7 +356,6 @@ impl MessagingProtocol { let messaging_events_tx = self.messaging_events_tx.clone(); let mut inbound_message_tx = self.inbound_message_tx.clone(); let mut framed_substream = Self::framed(substream); - let inbound = InboundMessaging; self.executor.spawn(async move { while let Some(result) = framed_substream.next().await { @@ -376,42 +368,23 @@ impl MessagingProtocol { raw_msg.len() ); - let mut raw_msg = raw_msg.freeze(); - let (event, in_msg) = match inbound.process_message(Arc::clone(&peer), &mut raw_msg).await { - Ok(inbound_msg) => ( - MessagingEvent::MessageReceived( - Box::new(inbound_msg.source_peer.node_id.clone()), - inbound_msg.tag, - ), - Some(inbound_msg), - ), - Err(err) => { - // TODO: #banheuristic - warn!( - target: LOG_TARGET, - "Received invalid message from peer '{}' ({})", - peer.node_id.short_str(), - err - ); - ( - MessagingEvent::InvalidMessageReceived(Box::new(peer.node_id.clone())), - None, - ) - }, - }; - - if let Some(in_msg) = in_msg { - if let Err(err) = inbound_message_tx.send(in_msg).await { - warn!( - target: LOG_TARGET, - "Failed to send InboundMessage for peer '{}' because '{}'", - peer.node_id.short_str(), - err - ); - - if err.is_disconnected() { - break; - } + let inbound_msg = InboundMessage::new(Arc::clone(&peer), raw_msg.freeze()); + + let event = MessagingEvent::MessageReceived( + Box::new(inbound_msg.source_peer.node_id.clone()), + inbound_msg.tag, + ); + + if let Err(err) = inbound_message_tx.send(inbound_msg).await { + warn!( + target: LOG_TARGET, + "Failed to send InboundMessage for peer '{}' because '{}'", + peer.node_id.short_str(), + err + ); + + if err.is_disconnected() { + break; } } @@ -426,16 +399,19 @@ impl MessagingProtocol { ); } }, - Err(err) => debug!( - target: LOG_TARGET, - "Failed to receive from peer '{}' because '{}'", - peer.node_id.short_str(), - err - ), + Err(err) => { + error!( + target: LOG_TARGET, + "Failed to receive from peer '{}' because '{}'", + peer.node_id.short_str(), + err + ); + break; + }, } } - trace!( + debug!( target: LOG_TARGET, "Inbound messaging handler for peer '{}' has stopped", peer.node_id.short_str() diff --git a/comms/src/protocol/messaging/test.rs b/comms/src/protocol/messaging/test.rs index 6d579c5680..f814fdc347 100644 --- a/comms/src/protocol/messaging/test.rs +++ b/comms/src/protocol/messaging/test.rs @@ -28,10 +28,9 @@ use super::protocol::{ MESSAGING_PROTOCOL, }; use crate::{ - message::{InboundMessage, MessageExt, MessageFlags, MessageTag, OutboundMessage}, + message::{InboundMessage, MessageTag, OutboundMessage}, net_address::MultiaddressesWithStats, peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags, PeerManager}, - proto::envelope::Envelope, protocol::{messaging::SendFailReason, ProtocolEvent, ProtocolNotification}, test_utils::{ mocks::{create_connection_manager_mock, create_peer_connection_mock_pair, ConnectionManagerMockState}, @@ -42,8 +41,12 @@ use crate::{ types::{CommsDatabase, CommsPublicKey, CommsSubstream}, }; use bytes::Bytes; -use futures::{channel::mpsc, SinkExt, StreamExt}; -use prost::Message; +use futures::{ + channel::{mpsc, oneshot}, + stream::FuturesUnordered, + SinkExt, + StreamExt, +}; use rand::rngs::OsRng; use std::{sync::Arc, time::Duration}; use tari_crypto::keys::PublicKey; @@ -110,7 +113,7 @@ async fn new_inbound_substream_handling() { spawn_messaging_protocol().await; let expected_node_id = node_id::random(); - let (sk, pk) = CommsPublicKey::random_keypair(&mut OsRng); + let 
(_, pk) = CommsPublicKey::random_keypair(&mut OsRng); peer_manager .add_peer(Peer::new( pk.clone(), @@ -139,12 +142,7 @@ async fn new_inbound_substream_handling() { let stream_theirs = muxer_theirs.incoming_mut().next().await.unwrap(); let mut framed_theirs = MessagingProtocol::framed(stream_theirs); - let envelope = Envelope::construct_signed(&sk, &pk, TEST_MSG1, MessageFlags::empty()).unwrap(); - - framed_theirs - .send(Bytes::copy_from_slice(&envelope.to_encoded_bytes().unwrap())) - .await - .unwrap(); + framed_theirs.send(TEST_MSG1).await.unwrap(); let in_msg = time::timeout(Duration::from_secs(5), inbound_msg_rx.next()) .await @@ -177,15 +175,14 @@ async fn send_message_request() { conn_man_mock.add_active_connection(peer_node_id.clone(), conn1).await; // Send a message to node - let out_msg = OutboundMessage::new(peer_node_id, MessageFlags::NONE, TEST_MSG1); + let out_msg = OutboundMessage::new(peer_node_id, TEST_MSG1); request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); // Check that node got the message let stream = peer_conn_mock2.next_incoming_substream().await.unwrap(); let mut framed = MessagingProtocol::framed(stream); let msg = framed.next().await.unwrap().unwrap(); - let msg = Envelope::decode(msg).unwrap(); - assert_eq!(msg.body, TEST_MSG1); + assert_eq!(msg, TEST_MSG1); // Got the call to create a substream assert_eq!(peer_conn_mock1.call_count(), 1); @@ -196,7 +193,7 @@ async fn send_message_dial_failed() { let (_, _, conn_manager_mock, _, mut request_tx, _, mut event_tx, _shutdown) = spawn_messaging_protocol().await; let node_id = node_id::random(); - let out_msg = OutboundMessage::new(node_id, MessageFlags::NONE, TEST_MSG1); + let out_msg = OutboundMessage::new(node_id, TEST_MSG1); let expected_out_msg_tag = out_msg.tag; // Send a message to node 2 request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); @@ -228,7 +225,7 @@ async fn send_message_substream_bulk_failure() { .await; async fn send_msg(request_tx: &mut mpsc::Sender, node_id: NodeId) -> MessageTag { - let out_msg = OutboundMessage::new(node_id, MessageFlags::NONE, TEST_MSG1); + let out_msg = OutboundMessage::new(node_id, TEST_MSG1); let msg_tag = out_msg.tag; // Send a message to node 2 request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); @@ -274,9 +271,17 @@ async fn many_concurrent_send_message_requests() { // Send many messages to node let mut msg_tags = Vec::with_capacity(NUM_MSGS); + let mut reply_rxs = Vec::with_capacity(NUM_MSGS); for _ in 0..NUM_MSGS { - let out_msg = OutboundMessage::new(node_id2.clone(), MessageFlags::NONE, TEST_MSG1); + let (reply_tx, reply_rx) = oneshot::channel(); + let out_msg = OutboundMessage { + tag: MessageTag::new(), + reply_tx: Some(reply_tx), + peer_node_id: node_id2.clone(), + body: TEST_MSG1, + }; msg_tags.push(out_msg.tag); + reply_rxs.push(reply_rx); request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); } @@ -297,6 +302,11 @@ async fn many_concurrent_send_message_requests() { msg_tags.remove(index); } + let unordered = FuturesUnordered::new(); + reply_rxs.into_iter().for_each(|rx| unordered.push(rx)); + let results = unordered.collect::>().await; + assert_eq!(results.into_iter().map(|r| r.unwrap()).all(|r| r.is_ok()), true); + // Got a single call to create a substream assert_eq!(peer_conn_mock1.call_count(), 1); } @@ -310,9 +320,17 @@ async fn many_concurrent_send_message_requests_that_fail() { // Send many messages to node let mut msg_tags = Vec::with_capacity(NUM_MSGS); + let mut reply_rxs = 
Vec::with_capacity(NUM_MSGS); for _ in 0..NUM_MSGS { - let out_msg = OutboundMessage::new(node_id2.clone(), MessageFlags::NONE, TEST_MSG1); + let (reply_tx, reply_rx) = oneshot::channel(); + let out_msg = OutboundMessage { + tag: MessageTag::new(), + reply_tx: Some(reply_tx), + peer_node_id: node_id2.clone(), + body: TEST_MSG1, + }; msg_tags.push(out_msg.tag); + reply_rxs.push(reply_rx); request_tx.send(MessagingRequest::SendMessage(out_msg)).await.unwrap(); } @@ -327,5 +345,11 @@ async fn many_concurrent_send_message_requests_that_fail() { let index = msg_tags.iter().position(|t| t == &out_msg.tag).unwrap(); msg_tags.remove(index); } + + let unordered = FuturesUnordered::new(); + reply_rxs.into_iter().for_each(|rx| unordered.push(rx)); + let results = unordered.collect::<Vec<_>>().await; + assert_eq!(results.into_iter().map(|r| r.unwrap()).all(|r| r.is_err()), true); + assert_eq!(msg_tags.len(), 0); } diff --git a/comms/src/test_utils/mocks/connection_manager.rs b/comms/src/test_utils/mocks/connection_manager.rs index d6e9f6f949..e4eb6aff06 100644 --- a/comms/src/test_utils/mocks/connection_manager.rs +++ b/comms/src/test_utils/mocks/connection_manager.rs @@ -127,7 +127,7 @@ impl ConnectionManagerMock { self.state.inc_call_count(); self.state.add_call(format!("{:?}", req)).await; match req { - DialPeer(node_id, _, reply_tx) => { + DialPeer(node_id, reply_tx) => { // Send Ok(conn) if we have an active connection, otherwise Err(DialConnectFailedAllAddresses) reply_tx .send( @@ -152,6 +152,9 @@ impl ConnectionManagerMock { .send(self.state.active_conns.lock().await.values().cloned().collect()) .unwrap(); }, + GetNumActiveConnections(reply_tx) => { + reply_tx.send(self.state.active_conns.lock().await.len()).unwrap(); + }, DisconnectPeer(node_id, reply_tx) => { let _ = self.state.active_conns.lock().await.remove(&node_id); reply_tx.send(Ok(())).unwrap(); diff --git a/comms/src/test_utils/mocks/peer_connection.rs b/comms/src/test_utils/mocks/peer_connection.rs index 7753ee4043..c3ba9c3214 100644 --- a/comms/src/test_utils/mocks/peer_connection.rs +++ b/comms/src/test_utils/mocks/peer_connection.rs @@ -142,7 +142,7 @@ impl PeerConnectionMock { reply_tx.send(Ok(negotiated_substream)).unwrap(); }, Err(err) => { - reply_tx.send(Err(err.into())).unwrap(); + reply_tx.send(Err(err)).unwrap(); }, }, Disconnect(_, reply_tx) => { diff --git a/comms/src/test_utils/node_identity.rs b/comms/src/test_utils/node_identity.rs index fd06a40dc2..86df4b4c8b 100644 --- a/comms/src/test_utils/node_identity.rs +++ b/comms/src/test_utils/node_identity.rs @@ -34,10 +34,8 @@ pub fn build_node_identity(features: PeerFeatures) -> Arc<NodeIdentity> { Arc::new(NodeIdentity::random(&mut OsRng, public_addr, features).unwrap()) } -pub fn ordered_node_identities(n: usize) -> Vec<Arc<NodeIdentity>> { - let mut ids = (0..n) .map(|_| build_node_identity(PeerFeatures::default())) .collect::<Vec<_>>(); +pub fn ordered_node_identities(n: usize, features: PeerFeatures) -> Vec<Arc<NodeIdentity>> { + let mut ids = (0..n).map(|_| build_node_identity(features)).collect::<Vec<_>>(); ids.sort_unstable_by(|a, b| a.node_id().cmp(b.node_id())); ids } diff --git a/comms/src/utils/cidr.rs b/comms/src/utils/cidr.rs index 8da1e7ae02..29b8d4615e 100644 --- a/comms/src/utils/cidr.rs +++ b/comms/src/utils/cidr.rs @@ -28,7 +28,7 @@ pub fn parse_cidrs<'a, I: IntoIterator<Item = T>, T: AsRef<str>>(cidr_strs: I) -> Result<Vec<AnyIpCidr>, String> { .map(|s| ::cidr::AnyIpCidr::from_str(s.as_ref())) .partition::<Vec<_>, _>(Result::is_ok); - if failed.len() > 0 { + if !failed.is_empty() { return Err(format!("Invalid CIDR strings: {:?}", failed)); } diff --git 
a/comms/src/utils/datetime.rs b/comms/src/utils/datetime.rs new file mode 100644 index 0000000000..0da83c8ff7 --- /dev/null +++ b/comms/src/utils/datetime.rs @@ -0,0 +1,33 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use chrono::{DateTime, NaiveTime, Utc}; +use std::time::Duration; + +pub fn safe_future_datetime_from_duration(duration: Duration) -> DateTime<Utc> { + let old_duration = chrono::Duration::from_std(duration).unwrap_or_else(|_| chrono::Duration::max_value()); + Utc::now().checked_add_signed(old_duration).unwrap_or_else(|| { + chrono::MAX_DATE + .and_time(NaiveTime::from_hms(0, 0, 0)) + .expect("cannot fail") + }) +} diff --git a/comms/src/utils/mod.rs b/comms/src/utils/mod.rs index 39a2e768f1..697a758c1f 100644 --- a/comms/src/utils/mod.rs +++ b/comms/src/utils/mod.rs @@ -21,5 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
pub mod cidr; +pub mod datetime; pub mod multiaddr; pub mod signature; diff --git a/config/presets/rincewind-simple.toml b/config/presets/rincewind-simple.toml deleted file mode 100644 index 153b19af89..0000000000 --- a/config/presets/rincewind-simple.toml +++ /dev/null @@ -1,16 +0,0 @@ -# A simple set of sane defaults for connecting to the Rincewind testnet -[common] -#peer_database = "~/.tari/peers" - -[base_node] -network = "rincewind" - -[base_node.rincewind] -db_type = "lmdb" -transport = "tor" -peer_seeds = [ - "5edb022af1c21d644dfceeea2fcc7d3fac7a57ab44cf775b9a6f692cb75ed767::/onion3/vjkj44zpriqzrlve2qbiasrluaaxagrb6iuavzaascbujri6gw3rcmyd:18141", - "d44d23b005dcd364776e4cad69ac800b8ab6d6bf12097a5edb8720ce584ed45a::/onion3/3gficjdxzduuxtbyzt3auwwzjv7xlljnonzer5t2aglczrjb54wxadyd:18141", - "2e93c460df49d8cfbbf7a06dd9004c25a84f92584f7d0ac5e30bd8e0beee9a43::/onion3/nuuq3e2olck22rudimovhmrdwkmjncxvwdgbvfxhz6myzcnx2j4rssyd:18141" -] -enable_mining = false \ No newline at end of file diff --git a/infrastructure/storage/Cargo.toml b/infrastructure/storage/Cargo.toml index 9bdb515dc0..d824406659 100644 --- a/infrastructure/storage/Cargo.toml +++ b/infrastructure/storage/Cargo.toml @@ -6,7 +6,7 @@ repository = "https://github.com/tari-project/tari" homepage = "https://tari.com" readme = "README.md" license = "BSD-3-Clause" -version = "0.0.10" +version = "0.1.0" edition = "2018" [dependencies] @@ -14,6 +14,7 @@ bincode = "1.1" derive-error = "0.0.4" log = "0.4.0" lmdb-zero = "0.4.4" +thiserror = "1.0.15" rmp = "0.8.7" rmp-serde = "0.13.7" serde = "1.0.80" diff --git a/infrastructure/storage/src/key_val_store/lmdb_database.rs b/infrastructure/storage/src/key_val_store/lmdb_database.rs index c78d7dca93..7a40110f2d 100644 --- a/infrastructure/storage/src/key_val_store/lmdb_database.rs +++ b/infrastructure/storage/src/key_val_store/lmdb_database.rs @@ -125,7 +125,7 @@ mod test { std::fs::create_dir(&path).unwrap_or_default(); LMDBBuilder::new() .set_path(&path) - .set_environment_size(10) + .set_environment_size(50) .set_max_number_of_databases(2) .add_database(name, lmdb_zero::db::CREATE) .build() diff --git a/infrastructure/storage/src/key_val_store/mod.rs b/infrastructure/storage/src/key_val_store/mod.rs index 0aad92ad66..8bb27f7a69 100644 --- a/infrastructure/storage/src/key_val_store/mod.rs +++ b/infrastructure/storage/src/key_val_store/mod.rs @@ -22,6 +22,7 @@ pub mod error; pub mod hmap_database; +#[allow(clippy::module_inception)] pub mod key_val_store; pub mod lmdb_database; diff --git a/infrastructure/storage/src/lmdb_store/error.rs b/infrastructure/storage/src/lmdb_store/error.rs index 21b2ad3316..311f249919 100644 --- a/infrastructure/storage/src/lmdb_store/error.rs +++ b/infrastructure/storage/src/lmdb_store/error.rs @@ -19,36 +19,21 @@ // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use thiserror::Error; -use derive_error::Error; - -#[derive(Debug, Error)] +#[derive(Error, Debug)] pub enum LMDBError { - /// Cannot create LMDB. The path does not exist + #[error("Cannot create LMDB. 
The path does not exist")] InvalidPath, - /// An error occurred with the underlying data store implementation - #[error(embedded_msg, no_from, non_std)] - InternalError(String), - /// An error occurred during serialization - #[error(no_from, non_std)] + #[error("An error occurred during serialization:{0}")] SerializationErr(String), - /// An error occurred during deserialization - #[error(no_from, non_std)] - DeserializationErr(String), - /// Occurs when trying to perform an action that requires us to be in a live transaction - TransactionNotLiveError, - /// A transaction or query was attempted while no database was open. - DatabaseNotOpen, - /// A database with the requested name does not exist - UnknownDatabase, - /// An error occurred during a put query - #[error(embedded_msg, no_from, non_std)] - PutError(String), - /// An error occurred during a get query - #[error(embedded_msg, no_from, non_std)] + #[error("An error occurred during a get query:{0}")] GetError(String), - #[error(embedded_msg, no_from, non_std)] + #[error("An error occurred during commit:{0}")] CommitError(String), - /// An LMDB error occurred - DatabaseError(lmdb_zero::error::Error), + #[error("An LMDB error occurred:{source}")] + DatabaseError { + #[from] + source: lmdb_zero::error::Error, + }, } diff --git a/infrastructure/storage/src/lmdb_store/store.rs b/infrastructure/storage/src/lmdb_store/store.rs index ab88c4b79b..1b6c0d4bfd 100644 --- a/infrastructure/storage/src/lmdb_store/store.rs +++ b/infrastructure/storage/src/lmdb_store/store.rs @@ -46,8 +46,9 @@ type DatabaseRef = Arc<Database<'static>>; /// ``` /// # use tari_storage::lmdb_store::LMDBBuilder; /// # use lmdb_zero::db; +/// # use std::env; /// let mut store = LMDBBuilder::new() -/// .set_path("/tmp/") +/// .set_path(env::temp_dir()) /// .set_environment_size(500) /// .set_max_number_of_databases(10) /// .add_database("db1", db::CREATE) @@ -377,9 +378,7 @@ impl LMDBDatabase { /// Return statistics about the database, See [Stat](lmdb_zero/struct.Stat.html) for more details. pub fn get_stats(&self) -> Result<Stat, LMDBError> { let env = &(*self.db.env()); - ReadTransaction::new(env) .and_then(|txn| txn.db_stat(&self.db)) .map_err(LMDBError::DatabaseError) + Ok(ReadTransaction::new(env).and_then(|txn| txn.db_stat(&self.db))?) } /// Log some pretty printed stats.See [Stat](lmdb_zero/struct.Stat.html) for more details. @@ -440,10 +439,10 @@ impl LMDBDatabase { { let env = self.env.clone(); let db = self.db.clone(); - let txn = ReadTransaction::new(env).map_err(LMDBError::DatabaseError)?; + let txn = ReadTransaction::new(env)?; let access = txn.access(); - let cursor = txn.cursor(db).map_err(LMDBError::DatabaseError)?; + let cursor = txn.cursor(db)?; let head = |c: &mut Cursor, a: &ConstAccessor| { let (key_bytes, val_bytes) = c.first(a)?; @@ -451,7 +450,7 @@ impl LMDBDatabase { }; let cursor = MaybeOwned::Owned(cursor); - let iter = CursorIter::new(cursor, &access, head, ReadOnlyIterator::next).map_err(LMDBError::DatabaseError)?; + let iter = CursorIter::new(cursor, &access, head, ReadOnlyIterator::next)?; for p in iter { match f(p.map_err(|e| KeyValStoreError::DatabaseError(e.to_string()))) { @@ -602,7 +601,7 @@ impl<'txn, 'db: 'txn> LMDBWriteTransaction<'txn, 'db> { pub fn delete<K>(&mut self, key: &K) -> Result<(), LMDBError> where K: AsLmdbBytes + ?Sized { - self.access.del_key(&self.db, key).map_err(LMDBError::DatabaseError) + Ok(self.access.del_key(&self.db, key)?) 
} fn convert_value<V>(value: &V, size_estimate: usize) -> Result<Vec<u8>, LMDBError> @@ -612,3 +611,23 @@ impl<'txn, 'db: 'txn> LMDBWriteTransaction<'txn, 'db> { Ok(buf) } } + +#[cfg(test)] +mod test { + use crate::lmdb_store::LMDBBuilder; + use lmdb_zero::db; + use std::env; + + #[test] + fn test_lmdb_builder() { + let store = LMDBBuilder::new() + .set_path(env::temp_dir()) + .set_environment_size(500) + .set_max_number_of_databases(10) + .add_database("db1", db::CREATE) + .add_database("db2", db::CREATE) + .build() + .unwrap(); + assert!(&store.databases.len() == &2); + } +} diff --git a/scripts/build-dists-tarball.sh b/scripts/build-dists-tarball.sh index 0bec0343fa..39833a8875 100755 --- a/scripts/build-dists-tarball.sh +++ b/scripts/build-dists-tarball.sh @@ -23,6 +23,13 @@ if [ -f "$envFile" ]; then source "$envFile" fi +if [ -f "Cargo.toml" ]; then + cargo build --release +else + echo "Can't find Cargo.toml, exiting" + exit 1 +fi + if [ "$(uname)" == "Darwin" ]; then osname="osx" osversion="catalina" @@ -89,12 +96,20 @@ fi distFullName="$distName-$osname-$osversion-$osarch" echo $distFullName +if [ -f "applications/tari_base_node/src/consts.rs" ];then + rustver=$(grep -i 'VERSION' applications/tari_base_node/src/consts.rs | cut -d "\"" -f 2) +else + rustver="unversion" +fi + # Just a basic clean check if [ -n "$(git status --porcelain)" ]; then echo "There are changes, please clean up before re-running $0"; -# exit 1 + gitclean="uncommitted" +# exit 2 else echo "No changes"; + gitclean="gitclean" fi git fetch --all --tags @@ -122,12 +137,9 @@ git rev-parse --symbolic-full-name --abbrev-ref HEAD #git pull shaSumVal="256" -#hashFile="$distFullName-$gitTagVersion.sha" -#hashFile+="$shaSumVal" -#hashFile+="sum" -hashFile="$distFullName-$gitTagVersion.sha${shaSumVal}sum" -archiveFile="$distFullName-$gitTagVersion.zip" +hashFile="$distFullName-$rustver-$gitTagVersion-$gitclean.sha${shaSumVal}sum" +archiveFile="$distFullName-$rustver-$gitTagVersion-$gitclean.zip" echo $hashFile distDir=$(mktemp -d) @@ -135,17 +147,15 @@ if [ -d $distDir ]; then echo "Temporary directory $distDir exists" else echo "Temporary directory $distDir does not exist" - exit 2 + exit 3 fi mkdir $distDir/dist -cargo build --release - COPY_FILES=( "target/release/tari_base_node" - "config/presets/rincewind-simple.toml" - "config/tari_config_sample.toml" + "common/config/presets/rincewind-simple.toml" + "common/config/tari_config_sample.toml" # "log4rs.yml" "common/logging/log4rs-sample.yml" "applications/tari_base_node/README.md" @@ -160,7 +170,10 @@ done pushd $distDir/dist if [ "$osname" == "osx" ] && [ -n "${osxsign}" ]; then echo "Signing OSX Binary ..." - codesign --force --verify --verbose --sign "${osxsign}" "${distDir}/dist/tari_base_node" + codesign --options runtime --force --verify --verbose --sign "${osxsign}" "${distDir}/dist/tari_base_node" + echo "Verify signed OSX Binary ..." 
+ codesign --verify --deep --display --verbose=4 "${distDir}/dist/tari_base_node" + spctl -a -v "${distDir}/dist/tari_base_node" fi shasum -a $shaSumVal * >> "$distDir/$hashFile" #echo "$(cat $distDir/$hashFile)" | shasum -a $shaSumVal --check --status diff --git a/scripts/create_bundle.sh b/scripts/create_bundle.sh index 2c1b741494..7c1a406dae 100755 --- a/scripts/create_bundle.sh +++ b/scripts/create_bundle.sh @@ -19,7 +19,7 @@ fi BUNDLE=' target/release/tari_base_node scripts/install_tor.sh -config/presets/rincewind-simple.toml +common/config/presets/rincewind-simple.toml common/logging/log4rs-sample.yml applications/tari_base_node/install-osx.sh applications/tari_base_node/start_tor.sh diff --git a/scripts/publish_crates.sh b/scripts/publish_crates.sh index a40a84851a..a7d07ccee0 100755 --- a/scripts/publish_crates.sh +++ b/scripts/publish_crates.sh @@ -1,18 +1,23 @@ #!/usr/bin/env bash # NB: The order these are listed in is IMPORTANT! Dependencies must go first +#infrastructure/derive +#infrastructure/shutdown +#infrastructure/storage +#infrastructure/test_utils +#common +#comms +#comms/dht +#base_layer/service_framework +#base_layer/mmr +#base_layer/key_manager +#base_layer/p2p +#base_layer/core +#base_layer/wallet +#base_layer/wallet_ffi +#applications/tari_base_node + packages=${@:-' -infrastructure/derive -infrastructure/shutdown -infrastructure/storage -infrastructure/test_utils -common -comms -comms/dht -base_layer/service_framework -base_layer/mmr -base_layer/key_manager -base_layer/p2p base_layer/core base_layer/wallet base_layer/wallet_ffi diff --git a/scripts/update_crate_metadata.sh b/scripts/update_crate_metadata.sh index 5aef0631ba..c4b199cd68 100755 --- a/scripts/update_crate_metadata.sh +++ b/scripts/update_crate_metadata.sh @@ -6,17 +6,28 @@ if [ "x$VERSION" == "x" ]; then exit 1 fi +# infrastructure/derive +# infrastructure/shutdown +# infrastructure/storage +# infrastructure/test_utils +# base_layer/core +# base_layer/key_manager +# base_layer/mmr +# base_layer/p2p +# base_layer/service_framework +# base_layer/wallet +# base_layer/wallet_ffi +# common +# comms +# comms/dht +# applications/tari_base_node + function update_versions { packages=${@:-' - infrastructure/derive - infrastructure/shutdown infrastructure/storage - infrastructure/test_utils base_layer/core - base_layer/key_manager base_layer/mmr base_layer/p2p - base_layer/service_framework base_layer/wallet base_layer/wallet_ffi common @@ -35,7 +46,8 @@ function update_versions { function update_version { CARGO=$1 VERSION=$2 - SCRIPT='s/^version = "[0-9]\.[0-9]\.[0-9]"$/version = "'"$VERSION"'"/g' + SCRIPT='s/^version = ".+\..+\..+"/version = "'"$VERSION"'"/' + echo "$SCRIPT" "$CARGO" sed -i.bak -e "$SCRIPT" "$CARGO" rm $CARGO.bak }
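
Note on the reply-channel pattern in the messaging test changes above: each OutboundMessage now carries an optional oneshot reply_tx, and the tests await every receiver at once with FuturesUnordered. A minimal, self-contained sketch of that pattern follows, using only the futures crate; the u32/String payload types are illustrative stand-ins, not the comms types.

use futures::{channel::oneshot, executor::block_on, stream::FuturesUnordered, StreamExt};

fn main() {
    // One oneshot receiver per "sent" message, as in the tests.
    let mut reply_rxs = Vec::new();
    for i in 0..10u32 {
        let (reply_tx, reply_rx) = oneshot::channel::<Result<u32, String>>();
        reply_rxs.push(reply_rx);
        // The messaging protocol would resolve this when the send completes;
        // here every send "succeeds" immediately.
        reply_tx.send(Ok(i)).unwrap();
    }

    // Await all replies concurrently; results arrive in completion order.
    let unordered: FuturesUnordered<_> = reply_rxs.into_iter().collect();
    let results = block_on(unordered.collect::<Vec<_>>());

    // The outer Result reports oneshot cancellation; the inner one is the send outcome.
    assert!(results.into_iter().map(|r| r.unwrap()).all(|r| r.is_ok()));
}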
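
Note on the thiserror rewrite of LMDBError: the #[from] attribute derives From<lmdb_zero::error::Error>, which is what lets store.rs replace the explicit .map_err(LMDBError::DatabaseError) calls with the ? operator. A minimal sketch of the mechanism, with BackendError standing in for lmdb_zero::error::Error:

use thiserror::Error;

// Stand-in for lmdb_zero::error::Error.
#[derive(Debug, Error)]
#[error("backend failure")]
struct BackendError;

#[derive(Debug, Error)]
enum StoreError {
    #[error("An LMDB error occurred:{source}")]
    DatabaseError {
        #[from]
        source: BackendError,
    },
}

fn backend_op() -> Result<(), BackendError> {
    Err(BackendError)
}

// `?` converts BackendError into StoreError::DatabaseError automatically,
// so no explicit .map_err(...) is needed.
fn store_op() -> Result<(), StoreError> {
    backend_op()?;
    Ok(())
}

fn main() {
    assert!(matches!(store_op(), Err(StoreError::DatabaseError { .. })));
}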