diff --git a/README.md b/README.md index 2d2cc5a..9d34788 100644 --- a/README.md +++ b/README.md @@ -839,7 +839,7 @@ of the success of the DKG session by presenting recovery data to us. #### participant\_blame ```python -def participant_blame(hostseckey: bytes, state1: ParticipantState1, cmsg1: CoordinatorMsg1, cblame: CoordinatorBlameMsg) -> NoReturn +def participant_blame(blame_state: ParticipantBlameState, cblame: CoordinatorBlameMsg) -> NoReturn ``` Perform a participant's blame step of a ChillDKG session. TODO diff --git a/python/chilldkg_ref/chilldkg.py b/python/chilldkg_ref/chilldkg.py index 4f2acbb..4bcb02b 100644 --- a/python/chilldkg_ref/chilldkg.py +++ b/python/chilldkg_ref/chilldkg.py @@ -384,6 +384,10 @@ class ParticipantState2(NamedTuple): dkg_output: DKGOutput +class ParticipantBlameState(NamedTuple): + enc_blame_state: encpedpop.ParticipantBlameState + + def participant_step1( hostseckey: bytes, params: SessionParams, random: bytes ) -> Tuple[ParticipantState1, ParticipantMsg1]: @@ -472,16 +476,24 @@ def participant_step2( params, idx, enc_state = state1 enc_cmsg, enc_secshares = cmsg1 - enc_dkg_output, eq_input = encpedpop.participant_step2( - state=enc_state, - deckey=hostseckey, - cmsg=enc_cmsg, - enc_secshare=enc_secshares[idx], - ) + try: + enc_dkg_output, eq_input = encpedpop.participant_step2( + state=enc_state, + deckey=hostseckey, + cmsg=enc_cmsg, + enc_secshare=enc_secshares[idx], + ) + except UnknownFaultyPartyError as e: + assert isinstance(e.blame_state, encpedpop.ParticipantBlameState) + # Translate encpedpop.UnknownFaultyPartyError into our own + # chilldkg.UnknownFaultyPartyError. + blame_state = ParticipantBlameState(e.blame_state) + raise UnknownFaultyPartyError(blame_state, e.args) from e + # Include the enc_shares in eq_input to ensure that participants agree on all # shares, which in turn ensures that they have the right recovery data. eq_input += b"".join([bytes_from_int(int(share)) for share in enc_secshares]) - dkg_output = DKGOutput._make(enc_dkg_output) # Convert to chilldkg.DKGOutput type + dkg_output = DKGOutput._make(enc_dkg_output) state2 = ParticipantState2(params, eq_input, dkg_output) sig = certeq_participant_step(hostseckey, idx, eq_input) pmsg2 = ParticipantMsg2(sig) @@ -529,18 +541,12 @@ def participant_finalize( def participant_blame( - hostseckey: bytes, - state1: ParticipantState1, - cmsg1: CoordinatorMsg1, + blame_state: ParticipantBlameState, cblame: CoordinatorBlameMsg, ) -> NoReturn: """Perform a participant's blame step of a ChillDKG session. TODO""" - _, idx, enc_state = state1 encpedpop.participant_blame( - state=enc_state, - deckey=hostseckey, - cmsg=cmsg1.enc_cmsg, - enc_secshare=cmsg1.enc_secshares[idx], + blame_state=blame_state.enc_blame_state, cblame=cblame.enc_cblame, ) diff --git a/python/chilldkg_ref/encpedpop.py b/python/chilldkg_ref/encpedpop.py index 0c75a72..92c18e0 100644 --- a/python/chilldkg_ref/encpedpop.py +++ b/python/chilldkg_ref/encpedpop.py @@ -7,6 +7,7 @@ from . import simplpedpop from .util import ( + UnknownFaultyPartyError, tagged_hash_bip_dkg, prf, FaultyParticipantOrCoordinatorError, @@ -162,6 +163,12 @@ class ParticipantState(NamedTuple): idx: int +class ParticipantBlameState(NamedTuple): + simpl_bstate: simplpedpop.ParticipantBlameState + enc_secshare: Scalar + pads: List[Scalar] + + def serialize_enc_context(t: int, enckeys: List[bytes]) -> bytes: # TODO Consider hashing the result here because the string can be long, and # we'll feed it into hashes on multiple occasions @@ -223,33 +230,30 @@ def participant_step2( raise FaultyCoordinatorError("Coordinator replied with wrong pubnonce") enc_context = serialize_enc_context(simpl_state.t, enckeys) - secshare = decrypt_sum( - deckey, enckeys[idx], pubnonces, enc_context, idx, enc_secshare - ) + pads = decaps_multi(deckey, enckeys[idx], pubnonces, enc_context, idx) + secshare = enc_secshare - Scalar.sum(*pads) + + try: + dkg_output, eq_input = simplpedpop.participant_step2( + simpl_state, simpl_cmsg, secshare + ) + except UnknownFaultyPartyError as e: + assert isinstance(e.blame_state, simplpedpop.ParticipantBlameState) + # Translate simplpedpop.ParticipantBlamestate into our own + # encpedpop.ParticipantBlameState. + blame_state = ParticipantBlameState(e.blame_state, enc_secshare, pads) + raise UnknownFaultyPartyError(blame_state, e.args) from e - dkg_output, eq_input = simplpedpop.participant_step2( - simpl_state, simpl_cmsg, secshare - ) eq_input += b"".join(enckeys) + b"".join(pubnonces) return dkg_output, eq_input def participant_blame( - state: ParticipantState, - deckey: bytes, - cmsg: CoordinatorMsg, - enc_secshare: Scalar, + blame_state: ParticipantBlameState, cblame: CoordinatorBlameMsg, ) -> NoReturn: - simpl_state, _, enckeys, idx = state - _, pubnonces = cmsg + simpl_blame_state, enc_secshare, pads = blame_state enc_partial_secshares, partial_pubshares = cblame - - # Compute the encryption pads once and use them to decrypt both the - # enc_secshare and all enc_partial_secshares - enc_context = serialize_enc_context(simpl_state.t, enckeys) - pads = decaps_multi(deckey, enckeys[idx], pubnonces, enc_context, idx) - secshare = enc_secshare - Scalar.sum(*pads) partial_secshares = [ enc_partial_secshare - pad for enc_partial_secshare, pad in zip(enc_partial_secshares, pads, strict=True) @@ -258,7 +262,7 @@ def participant_blame( simpl_cblame = simplpedpop.CoordinatorBlameMsg(partial_pubshares) try: simplpedpop.participant_blame( - simpl_state, secshare, partial_secshares, simpl_cblame + simpl_blame_state, simpl_cblame, partial_secshares ) except simplpedpop.SecshareSumError as e: # The secshare is not equal to the sum of the partial secshares in the diff --git a/python/chilldkg_ref/simplpedpop.py b/python/chilldkg_ref/simplpedpop.py index 4e59940..32bcafa 100644 --- a/python/chilldkg_ref/simplpedpop.py +++ b/python/chilldkg_ref/simplpedpop.py @@ -108,6 +108,13 @@ class ParticipantState(NamedTuple): com_to_secret: GE +class ParticipantBlameState(NamedTuple): + n: int + idx: int + secshare: Scalar + pubshare: GE + + # To keep the algorithms of SimplPedPop and EncPedPop purely non-interactive # computations, we omit explicit invocations of an interactive equality check # protocol. ChillDKG will take care of invoking the equality check protocol. @@ -201,6 +208,7 @@ def participant_step2( if not VSSCommitment.verify_secshare(secshare, pubshare): raise UnknownFaultyPartyError( + ParticipantBlameState(n, idx, secshare, pubshare), "Received invalid secshare, consider blaming to determine faulty party", ) @@ -215,14 +223,16 @@ def participant_step2( def participant_blame( - state: ParticipantState, - secshare: Scalar, - partial_secshares: List[Scalar], + blame_state: ParticipantBlameState, cblame: CoordinatorBlameMsg, + partial_secshares: List[Scalar], ) -> NoReturn: - _, n, idx, _ = state + n, idx, secshare, pubshare = blame_state partial_pubshares = cblame.partial_pubshares + if GE.sum(*partial_pubshares) != pubshare: + raise FaultyCoordinatorError("Sum of partial pubshares not equal to pubshare") + if Scalar.sum(*partial_secshares) != secshare: raise SecshareSumError("Sum of partial secshares not equal to secshare") @@ -242,15 +252,13 @@ def participant_blame( # We now know: # - The sum of the partial secshares is equal to the secshare. + # - The sum of the partial pubshares is equal to the pubshare. # - Every partial secshare matches its corresponding partial pubshare. - # - The secshare does not match the pubshare (because the caller shouldn't - # have called us otherwise). - # Therefore, the sum of the partial pubshares is not equal to the pubshare, - # and this is the coordinator's fault. - raise FaultyCoordinatorError( - "Sum of partial pubshares not equal to pubshare (or participant_blame() " - "was called even though participant_step2() was successful)" - ) + # Hence, the secshare matches the pubshare. + assert VSSCommitment.verify_secshare(secshare, pubshare) + + # This should never happen (unless the caller fiddled with the inputs). + raise RuntimeError("participant_blame() was called, but all inputs are consistent.") ### diff --git a/python/chilldkg_ref/util.py b/python/chilldkg_ref/util.py index 4e42c14..8dbbd50 100644 --- a/python/chilldkg_ref/util.py +++ b/python/chilldkg_ref/util.py @@ -56,4 +56,8 @@ class FaultyCoordinatorError(ProtocolError): class UnknownFaultyPartyError(ProtocolError): - pass + """TODO""" + + def __init__(self, blame_state: Any, *args: Any): + self.blame_state = blame_state + super().__init__(*args) diff --git a/python/example.py b/python/example.py index 59718bb..596ebb0 100755 --- a/python/example.py +++ b/python/example.py @@ -2,9 +2,10 @@ """Example of a full ChillDKG session""" -from typing import Tuple, List +from typing import Tuple, List, Optional import asyncio import pprint +from random import randint from secrets import token_bytes as random_bytes import sys @@ -13,11 +14,14 @@ participant_step1, participant_step2, participant_finalize, + participant_blame, coordinator_step1, coordinator_finalize, + coordinator_blame, SessionParams, DKGOutput, RecoveryData, + UnknownFaultyPartyError, ) # @@ -72,21 +76,38 @@ async def participant( chan: ParticipantChannel, hostseckey: bytes, params: SessionParams, - blame: bool = True, + blame: bool, ) -> Tuple[DKGOutput, RecoveryData]: # TODO Top-level error handling random = random_bytes(32) state1, pmsg1 = participant_step1(hostseckey, params, random) + + # The following code simulate a faulty participant that sends a single + # incorrect share with probability 1/2. Of course, this should not be part + # of a real implementation. + if blame: + # Flip a coin to decide if we send incorrect shares. + faulty = randint(0, 1) + print(faulty) + if faulty: + # Pick a random victim participant. + victim = randint(0, len(pmsg1.enc_pmsg.enc_shares) - 1) + pmsg1.enc_pmsg.enc_shares[victim] += 17 chan.send(pmsg1) cmsg1 = await chan.receive() - # TODO - # if blame: - # blame_rec = await chan.receive() - # else: - # blame_rec = None - - state2, eq_round1 = participant_step2(hostseckey, state1, cmsg1) + try: + state2, eq_round1 = participant_step2(hostseckey, state1, cmsg1) + except UnknownFaultyPartyError as e: + if not blame: + # Not in blame mode, so we give up. + raise + blame_msg = await chan.receive() + participant_blame(e.blame_state, blame_msg) + else: + if blame: + # Ignore the blame message because we don't need it. + _ = await chan.receive() chan.send(eq_round1) cmsg2 = await chan.receive() @@ -95,7 +116,7 @@ async def participant( async def coordinator( - chans: CoordinatorChannels, params: SessionParams + chans: CoordinatorChannels, params: SessionParams, blame: bool ) -> Tuple[DKGOutput, RecoveryData]: (hostpubkeys, t) = params n = len(hostpubkeys) @@ -106,10 +127,10 @@ async def coordinator( state, cmsg1 = coordinator_step1(pmsgs1, params) chans.send_all(cmsg1) - # TODO - # if blame: - # for i in range(n): - # chans.send_to(i, blame_recs[i]) + if blame: + blame_msgs = coordinator_blame(pmsgs1) + for i in range(n): + chans.send_to(i, blame_msgs[i]) sigs = [] for i in range(n): @@ -125,7 +146,9 @@ async def coordinator( # -def simulate_chilldkg_full(hostseckeys, t) -> List[Tuple[DKGOutput, RecoveryData]]: +def simulate_chilldkg_full( + hostseckeys, t, blame: bool = False +) -> Optional[List[Tuple[DKGOutput, RecoveryData]]]: # Generate common inputs for all participants and coordinator n = len(hostseckeys) hostpubkeys = [] @@ -136,7 +159,6 @@ def simulate_chilldkg_full(hostseckeys, t) -> List[Tuple[DKGOutput, RecoveryData params = SessionParams(hostpubkeys, t) async def session(): - # TODO Blame coord_chans = CoordinatorChannels(n) participant_chans = [ ParticipantChannel(coord_chans.queues[i]) for i in range(n) @@ -144,8 +166,9 @@ async def session(): coord_chans.set_participant_queues( [participant_chans[i].queue for i in range(n)] ) - coroutines = [coordinator(coord_chans, params)] + [ - participant(participant_chans[i], hostseckeys[i], params) for i in range(n) + coroutines = [coordinator(coord_chans, params, blame)] + [ + participant(participant_chans[i], hostseckeys[i], params, blame) + for i in range(n) ] return await asyncio.gather(*coroutines) @@ -168,7 +191,8 @@ def main(): print(f"Participant {i}'s hostseckey:", hostseckeys[i].hex()) print() - rets = simulate_chilldkg_full(hostseckeys, t) + # TODO Add cli arguments to enable blame mode + rets = simulate_chilldkg_full(hostseckeys, t, blame=False) assert len(rets) == n + 1 print("=== Coordinator's DKGOutput ===") diff --git a/python/tests.py b/python/tests.py index 1602b90..b348496 100755 --- a/python/tests.py +++ b/python/tests.py @@ -4,13 +4,18 @@ from itertools import combinations from random import randint -from typing import Tuple, List +from typing import Tuple, List, Optional from secrets import token_bytes as random_bytes from secp256k1proto.secp256k1 import GE, G, Scalar from secp256k1proto.keys import pubkey_gen_plain -from chilldkg_ref.util import prf, FaultyCoordinatorError +from chilldkg_ref.util import ( + FaultyParticipantOrCoordinatorError, + FaultyCoordinatorError, + UnknownFaultyPartyError, + prf, +) from chilldkg_ref.vss import Polynomial, VSS, VSSCommitment import chilldkg_ref.simplpedpop as simplpedpop import chilldkg_ref.encpedpop as encpedpop @@ -35,7 +40,9 @@ def rand_polynomial(t): ) -def simulate_simplpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]: +def simulate_simplpedpop( + seeds, t, blame: bool +) -> Optional[List[Tuple[simplpedpop.DKGOutput, bytes]]]: n = len(seeds) prets = [] for i in range(n): @@ -45,26 +52,38 @@ def simulate_simplpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]: pmsgs = [pmsg for (_, pmsg, _) in prets] cmsg, cout, ceq = simplpedpop.coordinator_step(pmsgs, t, n) - blame_recs = simplpedpop.coordinator_blame(pmsgs) pre_finalize_rets = [(cout, ceq)] for i in range(n): partial_secshares = [ partial_secshares_for[i] for (_, _, partial_secshares_for) in prets ] - # TODO Test that the protocol fails when wrong shares are sent. - # if i == n - 1: - # partial_secshares[-1] += Scalar(17) + if blame: + # Let a random participant send incorrect shares to participant i. + faulty_idx = randint(0, n - 1) + partial_secshares[faulty_idx] += Scalar(17) + secshare = simplpedpop.participant_step2_prepare_secshare(partial_secshares) - pre_finalize_rets += [simplpedpop.participant_step2(pstates[i], cmsg, secshare)] - # This was a correct run, so blame should fail. try: - simplpedpop.participant_blame( - pstates[i], secshare, partial_secshares, blame_recs[i] - ) - except FaultyCoordinatorError: - pass - else: - assert False + pre_finalize_rets += [ + simplpedpop.participant_step2(pstates[i], cmsg, secshare) + ] + except UnknownFaultyPartyError as e: + if not blame: + raise + blame_msgs = simplpedpop.coordinator_blame(pmsgs) + assert len(blame_msgs) == len(pmsgs) + try: + simplpedpop.participant_blame( + e.blame_state, blame_msgs[i], partial_secshares + ) + # If we're not faulty, we should blame the faulty party. + except FaultyParticipantOrCoordinatorError as e: + assert i != faulty_idx + assert e.participant == faulty_idx + # If we're faulty, we'll blame the coordinator. + except FaultyCoordinatorError: + assert i == faulty_idx + return None return pre_finalize_rets @@ -74,7 +93,9 @@ def encpedpop_keys(seed: bytes) -> Tuple[bytes, bytes]: return deckey, enckey -def simulate_encpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]: +def simulate_encpedpop( + seeds, t, blame: bool +) -> Optional[List[Tuple[simplpedpop.DKGOutput, bytes]]]: n = len(seeds) enc_prets0 = [] enc_prets1 = [] @@ -89,32 +110,44 @@ def simulate_encpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]: encpedpop.participant_step1(seeds[i], deckey, enckeys, t, i, random) ] - pmsgs = [pmsg for (_, pmsg) in enc_prets1] pstates = [pstate for (pstate, _) in enc_prets1] + pmsgs = [pmsg for (_, pmsg) in enc_prets1] + if blame: + faulty_idx: List[int] = [] + for i in range(n): + # Let a random participant faulty_idx[i] send incorrect shares to i. + faulty_idx[i:] = [randint(0, n - 1)] + pmsgs[faulty_idx[i]].enc_shares[i] += Scalar(17) cmsg, cout, ceq, enc_secshares = encpedpop.coordinator_step(pmsgs, t, enckeys) - blame_recs = encpedpop.coordinator_blame(pmsgs) pre_finalize_rets = [(cout, ceq)] for i in range(n): deckey = enc_prets0[i][0] - pre_finalize_rets += [ - encpedpop.participant_step2(pstates[i], deckey, cmsg, enc_secshares[i]) - ] try: - encpedpop.participant_blame( - pstates[i], deckey, cmsg, enc_secshares[i], blame_recs[i] - ) - # This was a correct run, so blame should fail. - except FaultyCoordinatorError: - pass - else: - assert False + pre_finalize_rets += [ + encpedpop.participant_step2(pstates[i], deckey, cmsg, enc_secshares[i]) + ] + except UnknownFaultyPartyError as e: + if not blame: + raise + blame_msgs = encpedpop.coordinator_blame(pmsgs) + assert len(blame_msgs) == len(pmsgs) + try: + encpedpop.participant_blame(e.blame_state, blame_msgs[i]) + # If we're not faulty, we should blame the faulty party. + except FaultyParticipantOrCoordinatorError as e: + assert i != faulty_idx[i] + assert e.participant == faulty_idx[i] + # If we're faulty, we'll blame the coordinator. + except FaultyCoordinatorError: + assert i == faulty_idx[i] + return None return pre_finalize_rets def simulate_chilldkg( - hostseckeys, t -) -> List[Tuple[chilldkg.DKGOutput, chilldkg.RecoveryData]]: + hostseckeys, t, blame: bool +) -> Optional[List[Tuple[chilldkg.DKGOutput, chilldkg.RecoveryData]]]: n = len(hostseckeys) hostpubkeys = [] @@ -130,21 +163,34 @@ def simulate_chilldkg( pstates1 = [pret[0] for pret in prets1] pmsgs = [pret[1] for pret in prets1] + if blame: + faulty_idx: List[int] = [] + for i in range(n): + # Let a random participant faulty_idx[i] send incorrect shares to i. + faulty_idx[i:] = [randint(0, n - 1)] + pmsgs[faulty_idx[i]].enc_pmsg.enc_shares[i] += Scalar(17) + cstate, cmsg1 = chilldkg.coordinator_step1(pmsgs, params) - blame_recs = chilldkg.coordinator_blame(pmsgs) prets2 = [] for i in range(n): - prets2 += [chilldkg.participant_step2(hostseckeys[i], pstates1[i], cmsg1)] - # This was a correct run, so blame should fail. try: - chilldkg.participant_blame( - hostseckeys[i], pstates1[i], cmsg1, blame_recs[i] - ) - except FaultyCoordinatorError: - pass - else: - assert False + prets2 += [chilldkg.participant_step2(hostseckeys[i], pstates1[i], cmsg1)] + except UnknownFaultyPartyError as e: + if not blame: + raise + blame_msgs = chilldkg.coordinator_blame(pmsgs) + assert len(blame_msgs) == len(pmsgs) + try: + chilldkg.participant_blame(e.blame_state, blame_msgs[i]) + # If we're not faulty, we should blame the faulty party. + except FaultyParticipantOrCoordinatorError as e: + assert i != faulty_idx[i] + assert e.participant == faulty_idx[i] + # If we're faulty, we'll blame the coordinator. + except FaultyCoordinatorError: + assert i == faulty_idx[i] + return None cmsg2, cout, crec = chilldkg.coordinator_finalize( cstate, [pret[1] for pret in prets2] @@ -221,14 +267,18 @@ def test_correctness_dkg_output(t, n, dkg_outputs: List[simplpedpop.DKGOutput]): assert recovered * G == GE.from_bytes_compressed(threshold_pubkey) -def test_correctness(t, n, simulate_dkg, recovery=False): +def test_correctness(t, n, simulate_dkg, recovery=False, blame=False): seeds = [None] + [random_bytes(32) for _ in range(n)] + rets = simulate_dkg(seeds[1:], t, blame=blame) + if blame: + assert rets is None + # The session has failed correctly, so there's nothing further to check. + return + # rets[0] are the return values from the coordinator # rets[1 : n + 1] are from the participants - rets = simulate_dkg(seeds[1:], t) assert len(rets) == n + 1 - dkg_outputs = [ret[0] for ret in rets] test_correctness_dkg_output(t, n, dkg_outputs) @@ -250,6 +300,9 @@ def test_correctness(t, n, simulate_dkg, recovery=False): test_recover_secret() for t, n in [(1, 1), (1, 2), (2, 2), (2, 3), (2, 5)]: test_correctness(t, n, simulate_simplpedpop) + test_correctness(t, n, simulate_simplpedpop, blame=True) test_correctness(t, n, simulate_encpedpop) + test_correctness(t, n, simulate_encpedpop, blame=True) test_correctness(t, n, simulate_chilldkg, recovery=True) + test_correctness(t, n, simulate_chilldkg, recovery=True, blame=True) test_correctness(t, n, simulate_chilldkg_full, recovery=True)