Skip to content

Commit

Permalink
python: Cache blame() inputs, piggyback them on exception
Browse files Browse the repository at this point in the history
  • Loading branch information
real-or-random committed Nov 25, 2024
1 parent 4c2852f commit 0cdefdb
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 112 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ of the success of the DKG session by presenting recovery data to us.
#### participant\_blame

```python
def participant_blame(hostseckey: bytes, state1: ParticipantState1, cmsg1: CoordinatorMsg1, cblame: CoordinatorBlameMsg) -> NoReturn
def participant_blame(blame_state: ParticipantBlameState, cblame: CoordinatorBlameMsg) -> NoReturn
```

Perform a participant's blame step of a ChillDKG session. TODO
Expand Down
36 changes: 21 additions & 15 deletions python/chilldkg_ref/chilldkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@ class ParticipantState2(NamedTuple):
dkg_output: DKGOutput


class ParticipantBlameState(NamedTuple):
enc_blame_state: encpedpop.ParticipantBlameState


def participant_step1(
hostseckey: bytes, params: SessionParams, random: bytes
) -> Tuple[ParticipantState1, ParticipantMsg1]:
Expand Down Expand Up @@ -472,16 +476,24 @@ def participant_step2(
params, idx, enc_state = state1
enc_cmsg, enc_secshares = cmsg1

enc_dkg_output, eq_input = encpedpop.participant_step2(
state=enc_state,
deckey=hostseckey,
cmsg=enc_cmsg,
enc_secshare=enc_secshares[idx],
)
try:
enc_dkg_output, eq_input = encpedpop.participant_step2(
state=enc_state,
deckey=hostseckey,
cmsg=enc_cmsg,
enc_secshare=enc_secshares[idx],
)
except UnknownFaultyPartyError as e:
assert isinstance(e.blame_state, encpedpop.ParticipantBlameState)
# Translate encpedpop.UnknownFaultyPartyError into our own
# chilldkg.UnknownFaultyPartyError.
blame_state = ParticipantBlameState(e.blame_state)
raise UnknownFaultyPartyError(blame_state, e.args) from e

# Include the enc_shares in eq_input to ensure that participants agree on all
# shares, which in turn ensures that they have the right recovery data.
eq_input += b"".join([bytes_from_int(int(share)) for share in enc_secshares])
dkg_output = DKGOutput._make(enc_dkg_output) # Convert to chilldkg.DKGOutput type
dkg_output = DKGOutput._make(enc_dkg_output)
state2 = ParticipantState2(params, eq_input, dkg_output)
sig = certeq_participant_step(hostseckey, idx, eq_input)
pmsg2 = ParticipantMsg2(sig)
Expand Down Expand Up @@ -529,18 +541,12 @@ def participant_finalize(


def participant_blame(
hostseckey: bytes,
state1: ParticipantState1,
cmsg1: CoordinatorMsg1,
blame_state: ParticipantBlameState,
cblame: CoordinatorBlameMsg,
) -> NoReturn:
"""Perform a participant's blame step of a ChillDKG session. TODO"""
_, idx, enc_state = state1
encpedpop.participant_blame(
state=enc_state,
deckey=hostseckey,
cmsg=cmsg1.enc_cmsg,
enc_secshare=cmsg1.enc_secshares[idx],
blame_state=blame_state.enc_blame_state,
cblame=cblame.enc_cblame,
)

Expand Down
42 changes: 23 additions & 19 deletions python/chilldkg_ref/encpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import simplpedpop
from .util import (
UnknownFaultyPartyError,
tagged_hash_bip_dkg,
prf,
FaultyParticipantOrCoordinatorError,
Expand Down Expand Up @@ -162,6 +163,12 @@ class ParticipantState(NamedTuple):
idx: int


class ParticipantBlameState(NamedTuple):
simpl_bstate: simplpedpop.ParticipantBlameState
enc_secshare: Scalar
pads: List[Scalar]


def serialize_enc_context(t: int, enckeys: List[bytes]) -> bytes:
# TODO Consider hashing the result here because the string can be long, and
# we'll feed it into hashes on multiple occasions
Expand Down Expand Up @@ -223,33 +230,30 @@ def participant_step2(
raise FaultyCoordinatorError("Coordinator replied with wrong pubnonce")

enc_context = serialize_enc_context(simpl_state.t, enckeys)
secshare = decrypt_sum(
deckey, enckeys[idx], pubnonces, enc_context, idx, enc_secshare
)
pads = decaps_multi(deckey, enckeys[idx], pubnonces, enc_context, idx)
secshare = enc_secshare - Scalar.sum(*pads)

try:
dkg_output, eq_input = simplpedpop.participant_step2(
simpl_state, simpl_cmsg, secshare
)
except UnknownFaultyPartyError as e:
assert isinstance(e.blame_state, simplpedpop.ParticipantBlameState)
# Translate simplpedpop.ParticipantBlamestate into our own
# encpedpop.ParticipantBlameState.
blame_state = ParticipantBlameState(e.blame_state, enc_secshare, pads)
raise UnknownFaultyPartyError(blame_state, e.args) from e

dkg_output, eq_input = simplpedpop.participant_step2(
simpl_state, simpl_cmsg, secshare
)
eq_input += b"".join(enckeys) + b"".join(pubnonces)
return dkg_output, eq_input


def participant_blame(
state: ParticipantState,
deckey: bytes,
cmsg: CoordinatorMsg,
enc_secshare: Scalar,
blame_state: ParticipantBlameState,
cblame: CoordinatorBlameMsg,
) -> NoReturn:
simpl_state, _, enckeys, idx = state
_, pubnonces = cmsg
simpl_blame_state, enc_secshare, pads = blame_state
enc_partial_secshares, partial_pubshares = cblame

# Compute the encryption pads once and use them to decrypt both the
# enc_secshare and all enc_partial_secshares
enc_context = serialize_enc_context(simpl_state.t, enckeys)
pads = decaps_multi(deckey, enckeys[idx], pubnonces, enc_context, idx)
secshare = enc_secshare - Scalar.sum(*pads)
partial_secshares = [
enc_partial_secshare - pad
for enc_partial_secshare, pad in zip(enc_partial_secshares, pads, strict=True)
Expand All @@ -258,7 +262,7 @@ def participant_blame(
simpl_cblame = simplpedpop.CoordinatorBlameMsg(partial_pubshares)
try:
simplpedpop.participant_blame(
simpl_state, secshare, partial_secshares, simpl_cblame
simpl_blame_state, simpl_cblame, partial_secshares
)
except simplpedpop.SecshareSumError as e:
# The secshare is not equal to the sum of the partial secshares in the
Expand Down
32 changes: 20 additions & 12 deletions python/chilldkg_ref/simplpedpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ class ParticipantState(NamedTuple):
com_to_secret: GE


class ParticipantBlameState(NamedTuple):
n: int
idx: int
secshare: Scalar
pubshare: GE


# To keep the algorithms of SimplPedPop and EncPedPop purely non-interactive
# computations, we omit explicit invocations of an interactive equality check
# protocol. ChillDKG will take care of invoking the equality check protocol.
Expand Down Expand Up @@ -201,6 +208,7 @@ def participant_step2(

if not VSSCommitment.verify_secshare(secshare, pubshare):
raise UnknownFaultyPartyError(
ParticipantBlameState(n, idx, secshare, pubshare),
"Received invalid secshare, consider blaming to determine faulty party",
)

Expand All @@ -215,14 +223,16 @@ def participant_step2(


def participant_blame(
state: ParticipantState,
secshare: Scalar,
partial_secshares: List[Scalar],
blame_state: ParticipantBlameState,
cblame: CoordinatorBlameMsg,
partial_secshares: List[Scalar],
) -> NoReturn:
_, n, idx, _ = state
n, idx, secshare, pubshare = blame_state
partial_pubshares = cblame.partial_pubshares

if GE.sum(*partial_pubshares) != pubshare:
raise FaultyCoordinatorError("Sum of partial pubshares not equal to pubshare")

if Scalar.sum(*partial_secshares) != secshare:
raise SecshareSumError("Sum of partial secshares not equal to secshare")

Expand All @@ -242,15 +252,13 @@ def participant_blame(

# We now know:
# - The sum of the partial secshares is equal to the secshare.
# - The sum of the partial pubshares is equal to the pubshare.
# - Every partial secshare matches its corresponding partial pubshare.
# - The secshare does not match the pubshare (because the caller shouldn't
# have called us otherwise).
# Therefore, the sum of the partial pubshares is not equal to the pubshare,
# and this is the coordinator's fault.
raise FaultyCoordinatorError(
"Sum of partial pubshares not equal to pubshare (or participant_blame() "
"was called even though participant_step2() was successful)"
)
# Hence, the secshare matches the pubshare.
assert VSSCommitment.verify_secshare(secshare, pubshare)

# This should never happen (unless the caller fiddled with the inputs).
raise RuntimeError("participant_blame() was called, but all inputs are consistent.")


###
Expand Down
6 changes: 5 additions & 1 deletion python/chilldkg_ref/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,8 @@ class FaultyCoordinatorError(ProtocolError):


class UnknownFaultyPartyError(ProtocolError):
pass
"""TODO"""

def __init__(self, blame_state: Any, *args: Any):
self.blame_state = blame_state
super().__init__(*args)
62 changes: 43 additions & 19 deletions python/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

"""Example of a full ChillDKG session"""

from typing import Tuple, List
from typing import Tuple, List, Optional
import asyncio
import pprint
from random import randint
from secrets import token_bytes as random_bytes
import sys

Expand All @@ -13,11 +14,14 @@
participant_step1,
participant_step2,
participant_finalize,
participant_blame,
coordinator_step1,
coordinator_finalize,
coordinator_blame,
SessionParams,
DKGOutput,
RecoveryData,
UnknownFaultyPartyError,
)

#
Expand Down Expand Up @@ -72,21 +76,38 @@ async def participant(
chan: ParticipantChannel,
hostseckey: bytes,
params: SessionParams,
blame: bool = True,
blame: bool,
) -> Tuple[DKGOutput, RecoveryData]:
# TODO Top-level error handling
random = random_bytes(32)
state1, pmsg1 = participant_step1(hostseckey, params, random)

# The following code simulate a faulty participant that sends a single
# incorrect share with probability 1/2. Of course, this should not be part
# of a real implementation.
if blame:
# Flip a coin to decide if we send incorrect shares.
faulty = randint(0, 1)
print(faulty)
if faulty:
# Pick a random victim participant.
victim = randint(0, len(pmsg1.enc_pmsg.enc_shares) - 1)
pmsg1.enc_pmsg.enc_shares[victim] += 17
chan.send(pmsg1)
cmsg1 = await chan.receive()

# TODO
# if blame:
# blame_rec = await chan.receive()
# else:
# blame_rec = None

state2, eq_round1 = participant_step2(hostseckey, state1, cmsg1)
try:
state2, eq_round1 = participant_step2(hostseckey, state1, cmsg1)
except UnknownFaultyPartyError as e:
if not blame:
# Not in blame mode, so we give up.
raise
blame_msg = await chan.receive()
participant_blame(e.blame_state, blame_msg)
else:
if blame:
# Ignore the blame message because we don't need it.
_ = await chan.receive()

chan.send(eq_round1)
cmsg2 = await chan.receive()
Expand All @@ -95,7 +116,7 @@ async def participant(


async def coordinator(
chans: CoordinatorChannels, params: SessionParams
chans: CoordinatorChannels, params: SessionParams, blame: bool
) -> Tuple[DKGOutput, RecoveryData]:
(hostpubkeys, t) = params
n = len(hostpubkeys)
Expand All @@ -106,10 +127,10 @@ async def coordinator(
state, cmsg1 = coordinator_step1(pmsgs1, params)
chans.send_all(cmsg1)

# TODO
# if blame:
# for i in range(n):
# chans.send_to(i, blame_recs[i])
if blame:
blame_msgs = coordinator_blame(pmsgs1)
for i in range(n):
chans.send_to(i, blame_msgs[i])

sigs = []
for i in range(n):
Expand All @@ -125,7 +146,9 @@ async def coordinator(
#


def simulate_chilldkg_full(hostseckeys, t) -> List[Tuple[DKGOutput, RecoveryData]]:
def simulate_chilldkg_full(
hostseckeys, t, blame: bool = False
) -> Optional[List[Tuple[DKGOutput, RecoveryData]]]:
# Generate common inputs for all participants and coordinator
n = len(hostseckeys)
hostpubkeys = []
Expand All @@ -136,16 +159,16 @@ def simulate_chilldkg_full(hostseckeys, t) -> List[Tuple[DKGOutput, RecoveryData
params = SessionParams(hostpubkeys, t)

async def session():
# TODO Blame
coord_chans = CoordinatorChannels(n)
participant_chans = [
ParticipantChannel(coord_chans.queues[i]) for i in range(n)
]
coord_chans.set_participant_queues(
[participant_chans[i].queue for i in range(n)]
)
coroutines = [coordinator(coord_chans, params)] + [
participant(participant_chans[i], hostseckeys[i], params) for i in range(n)
coroutines = [coordinator(coord_chans, params, blame)] + [
participant(participant_chans[i], hostseckeys[i], params, blame)
for i in range(n)
]
return await asyncio.gather(*coroutines)

Expand All @@ -168,7 +191,8 @@ def main():
print(f"Participant {i}'s hostseckey:", hostseckeys[i].hex())
print()

rets = simulate_chilldkg_full(hostseckeys, t)
# TODO Add cli arguments to enable blame mode
rets = simulate_chilldkg_full(hostseckeys, t, blame=False)
assert len(rets) == n + 1

print("=== Coordinator's DKGOutput ===")
Expand Down
Loading

0 comments on commit 0cdefdb

Please sign in to comment.