Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Coordinator aware of transcript size #334

Merged
merged 5 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion contracts/contracts/coordination/Coordinator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable
return keccak256(abi.encode(nodes));
}

function expectedTranscriptSize(
uint16 dkgSize,
uint16 threshold
) public pure returns (uint256) {
return 40 + (dkgSize + 1) * BLS12381.G2_POINT_SIZE + threshold * BLS12381.G1_POINT_SIZE;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nitpick but 40 is the only unnamed component of this calculation. Perhaps it can be assigned to a constant similar to the BLS point sizes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, but the thing is that I don't understand well enough where that 40 comes from. I mean, it's a fixed overhead that comes from rust serialization but I don't have a full explanation for it. I'll update this when I have more info.

}

function postTranscript(uint32 ritualId, bytes calldata transcript) external {
uint256 initialGasLeft = gasleft();

Expand All @@ -394,6 +401,11 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable
"Not waiting for transcripts"
);

require(
transcript.length == expectedTranscriptSize(ritual.dkgSize, ritual.threshold),
"Invalid transcript size"
);

address provider = application.operatorToStakingProvider(msg.sender);
Participant storage participant = getParticipant(ritual, provider);

Expand Down Expand Up @@ -449,6 +461,11 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable
"Invalid length for decryption request static key"
);

require(
aggregatedTranscript.length == expectedTranscriptSize(ritual.dkgSize, ritual.threshold),
"Invalid transcript size"
);

// nodes commit to their aggregation result
bytes32 aggregatedTranscriptDigest = keccak256(aggregatedTranscript);
participant.aggregated = true;
Expand Down Expand Up @@ -618,7 +635,9 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable

function processReimbursement(uint256 initialGasLeft) internal {
if (address(reimbursementPool) != address(0)) {
uint256 gasUsed = initialGasLeft - gasleft();
// For calldataGasCost calculation, see https://github.com/nucypher/nucypher-contracts/issues/328
uint256 calldataGasCost = (msg.data.length - 128) * 16 + 128 * 4;
uint256 gasUsed = initialGasLeft - gasleft() + calldataGasCost;
try reimbursementPool.refund(gasUsed, msg.sender) {
return;
} catch {
Expand Down
24 changes: 24 additions & 0 deletions deployment/artifacts/lynx.json
Original file line number Diff line number Diff line change
Expand Up @@ -3964,6 +3964,30 @@
}
]
},
{
"type": "function",
"name": "expectedTranscriptSize",
"stateMutability": "pure",
"inputs": [
{
"name": "dkgSize",
"type": "uint16",
"internalType": "uint16"
},
{
"name": "threshold",
"type": "uint16",
"internalType": "uint16"
}
],
"outputs": [
{
"name": "",
"type": "uint256",
"internalType": "uint256"
}
]
},
{
"type": "function",
"name": "extendRitual",
Expand Down
Empty file added tests/__init__.py
Empty file.
49 changes: 41 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,47 @@
import os
import pytest
from ape import convert, project
from ape import project
from enum import IntEnum

# Common constants
G1_SIZE = 48
G2_SIZE = 48 * 2
ONE_DAY = 24 * 60 * 60

RitualState = IntEnum(
"RitualState",
[
"NON_INITIATED",
"DKG_AWAITING_TRANSCRIPTS",
"DKG_AWAITING_AGGREGATIONS",
"DKG_TIMEOUT",
"DKG_INVALID",
"ACTIVE",
"EXPIRED",
],
start=0,
)


# Utility functions
def transcript_size(shares, threshold):
return 40 + (1 + shares) * G2_SIZE + threshold * G1_SIZE


def generate_transcript(shares, threshold):
return os.urandom(transcript_size(shares, threshold))


def gen_public_key():
return (os.urandom(32), os.urandom(32), os.urandom(32))


def access_control_error_message(address, role=None):
role = role or b"\x00" * 32
return f"account={address}, neededRole={role}"


# Fixtures
@pytest.fixture(scope="session")
def oz_dependency():
return project.dependencies["openzeppelin"]["5.0.0"]
Expand All @@ -20,10 +60,3 @@ def account1(accounts):
@pytest.fixture
def account2(accounts):
return accounts[2]


@pytest.fixture
def nu_token(NuCypherToken, creator):
TOTAL_SUPPLY = convert("1_000_000_000 ether", int)
nu_token = creator.deploy(NuCypherToken, TOTAL_SUPPLY)
return nu_token
102 changes: 62 additions & 40 deletions tests/test_coordinator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from enum import IntEnum

import ape
import pytest
Expand All @@ -8,41 +7,13 @@
from hexbytes import HexBytes
from web3 import Web3

from tests.conftest import ONE_DAY, gen_public_key, generate_transcript, RitualState

TIMEOUT = 1000
MAX_DKG_SIZE = 31
FEE_RATE = 42
ERC20_SUPPLY = 10**24
DURATION = 48 * 60 * 60
ONE_DAY = 24 * 60 * 60
Copy link
Member

@derekpierre derekpierre Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this cleanup. We were repeating the RitualState enum and fixtures in a number of places.


RitualState = IntEnum(
"RitualState",
[
"NON_INITIATED",
"DKG_AWAITING_TRANSCRIPTS",
"DKG_AWAITING_AGGREGATIONS",
"DKG_TIMEOUT",
"DKG_INVALID",
"ACTIVE",
"EXPIRED",
],
start=0,
)


# This formula returns an approximated size
# To have a representative size, create transcripts with `nucypher-core`
def transcript_size(shares, threshold):
return int(424 + 240 * (shares / 2) + 50 * (threshold))


def gen_public_key():
return (os.urandom(32), os.urandom(32), os.urandom(32))


def access_control_error_message(address, role=None):
role = role or b"\x00" * 32
return f"account={address}, neededRole={role}"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -275,7 +246,9 @@ def test_post_transcript(coordinator, nodes, initiator, erc20, fee_model, global
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
assert coordinator.getRitualState(0) == RitualState.DKG_AWAITING_TRANSCRIPTS
Expand Down Expand Up @@ -309,7 +282,10 @@ def test_post_transcript_but_not_part_of_ritual(
allow_logic=global_allow_list,
)

transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

with ape.reverts("Participant not part of ritual"):
coordinator.postTranscript(0, transcript, sender=initiator)

Expand All @@ -325,12 +301,40 @@ def test_post_transcript_but_already_posted_transcript(
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

coordinator.postTranscript(0, transcript, sender=nodes[0])
with ape.reverts("Node already posted transcript"):
coordinator.postTranscript(0, transcript, sender=nodes[0])


def test_post_transcript_but_wrong_size(
coordinator, nodes, initiator, erc20, fee_model, global_allow_list
):
initiate_ritual(
coordinator=coordinator,
fee_model=fee_model,
erc20=erc20,
authority=initiator,
nodes=nodes,
allow_logic=global_allow_list,
)

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
bad_transcript = generate_transcript(size, threshold + 1)

with ape.reverts("Invalid transcript size"):
coordinator.postTranscript(0, bad_transcript, sender=nodes[0])

bad_transcript = b""
with ape.reverts("Invalid transcript size"):
coordinator.postTranscript(0, bad_transcript, sender=nodes[0])


def test_post_transcript_but_not_waiting_for_transcripts(
coordinator, nodes, initiator, erc20, fee_model, global_allow_list
):
Expand All @@ -342,7 +346,11 @@ def test_post_transcript_but_not_waiting_for_transcripts(
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
coordinator.postTranscript(0, transcript, sender=node)

Expand All @@ -359,7 +367,10 @@ def test_get_participants(coordinator, nodes, initiator, erc20, fee_model, globa
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
_ = coordinator.postTranscript(0, transcript, sender=node)
Expand Down Expand Up @@ -413,7 +424,10 @@ def test_get_participant(nodes, coordinator, initiator, erc20, fee_model, global
nodes=nodes,
allow_logic=global_allow_list,
)
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))

size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
_ = coordinator.postTranscript(0, transcript, sender=node)
Expand Down Expand Up @@ -462,8 +476,12 @@ def test_post_aggregation(
nodes=nodes,
allow_logic=global_allow_list,
)

ritualID = 0
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
coordinator.postTranscript(ritualID, transcript, sender=node)

Expand Down Expand Up @@ -520,8 +538,12 @@ def test_post_aggregation_fails(
nodes=nodes,
allow_logic=global_allow_list,
)

ritualID = 0
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
size = len(nodes)
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)

for node in nodes:
coordinator.postTranscript(ritualID, transcript, sender=node)

Expand All @@ -535,7 +557,7 @@ def test_post_aggregation_fails(
)

# Second node screws up everything
bad_aggregated = os.urandom(transcript_size(len(nodes), len(nodes)))
bad_aggregated = generate_transcript(size, threshold)
tx = coordinator.postAggregation(
ritualID, bad_aggregated, dkg_public_key, decryption_request_static_keys[1], sender=nodes[1]
)
Expand Down
36 changes: 5 additions & 31 deletions tests/test_global_allow_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,13 @@
from eth_account.messages import encode_defunct
from web3 import Web3

from tests.conftest import gen_public_key, generate_transcript

TIMEOUT = 1000
MAX_DKG_SIZE = 31
FEE_RATE = 42
ERC20_SUPPLY = 10**24
DURATION = 48 * 60 * 60
ONE_DAY = 24 * 60 * 60

RitualState = IntEnum(
"RitualState",
[
"NON_INITIATED",
"DKG_AWAITING_TRANSCRIPTS",
"DKG_AWAITING_AGGREGATIONS",
"DKG_TIMEOUT",
"DKG_INVALID",
"ACTIVE",
"EXPIRED",
],
start=0,
)


# This formula returns an approximated size
# To have a representative size, create transcripts with `nucypher-core`
def transcript_size(shares, threshold):
return int(424 + 240 * (shares / 2) + 50 * (threshold))


def gen_public_key():
return (os.urandom(32), os.urandom(32), os.urandom(32))


def access_control_error_message(address, role=None):
role = role or b"\x00" * 32
return f"account={address}, neededRole={role}"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -153,6 +125,7 @@ def test_authorize_using_global_allow_list(
signable_message = encode_defunct(digest)
signed_digest = w3.eth.account.sign_message(signable_message, private_key=deployer.private_key)
signature = signed_digest.signature
size = len(nodes)

# Not authorized
assert not global_allow_list.isAuthorized(0, bytes(signature), bytes(digest))
Expand All @@ -168,7 +141,8 @@ def test_authorize_using_global_allow_list(
coordinator.isEncryptionAuthorized(0, bytes(signature), bytes(digest))

# Finalize ritual
transcript = os.urandom(transcript_size(len(nodes), len(nodes)))
threshold = coordinator.getThresholdForRitualSize(size)
transcript = generate_transcript(size, threshold)
for node in nodes:
coordinator.postTranscript(0, transcript, sender=node)

Expand Down
Loading