Skip to content

Commit

Permalink
Merge pull request #51 from BlockstreamResearch/202410-encrypt-own-cheap
Browse files Browse the repository at this point in the history
  • Loading branch information
real-or-random authored Oct 8, 2024
2 parents b76b40b + f86957b commit 8e052d9
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 47 deletions.
11 changes: 2 additions & 9 deletions python/chilldkg_ref/chilldkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from secp256k1proto.keys import pubkey_gen_plain
from secp256k1proto.util import int_from_bytes, bytes_from_int

from .vss import VSS, VSSCommitment
from .vss import VSSCommitment
from . import encpedpop
from .util import (
BIP_TAG,
Expand Down Expand Up @@ -412,6 +412,7 @@ def participant_step1(
# function. Thus, it is sufficient that the seed has a high entropy,
# and so we can simply pass the hostseckey as seed.
seed=hostseckey,
deckey=hostseckey,
t=t,
# This requires the joint security of Schnorr signatures and ECDH.
enckeys=hostpubkeys,
Expand Down Expand Up @@ -650,9 +651,6 @@ def recover(

# Decrypt share
enc_context = encpedpop.serialize_enc_context(t, hostpubkeys)
simpl_seed = encpedpop.derive_simpl_seed(
hostseckey, pubnonces[idx], enc_context
)
secshare = encpedpop.decrypt_sum(
hostseckey,
hostpubkeys[idx],
Expand All @@ -662,11 +660,6 @@ def recover(
idx,
)

# Derive my_share
vss = VSS.generate(simpl_seed, t)
my_share = vss.secshare_for(idx)
secshare += my_share

# This is just a sanity check. Our signature is valid, so we have done
# this check already during the actual session.
assert VSSCommitment.verify_secshare(secshare, pubshares[idx])
Expand Down
89 changes: 52 additions & 37 deletions python/chilldkg_ref/encpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,45 +28,58 @@ def ecdh(
return Scalar(int_from_bytes(tagged_hash_bip_dkg("encpedpop ecdh", data)))


def self_pad(deckey: bytes, context_: bytes) -> Scalar:
return Scalar(
int_from_bytes(
prf(seed=deckey, tag="encaps_multi self_pad", extra_input=context_)
)
)


def encaps_multi(
secnonce: bytes, pubnonce: bytes, enckeys: List[bytes], context: bytes, idx: int
secnonce: bytes,
pubnonce: bytes,
deckey: bytes,
enckeys: List[bytes],
context: bytes,
idx: int,
) -> List[Scalar]:
# This is effectively the "Hashed ElGamal" multi-recipient KEM described in
# Section 5 of "Multi-recipient encryption, revisited" by Alexandre Pinto,
# Bertram Poettering, Jacob C. N. Schuldt (AsiaCCS 2014). Its crucial
# feature is to feed the index of the enckey to the hash function. The only
# differences are that we skip our own index (because we don't need to
# encrypt to ourselves) and that we feed also the pubnonce and context data
# into the hash function.
if idx >= len(enckeys):
raise IndexError
keys = [
ecdh(
secnonce,
pubnonce,
enckey,
i.to_bytes(4, byteorder="big") + context,
sending=True,
)
for i, enckey in enumerate(enckeys)
if i != idx # Skip own index
]
return keys
# difference is that we feed also the pubnonce and context data into the
# hash function.
pads = []
for i, enckey in enumerate(enckeys):
context_ = i.to_bytes(4, byteorder="big") + context
if i == idx:
# We're encrypting to ourselves, so we use a symmetrically derived
# pad to save the ECDH computation.
pad = self_pad(deckey, context_)
else:
pad = ecdh(
seckey=secnonce,
my_pubkey=pubnonce,
their_pubkey=enckey,
context=context_,
sending=True,
)
pads.append(pad)
return pads


def encrypt_multi(
secnonce: bytes,
pubnonce: bytes,
deckey: bytes,
enckeys: List[bytes],
messages: List[Scalar],
context: bytes,
idx: int,
) -> List[Scalar]:
keys = encaps_multi(secnonce, pubnonce, enckeys, context, idx)
their_messages = messages[:idx] + messages[idx + 1 :] # Skip own index
ciphertexts = [
message + key for message, key in zip(their_messages, keys, strict=True)
]
pads = encaps_multi(secnonce, pubnonce, deckey, enckeys, context, idx)
ciphertexts = [message + pad for message, pad in zip(messages, pads, strict=True)]
return ciphertexts


Expand All @@ -80,12 +93,19 @@ def decrypt_sum(
) -> Scalar:
if idx >= len(pubnonces):
raise IndexError
ecdh_context = idx.to_bytes(4, byteorder="big") + context
context_ = idx.to_bytes(4, byteorder="big") + context
secshare = sum_ciphertexts
for i, pubnonce in enumerate(pubnonces):
if i == idx:
continue # Skip own index
pad = ecdh(deckey, enckey, pubnonce, ecdh_context, sending=False)
pad = self_pad(deckey, context_)
else:
pad = ecdh(
seckey=deckey,
my_pubkey=enckey,
their_pubkey=pubnonce,
context=context_,
sending=False,
)
secshare = secshare - pad
return secshare

Expand Down Expand Up @@ -116,7 +136,6 @@ class ParticipantState(NamedTuple):
pubnonce: bytes
enckeys: List[bytes]
idx: int
self_share: Scalar


def serialize_enc_context(t: int, enckeys: List[bytes]) -> bytes:
Expand All @@ -131,8 +150,9 @@ def derive_simpl_seed(seed: bytes, pubnonce: bytes, enc_context: bytes) -> bytes

def participant_step1(
seed: bytes,
t: int,
deckey: bytes,
enckeys: List[bytes],
t: int,
idx: int,
random: bytes,
) -> Tuple[ParticipantState, ParticipantMsg]:
Expand All @@ -156,13 +176,12 @@ def participant_step1(
)
assert len(shares) == n

# Encrypt shares, no need to encrypt to ourselves
enc_shares = encrypt_multi(
secnonce, pubnonce, enckeys, shares, enc_context, idx=idx
secnonce, pubnonce, deckey, enckeys, shares, enc_context, idx
)

pmsg = ParticipantMsg(simpl_pmsg, pubnonce, enc_shares)
state = ParticipantState(simpl_state, pubnonce, enckeys, idx, shares[idx])
state = ParticipantState(simpl_state, pubnonce, enckeys, idx)
return state, pmsg


Expand All @@ -172,7 +191,7 @@ def participant_step2(
cmsg: CoordinatorMsg,
enc_secshare: Scalar,
) -> Tuple[simplpedpop.DKGOutput, bytes]:
simpl_state, pubnonce, enckeys, idx, self_share = state
simpl_state, pubnonce, enckeys, idx = state
simpl_cmsg, pubnonces = cmsg

reported_pubnonce = pubnonces[idx]
Expand All @@ -183,7 +202,6 @@ def participant_step2(
secshare = decrypt_sum(
deckey, enckeys[idx], pubnonces, enc_secshare, enc_context, idx
)
secshare += self_share
dkg_output, eq_input = simplpedpop.participant_step2(
simpl_state, simpl_cmsg, secshare
)
Expand All @@ -209,13 +227,10 @@ def coordinator_step(
)
pubnonces = [pmsg.pubnonce for pmsg in pmsgs]
for i in range(n):
# Participant i implicitly uses a pad of 0 to encrypt to themselves.
# Make this pad explicit at the right position.
if len(pmsgs[i].enc_shares) != n - 1:
if len(pmsgs[i].enc_shares) != n:
raise InvalidContributionError(
i, "Participant sent enc_shares with invalid length"
)
pmsgs[i].enc_shares.insert(i, Scalar(0))
enc_secshares = [
Scalar.sum(*([pmsg.enc_shares[i] for pmsg in pmsgs])) for i in range(n)
]
Expand Down
5 changes: 4 additions & 1 deletion python/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ def simulate_encpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]:

enckeys = [pret[1] for pret in enc_prets0]
for i in range(n):
deckey = enc_prets0[i][0]
random = random_bytes(32)
enc_prets1 += [encpedpop.participant_step1(seeds[i], t, enckeys, i, random)]
enc_prets1 += [
encpedpop.participant_step1(seeds[i], deckey, enckeys, t, i, random)
]

pmsgs = [pmsg for (_, pmsg) in enc_prets1]
pstates = [pstate for (pstate, _) in enc_prets1]
Expand Down

0 comments on commit 8e052d9

Please sign in to comment.