diff --git a/python/chilldkg_ref/chilldkg.py b/python/chilldkg_ref/chilldkg.py index 6c21034..5d3528f 100644 --- a/python/chilldkg_ref/chilldkg.py +++ b/python/chilldkg_ref/chilldkg.py @@ -16,7 +16,7 @@ from secp256k1proto.keys import pubkey_gen_plain from secp256k1proto.util import int_from_bytes, bytes_from_int -from .vss import VSS, VSSCommitment +from .vss import VSSCommitment from . import encpedpop from .util import ( BIP_TAG, @@ -408,7 +408,7 @@ def participant_step1( idx = hostpubkeys.index(hostpubkey) # ValueError if not found enc_state, enc_pmsg = encpedpop.participant_step1( - hostseckey, t, hostpubkeys, idx, random + hostseckey, hostseckey, hostpubkeys, t, idx, random ) # SecKeyError if len(hostseckey) != 32 state1 = ParticipantState1(params, idx, enc_state) return state1, ParticipantMsg1(enc_pmsg) @@ -639,9 +639,6 @@ def recover( # Decrypt share enc_context = encpedpop.serialize_enc_context(t, hostpubkeys) - session_seed = encpedpop.derive_session_seed( - hostseckey, pubnonces[idx], enc_context - ) secshare = encpedpop.decrypt_sum( hostseckey, hostpubkeys[idx], @@ -651,11 +648,6 @@ def recover( idx, ) - # Derive my_share - vss = VSS.generate(session_seed, t) - my_share = vss.secshare_for(idx) - secshare += my_share - # This is just a sanity check. Our signature is valid, so we have done # this check already during the actual session. assert VSSCommitment.verify_secshare(secshare, pubshares[idx]) diff --git a/python/chilldkg_ref/encpedpop.py b/python/chilldkg_ref/encpedpop.py index e338f15..98578d9 100644 --- a/python/chilldkg_ref/encpedpop.py +++ b/python/chilldkg_ref/encpedpop.py @@ -28,45 +28,58 @@ def ecdh( return Scalar(int_from_bytes(tagged_hash_bip_dkg("encpedpop ecdh", data))) +def self_pad(deckey: bytes, context_: bytes) -> Scalar: + return Scalar( + int_from_bytes( + prf(seed=deckey, tag="encaps_multi self_pad", extra_input=context_) + ) + ) + + def encaps_multi( - secnonce: bytes, pubnonce: bytes, enckeys: List[bytes], context: bytes, idx: int + secnonce: bytes, + pubnonce: bytes, + deckey: bytes, + enckeys: List[bytes], + context: bytes, + idx: int, ) -> List[Scalar]: # This is effectively the "Hashed ElGamal" multi-recipient KEM described in # Section 5 of "Multi-recipient encryption, revisited" by Alexandre Pinto, # Bertram Poettering, Jacob C. N. Schuldt (AsiaCCS 2014). Its crucial # feature is to feed the index of the enckey to the hash function. The only - # differences are that we skip our own index (because we don't need to - # encrypt to ourselves) and that we feed also the pubnonce and context data - # into the hash function. - if idx >= len(enckeys): - raise IndexError - keys = [ - ecdh( - secnonce, - pubnonce, - enckey, - i.to_bytes(4, byteorder="big") + context, - sending=True, - ) - for i, enckey in enumerate(enckeys) - if i != idx # Skip own index - ] - return keys + # difference is that we feed also the pubnonce and context data into the + # hash function. + pads = [] + for i, enckey in enumerate(enckeys): + context_ = i.to_bytes(4, byteorder="big") + context + if i == idx: + # We're encrypting to ourselves, so we use a symmetrically derived + # pad to save the ECDH computation. + pad = self_pad(deckey, context_) + else: + pad = ecdh( + seckey=secnonce, + my_pubkey=pubnonce, + their_pubkey=enckey, + context=context_, + sending=True, + ) + pads.append(pad) + return pads def encrypt_multi( secnonce: bytes, pubnonce: bytes, + deckey: bytes, enckeys: List[bytes], messages: List[Scalar], context: bytes, idx: int, ) -> List[Scalar]: - keys = encaps_multi(secnonce, pubnonce, enckeys, context, idx) - their_messages = messages[:idx] + messages[idx + 1 :] # Skip own index - ciphertexts = [ - message + key for message, key in zip(their_messages, keys, strict=True) - ] + pads = encaps_multi(secnonce, pubnonce, deckey, enckeys, context, idx) + ciphertexts = [message + pad for message, pad in zip(messages, pads, strict=True)] return ciphertexts @@ -80,12 +93,19 @@ def decrypt_sum( ) -> Scalar: if idx >= len(pubnonces): raise IndexError - ecdh_context = idx.to_bytes(4, byteorder="big") + context + context_ = idx.to_bytes(4, byteorder="big") + context secshare = sum_ciphertexts for i, pubnonce in enumerate(pubnonces): if i == idx: - continue # Skip own index - pad = ecdh(deckey, enckey, pubnonce, ecdh_context, sending=False) + pad = self_pad(deckey, context_) + else: + pad = ecdh( + seckey=deckey, + my_pubkey=enckey, + their_pubkey=pubnonce, + context=context_, + sending=False, + ) secshare = secshare - pad return secshare @@ -116,7 +136,6 @@ class ParticipantState(NamedTuple): pubnonce: bytes enckeys: List[bytes] idx: int - self_share: Scalar def serialize_enc_context(t: int, enckeys: List[bytes]) -> bytes: @@ -131,8 +150,9 @@ def derive_session_seed(seed: bytes, pubnonce: bytes, enc_context: bytes) -> byt def participant_step1( seed: bytes, - t: int, + deckey: bytes, enckeys: List[bytes], + t: int, idx: int, random: bytes, ) -> Tuple[ParticipantState, ParticipantMsg]: @@ -156,13 +176,12 @@ def participant_step1( ) assert len(shares) == n - # Encrypt shares, no need to encrypt to ourselves enc_shares = encrypt_multi( - secnonce, pubnonce, enckeys, shares, enc_context, idx=idx + secnonce, pubnonce, deckey, enckeys, shares, enc_context, idx ) pmsg = ParticipantMsg(simpl_pmsg, pubnonce, enc_shares) - state = ParticipantState(simpl_state, pubnonce, enckeys, idx, shares[idx]) + state = ParticipantState(simpl_state, pubnonce, enckeys, idx) return state, pmsg @@ -172,7 +191,7 @@ def participant_step2( cmsg: CoordinatorMsg, enc_secshare: Scalar, ) -> Tuple[simplpedpop.DKGOutput, bytes]: - simpl_state, pubnonce, enckeys, idx, self_share = state + simpl_state, pubnonce, enckeys, idx = state simpl_cmsg, pubnonces = cmsg reported_pubnonce = pubnonces[idx] @@ -183,7 +202,6 @@ def participant_step2( secshare = decrypt_sum( deckey, enckeys[idx], pubnonces, enc_secshare, enc_context, idx ) - secshare += self_share dkg_output, eq_input = simplpedpop.participant_step2( simpl_state, simpl_cmsg, secshare ) @@ -209,13 +227,10 @@ def coordinator_step( ) pubnonces = [pmsg.pubnonce for pmsg in pmsgs] for i in range(n): - # Participant i implicitly uses a pad of 0 to encrypt to themselves. - # Make this pad explicit at the right position. - if len(pmsgs[i].enc_shares) != n - 1: + if len(pmsgs[i].enc_shares) != n: raise InvalidContributionError( i, "Participant sent enc_shares with invalid length" ) - pmsgs[i].enc_shares.insert(i, Scalar(0)) enc_secshares = [ Scalar.sum(*([pmsg.enc_shares[i] for pmsg in pmsgs])) for i in range(n) ] diff --git a/python/tests.py b/python/tests.py index 1bd94a5..415d6dd 100755 --- a/python/tests.py +++ b/python/tests.py @@ -67,8 +67,11 @@ def simulate_encpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]: enckeys = [pret[1] for pret in enc_prets0] for i in range(n): + deckey = enc_prets0[i][0] random = random_bytes(32) - enc_prets1 += [encpedpop.participant_step1(seeds[i], t, enckeys, i, random)] + enc_prets1 += [ + encpedpop.participant_step1(seeds[i], deckey, enckeys, t, i, random) + ] pmsgs = [pmsg for (_, pmsg) in enc_prets1] pstates = [pstate for (pstate, _) in enc_prets1]