diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index 6bd2bcae..1856ceec 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -30,6 +30,8 @@ contract Coordinator is AccessControlDefaultAdminRules { event TimeoutChanged(uint32 oldTimeout, uint32 newTimeout); event MaxDkgSizeChanged(uint16 oldSize, uint16 newSize); + event ParticipantPublicKeySet(address indexed participant, BLS12381.G1Point publicKey); + enum RitualState { NON_INITIATED, AWAITING_TRANSCRIPTS, @@ -61,11 +63,16 @@ contract Coordinator is AccessControlDefaultAdminRules { Participant[] participant; } + struct ParticipantKey { + uint32 ritualId; + BLS12381.G1Point publicKey; + } + using SafeERC20 for IERC20; bytes32 public constant INITIATOR_ROLE = keccak256("INITIATOR_ROLE"); - mapping(address => bytes) public providerPublicKey; + mapping(address => ParticipantKey[]) keysHistory; IAccessControlApplication public immutable application; @@ -101,13 +108,13 @@ contract Coordinator is AccessControlDefaultAdminRules { function getRitualState(Ritual storage ritual) internal view returns (RitualState){ uint32 t0 = ritual.initTimestamp; uint32 deadline = t0 + timeout; - if (t0 == 0){ + if (t0 == 0) { return RitualState.NON_INITIATED; } else if (ritual.totalAggregations == ritual.dkgSize) { return RitualState.FINALIZED; - } else if (ritual.aggregationMismatch){ + } else if (ritual.aggregationMismatch) { return RitualState.INVALID; - } else if (block.timestamp > deadline){ + } else if (block.timestamp > deadline) { return RitualState.TIMEOUT; } else if (ritual.totalTranscripts < ritual.dkgSize) { return RitualState.AWAITING_TRANSCRIPTS; @@ -126,10 +133,26 @@ contract Coordinator is AccessControlDefaultAdminRules { _setRoleAdmin(INITIATOR_ROLE, bytes32(0)); } - function setProviderPublicKey(bytes calldata publicKey) external { - // TODO: Verify public key length - require(publicKey.length == 48, "Invalid public key length"); - providerPublicKey[msg.sender] = publicKey; + function setProviderPublicKey(BLS12381.G1Point calldata _publicKey) public { + uint32 lastRitualId = uint32(rituals.length); + address provider = application.stakingProviderFromOperator(msg.sender); + + ParticipantKey memory newRecord = ParticipantKey(lastRitualId, _publicKey); + keysHistory[provider].push(newRecord); + + emit ParticipantPublicKeySet(provider, _publicKey); + } + + function getProviderPublicKey(address _address, uint _ritualId) public view returns (BLS12381.G1Point memory) { + ParticipantKey[] storage participantHistory = keysHistory[_address]; + + for (uint i = participantHistory.length - 1; i >= 0; i--) { + if (participantHistory[i].ritualId <= _ritualId) { + return participantHistory[i].publicKey; + } + } + + revert("No keys found prior to the provided ritual"); } function setTimeout(uint32 newTimeout) external onlyRole(DEFAULT_ADMIN_ROLE) { @@ -152,11 +175,11 @@ contract Coordinator is AccessControlDefaultAdminRules { // TODO: Events } - function numberOfRituals() external view returns(uint256) { + function numberOfRituals() external view returns (uint256) { return rituals.length; } - function getParticipants(uint32 ritualId) external view returns(Participant[] memory) { + function getParticipants(uint32 ritualId) external view returns (Participant[] memory) { Ritual storage ritual = rituals[ritualId]; return ritual.participant; } @@ -184,19 +207,18 @@ contract Coordinator is AccessControlDefaultAdminRules { ritual.endTimestamp = ritual.initTimestamp + duration; address previous = address(0); - for(uint256 i=0; i < length; i++){ + for (uint256 i = 0; i < length; i++) { Participant storage newParticipant = ritual.participant.push(); address current = providers[i]; // Make sure that current provider has already set their public key - require( - providerPublicKey[current].length > 0, - "Provider has not set their public key" - ); + ParticipantKey[] storage participantHistory = keysHistory[current]; + require(participantHistory.length > 0, "Provider has not set their public key"); + require(previous < current, "Providers must be sorted"); // TODO: Improve check for eligible nodes (staking, etc) - nucypher#3109 // TODO: Change check to isAuthorized(), without amount require( - application.authorizedStake(current) > 0, + application.authorizedStake(current) > 0, "Not enough authorization" ); newParticipant.provider = current; @@ -210,7 +232,7 @@ contract Coordinator is AccessControlDefaultAdminRules { return id; } - function cohortFingerprint(address[] calldata nodes) public pure returns(bytes32) { + function cohortFingerprint(address[] calldata nodes) public pure returns (bytes32) { return keccak256(abi.encode(nodes)); } @@ -244,7 +266,7 @@ contract Coordinator is AccessControlDefaultAdminRules { ritual.totalTranscripts++; // end round - if (ritual.totalTranscripts == ritual.dkgSize){ + if (ritual.totalTranscripts == ritual.dkgSize) { emit StartAggregationRound(ritualId); } processReimbursement(initialGasLeft); @@ -296,9 +318,9 @@ contract Coordinator is AccessControlDefaultAdminRules { ritual.aggregatedTranscript = aggregatedTranscript; ritual.publicKey = publicKey; } else if ( - !BLS12381.eqG1Point(ritual.publicKey, publicKey) || - keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest - ){ + !BLS12381.eqG1Point(ritual.publicKey, publicKey) || + keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest + ) { ritual.aggregationMismatch = true; emit EndRitual({ ritualId: ritualId, @@ -306,9 +328,9 @@ contract Coordinator is AccessControlDefaultAdminRules { }); } - if(!ritual.aggregationMismatch){ + if (!ritual.aggregationMismatch) { ritual.totalAggregations++; - if (ritual.totalAggregations == ritual.dkgSize){ + if (ritual.totalAggregations == ritual.dkgSize) { processPendingFee(ritualId); emit EndRitual({ ritualId: ritualId, @@ -327,9 +349,9 @@ contract Coordinator is AccessControlDefaultAdminRules { ) internal view returns (Participant storage) { uint length = ritual.participant.length; // TODO: Improve with binary search - for(uint i = 0; i < length; i++){ + for (uint i = 0; i < length; i++) { Participant storage participant = ritual.participant[i]; - if(participant.provider == provider){ + if (participant.provider == provider) { return participant; } } @@ -345,7 +367,7 @@ contract Coordinator is AccessControlDefaultAdminRules { function processRitualPayment(uint256 ritualID, address[] calldata providers, uint32 duration) internal { uint256 ritualCost = feeModel.getRitualInitiationCost(providers, duration); - if (ritualCost > 0){ + if (ritualCost > 0) { totalPendingFees += ritualCost; assert(pendingFees[ritualID] == 0); // TODO: This is an invariant, not sure if actually needed pendingFees[ritualID] += ritualCost; @@ -371,7 +393,7 @@ contract Coordinator is AccessControlDefaultAdminRules { totalPendingFees -= pending; delete pendingFees[ritualID]; // Transfer fees back to initiator if failed - if(state == RitualState.TIMEOUT || state == RitualState.INVALID){ + if (state == RitualState.TIMEOUT || state == RitualState.INVALID) { // Amount to refund depends on how much work nodes did for the ritual. // TODO: Validate if this is enough to remove griefing attacks uint256 executedTransactions = ritual.totalTranscripts + ritual.totalAggregations; @@ -384,7 +406,7 @@ contract Coordinator is AccessControlDefaultAdminRules { } function processReimbursement(uint256 initialGasLeft) internal { - if(address(reimbursementPool) != address(0)){ // TODO: Consider defining a method + if (address(reimbursementPool) != address(0)) { // TODO: Consider defining a method uint256 gasUsed = initialGasLeft - gasleft(); try reimbursementPool.refund(gasUsed, msg.sender) { return; diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 223d1fbd..3fa4adcd 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -110,7 +110,8 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) for node in nodes: - coordinator.setProviderPublicKey(os.urandom(48), sender=node) + public_key = (os.urandom(32), os.urandom(16)) + coordinator.setProviderPublicKey(public_key, sender=node) with ape.reverts("Providers must be sorted"): coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator) @@ -121,7 +122,8 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes): for node in nodes: - coordinator.setProviderPublicKey(os.urandom(48), sender=node) + public_key = (os.urandom(32), os.urandom(16)) + coordinator.setProviderPublicKey(public_key, sender=node) cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) erc20.approve(coordinator.address, cost, sender=initiator) tx = coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) @@ -141,6 +143,20 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_mod assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS +def test_test_provider_public_key(coordinator, nodes): + selected_provider = nodes[0] + public_key = (os.urandom(32), os.urandom(16)) + tx = coordinator.setProviderPublicKey(public_key, sender=selected_provider) + ritual_id = coordinator.numberOfRituals() + + events = list(coordinator.ParticipantPublicKeySet.from_receipt(tx)) + assert len(events) == 1 + event = events[0] + assert event["participant"] == selected_provider + assert event["publicKey"] == public_key + assert coordinator.getProviderPublicKey(selected_provider, ritual_id) == public_key + + def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model): initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes) transcript = os.urandom(transcript_size(len(nodes), len(nodes))) @@ -203,12 +219,12 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_mo coordinator.postTranscript(0, transcript, sender=node) aggregated = transcript # has the same size as transcript - decryptionRequestStaticKeys = [os.urandom(42) for node in nodes] - publicKey = (os.urandom(32), os.urandom(16)) + decryption_request_static_keys = [os.urandom(42) for _ in nodes] + public_key = (os.urandom(32), os.urandom(16)) for i, node in enumerate(nodes): assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS tx = coordinator.postAggregation( - 0, aggregated, publicKey, decryptionRequestStaticKeys[i], sender=node + 0, aggregated, public_key, decryption_request_static_keys[i], sender=node ) events = list(coordinator.AggregationPosted.from_receipt(tx)) @@ -221,7 +237,7 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_mo participants = coordinator.getParticipants(0) for i, participant in enumerate(participants): assert participant.aggregated - assert participant.decryptionRequestStaticKey == decryptionRequestStaticKeys[i] + assert participant.decryptionRequestStaticKey == decryption_request_static_keys[i] assert coordinator.getRitualState(0) == RitualState.FINALIZED events = list(coordinator.EndRitual.from_receipt(tx))