Skip to content

Commit

Permalink
Merge pull request #77 from piotr-roslaniec/participant-pk
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec authored Jul 8, 2023
2 parents 951187c + 867154b commit 3eebc89
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 47 deletions.
83 changes: 59 additions & 24 deletions contracts/contracts/coordination/Coordinator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ contract Coordinator is AccessControlDefaultAdminRules {
event TimeoutChanged(uint32 oldTimeout, uint32 newTimeout);
event MaxDkgSizeChanged(uint16 oldSize, uint16 newSize);

event ParticipantPublicKeySet(uint32 indexed ritualId, address indexed participant, BLS12381.G2Point publicKey);

enum RitualState {
NON_INITIATED,
AWAITING_TRANSCRIPTS,
Expand Down Expand Up @@ -61,10 +63,17 @@ contract Coordinator is AccessControlDefaultAdminRules {
Participant[] participant;
}

struct ParticipantKey {
uint32 lastRitualId;
BLS12381.G2Point publicKey;
}

using SafeERC20 for IERC20;

bytes32 public constant INITIATOR_ROLE = keccak256("INITIATOR_ROLE");

mapping(address => ParticipantKey[]) keysHistory;

IAccessControlApplication public immutable application;

Ritual[] public rituals;
Expand Down Expand Up @@ -99,13 +108,13 @@ contract Coordinator is AccessControlDefaultAdminRules {
function getRitualState(Ritual storage ritual) internal view returns (RitualState){
uint32 t0 = ritual.initTimestamp;
uint32 deadline = t0 + timeout;
if (t0 == 0){
if (t0 == 0) {
return RitualState.NON_INITIATED;
} else if (ritual.totalAggregations == ritual.dkgSize) {
return RitualState.FINALIZED;
} else if (ritual.aggregationMismatch){
} else if (ritual.aggregationMismatch) {
return RitualState.INVALID;
} else if (block.timestamp > deadline){
} else if (block.timestamp > deadline) {
return RitualState.TIMEOUT;
} else if (ritual.totalTranscripts < ritual.dkgSize) {
return RitualState.AWAITING_TRANSCRIPTS;
Expand All @@ -124,6 +133,28 @@ contract Coordinator is AccessControlDefaultAdminRules {
_setRoleAdmin(INITIATOR_ROLE, bytes32(0));
}

function setProviderPublicKey(BLS12381.G2Point calldata _publicKey) public {
uint32 lastRitualId = uint32(rituals.length);
address provider = application.stakingProviderFromOperator(msg.sender);

ParticipantKey memory newRecord = ParticipantKey(lastRitualId, _publicKey);
keysHistory[provider].push(newRecord);

emit ParticipantPublicKeySet(lastRitualId, provider, _publicKey);
}

function getProviderPublicKey(address _provider, uint _ritualId) external view returns (BLS12381.G2Point memory) {
ParticipantKey[] storage participantHistory = keysHistory[_provider];

for (uint i = participantHistory.length - 1; i >= 0; i--) {
if (participantHistory[i].lastRitualId <= _ritualId) {
return participantHistory[i].publicKey;
}
}

revert("No keys found prior to the provided ritual");
}

function setTimeout(uint32 newTimeout) external onlyRole(DEFAULT_ADMIN_ROLE) {
emit TimeoutChanged(timeout, newTimeout);
timeout = newTimeout;
Expand All @@ -136,19 +167,19 @@ contract Coordinator is AccessControlDefaultAdminRules {

function setReimbursementPool(IReimbursementPool pool) external onlyRole(DEFAULT_ADMIN_ROLE) {
require(
address(pool) == address(0) ||
address(pool) == address(0) ||
pool.isAuthorized(address(this)),
"Invalid ReimbursementPool"
);
reimbursementPool = pool;
// TODO: Events
}

function numberOfRituals() external view returns(uint256) {
function numberOfRituals() external view returns (uint256) {
return rituals.length;
}

function getParticipants(uint32 ritualId) external view returns(Participant[] memory) {
function getParticipants(uint32 ritualId) external view returns (Participant[] memory) {
Ritual storage ritual = rituals[ritualId];
return ritual.participant;
}
Expand Down Expand Up @@ -176,28 +207,32 @@ contract Coordinator is AccessControlDefaultAdminRules {
ritual.endTimestamp = ritual.initTimestamp + duration;

address previous = address(0);
for(uint256 i=0; i < length; i++){
for (uint256 i = 0; i < length; i++) {
Participant storage newParticipant = ritual.participant.push();
address current = providers[i];
// Make sure that current provider has already set their public key
ParticipantKey[] storage participantHistory = keysHistory[current];
require(participantHistory.length > 0, "Provider has not set their public key");

require(previous < current, "Providers must be sorted");
// TODO: Improve check for eligible nodes (staking, etc) - nucypher#3109
// TODO: Change check to isAuthorized(), without amount
require(
application.authorizedStake(current) > 0,
application.authorizedStake(current) > 0,
"Not enough authorization"
);
newParticipant.provider = current;
previous = current;
}

processRitualPayment(id, providers, duration);

// TODO: Include cohort fingerprint in StartRitual event?
emit StartRitual(id, ritual.authority, providers);
return id;
}

function cohortFingerprint(address[] calldata nodes) public pure returns(bytes32) {
function cohortFingerprint(address[] calldata nodes) public pure returns (bytes32) {
return keccak256(abi.encode(nodes));
}

Expand Down Expand Up @@ -231,7 +266,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
ritual.totalTranscripts++;

// end round
if (ritual.totalTranscripts == ritual.dkgSize){
if (ritual.totalTranscripts == ritual.dkgSize) {
emit StartAggregationRound(ritualId);
}
processReimbursement(initialGasLeft);
Expand All @@ -240,7 +275,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
function postAggregation(
uint32 ritualId,
bytes calldata aggregatedTranscript,
BLS12381.G1Point calldata publicKey,
BLS12381.G1Point calldata dkgPublicKey,
bytes calldata decryptionRequestStaticKey
) external {
uint256 initialGasLeft = gasleft();
Expand Down Expand Up @@ -281,21 +316,21 @@ contract Coordinator is AccessControlDefaultAdminRules {

if (ritual.aggregatedTranscript.length == 0) {
ritual.aggregatedTranscript = aggregatedTranscript;
ritual.publicKey = publicKey;
ritual.publicKey = dkgPublicKey;
} else if (
!BLS12381.eqG1Point(ritual.publicKey, publicKey) ||
keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest
){
!BLS12381.eqG1Point(ritual.publicKey, dkgPublicKey) ||
keccak256(ritual.aggregatedTranscript) != aggregatedTranscriptDigest
) {
ritual.aggregationMismatch = true;
emit EndRitual({
ritualId: ritualId,
successful: false
});
}

if(!ritual.aggregationMismatch){
if (!ritual.aggregationMismatch) {
ritual.totalAggregations++;
if (ritual.totalAggregations == ritual.dkgSize){
if (ritual.totalAggregations == ritual.dkgSize) {
processPendingFee(ritualId);
emit EndRitual({
ritualId: ritualId,
Expand All @@ -314,9 +349,9 @@ contract Coordinator is AccessControlDefaultAdminRules {
) internal view returns (Participant storage) {
uint length = ritual.participant.length;
// TODO: Improve with binary search
for(uint i = 0; i < length; i++){
for (uint i = 0; i < length; i++) {
Participant storage participant = ritual.participant[i];
if(participant.provider == provider){
if (participant.provider == provider) {
return participant;
}
}
Expand All @@ -332,7 +367,7 @@ contract Coordinator is AccessControlDefaultAdminRules {

function processRitualPayment(uint256 ritualID, address[] calldata providers, uint32 duration) internal {
uint256 ritualCost = feeModel.getRitualInitiationCost(providers, duration);
if (ritualCost > 0){
if (ritualCost > 0) {
totalPendingFees += ritualCost;
assert(pendingFees[ritualID] == 0); // TODO: This is an invariant, not sure if actually needed
pendingFees[ritualID] += ritualCost;
Expand All @@ -358,7 +393,7 @@ contract Coordinator is AccessControlDefaultAdminRules {
totalPendingFees -= pending;
delete pendingFees[ritualID];
// Transfer fees back to initiator if failed
if(state == RitualState.TIMEOUT || state == RitualState.INVALID){
if (state == RitualState.TIMEOUT || state == RitualState.INVALID) {
// Amount to refund depends on how much work nodes did for the ritual.
// TODO: Validate if this is enough to remove griefing attacks
uint256 executedTransactions = ritual.totalTranscripts + ritual.totalAggregations;
Expand All @@ -369,9 +404,9 @@ contract Coordinator is AccessControlDefaultAdminRules {
currency.transferFrom(address(this), ritual.initiator, refundableFee);
}
}

function processReimbursement(uint256 initialGasLeft) internal {
if(address(reimbursementPool) != address(0)){ // TODO: Consider defining a method
if (address(reimbursementPool) != address(0)) { // TODO: Consider defining a method
uint256 gasUsed = initialGasLeft - gasleft();
try reimbursementPool.refund(gasUsed, msg.sender) {
return;
Expand Down
66 changes: 43 additions & 23 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def transcript_size(shares, threshold):
return int(424 + 240 * (shares / 2) + 50 * (threshold))


def gen_public_key():
return (os.urandom(32), os.urandom(32), os.urandom(32))


@pytest.fixture(scope="module")
def nodes(accounts):
return sorted(accounts[:MAX_DKG_SIZE], key=lambda x: x.address.lower())
Expand Down Expand Up @@ -106,6 +110,12 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator):
with ape.reverts("Invalid ritual duration"):
coordinator.initiateRitual(nodes, initiator, 0, sender=initiator)

with ape.reverts("Provider has not set their public key"):
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)

for node in nodes:
public_key = gen_public_key()
coordinator.setProviderPublicKey(public_key, sender=node)
with ape.reverts("Providers must be sorted"):
coordinator.initiateRitual(nodes[1:] + [nodes[0]], initiator, DURATION, sender=initiator)

Expand All @@ -114,11 +124,18 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator):
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)


def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
def initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes):
for node in nodes:
public_key = gen_public_key()
coordinator.setProviderPublicKey(public_key, sender=node)
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
authority = initiator
tx = coordinator.initiateRitual(nodes, authority, DURATION, sender=initiator)
tx = coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
return initiator, tx


def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
authority, tx = initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)

events = list(coordinator.StartRitual.from_receipt(tx))
assert len(events) == 1
Expand All @@ -130,10 +147,22 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, flat_rate_fee_mod
assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS


def test_provider_public_key(coordinator, nodes):
selected_provider = nodes[0]
public_key = gen_public_key()
tx = coordinator.setProviderPublicKey(public_key, sender=selected_provider)
ritual_id = coordinator.numberOfRituals()

events = list(coordinator.ParticipantPublicKeySet.from_receipt(tx))
assert len(events) == 1
event = events[0]
assert event["participant"] == selected_provider
assert event["publicKey"] == public_key
assert coordinator.getProviderPublicKey(selected_provider, ritual_id) == public_key


def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

for node in nodes:
Expand All @@ -159,9 +188,7 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, flat_rate_fee_mod
def test_post_transcript_but_not_part_of_ritual(
coordinator, nodes, initiator, erc20, flat_rate_fee_model
):
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
with ape.reverts("Participant not part of ritual"):
coordinator.postTranscript(0, transcript, sender=initiator)
Expand All @@ -170,9 +197,7 @@ def test_post_transcript_but_not_part_of_ritual(
def test_post_transcript_but_already_posted_transcript(
coordinator, nodes, initiator, erc20, flat_rate_fee_model
):
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
coordinator.postTranscript(0, transcript, sender=nodes[0])
with ape.reverts("Node already posted transcript"):
Expand All @@ -182,9 +207,7 @@ def test_post_transcript_but_already_posted_transcript(
def test_post_transcript_but_not_waiting_for_transcripts(
coordinator, nodes, initiator, erc20, flat_rate_fee_model
):
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
for node in nodes:
coordinator.postTranscript(0, transcript, sender=node)
Expand All @@ -194,21 +217,18 @@ def test_post_transcript_but_not_waiting_for_transcripts(


def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_model):
cost = flat_rate_fee_model.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=initiator)
coordinator.initiateRitual(nodes, initiator, DURATION, sender=initiator)
initiate_ritual(coordinator, erc20, flat_rate_fee_model, initiator, nodes)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

for node in nodes:
coordinator.postTranscript(0, transcript, sender=node)

aggregated = transcript # has the same size as transcript
decryptionRequestStaticKeys = [os.urandom(42) for node in nodes]
publicKey = (os.urandom(32), os.urandom(16))
decryption_request_static_keys = [os.urandom(42) for _ in nodes]
dkg_public_key = (os.urandom(32), os.urandom(16))
for i, node in enumerate(nodes):
assert coordinator.getRitualState(0) == RitualState.AWAITING_AGGREGATIONS
tx = coordinator.postAggregation(
0, aggregated, publicKey, decryptionRequestStaticKeys[i], sender=node
0, aggregated, dkg_public_key, decryption_request_static_keys[i], sender=node
)

events = list(coordinator.AggregationPosted.from_receipt(tx))
Expand All @@ -221,7 +241,7 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, flat_rate_fee_mo
participants = coordinator.getParticipants(0)
for i, participant in enumerate(participants):
assert participant.aggregated
assert participant.decryptionRequestStaticKey == decryptionRequestStaticKeys[i]
assert participant.decryptionRequestStaticKey == decryption_request_static_keys[i]

assert coordinator.getRitualState(0) == RitualState.FINALIZED
events = list(coordinator.EndRitual.from_receipt(tx))
Expand Down

0 comments on commit 3eebc89

Please sign in to comment.