Skip to content

Commit

Permalink
WIP python: Cache blame() inputs, piggyback them on exception
Browse files Browse the repository at this point in the history
  • Loading branch information
real-or-random committed Nov 13, 2024
1 parent 4c2852f commit c6d7934
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 92 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ of the success of the DKG session by presenting recovery data to us.
#### participant\_blame

```python
def participant_blame(hostseckey: bytes, state1: ParticipantState1, cmsg1: CoordinatorMsg1, cblame: CoordinatorBlameMsg) -> NoReturn
def participant_blame(blame_state: ParticipantBlameState, cblame: CoordinatorBlameMsg) -> NoReturn
```

Perform a participant's blame step of a ChillDKG session. TODO
Expand Down
33 changes: 18 additions & 15 deletions python/chilldkg_ref/chilldkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@ class ParticipantState2(NamedTuple):
dkg_output: DKGOutput


class ParticipantBlameState(NamedTuple):
enc_blame_state: encpedpop.ParticipantBlameState


def participant_step1(
hostseckey: bytes, params: SessionParams, random: bytes
) -> Tuple[ParticipantState1, ParticipantMsg1]:
Expand Down Expand Up @@ -472,16 +476,21 @@ def participant_step2(
params, idx, enc_state = state1
enc_cmsg, enc_secshares = cmsg1

enc_dkg_output, eq_input = encpedpop.participant_step2(
state=enc_state,
deckey=hostseckey,
cmsg=enc_cmsg,
enc_secshare=enc_secshares[idx],
)
try:
enc_dkg_output, eq_input = encpedpop.participant_step2(
state=enc_state,
deckey=hostseckey,
cmsg=enc_cmsg,
enc_secshare=enc_secshares[idx],
)
except UnknownFaultyPartyError[encpedpop.ParticipantBlameState] as e:
# Convert encpedpop.ParticipanBlameState to chilldkg.ParticipantBlameState
blame_state = ParticipantBlameState(e.blame_state)
raise UnknownFaultyPartyError[ParticipantBlameState](blame_state, e.args) from e
# Include the enc_shares in eq_input to ensure that participants agree on all
# shares, which in turn ensures that they have the right recovery data.
eq_input += b"".join([bytes_from_int(int(share)) for share in enc_secshares])
dkg_output = DKGOutput._make(enc_dkg_output) # Convert to chilldkg.DKGOutput type
dkg_output = DKGOutput._make(enc_dkg_output)
state2 = ParticipantState2(params, eq_input, dkg_output)
sig = certeq_participant_step(hostseckey, idx, eq_input)
pmsg2 = ParticipantMsg2(sig)
Expand Down Expand Up @@ -529,18 +538,12 @@ def participant_finalize(


def participant_blame(
hostseckey: bytes,
state1: ParticipantState1,
cmsg1: CoordinatorMsg1,
blame_state: ParticipantBlameState,
cblame: CoordinatorBlameMsg,
) -> NoReturn:
"""Perform a participant's blame step of a ChillDKG session. TODO"""
_, idx, enc_state = state1
encpedpop.participant_blame(
state=enc_state,
deckey=hostseckey,
cmsg=cmsg1.enc_cmsg,
enc_secshare=cmsg1.enc_secshares[idx],
blame_state=blame_state.enc_blame_state,
cblame=cblame.enc_cblame,
)

Expand Down
39 changes: 20 additions & 19 deletions python/chilldkg_ref/encpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import simplpedpop
from .util import (
UnknownFaultyPartyError,
tagged_hash_bip_dkg,
prf,
FaultyParticipantOrCoordinatorError,
Expand Down Expand Up @@ -162,6 +163,12 @@ class ParticipantState(NamedTuple):
idx: int


class ParticipantBlameState(NamedTuple):
simpl_bstate: simplpedpop.ParticipantBlameState
enc_secshare: Scalar
pads: List[Scalar]


def serialize_enc_context(t: int, enckeys: List[bytes]) -> bytes:
# TODO Consider hashing the result here because the string can be long, and
# we'll feed it into hashes on multiple occasions
Expand Down Expand Up @@ -223,33 +230,27 @@ def participant_step2(
raise FaultyCoordinatorError("Coordinator replied with wrong pubnonce")

enc_context = serialize_enc_context(simpl_state.t, enckeys)
secshare = decrypt_sum(
deckey, enckeys[idx], pubnonces, enc_context, idx, enc_secshare
)
pads = decaps_multi(deckey, enckeys[idx], pubnonces, enc_context, idx)
secshare = enc_secshare - Scalar.sum(*pads)

try:
dkg_output, eq_input = simplpedpop.participant_step2(
simpl_state, simpl_cmsg, secshare
)
except UnknownFaultyPartyError[simplpedpop.ParticipantBlameState] as e:
blame_state = ParticipantBlameState(e.blame_state, enc_secshare, pads)
raise UnknownFaultyPartyError[ParticipantBlameState](blame_state, e.args) from e

dkg_output, eq_input = simplpedpop.participant_step2(
simpl_state, simpl_cmsg, secshare
)
eq_input += b"".join(enckeys) + b"".join(pubnonces)
return dkg_output, eq_input


def participant_blame(
state: ParticipantState,
deckey: bytes,
cmsg: CoordinatorMsg,
enc_secshare: Scalar,
blame_state: ParticipantBlameState,
cblame: CoordinatorBlameMsg,
) -> NoReturn:
simpl_state, _, enckeys, idx = state
_, pubnonces = cmsg
simpl_blame_state, enc_secshare, pads = blame_state
enc_partial_secshares, partial_pubshares = cblame

# Compute the encryption pads once and use them to decrypt both the
# enc_secshare and all enc_partial_secshares
enc_context = serialize_enc_context(simpl_state.t, enckeys)
pads = decaps_multi(deckey, enckeys[idx], pubnonces, enc_context, idx)
secshare = enc_secshare - Scalar.sum(*pads)
partial_secshares = [
enc_partial_secshare - pad
for enc_partial_secshare, pad in zip(enc_partial_secshares, pads, strict=True)
Expand All @@ -258,7 +259,7 @@ def participant_blame(
simpl_cblame = simplpedpop.CoordinatorBlameMsg(partial_pubshares)
try:
simplpedpop.participant_blame(
simpl_state, secshare, partial_secshares, simpl_cblame
simpl_blame_state, simpl_cblame, partial_secshares
)
except simplpedpop.SecshareSumError as e:
# The secshare is not equal to the sum of the partial secshares in the
Expand Down
34 changes: 21 additions & 13 deletions python/chilldkg_ref/simplpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ class ParticipantState(NamedTuple):
com_to_secret: GE


class ParticipantBlameState(NamedTuple):
n: int
idx: int
secshare: Scalar
pubshare: GE


# To keep the algorithms of SimplPedPop and EncPedPop purely non-interactive
# computations, we omit explicit invocations of an interactive equality check
# protocol. ChillDKG will take care of invoking the equality check protocol.
Expand Down Expand Up @@ -200,7 +207,8 @@ def participant_step2(
pubshare = sum_coms.pubshare(idx)

if not VSSCommitment.verify_secshare(secshare, pubshare):
raise UnknownFaultyPartyError(
raise UnknownFaultyPartyError[ParticipantBlameState](
ParticipantBlameState(n, idx, secshare, pubshare),
"Received invalid secshare, consider blaming to determine faulty party",
)

Expand All @@ -215,14 +223,16 @@ def participant_step2(


def participant_blame(
state: ParticipantState,
secshare: Scalar,
partial_secshares: List[Scalar],
blame_state: ParticipantBlameState,
cblame: CoordinatorBlameMsg,
partial_secshares: List[Scalar],
) -> NoReturn:
_, n, idx, _ = state
n, idx, secshare, pubshare = blame_state
partial_pubshares = cblame.partial_pubshares

if GE.sum(*partial_pubshares) != pubshare:
raise FaultyCoordinatorError("Sum of partial pubshares not equal to pubshare")

if Scalar.sum(*partial_secshares) != secshare:
raise SecshareSumError("Sum of partial secshares not equal to secshare")

Expand All @@ -242,15 +252,13 @@ def participant_blame(

# We now know:
# - The sum of the partial secshares is equal to the secshare.
# - The sum of the partial pubshares is equal to the pubshare.
# - Every partial secshare matches its corresponding partial pubshare.
# - The secshare does not match the pubshare (because the caller shouldn't
# have called us otherwise).
# Therefore, the sum of the partial pubshares is not equal to the pubshare,
# and this is the coordinator's fault.
raise FaultyCoordinatorError(
"Sum of partial pubshares not equal to pubshare (or participant_blame() "
"was called even though participant_step2() was successful)"
)
# Hence, the secshare matches the pubshare.
assert VSSCommitment.verify_secshare(secshare, pubshare)

# This should never happen (unless the caller fiddled with the inputs).
raise RuntimeError("participant_blame() was called, but all inputs are consistent.")


###
Expand Down
11 changes: 8 additions & 3 deletions python/chilldkg_ref/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, TypeVar, Generic

from secp256k1proto.util import tagged_hash

Expand Down Expand Up @@ -55,5 +55,10 @@ class FaultyCoordinatorError(ProtocolError):
"""


class UnknownFaultyPartyError(ProtocolError):
pass
S = TypeVar("S")


class UnknownFaultyPartyError(ProtocolError, Generic[S]):
def __init__(self, blame_state: S, *args: Any):
self.blame_state = blame_state
super().__init__(*args)
91 changes: 50 additions & 41 deletions python/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@

from itertools import combinations
from random import randint
from typing import Tuple, List
from typing import Tuple, List, Optional
from secrets import token_bytes as random_bytes

from secp256k1proto.secp256k1 import GE, G, Scalar
from secp256k1proto.keys import pubkey_gen_plain

from chilldkg_ref.util import prf, FaultyCoordinatorError
from chilldkg_ref.util import (
FaultyParticipantOrCoordinatorError,
FaultyCoordinatorError,
UnknownFaultyPartyError,
prf,
)
from chilldkg_ref.vss import Polynomial, VSS, VSSCommitment
import chilldkg_ref.simplpedpop as simplpedpop
import chilldkg_ref.encpedpop as encpedpop
Expand All @@ -35,7 +40,9 @@ def rand_polynomial(t):
)


def simulate_simplpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]:
def simulate_simplpedpop(
seeds, t, blame: bool
) -> Optional[List[Tuple[simplpedpop.DKGOutput, bytes]]]:
n = len(seeds)
prets = []
for i in range(n):
Expand All @@ -45,27 +52,42 @@ def simulate_simplpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]:
pmsgs = [pmsg for (_, pmsg, _) in prets]

cmsg, cout, ceq = simplpedpop.coordinator_step(pmsgs, t, n)
blame_recs = simplpedpop.coordinator_blame(pmsgs)
pre_finalize_rets = [(cout, ceq)]
for i in range(n):
partial_secshares = [
partial_secshares_for[i] for (_, _, partial_secshares_for) in prets
]
# TODO Test that the protocol fails when wrong shares are sent.
# if i == n - 1:
# partial_secshares[-1] += Scalar(17)
if blame:
# Test that the protocol fails when wrong shares are sent.
if i == n - 1:
partial_secshares[-1] += Scalar(17)
secshare = simplpedpop.participant_step2_prepare_secshare(partial_secshares)
pre_finalize_rets += [simplpedpop.participant_step2(pstates[i], cmsg, secshare)]
# This was a correct run, so blame should fail.
try:
simplpedpop.participant_blame(
pstates[i], secshare, partial_secshares, blame_recs[i]
)
pre_finalize_rets += [
simplpedpop.participant_step2(pstates[i], cmsg, secshare)
]
except UnknownFaultyPartyError as e:
if blame:
simplpedpop_simulate_blame(pmsgs, e.blame_state, partial_secshares)
return None
else:
raise
return pre_finalize_rets


def simplpedpop_simulate_blame(pmsgs, blame_state, partial_secshares):
blame_msgs = simplpedpop.coordinator_blame(pmsgs)
assert len(blame_msgs) == len(pmsgs)
for i in range(n):
try:
simplpedpop.participant_blame(blame_state, blame_msgs[i], partial_secshares)
# Any of the following errors is good; it means we identified a faulty party.
# TODO We should also check if the right expection is raised, but this doesn't
# belong in a "correctness" test.
except FaultyParticipantOrCoordinatorError:
pass
except FaultyCoordinatorError:
pass
else:
assert False
return pre_finalize_rets


def encpedpop_keys(seed: bytes) -> Tuple[bytes, bytes]:
Expand Down Expand Up @@ -93,22 +115,13 @@ def simulate_encpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]:
pstates = [pstate for (pstate, _) in enc_prets1]

cmsg, cout, ceq, enc_secshares = encpedpop.coordinator_step(pmsgs, t, enckeys)
blame_recs = encpedpop.coordinator_blame(pmsgs)
blame_msgs = encpedpop.coordinator_blame(pmsgs) # FIXME
pre_finalize_rets = [(cout, ceq)]
for i in range(n):
deckey = enc_prets0[i][0]
pre_finalize_rets += [
encpedpop.participant_step2(pstates[i], deckey, cmsg, enc_secshares[i])
]
try:
encpedpop.participant_blame(
pstates[i], deckey, cmsg, enc_secshares[i], blame_recs[i]
)
# This was a correct run, so blame should fail.
except FaultyCoordinatorError:
pass
else:
assert False
return pre_finalize_rets


Expand All @@ -131,20 +144,11 @@ def simulate_chilldkg(
pstates1 = [pret[0] for pret in prets1]
pmsgs = [pret[1] for pret in prets1]
cstate, cmsg1 = chilldkg.coordinator_step1(pmsgs, params)
blame_recs = chilldkg.coordinator_blame(pmsgs)
blame_msgs = chilldkg.coordinator_blame(pmsgs) # FIXME

prets2 = []
for i in range(n):
prets2 += [chilldkg.participant_step2(hostseckeys[i], pstates1[i], cmsg1)]
# This was a correct run, so blame should fail.
try:
chilldkg.participant_blame(
hostseckeys[i], pstates1[i], cmsg1, blame_recs[i]
)
except FaultyCoordinatorError:
pass
else:
assert False

cmsg2, cout, crec = chilldkg.coordinator_finalize(
cstate, [pret[1] for pret in prets2]
Expand Down Expand Up @@ -221,14 +225,18 @@ def test_correctness_dkg_output(t, n, dkg_outputs: List[simplpedpop.DKGOutput]):
assert recovered * G == GE.from_bytes_compressed(threshold_pubkey)


def test_correctness(t, n, simulate_dkg, recovery=False):
def test_correctness(t, n, simulate_dkg, recovery=False, blame=False):
seeds = [None] + [random_bytes(32) for _ in range(n)]

rets = simulate_dkg(seeds[1:], t, blame=blame)
if blame:
assert rets is None
# The session has failed correctly, so there's nothing further to check.
return

# rets[0] are the return values from the coordinator
# rets[1 : n + 1] are from the participants
rets = simulate_dkg(seeds[1:], t)
assert len(rets) == n + 1

dkg_outputs = [ret[0] for ret in rets]
test_correctness_dkg_output(t, n, dkg_outputs)

Expand All @@ -250,6 +258,7 @@ def test_correctness(t, n, simulate_dkg, recovery=False):
test_recover_secret()
for t, n in [(1, 1), (1, 2), (2, 2), (2, 3), (2, 5)]:
test_correctness(t, n, simulate_simplpedpop)
test_correctness(t, n, simulate_encpedpop)
test_correctness(t, n, simulate_chilldkg, recovery=True)
test_correctness(t, n, simulate_chilldkg_full, recovery=True)
test_correctness(t, n, simulate_simplpedpop, blame=True)
# test_correctness(t, n, simulate_encpedpop)
# test_correctness(t, n, simulate_chilldkg, recovery=True)
# test_correctness(t, n, simulate_chilldkg_full, recovery=True)

0 comments on commit c6d7934

Please sign in to comment.