diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index 1856ceec..895eb636 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -30,7 +30,7 @@ contract Coordinator is AccessControlDefaultAdminRules { event TimeoutChanged(uint32 oldTimeout, uint32 newTimeout); event MaxDkgSizeChanged(uint16 oldSize, uint16 newSize); - event ParticipantPublicKeySet(address indexed participant, BLS12381.G1Point publicKey); + event ParticipantPublicKeySet(uint32 indexed ritualId, address indexed participant, BLS12381.G2Point publicKey); enum RitualState { NON_INITIATED, @@ -64,8 +64,8 @@ contract Coordinator is AccessControlDefaultAdminRules { } struct ParticipantKey { - uint32 ritualId; - BLS12381.G1Point publicKey; + uint32 lastRitualId; + BLS12381.G2Point publicKey; } using SafeERC20 for IERC20; @@ -133,21 +133,21 @@ contract Coordinator is AccessControlDefaultAdminRules { _setRoleAdmin(INITIATOR_ROLE, bytes32(0)); } - function setProviderPublicKey(BLS12381.G1Point calldata _publicKey) public { + function setProviderPublicKey(BLS12381.G2Point calldata _publicKey) public { uint32 lastRitualId = uint32(rituals.length); address provider = application.stakingProviderFromOperator(msg.sender); ParticipantKey memory newRecord = ParticipantKey(lastRitualId, _publicKey); keysHistory[provider].push(newRecord); - emit ParticipantPublicKeySet(provider, _publicKey); + emit ParticipantPublicKeySet(lastRitualId, provider, _publicKey); } - function getProviderPublicKey(address _address, uint _ritualId) public view returns (BLS12381.G1Point memory) { - ParticipantKey[] storage participantHistory = keysHistory[_address]; + function getProviderPublicKey(address _provider, uint _ritualId) external view returns (BLS12381.G2Point memory) { + ParticipantKey[] storage participantHistory = keysHistory[_provider]; for (uint i = participantHistory.length - 1; i >= 0; i--) { - if (participantHistory[i].ritualId <= _ritualId) { + if (participantHistory[i].lastRitualId <= _ritualId) { return participantHistory[i].publicKey; } } @@ -275,7 +275,7 @@ contract Coordinator is AccessControlDefaultAdminRules { function postAggregation( uint32 ritualId, bytes calldata aggregatedTranscript, - BLS12381.G1Point calldata publicKey, + BLS12381.G1Point calldata dkgPublicKey, bytes calldata decryptionRequestStaticKey ) external { uint256 initialGasLeft = gasleft(); @@ -316,9 +316,9 @@ contract Coordinator is AccessControlDefaultAdminRules { if (ritual.aggregatedTranscript.length == 0) { ritual.aggregatedTranscript = aggregatedTranscript; - ritual.publicKey = publicKey; + ritual.publicKey = dkgPublicKey; } else if ( - !BLS12381.eqG1Point(ritual.publicKey, publicKey) || + !BLS12381.eqG1Point(ritual.publicKey, dkgPublicKey) || keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest ) { ritual.aggregationMismatch = true; diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 3fa4adcd..9ce4cbac 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -31,6 +31,10 @@ def transcript_size(shares, threshold): return int(424 + 240 * (shares / 2) + 50 * (threshold)) +def gen_public_key(): + return (os.urandom(32), os.urandom(32), os.urandom(32)) + + @pytest.fixture(scope="module") def nodes(accounts): return sorted(accounts[:MAX_DKG_SIZE], key=lambda x: x.address.lower()) @@ -110,7 +114,7 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator) for node in nodes: - public_key = (os.urandom(32), os.urandom(16)) + public_key = gen_public_key() coordinator.setProviderPublicKey(public_key, sender=node) with ape.reverts("Providers must be sorted"): coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator) @@ -122,7 +126,7 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator): def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes): for node in nodes: - public_key = (os.urandom(32), os.urandom(16)) + public_key = gen_public_key() coordinator.setProviderPublicKey(public_key, sender=node) cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION) erc20.approve(coordinator.address, cost, sender=initiator) @@ -143,9 +147,9 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_mod assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS -def test_test_provider_public_key(coordinator, nodes): +def test_provider_public_key(coordinator, nodes): selected_provider = nodes[0] - public_key = (os.urandom(32), os.urandom(16)) + public_key = gen_public_key() tx = coordinator.setProviderPublicKey(public_key, sender=selected_provider) ritual_id = coordinator.numberOfRituals() @@ -220,11 +224,11 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_mo aggregated = transcript # has the same size as transcript decryption_request_static_keys = [os.urandom(42) for _ in nodes] - public_key = (os.urandom(32), os.urandom(16)) + dkg_public_key = (os.urandom(32), os.urandom(16)) for i, node in enumerate(nodes): assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS tx = coordinator.postAggregation( - 0, aggregated, public_key, decryption_request_static_keys[i], sender=node + 0, aggregated, dkg_public_key, decryption_request_static_keys[i], sender=node ) events = list(coordinator.AggregationPosted.from_receipt(tx))