diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index ca1c6483b..cb2500073 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -93,7 +93,7 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable ITACoChildApplication public immutable application; uint96 private immutable minAuthorization; // TODO use child app for checking eligibility - Ritual[] private ritualsStub; // former rituals + Ritual[] internal ritualsStub; // former rituals, "internal" for testing only uint32 public timeout; uint16 public maxDkgSize; bool private stub1; // former isInitiationPublic diff --git a/contracts/test/CoordinatorTestSet.sol b/contracts/test/CoordinatorTestSet.sol index d7664553f..6b777d116 100644 --- a/contracts/test/CoordinatorTestSet.sol +++ b/contracts/test/CoordinatorTestSet.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "../threshold/ITACoChildApplication.sol"; +import "../contracts/coordination/Coordinator.sol"; /** * @notice Contract for testing Coordinator contract @@ -33,3 +34,37 @@ contract ChildApplicationForCoordinatorMock is ITACoChildApplication { // solhint-disable-next-line no-empty-blocks function penalize(address _stakingProvider) external {} } + +contract ExtendedCoordinator is Coordinator { + constructor(ITACoChildApplication _application) Coordinator(_application) {} + + function initiateOldRitual( + IFeeModel feeModel, + address[] calldata providers, + address authority, + uint32 duration, + IEncryptionAuthorizer accessController + ) external returns (uint32) { + uint16 length = uint16(providers.length); + + uint32 id = uint32(ritualsStub.length); + Ritual storage ritual = ritualsStub.push(); + ritual.initiator = msg.sender; + ritual.authority = authority; + ritual.dkgSize = length; + ritual.threshold = getThresholdForRitualSize(length); + ritual.initTimestamp = uint32(block.timestamp); + ritual.endTimestamp = ritual.initTimestamp + duration; + ritual.accessController = accessController; + ritual.feeModel = feeModel; + + address previous = address(0); + for (uint256 i = 0; i < length; i++) { + Participant storage newParticipant = ritual.participant.push(); + address current = providers[i]; + newParticipant.provider = current; + previous = current; + } + return id; + } +} diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 22b88b092..64e799953 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -3,6 +3,7 @@ import ape import pytest +from ape.utils import ZERO_ADDRESS from eth_account import Account from hexbytes import HexBytes from web3 import Web3 @@ -86,9 +87,9 @@ def erc20(project, initiator): @pytest.fixture() -def coordinator(project, deployer, application, initiator, oz_dependency): +def coordinator(project, deployer, application, oz_dependency): admin = deployer - contract = project.Coordinator.deploy( + contract = project.ExtendedCoordinator.deploy( application.address, sender=deployer, ) @@ -100,7 +101,7 @@ def coordinator(project, deployer, application, initiator, oz_dependency): encoded_initializer_function, sender=deployer, ) - proxy_contract = project.Coordinator.at(proxy.address) + proxy_contract = project.ExtendedCoordinator.at(proxy.address) return proxy_contract @@ -564,3 +565,71 @@ def test_post_aggregation_fails( assert fee_model.totalPendingFees() == 0 assert fee_model.pendingFees(ritualID) == 0 fee_model.withdrawTokens(fee_model_balance_after_refund, sender=deployer) + + +def test_upgrade( + coordinator, nodes, initiator, erc20, fee_model, treasury, deployer, global_allow_list +): + coordinator.initiateOldRitual( + fee_model, nodes, initiator, DURATION, global_allow_list.address, sender=initiator + ) + coordinator.initiateOldRitual( + ZERO_ADDRESS, [nodes[0]], treasury, DURATION // 2, deployer, sender=initiator + ) + assert coordinator.numberOfRituals() == 0 + coordinator.initializeNumberOfRituals(sender=deployer) + assert coordinator.numberOfRituals() == 2 + + initiate_ritual( + coordinator=coordinator, + fee_model=fee_model, + erc20=erc20, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list, + ) + assert coordinator.numberOfRituals() == 3 + + assert coordinator.getRitualState(0) == RitualState.DKG_AWAITING_TRANSCRIPTS + + ritual_struct = coordinator.rituals(0) + assert ritual_struct[0] == initiator + init, end = ritual_struct[1], ritual_struct[2] + assert end - init == DURATION + total_transcripts, total_aggregations = ritual_struct[3], ritual_struct[4] + assert total_transcripts == total_aggregations == 0 + assert ritual_struct[5] == initiator + assert ritual_struct[6] == len(nodes) + assert ritual_struct[7] == 1 + len(nodes) // 2 # threshold + assert not ritual_struct[8] # aggregationMismatch + assert ritual_struct[9] == global_allow_list.address # accessController + assert ritual_struct[10] == (b"\x00" * 32, b"\x00" * 16) # publicKey + assert not ritual_struct[11] # aggregatedTranscript + + ritual_struct = coordinator.rituals(1) + assert ritual_struct[0] == initiator + init, end = ritual_struct[1], ritual_struct[2] + assert end - init == DURATION // 2 + total_transcripts, total_aggregations = ritual_struct[3], ritual_struct[4] + assert total_transcripts == total_aggregations == 0 + assert ritual_struct[5] == treasury + assert ritual_struct[6] == 1 + assert ritual_struct[7] == 1 # threshold + assert not ritual_struct[8] # aggregationMismatch + assert ritual_struct[9] == deployer # accessController + assert ritual_struct[10] == (b"\x00" * 32, b"\x00" * 16) # publicKey + assert not ritual_struct[11] # aggregatedTranscript + + ritual_struct = coordinator.rituals(2) + assert ritual_struct[0] == initiator + init, end = ritual_struct[1], ritual_struct[2] + assert end - init == DURATION + total_transcripts, total_aggregations = ritual_struct[3], ritual_struct[4] + assert total_transcripts == total_aggregations == 0 + assert ritual_struct[5] == initiator + assert ritual_struct[6] == len(nodes) + assert ritual_struct[7] == 1 + len(nodes) // 2 # threshold + assert not ritual_struct[8] # aggregationMismatch + assert ritual_struct[9] == global_allow_list.address # accessController + assert ritual_struct[10] == (b"\x00" * 32, b"\x00" * 16) # publicKey + assert not ritual_struct[11] # aggregatedTranscript