Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement Beaver triple generation #335

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions integration-tests/src/multichain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ impl Nodes<'_> {
}
}

pub fn url(&self, id: usize) -> &str {
match self {
Nodes::Local { nodes, .. } => &nodes[id].address,
Nodes::Docker { nodes, .. } => &nodes[id].address,
}
}

pub async fn add_node(
&mut self,
node_id: u32,
Expand Down
38 changes: 37 additions & 1 deletion integration-tests/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ where
pub struct MultichainTestContext<'a> {
nodes: mpc_recovery_integration_tests::multichain::Nodes<'a>,
rpc_client: near_fetch::Client,
http_client: reqwest::Client,
}

async fn with_multichain_nodes<F>(nodes: usize, f: F) -> anyhow::Result<()>
Expand All @@ -72,7 +73,12 @@ where
let nodes = mpc_recovery_integration_tests::multichain::run(nodes, &docker_client).await?;

let rpc_client = near_fetch::Client::new(&nodes.ctx().sandbox.local_address);
f(MultichainTestContext { nodes, rpc_client }).await?;
f(MultichainTestContext {
nodes,
rpc_client,
http_client: reqwest::Client::default(),
})
.await?;

Ok(())
}
Expand Down Expand Up @@ -184,6 +190,7 @@ mod wait_for {
use backon::Retryable;
use mpc_contract::ProtocolContractState;
use mpc_contract::RunningContractState;
use mpc_recovery_node::web::StateView;

pub async fn running_mpc<'a>(
ctx: &MultichainTestContext<'a>,
Expand All @@ -207,6 +214,35 @@ mod wait_for {
.retry(&ExponentialBuilder::default().with_max_times(6))
.await
}

/// Polls node `id`'s `/state` endpoint until it reports a `Running` state
/// with at least `expected_triple_count` generated triples.
///
/// Retries with exponential backoff (at most 6 attempts) and returns the
/// last observed [`StateView`] on success.
pub async fn has_at_least_triples<'a>(
    ctx: &MultichainTestContext<'a>,
    id: usize,
    expected_triple_count: usize,
) -> anyhow::Result<StateView> {
    let fetch_and_check = || async {
        let url = format!("{}/state", ctx.nodes.url(id));
        let state_view: StateView = ctx.http_client.get(url).send().await?.json().await?;

        match state_view {
            StateView::NotRunning => anyhow::bail!("node is not running"),
            StateView::Running { triple_count, .. } => {
                if triple_count >= expected_triple_count {
                    Ok(state_view)
                } else {
                    anyhow::bail!("node does not have enough triples yet")
                }
            }
        }
    };

    fetch_and_check
        .retry(&ExponentialBuilder::default().with_max_times(6))
        .await
}
}

trait MpcCheck {
Expand Down
12 changes: 12 additions & 0 deletions integration-tests/tests/multichain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ async fn test_multichain_reshare() -> anyhow::Result<()> {
})
.await
}

/// Spins up a 3-node multichain cluster and verifies that node 0
/// eventually reports at least two generated Beaver triples.
#[test(tokio::test)]
async fn test_triples() -> anyhow::Result<()> {
    with_multichain_nodes(3, |ctx| {
        Box::pin(async move {
            // We only care that the threshold is reached; the returned
            // state view itself is not inspected further.
            let _state = wait_for::has_at_least_triples(&ctx, 0, 2).await?;
            Ok(())
        })
    })
    .await
}
1 change: 1 addition & 0 deletions node/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ clap = { version = "4.2", features = ["derive", "env"] }
hex = "0.4.3"
k256 = { version = "0.13.1", features = ["sha256", "ecdsa", "serde"] }
local-ip-address = "0.5.4"
rand = "0.8"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, why do we have both rand 0.7 and rand 0.8 as dependencies in regular MPC?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need both versions anymore, but in the past there was some issue with one of the dependencies relying on rand being exactly 0.7

reqwest = { version = "0.11.16", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
Expand Down
16 changes: 16 additions & 0 deletions node/src/protocol/consensus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::state::{
WaitingForConsensusState,
};
use crate::protocol::state::{GeneratingState, ResharingState};
use crate::protocol::triple::TripleManager;
use crate::types::PrivateKeyShare;
use crate::util::AffinePointExt;
use crate::{http_client, rpc_client};
Expand Down Expand Up @@ -85,12 +86,20 @@ impl ConsensusProtocol for StartedState {
tracing::info!(
"contract state is running and we are already a participant"
);
let participants_vec =
contract_state.participants.keys().cloned().collect();
Ok(NodeState::Running(RunningState {
epoch,
participants: contract_state.participants,
threshold: contract_state.threshold,
private_share,
public_key,
triple_manager: TripleManager::new(
participants_vec,
ctx.me(),
contract_state.threshold,
epoch,
),
}))
} else {
Ok(NodeState::Joining(JoiningState { public_key }))
Expand Down Expand Up @@ -252,12 +261,19 @@ impl ConsensusProtocol for WaitingForConsensusState {
if contract_state.public_key != self.public_key {
return Err(ConsensusError::MismatchedPublicKey);
}
let participants_vec = self.participants.keys().cloned().collect();
Ok(NodeState::Running(RunningState {
epoch: self.epoch,
participants: self.participants,
threshold: self.threshold,
private_share: self.private_share,
public_key: self.public_key,
triple_manager: TripleManager::new(
participants_vec,
ctx.me(),
self.threshold,
self.epoch,
),
}))
}
},
Expand Down
20 changes: 19 additions & 1 deletion node/src/protocol/cryptography.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::state::{GeneratingState, NodeState, ResharingState};
use super::state::{GeneratingState, NodeState, ResharingState, RunningState};
use crate::http_client::{self, SendError};
use crate::protocol::message::{GeneratingMessage, ResharingMessage};
use crate::protocol::state::WaitingForConsensusState;
Expand Down Expand Up @@ -163,6 +163,23 @@ impl CryptographicProtocol for ResharingState {
}
}

#[async_trait]
impl CryptographicProtocol for RunningState {
async fn progress<C: CryptographicCtx + Send + Sync>(
mut self,
ctx: C,
) -> Result<NodeState, CryptographicError> {
if self.triple_manager.potential_len() < 2 {
self.triple_manager.generate();
}
for (p, msg) in self.triple_manager.poke() {
Comment on lines +172 to +174
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so, we're generating more triples on the fly when we run low? Won't this be very computationally expensive or since we're just generating one, it will be fine? But wouldn't that still impose a good amount of latency with all the messaging the triple generation protocol requires?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is just pretty much a placeholder. We need to implement TripleStockpile that operates on top of TripleManager and actively tries to initiate generation when it can (the ticket is in the epic, but hasn't been scoped out yet).

let url = self.participants.get(&p).unwrap();
http_client::message(ctx.http_client(), url.clone(), MpcMessage::Triple(msg)).await?;
}
Ok(NodeState::Running(self))
Comment on lines +175 to +178
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this correct that we should be doing all the messaging after the protocol completes? What if all nodes are waiting on messages and we'll be stuck in a deadlock?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the protocol is interactive and occasionally you need to wait for other messages to arrive before you can progress. So ideally there should be a timeout and a restart mechanism but that hasn't been implemented yet.

}
}

#[async_trait]
impl CryptographicProtocol for NodeState {
async fn progress<C: CryptographicCtx + Send + Sync>(
Expand All @@ -172,6 +189,7 @@ impl CryptographicProtocol for NodeState {
match self {
NodeState::Generating(state) => state.progress(ctx).await,
NodeState::Resharing(state) => state.progress(ctx).await,
NodeState::Running(state) => state.progress(ctx).await,
_ => Ok(self),
}
}
Expand Down
78 changes: 49 additions & 29 deletions node/src/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::collections::{HashMap, VecDeque};

use super::state::{GeneratingState, NodeState, ResharingState};
use super::state::{GeneratingState, NodeState, ResharingState, RunningState};
use cait_sith::protocol::{MessageData, Participant};
use serde::{Deserialize, Serialize};

/// Context made available to message handlers; currently only exposes the
/// local node's participant id.
pub trait MessageCtx {
    /// This node's own participant id.
    fn me(&self) -> Participant;
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GeneratingMessage {
pub from: Participant,
Expand All @@ -17,16 +21,26 @@ pub struct ResharingMessage {
pub data: MessageData,
}

/// A message belonging to one triple generation protocol instance.
#[derive(Serialize, Deserialize, Debug)]
pub struct TripleMessage {
    // Id of the triple generation protocol instance this message targets.
    pub id: u64,
    // Epoch the triple is generated under; messages are binned per epoch.
    pub epoch: u64,
    // Participant that sent this message.
    pub from: Participant,
    // Opaque protocol payload forwarded to the underlying cait-sith protocol.
    pub data: MessageData,
}

/// Top-level wire message exchanged between MPC nodes.
#[derive(Serialize, Deserialize, Debug)]
pub enum MpcMessage {
    // Distributed key generation traffic.
    Generating(GeneratingMessage),
    // Key resharing traffic.
    Resharing(ResharingMessage),
    // Beaver triple generation traffic.
    Triple(TripleMessage),
}

/// Buffers incoming messages until the node is in a state that can consume them.
#[derive(Default)]
pub struct MpcMessageQueue {
    // FIFO of key generation messages.
    generating: VecDeque<GeneratingMessage>,
    // Resharing messages binned by the epoch they target.
    resharing_bins: HashMap<u64, VecDeque<ResharingMessage>>,
    // Triple messages binned first by epoch, then by triple protocol id.
    triple_bins: HashMap<u64, HashMap<u64, VecDeque<TripleMessage>>>,
}

impl MpcMessageQueue {
Expand All @@ -38,52 +52,58 @@ impl MpcMessageQueue {
.entry(message.epoch)
.or_default()
.push_back(message),
MpcMessage::Triple(message) => self
.triple_bins
.entry(message.epoch)
.or_default()
.entry(message.id)
.or_default()
.push_back(message),
}
}
}

pub trait MessageHandler {
fn handle(&mut self, queue: &mut MpcMessageQueue);
fn handle<C: MessageCtx + Send + Sync>(&mut self, ctx: C, queue: &mut MpcMessageQueue);
}

impl MessageHandler for GeneratingState {
fn handle(&mut self, queue: &mut MpcMessageQueue) {
match queue.generating.pop_front() {
Some(msg) => {
tracing::debug!("handling new generating message");
self.protocol.message(msg.from, msg.data);
}
None => {
tracing::debug!("no generating messages to handle")
}
};
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
while let Some(msg) = queue.generating.pop_front() {
tracing::debug!("handling new generating message");
self.protocol.message(msg.from, msg.data);
}
}
}

impl MessageHandler for ResharingState {
fn handle(&mut self, queue: &mut MpcMessageQueue) {
match queue
.resharing_bins
.entry(self.old_epoch)
.or_default()
.pop_front()
{
Some(msg) => {
tracing::debug!("handling new resharing message");
self.protocol.message(msg.from, msg.data);
}
None => {
tracing::debug!("no resharing messages to handle")
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
let q = queue.resharing_bins.entry(self.old_epoch).or_default();
while let Some(msg) = q.pop_front() {
tracing::debug!("handling new resharing message");
self.protocol.message(msg.from, msg.data);
}
}
}

impl MessageHandler for RunningState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
for (id, queue) in queue.triple_bins.entry(self.epoch).or_default() {
if let Some(protocol) = self.triple_manager.get_or_generate(*id) {
while let Some(message) = queue.pop_front() {
protocol.message(message.from, message.data);
}
}
};
}
}
}

impl MessageHandler for NodeState {
fn handle(&mut self, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(&mut self, ctx: C, queue: &mut MpcMessageQueue) {
match self {
NodeState::Generating(state) => state.handle(queue),
NodeState::Resharing(state) => state.handle(queue),
NodeState::Generating(state) => state.handle(ctx, queue),
NodeState::Resharing(state) => state.handle(ctx, queue),
NodeState::Running(state) => state.handle(ctx, queue),
_ => {
tracing::debug!("skipping message processing")
}
Expand Down
12 changes: 10 additions & 2 deletions node/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ mod contract;
mod cryptography;
mod message;
mod state;
mod triple;

pub use contract::ProtocolState;
pub use message::MpcMessage;
pub use state::NodeState;

use self::consensus::ConsensusCtx;
use self::cryptography::CryptographicCtx;
use self::message::MessageCtx;
use crate::protocol::consensus::ConsensusProtocol;
use crate::protocol::cryptography::CryptographicProtocol;
use crate::protocol::message::{MessageHandler, MpcMessageQueue};
Expand Down Expand Up @@ -68,6 +70,12 @@ impl CryptographicCtx for &Ctx {
}
}

// Expose the local participant id to message handlers.
impl MessageCtx for &Ctx {
    fn me(&self) -> Participant {
        self.me
    }
}

pub struct MpcSignProtocol {
ctx: Ctx,
receiver: mpsc::Receiver<MpcMessage>,
Expand Down Expand Up @@ -101,7 +109,7 @@ impl MpcSignProtocol {
}

pub async fn run(mut self) -> anyhow::Result<()> {
tracing::info!("running mpc recovery protocol");
let _span = tracing::info_span!("running", me = u32::from(self.ctx.me));
let mut queue = MpcMessageQueue::default();
loop {
tracing::debug!("trying to advance mpc recovery protocol");
Expand Down Expand Up @@ -140,7 +148,7 @@ impl MpcSignProtocol {
let mut state = std::mem::take(&mut *state_guard);
state = state.progress(&self.ctx).await?;
state = state.advance(&self.ctx, contract_state).await?;
state.handle(&mut queue);
state.handle(&self.ctx, &mut queue);
*state_guard = state;
drop(state_guard);
tokio::time::sleep(Duration::from_millis(1000)).await;
Expand Down
2 changes: 2 additions & 0 deletions node/src/protocol/state.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::triple::TripleManager;
use crate::types::{KeygenProtocol, PrivateKeyShare, PublicKey, ReshareProtocol};
use cait_sith::protocol::Participant;
use std::collections::HashMap;
Expand Down Expand Up @@ -31,6 +32,7 @@ pub struct RunningState {
pub threshold: usize,
pub private_share: PrivateKeyShare,
pub public_key: PublicKey,
pub triple_manager: TripleManager,
}

pub struct ResharingState {
Expand Down
Loading
Loading