diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index 5e0e4f70..895eb636 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -30,6 +30,8 @@ contract Coordinator is AccessControlDefaultAdminRules { event TimeoutChanged(uint32 oldTimeout, uint32 newTimeout); event MaxDkgSizeChanged(uint16 oldSize, uint16 newSize); + event ParticipantPublicKeySet(uint32 indexed ritualId, address indexed participant, BLS12381.G2Point publicKey); + enum RitualState { NON_INITIATED, AWAITING_TRANSCRIPTS, @@ -61,10 +63,17 @@ contract Coordinator is AccessControlDefaultAdminRules { Participant[] participant; } + struct ParticipantKey { + uint32 lastRitualId; + BLS12381.G2Point publicKey; + } + using SafeERC20 for IERC20; bytes32 public constant INITIATOR_ROLE = keccak256("INITIATOR_ROLE"); + mapping(address => ParticipantKey[]) keysHistory; + IAccessControlApplication public immutable application; Ritual[] public rituals; @@ -99,13 +108,13 @@ contract Coordinator is AccessControlDefaultAdminRules { function getRitualState(Ritual storage ritual) internal view returns (RitualState){ uint32 t0 = ritual.initTimestamp; uint32 deadline = t0 + timeout; - if (t0 == 0){ + if (t0 == 0) { return RitualState.NON_INITIATED; } else if (ritual.totalAggregations == ritual.dkgSize) { return RitualState.FINALIZED; - } else if (ritual.aggregationMismatch){ + } else if (ritual.aggregationMismatch) { return RitualState.INVALID; - } else if (block.timestamp > deadline){ + } else if (block.timestamp > deadline) { return RitualState.TIMEOUT; } else if (ritual.totalTranscripts < ritual.dkgSize) { return RitualState.AWAITING_TRANSCRIPTS; @@ -124,6 +133,28 @@ contract Coordinator is AccessControlDefaultAdminRules { _setRoleAdmin(INITIATOR_ROLE, bytes32(0)); } + function setProviderPublicKey(BLS12381.G2Point calldata _publicKey) public { + uint32 lastRitualId = uint32(rituals.length); + address provider = application.stakingProviderFromOperator(msg.sender); + + ParticipantKey memory newRecord = ParticipantKey(lastRitualId, _publicKey); + keysHistory[provider].push(newRecord); + + emit ParticipantPublicKeySet(lastRitualId, provider, _publicKey); + } + + function getProviderPublicKey(address _provider, uint _ritualId) external view returns (BLS12381.G2Point memory) { + ParticipantKey[] storage participantHistory = keysHistory[_provider]; + + for (uint i = participantHistory.length - 1; i >= 0; i--) { + if (participantHistory[i].lastRitualId <= _ritualId) { + return participantHistory[i].publicKey; + } + } + + revert("No keys found prior to the provided ritual"); + } + function setTimeout(uint32 newTimeout) external onlyRole(DEFAULT_ADMIN_ROLE) { emit TimeoutChanged(timeout, newTimeout); timeout = newTimeout; @@ -136,7 +167,7 @@ contract Coordinator is AccessControlDefaultAdminRules { function setReimbursementPool(IReimbursementPool pool) external onlyRole(DEFAULT_ADMIN_ROLE) { require( - address(pool) == address(0) || + address(pool) == address(0) || pool.isAuthorized(address(this)), "Invalid ReimbursementPool" ); @@ -144,11 +175,11 @@ contract Coordinator is AccessControlDefaultAdminRules { // TODO: Events } - function numberOfRituals() external view returns(uint256) { + function numberOfRituals() external view returns (uint256) { return rituals.length; } - function getParticipants(uint32 ritualId) external view returns(Participant[] memory) { + function getParticipants(uint32 ritualId) external view returns (Participant[] memory) { Ritual storage ritual = rituals[ritualId]; return ritual.participant; } @@ -176,14 +207,18 @@ contract Coordinator is AccessControlDefaultAdminRules { ritual.endTimestamp = ritual.initTimestamp + duration; address previous = address(0); - for(uint256 i=0; i < length; i++){ + for (uint256 i = 0; i < length; i++) { Participant storage newParticipant = ritual.participant.push(); address current = providers[i]; + // Make sure that current provider has already set their public key + ParticipantKey[] storage participantHistory = keysHistory[current]; + require(participantHistory.length > 0, "Provider has not set their public key"); + require(previous < current, "Providers must be sorted"); // TODO: Improve check for eligible nodes (staking, etc) - nucypher#3109 // TODO: Change check to isAuthorized(), without amount require( - application.authorizedStake(current) > 0, + application.authorizedStake(current) > 0, "Not enough authorization" ); newParticipant.provider = current; @@ -191,13 +226,13 @@ contract Coordinator is AccessControlDefaultAdminRules { } processRitualPayment(id, providers, duration); - + // TODO: Include cohort fingerprint in StartRitual event? emit StartRitual(id, ritual.authority, providers); return id; } - function cohortFingerprint(address[] calldata nodes) public pure returns(bytes32) { + function cohortFingerprint(address[] calldata nodes) public pure returns (bytes32) { return keccak256(abi.encode(nodes)); } @@ -231,7 +266,7 @@ contract Coordinator is AccessControlDefaultAdminRules { ritual.totalTranscripts++; // end round - if (ritual.totalTranscripts == ritual.dkgSize){ + if (ritual.totalTranscripts == ritual.dkgSize) { emit StartAggregationRound(ritualId); } processReimbursement(initialGasLeft); @@ -240,7 +275,7 @@ contract Coordinator is AccessControlDefaultAdminRules { function postAggregation( uint32 ritualId, bytes calldata aggregatedTranscript, - BLS12381.G1Point calldata publicKey, + BLS12381.G1Point calldata dkgPublicKey, bytes calldata decryptionRequestStaticKey ) external { uint256 initialGasLeft = gasleft(); @@ -281,11 +316,11 @@ contract Coordinator is AccessControlDefaultAdminRules { if (ritual.aggregatedTranscript.length == 0) { ritual.aggregatedTranscript = aggregatedTranscript; - ritual.publicKey = publicKey; + ritual.publicKey = dkgPublicKey; } else if ( - !BLS12381.eqG1Point(ritual.publicKey, publicKey) || - keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest - ){ + !BLS12381.eqG1Point(ritual.publicKey, dkgPublicKey) || + keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest + ) { ritual.aggregationMismatch = true; emit EndRitual({ ritualId: ritualId, @@ -293,9 +328,9 @@ contract Coordinator is AccessControlDefaultAdminRules { }); } - if(!ritual.aggregationMismatch){ + if (!ritual.aggregationMismatch) { ritual.totalAggregations++; - if (ritual.totalAggregations == ritual.dkgSize){ + if (ritual.totalAggregations == ritual.dkgSize) { processPendingFee(ritualId); emit EndRitual({ ritualId: ritualId, @@ -314,9 +349,9 @@ contract Coordinator is AccessControlDefaultAdminRules { ) internal view returns (Participant storage) { uint length = ritual.participant.length; // TODO: Improve with binary search - for(uint i = 0; i < length; i++){ + for (uint i = 0; i < length; i++) { Participant storage participant = ritual.participant[i]; - if(participant.provider == provider){ + if (participant.provider == provider) { return participant; } } @@ -332,7 +367,7 @@ contract Coordinator is AccessControlDefaultAdminRules { function processRitualPayment(uint256 ritualID, address[] calldata providers, uint32 duration) internal { uint256 ritualCost = feeModel.getRitualInitiationCost(providers, duration); - if (ritualCost > 0){ + if (ritualCost > 0) { totalPendingFees += ritualCost; assert(pendingFees[ritualID] == 0); // TODO: This is an invariant, not sure if actually needed pendingFees[ritualID] += ritualCost; @@ -358,7 +393,7 @@ contract Coordinator is AccessControlDefaultAdminRules { totalPendingFees -= pending; delete pendingFees[ritualID]; // Transfer fees back to initiator if failed - if(state == RitualState.TIMEOUT || state == RitualState.INVALID){ + if (state == RitualState.TIMEOUT || state == RitualState.INVALID) { // Amount to refund depends on how much work nodes did for the ritual. // TODO: Validate if this is enough to remove griefing attacks uint256 executedTransactions = ritual.totalTranscripts + ritual.totalAggregations; @@ -369,9 +404,9 @@ contract Coordinator is AccessControlDefaultAdminRules { currency.transferFrom(address(this), ritual.initiator, refundableFee); } } - + function processReimbursement(uint256 initialGasLeft) internal { - if(address(reimbursementPool) != address(0)){ // TODO: Consider defining a method + if (address(reimbursementPool) != address(0)) { // TODO: Consider defining a method uint256 gasUsed = initialGasLeft - gasleft(); try reimbursementPool.refund(gasUsed, msg.sender) { return; diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index a37d4eea..9ce4cbac 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -31,6 +31,10 @@ def transcript_size(shares, threshold): return int(424 + 240 * (shares / 2) + 50 * (threshold)) +def gen_public_key(): + return (os.urandom(32), os.urandom(32), os.urandom(32)) + + @pytest.fixture(scope="module") def nodes(accounts): return sorted(accounts[:MAX_DKG_SIZE], key=lambda x: x.address.lower()) @@ -106,6 +110,12 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): with ape.reverts("Invalid ritual duration"): coordinator.initiateRitual(nodes, initiator, 0, sender=initiator) + with ape.reverts("Provider has not set their public key"): + coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + + for node in nodes: + public_key = gen_public_key() + coordinator.setProviderPublicKey(public_key, sender=node) with ape.reverts("Providers must be sorted"): coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator) @@ -114,11 +124,18 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) -def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): +def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes): + for node in nodes: + public_key = gen_public_key() + coordinator.setProviderPublicKey(public_key, sender=node) cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) erc20.approve(coordinator.address, cost, sender=initiator) - authority = initiator - tx = coordinator.initiateRitual(nodes, authority, DURATION, sender=initiator) + tx = coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + return initiator, tx + + +def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model): + authority, tx = initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) events = list(coordinator.StartRitual.from_receipt(tx)) assert len(events) == 1 @@ -130,10 +147,22 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_mod assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS +def test_provider_public_key(coordinator, nodes): + selected_provider = nodes[0] + public_key = gen_public_key() + tx = coordinator.setProviderPublicKey(public_key, sender=selected_provider) + ritual_id = coordinator.numberOfRituals() + + events = list(coordinator.ParticipantPublicKeySet.from_receipt(tx)) + assert len(events) == 1 + event = events[0] + assert event["participant"] == selected_provider + assert event["publicKey"] == public_key + assert coordinator.getProviderPublicKey(selected_provider, ritual_id) == public_key + + def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) for node in nodes: @@ -159,9 +188,7 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_mod def test_post_transcript_but_not_part_of_ritual( coordinator, nodes, initiator, erc20, flat_rate_fee_model ): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) with ape.reverts("Participant not part of ritual"): coordinator.postTranscript(0, transcript, sender=initiator) @@ -170,9 +197,7 @@ def test_post_transcript_but_not_part_of_ritual( def test_post_transcript_but_already_posted_transcript( coordinator, nodes, initiator, erc20, flat_rate_fee_model ): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) coordinator.postTranscript(0, transcript, sender=nodes[0]) with ape.reverts("Node already posted transcript"): @@ -182,9 +207,7 @@ def test_post_transcript_but_already_posted_transcript( def test_post_transcript_but_not_waiting_for_transcripts( coordinator, nodes, initiator, erc20, flat_rate_fee_model ): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) for node in nodes: coordinator.postTranscript(0, transcript, sender=node) @@ -194,21 +217,18 @@ def test_post_transcript_but_not_waiting_for_transcripts( def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_model): - cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) - erc20.approve(coordinator.address, cost, sender=initiator) - coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) + initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) - for node in nodes: coordinator.postTranscript(0, transcript, sender=node) aggregated = transcript # has the same size as transcript - decryptionRequestStaticKeys = [os.urandom(42) for node in nodes] - publicKey = (os.urandom(32), os.urandom(16)) + decryption_request_static_keys = [os.urandom(42) for _ in nodes] + dkg_public_key = (os.urandom(32), os.urandom(16)) for i, node in enumerate(nodes): assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS tx = coordinator.postAggregation( - 0, aggregated, publicKey, decryptionRequestStaticKeys[i], sender=node + 0, aggregated, dkg_public_key, decryption_request_static_keys[i], sender=node ) events = list(coordinator.AggregationPosted.from_receipt(tx)) @@ -221,7 +241,7 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_mo participants = coordinator.getParticipants(0) for i, participant in enumerate(participants): assert participant.aggregated - assert participant.decryptionRequestStaticKey == decryptionRequestStaticKeys[i] + assert participant.decryptionRequestStaticKey == decryption_request_static_keys[i] assert coordinator.getRitualState(0) == RitualState.FINALIZED events = list(coordinator.EndRitual.from_receipt(tx))