Skip to content

Commit

Permalink
Merge pull request #116 from derekpierre/alpha-tweaks
Browse files Browse the repository at this point in the history
Tweaks for Alpha 12
  • Loading branch information
KPrasch authored Sep 13, 2023
2 parents 2685e35 + f6fdc2e commit a96b83b
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 18 deletions.
39 changes: 27 additions & 12 deletions contracts/contracts/coordination/Coordinator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ contract Coordinator is AccessControlDefaultAdminRules, FlatRateFeeModel {
}

struct Ritual {
// NOTE: changing the order here affects nucypher/nucypher: CoordinatorAgent
address initiator;
uint32 initTimestamp;
uint32 endTimestamp;
Expand Down Expand Up @@ -93,7 +94,7 @@ contract Coordinator is AccessControlDefaultAdminRules, FlatRateFeeModel {
mapping(uint256 => uint256) public pendingFees;
IFeeModel internal feeModel; // TODO: Consider making feeModel specific to each ritual
IReimbursementPool internal reimbursementPool;
mapping(address => ParticipantKey[]) internal keysHistory;
mapping(address => ParticipantKey[]) internal participantKeysHistory;
mapping(bytes32 => uint32) internal ritualPublicKeyRegistry;

constructor(
Expand Down Expand Up @@ -152,7 +153,7 @@ contract Coordinator is AccessControlDefaultAdminRules, FlatRateFeeModel {
address provider = application.stakingProviderFromOperator(msg.sender);

ParticipantKey memory newRecord = ParticipantKey(lastRitualId, _publicKey);
keysHistory[provider].push(newRecord);
participantKeysHistory[provider].push(newRecord);

emit ParticipantPublicKeySet(lastRitualId, provider, _publicKey);
}
Expand All @@ -161,17 +162,22 @@ contract Coordinator is AccessControlDefaultAdminRules, FlatRateFeeModel {
address _provider,
uint256 _ritualId
) external view returns (BLS12381.G2Point memory) {
ParticipantKey[] storage participantHistory = keysHistory[_provider];
ParticipantKey[] storage participantHistory = participantKeysHistory[_provider];

for (uint256 i = participantHistory.length - 1; i >= 0; i--) {
if (participantHistory[i].lastRitualId <= _ritualId) {
return participantHistory[i].publicKey;
for (uint256 i = participantHistory.length; i > 0; i--) {
if (participantHistory[i - 1].lastRitualId <= _ritualId) {
return participantHistory[i - 1].publicKey;
}
}

revert("No keys found prior to the provided ritual");
}

function isProviderPublicKeySet(address _provider) external view returns (bool) {
ParticipantKey[] storage participantHistory = participantKeysHistory[_provider];
return participantHistory.length > 0;
}

function setTimeout(uint32 newTimeout) external onlyRole(DEFAULT_ADMIN_ROLE) {
emit TimeoutChanged(timeout, newTimeout);
timeout = newTimeout;
Expand Down Expand Up @@ -243,7 +249,7 @@ contract Coordinator is AccessControlDefaultAdminRules, FlatRateFeeModel {
Participant storage newParticipant = ritual.participant.push();
address current = providers[i];
// Make sure that current provider has already set their public key
ParticipantKey[] storage participantHistory = keysHistory[current];
ParticipantKey[] storage participantHistory = participantKeysHistory[current];
require(participantHistory.length > 0, "Provider has not set their public key");

require(previous < current, "Providers must be sorted");
Expand Down Expand Up @@ -401,17 +407,26 @@ contract Coordinator is AccessControlDefaultAdminRules, FlatRateFeeModel {
return getParticipantFromProvider(rituals[ritualId], provider);
}

function isEncryptionAuthorized(
uint32 ritualId,
bytes memory evidence,
bytes memory ciphertextHeader
) external view returns (bool) {
Ritual storage ritual = rituals[ritualId];
require(getRitualState(ritual) == RitualState.FINALIZED, "Ritual not finalized");
return ritual.accessController.isAuthorized(ritualId, evidence, ciphertextHeader);
}

function processRitualPayment(
uint32 ritualId,
address[] calldata providers,
uint32 duration
) internal {
uint256 ritualCost = getRitualInitiationCost(providers, duration);
if (ritualCost > 0) {
totalPendingFees += ritualCost;
pendingFees[ritualId] = ritualCost;
currency.safeTransferFrom(msg.sender, address(this), ritualCost);
}
require(ritualCost > 0, "Invalid ritual cost");
totalPendingFees += ritualCost;
pendingFees[ritualId] = ritualCost;
currency.safeTransferFrom(msg.sender, address(this), ritualCost);
}

function processPendingFee(uint32 ritualId) public {
Expand Down
1 change: 1 addition & 0 deletions contracts/contracts/coordination/FlatRateFeeModel.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ contract FlatRateFeeModel is IFeeModel {
uint256 public immutable feeRatePerSecond;

constructor(IERC20 _currency, uint256 _feeRatePerSecond) {
require(_feeRatePerSecond > 0, "Invalid fee rate");
currency = _currency;
feeRatePerSecond = _feeRatePerSecond;
}
Expand Down
30 changes: 24 additions & 6 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def gen_public_key():


def access_control_error_message(address, role=None):
role = Web3.to_hex(role or b'\x00'*32)
role = Web3.to_hex(role or b"\x00" * 32)
return f"AccessControl: account {address.lower()} is missing role {role}"


Expand Down Expand Up @@ -152,7 +152,10 @@ def test_invalid_initiate_ritual(coordinator, nodes, accounts, initiator, global
def initiate_ritual(coordinator, erc20, allow_logic, authority, nodes):
for node in nodes:
public_key = gen_public_key()
assert not coordinator.isProviderPublicKeySet(node)
coordinator.setProviderPublicKey(public_key, sender=node)
assert coordinator.isProviderPublicKeySet(node)

cost = coordinator.getRitualInitiationCost(nodes, DURATION)
erc20.approve(coordinator.address, cost, sender=authority)
tx = coordinator.initiateRitual(
Expand All @@ -161,7 +164,9 @@ def initiate_ritual(coordinator, erc20, allow_logic, authority, nodes):
return authority, tx


def test_initiate_ritual(coordinator, nodes, initiator, erc20, global_allow_list, deployer, treasury):
def test_initiate_ritual(
coordinator, nodes, initiator, erc20, global_allow_list, deployer, treasury
):
authority, tx = initiate_ritual(
coordinator=coordinator,
erc20=erc20,
Expand All @@ -179,7 +184,7 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, global_allow_list
assert event["participants"] == tuple(n.address.lower() for n in nodes)

assert coordinator.getRitualState(0) == RitualState.AWAITING_TRANSCRIPTS

ritual_struct = coordinator.rituals(ritualID)
assert ritual_struct[0] == initiator
init, end = ritual_struct[1], ritual_struct[2]
Expand Down Expand Up @@ -210,7 +215,11 @@ def test_initiate_ritual(coordinator, nodes, initiator, erc20, global_allow_list
def test_provider_public_key(coordinator, nodes):
selected_provider = nodes[0]
public_key = gen_public_key()

assert not coordinator.isProviderPublicKeySet(selected_provider)
tx = coordinator.setProviderPublicKey(public_key, sender=selected_provider)
assert coordinator.isProviderPublicKeySet(selected_provider)

ritual_id = coordinator.numberOfRituals()

events = coordinator.ParticipantPublicKeySet.from_receipt(tx)
Expand Down Expand Up @@ -301,7 +310,9 @@ def test_post_transcript_but_not_waiting_for_transcripts(
coordinator.postTranscript(0, transcript, sender=nodes[1])


def test_post_aggregation(coordinator, nodes, initiator, erc20, global_allow_list, treasury, deployer):
def test_post_aggregation(
coordinator, nodes, initiator, erc20, global_allow_list, treasury, deployer
):
initiate_ritual(
coordinator=coordinator,
erc20=erc20,
Expand All @@ -320,8 +331,7 @@ def test_post_aggregation(coordinator, nodes, initiator, erc20, global_allow_lis
for i, node in enumerate(nodes):
assert coordinator.getRitualState(ritualID) == RitualState.AWAITING_AGGREGATIONS
tx = coordinator.postAggregation(
ritualID, aggregated, dkg_public_key, decryption_request_static_keys[i],
sender=node
ritualID, aggregated, dkg_public_key, decryption_request_static_keys[i], sender=node
)

events = coordinator.AggregationPosted.from_receipt(tx)
Expand Down Expand Up @@ -390,6 +400,9 @@ def test_authorize_using_global_allow_list(
with ape.reverts("Only active rituals can add authorizations"):
global_allow_list.authorize(0, [deployer.address], sender=initiator)

with ape.reverts("Ritual not finalized"):
coordinator.isEncryptionAuthorized(0, bytes(signature), bytes(digest))

# Finalize ritual
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
for node in nodes:
Expand All @@ -408,15 +421,20 @@ def test_authorize_using_global_allow_list(

# Authorized
assert global_allow_list.isAuthorized(0, bytes(signature), bytes(data))
assert coordinator.isEncryptionAuthorized(0, bytes(signature), bytes(data))

# Deauthorize
global_allow_list.deauthorize(0, [deployer.address], sender=initiator)
assert not global_allow_list.isAuthorized(0, bytes(signature), bytes(data))
assert not coordinator.isEncryptionAuthorized(0, bytes(signature), bytes(data))

# Reauthorize in batch
addresses_to_authorize = [deployer.address, initiator.address]
global_allow_list.authorize(0, addresses_to_authorize, sender=initiator)
signed_digest = w3.eth.account.sign_message(signable_message, private_key=initiator.private_key)
initiator_signature = signed_digest.signature
assert global_allow_list.isAuthorized(0, bytes(initiator_signature), bytes(data))
assert coordinator.isEncryptionAuthorized(0, bytes(initiator_signature), bytes(data))

assert global_allow_list.isAuthorized(0, bytes(signature), bytes(data))
assert coordinator.isEncryptionAuthorized(0, bytes(signature), bytes(data))

0 comments on commit a96b83b

Please sign in to comment.