Skip to content

Commit

Permalink
Prevent malicious taproot commitment
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasnick committed Dec 4, 2024
1 parent 76ba82f commit 7f8ecc3
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 47 deletions.
46 changes: 26 additions & 20 deletions python/chilldkg_ref/chilldkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .vss import VSSCommitment
from . import encpedpop
from . import simplpedpop
from .util import (
BIP_TAG,
tagged_hash_bip_dkg,
Expand Down Expand Up @@ -480,7 +481,7 @@ def participant_step2(
enc_cmsg, enc_secshares = cmsg1

try:
enc_dkg_output, eq_input = encpedpop.participant_step2(
enc_dkg_pre_output, eq_input = encpedpop.participant_step2(
state=enc_state,
deckey=hostseckey,
cmsg=enc_cmsg,
Expand All @@ -496,8 +497,7 @@ def participant_step2(
# Include the enc_shares in eq_input to ensure that participants agree on
# all shares, which in turn ensures that they have the right recovery data.
eq_input += b"".join([bytes_from_int(int(share)) for share in enc_secshares])
dkg_output = DKGOutput._make(enc_dkg_output)
state2 = ParticipantState2(params, eq_input, dkg_output)
state2 = ParticipantState2(params, eq_input, enc_dkg_pre_output)
sig = certeq_participant_step(hostseckey, idx, eq_input)
pmsg2 = ParticipantMsg2(sig)
return state2, pmsg2
Expand Down Expand Up @@ -538,9 +538,9 @@ def participant_finalize(
SessionNotFinalizedError: If finalizing the DKG session was not
successful from this participant's perspective (see above).
"""
params, eq_input, dkg_output = state2
params, eq_input, dkg_pre_output = state2
certeq_verify(params.hostpubkeys, eq_input, cmsg2.cert) # SessionNotFinalizedError
return dkg_output, RecoveryData(eq_input + cmsg2.cert)
return simplpedpop.dkg_output(dkg_pre_output), RecoveryData(eq_input + cmsg2.cert)


def participant_blame(
Expand Down Expand Up @@ -590,14 +590,13 @@ def coordinator_step1(
params_validate(params)
hostpubkeys, t = params

enc_cmsg, enc_dkg_output, eq_input, enc_secshares = encpedpop.coordinator_step(
enc_cmsg, enc_dkg_pre_output, eq_input, enc_secshares = encpedpop.coordinator_step(
pmsgs=[pmsg1.enc_pmsg for pmsg1 in pmsgs1],
t=t,
enckeys=hostpubkeys,
)
eq_input += b"".join([bytes_from_int(int(share)) for share in enc_secshares])
dkg_output = DKGOutput._make(enc_dkg_output) # Convert to chilldkg.DKGOutput type
state = CoordinatorState(params, eq_input, dkg_output)
state = CoordinatorState(params, eq_input, enc_dkg_pre_output)
cmsg1 = CoordinatorMsg1(enc_cmsg, enc_secshares)
return state, cmsg1

Expand Down Expand Up @@ -626,10 +625,10 @@ def coordinator_finalize(
received messages from other participants via a communication
channel beside the coordinator (or be malicious).
"""
params, eq_input, dkg_output = state
params, eq_input, dkg_pre_output = state
cert = certeq_coordinator_step([pmsg2.sig for pmsg2 in pmsgs2])
certeq_verify(params.hostpubkeys, eq_input, cert) # SessionNotFinalizedError
return CoordinatorMsg2(cert), dkg_output, RecoveryData(eq_input + cert)
return CoordinatorMsg2(cert), simplpedpop.coordinator_dkg_output(dkg_pre_output), RecoveryData(eq_input + cert)


def coordinator_blame(pmsgs: List[ParticipantMsg1]) -> List[CoordinatorBlameMsg]:
Expand Down Expand Up @@ -683,9 +682,6 @@ def recover(
eq_input = recovery_data[: -len(cert)]
certeq_verify(hostpubkeys, eq_input, cert)

# Compute threshold pubkey and individual pubshares
threshold_pubkey = sum_coms.commitment_to_secret()
pubshares = [sum_coms.pubshare(i) for i in range(n)]

if hostseckey:
hostpubkey = hostpubkey_gen(hostseckey)
Expand All @@ -707,15 +703,25 @@ def recover(
enc_secshares[idx],
)

pubshare = sum_coms.pubshare(idx)
# This is just a sanity check. Our signature is valid, so we have done
# this check already during the actual session.
assert VSSCommitment.verify_secshare(secshare, pubshares[idx])
assert VSSCommitment.verify_secshare(secshare, pubshare)

# Compute threshold pubkey and individual pubshares
dkg_pre_output = simplpedpop.DKGPreOutput(
n,
idx,
secshare,
sum_coms,
pubshare,
)
dkg_output = simplpedpop.dkg_output(dkg_pre_output)
else:
secshare = None
dkg_pre_output = simplpedpop.CoordinatorDKGPreOutput(
n,
sum_coms,
)
dkg_output = simplpedpop.coordinator_dkg_output(dkg_pre_output)

dkg_output = DKGOutput(
None if secshare is None else secshare.to_bytes(),
threshold_pubkey.to_bytes_compressed(),
[pubshare.to_bytes_compressed() for pubshare in pubshares],
)
return dkg_output, params
12 changes: 6 additions & 6 deletions python/chilldkg_ref/encpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def participant_step2(
deckey: bytes,
cmsg: CoordinatorMsg,
enc_secshare: Scalar,
) -> Tuple[simplpedpop.DKGOutput, bytes]:
) -> Tuple[simplpedpop.DKGPreOutput, bytes]:
simpl_state, pubnonce, enckeys, idx = state
simpl_cmsg, pubnonces = cmsg

Expand All @@ -231,7 +231,7 @@ def participant_step2(
secshare = enc_secshare - Scalar.sum(*pads)

try:
dkg_output, eq_input = simplpedpop.participant_step2(
dkg_pre_output, eq_input = simplpedpop.participant_step2(
simpl_state, simpl_cmsg, secshare
)
except UnknownFaultyParticipantOrCoordinatorError as e:
Expand All @@ -242,7 +242,7 @@ def participant_step2(
raise UnknownFaultyParticipantOrCoordinatorError(blame_state, e.args) from e

eq_input += b"".join(enckeys) + b"".join(pubnonces)
return dkg_output, eq_input
return dkg_pre_output, eq_input


def participant_blame(
Expand Down Expand Up @@ -282,13 +282,13 @@ def coordinator_step(
pmsgs: List[ParticipantMsg],
t: int,
enckeys: List[bytes],
) -> Tuple[CoordinatorMsg, simplpedpop.DKGOutput, bytes, List[Scalar]]:
) -> Tuple[CoordinatorMsg, simplpedpop.CoordinatorDKGPreOutput, bytes, List[Scalar]]:
n = len(enckeys)
if n != len(pmsgs):
raise ValueError

simpl_pmsgs = [pmsg.simpl_pmsg for pmsg in pmsgs]
simpl_cmsg, dkg_output, eq_input = simplpedpop.coordinator_step(simpl_pmsgs, t, n)
simpl_cmsg, vss_com, eq_input = simplpedpop.coordinator_step(simpl_pmsgs, t, n)
pubnonces = [pmsg.pubnonce for pmsg in pmsgs]
for i in range(n):
if len(pmsgs[i].enc_shares) != n:
Expand All @@ -311,7 +311,7 @@ def coordinator_step(
# participant i for participant_step2(); we leave this unspecified.
return (
CoordinatorMsg(simpl_cmsg, pubnonces),
dkg_output,
vss_com,
eq_input,
enc_secshares,
)
Expand Down
78 changes: 61 additions & 17 deletions python/chilldkg_ref/simplpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import List, NamedTuple, NewType, Tuple, Optional, NoReturn

from secp256k1proto.bip340 import schnorr_sign, schnorr_verify
from secp256k1proto.secp256k1 import GE, Scalar
from secp256k1proto.secp256k1 import G, GE, Scalar
from .util import (
tagged_hash_bip_dkg,
BIP_TAG,
SecretKeyError,
ThresholdError,
Expand Down Expand Up @@ -80,6 +81,18 @@ class CoordinatorBlameMsg(NamedTuple):
### Other common definitions
###

class DKGPreOutput(NamedTuple):
n: int
idx: int
secshare: Scalar
com: VSSCommitment
pubshare: GE


class CoordinatorDKGPreOutput(NamedTuple):
n: int
com: VSSCommitment


class DKGOutput(NamedTuple):
secshare: Optional[bytes] # None for coordinator
Expand Down Expand Up @@ -204,7 +217,6 @@ def participant_step2(
i, "Participant sent invalid proof-of-knowledge"
)
sum_coms = assemble_sum_coms(coms_to_secrets, sum_coms_to_nonconst_terms)
threshold_pubkey = sum_coms.commitment_to_secret()
pubshare = sum_coms.pubshare(idx)

if not VSSCommitment.verify_secshare(secshare, pubshare):
Expand All @@ -213,14 +225,40 @@ def participant_step2(
"Received invalid secshare, consider blaming to determine faulty party",
)

pubshares = [sum_coms.pubshare(i) if i != idx else pubshare for i in range(n)]
dkg_output = DKGOutput(
secshare.to_bytes(),
dkg_pre_output = DKGPreOutput(
n,
idx,
secshare,
sum_coms,
pubshare,
)
eq_input = t.to_bytes(4, byteorder="big") + sum_coms.to_bytes()
return dkg_pre_output, eq_input


# TODO: add comment
def vss_invalid_taproot_commit(vss: VSSCommitment, pubshare: Optional[GE]) -> (VSSCommitment, Scalar):
pk = vss.commitment_to_secret()
secshare_tweak = Scalar.from_bytes(
tagged_hash_bip_dkg("invalid Taproot commitment", pk.to_bytes_compressed())
)
pk_tweak = secshare_tweak * G
if pubshare:
pubshare += pk_tweak
vss_offset = VSSCommitment([pk_tweak] + [GE()] * (vss.t() - 1))
return (vss + vss_offset, pubshare, secshare_tweak)

def dkg_output(dkg_pre_output: DKGPreOutput) -> DKGOutput:
(n, idx, secshare, com, pubshare) = dkg_pre_output

com, pubshare, secshare_offset = vss_invalid_taproot_commit(com, pubshare)
threshold_pubkey = com.commitment_to_secret()
pubshares = [com.pubshare(i) if i != idx else pubshare for i in range(n)]
return DKGOutput(
(secshare + secshare_offset).to_bytes(),
threshold_pubkey.to_bytes_compressed(),
[pubshare.to_bytes_compressed() for pubshare in pubshares],
)
eq_input = t.to_bytes(4, byteorder="big") + sum_coms.to_bytes()
return dkg_output, eq_input


def participant_blame(
Expand Down Expand Up @@ -267,9 +305,22 @@ def participant_blame(
###



def coordinator_dkg_output(dkg_pre_output: CoordinatorDKGPreOutput) -> DKGOutput:
n, com = dkg_pre_output
com, _, _ = vss_invalid_taproot_commit(com, None)
threshold_pubkey = com.commitment_to_secret()
pubshares = [com.pubshare(i) for i in range(n)]
return DKGOutput(
None,
threshold_pubkey.to_bytes_compressed(),
[pubshare.to_bytes_compressed() for pubshare in pubshares],
)


def coordinator_step(
pmsgs: List[ParticipantMsg], t: int, n: int
) -> Tuple[CoordinatorMsg, DKGOutput, bytes]:
) -> Tuple[CoordinatorMsg, CoordinatorDKGPreOutput, bytes]:
# Sum the commitments to the i-th coefficients for i > 0
#
# This procedure corresponds to the one described by Pedersen in Section 5.1
Expand All @@ -286,16 +337,9 @@ def coordinator_step(
cmsg = CoordinatorMsg(coms_to_secrets, sum_coms_to_nonconst_terms, pops)

sum_coms = assemble_sum_coms(coms_to_secrets, sum_coms_to_nonconst_terms)
threshold_pubkey = sum_coms.commitment_to_secret()
pubshares = [sum_coms.pubshare(i) for i in range(n)]

dkg_output = DKGOutput(
None,
threshold_pubkey.to_bytes_compressed(),
[pubshare.to_bytes_compressed() for pubshare in pubshares],
)
dkg_pre_output = CoordinatorDKGPreOutput(n, sum_coms)
eq_input = t.to_bytes(4, byteorder="big") + sum_coms.to_bytes()
return cmsg, dkg_output, eq_input
return cmsg, dkg_pre_output, eq_input


def coordinator_blame(pmsgs: List[ParticipantMsg]) -> List[CoordinatorBlameMsg]:
Expand Down
10 changes: 6 additions & 4 deletions python/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def simulate_simplpedpop(
pmsgs = [pmsg for (_, pmsg, _) in prets]

cmsg, cout, ceq = simplpedpop.coordinator_step(pmsgs, t, n)
pre_finalize_rets = [(cout, ceq)]
pre_finalize_rets = [(simplpedpop.coordinator_dkg_output(cout), ceq)]
for i in range(n):
partial_secshares = [
partial_secshares_for[i] for (_, _, partial_secshares_for) in prets
Expand All @@ -64,8 +64,9 @@ def simulate_simplpedpop(

secshare = simplpedpop.participant_step2_prepare_secshare(partial_secshares)
try:
pre_out, eq = simplpedpop.participant_step2(pstates[i], cmsg, secshare)
pre_finalize_rets += [
simplpedpop.participant_step2(pstates[i], cmsg, secshare)
(simplpedpop.dkg_output(pre_out), eq)
]
except UnknownFaultyParticipantOrCoordinatorError as e:
if not blame:
Expand Down Expand Up @@ -120,12 +121,13 @@ def simulate_encpedpop(
pmsgs[faulty_idx[i]].enc_shares[i] += Scalar(17)

cmsg, cout, ceq, enc_secshares = encpedpop.coordinator_step(pmsgs, t, enckeys)
pre_finalize_rets = [(cout, ceq)]
pre_finalize_rets = [(simplpedpop.coordinator_dkg_output(cout), ceq)]
for i in range(n):
deckey = enc_prets0[i][0]
try:
pre_out, eq = encpedpop.participant_step2(pstates[i], deckey, cmsg, enc_secshares[i])
pre_finalize_rets += [
encpedpop.participant_step2(pstates[i], deckey, cmsg, enc_secshares[i])
(simplpedpop.dkg_output(pre_out), eq)
]
except UnknownFaultyParticipantOrCoordinatorError as e:
if not blame:
Expand Down

0 comments on commit 7f8ecc3

Please sign in to comment.