Skip to content

Commit

Permalink
fix: update enum-try-as-inner (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
sinui0 authored Oct 26, 2023
1 parent 832c434 commit 1ac6779
Show file tree
Hide file tree
Showing 17 changed files with 160 additions and 167 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,4 @@ once_cell = "1"
# DO NOT BUMP, SEE https://github.com/privacy-scaling-explorations/mpz/issues/61
generic-array = "0.14"
itybity = "0.2"
enum-try-as-inner = "0.1.0"
2 changes: 1 addition & 1 deletion ot/mpz-ot-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ derive_builder.workspace = true
itybity.workspace = true
opaque-debug.workspace = true
cfg-if.workspace = true
enum-try-as-inner = { tag = "0.1.0", git = "https://github.com/sinui0/enum-try-as-inner" }
bytemuck = { workspace = true, features = ["derive"] }
enum-try-as-inner.workspace = true

[dev-dependencies]
rstest.workspace = true
Expand Down
7 changes: 7 additions & 0 deletions ot/mpz-ot-core/src/chou_orlandi/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};

/// A CO15 protocol message.
#[derive(Debug, Clone, EnumTryAsInner, Serialize, Deserialize)]
#[derive_err(Debug)]
#[allow(missing_docs)]
pub enum Message {
SenderSetup(SenderSetup),
Expand All @@ -18,6 +19,12 @@ pub enum Message {
CointossReceiverPayload(cointoss::msgs::ReceiverPayload),
}

impl From<MessageError> for std::io::Error {
fn from(err: MessageError) -> Self {
std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string())
}
}

/// Sender setup message.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct SenderSetup {
Expand Down
7 changes: 7 additions & 0 deletions ot/mpz-ot-core/src/kos/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::msgs::Derandomize;

/// A KOS15 protocol message.
#[derive(Debug, Clone, EnumTryAsInner, Serialize, Deserialize)]
#[derive_err(Debug)]
#[allow(missing_docs)]
pub enum Message<BaseMsg> {
BaseMsg(BaseMsg),
Expand All @@ -26,6 +27,12 @@ pub enum Message<BaseMsg> {
CointossSenderPayload(CointossSenderPayload),
}

impl<BaseMsg> From<MessageError<BaseMsg>> for std::io::Error {
fn from(err: MessageError<BaseMsg>) -> Self {
std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string())
}
}

/// Extension message sent by the receiver.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extend {
Expand Down
2 changes: 1 addition & 1 deletion ot/mpz-ot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ p256 = { workspace = true, optional = true }
thiserror.workspace = true
rayon = { workspace = true }
itybity.workspace = true
enum-try-as-inner = { tag = "0.1.0", git = "https://github.com/sinui0/enum-try-as-inner" }
enum-try-as-inner.workspace = true
opaque-debug.workspace = true
serde = { workspace = true, optional = true }

Expand Down
26 changes: 13 additions & 13 deletions ot/mpz-ot/src/actor/kos/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
actor::kos::msgs::Message,
actor::kos::msgs::MessageError,
kos::{ReceiverError, SenderError},
};

Expand Down Expand Up @@ -33,9 +33,9 @@ impl From<crate::OTError> for SenderActorError {
}
}

impl From<enum_try_as_inner::Error<crate::kos::SenderState>> for SenderActorError {
fn from(value: enum_try_as_inner::Error<crate::kos::SenderState>) -> Self {
SenderError::StateError(value.to_string()).into()
impl From<crate::kos::SenderStateError> for SenderActorError {
fn from(err: crate::kos::SenderStateError) -> Self {
SenderError::from(err).into()
}
}

Expand All @@ -51,11 +51,11 @@ impl<T> From<futures::channel::mpsc::TrySendError<T>> for SenderActorError {
}
}

impl<T> From<enum_try_as_inner::Error<Message<T>>> for SenderActorError {
fn from(value: enum_try_as_inner::Error<Message<T>>) -> Self {
impl<T> From<MessageError<T>> for SenderActorError {
fn from(err: MessageError<T>) -> Self {
SenderActorError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
value.to_string(),
err.to_string(),
))
}
}
Expand Down Expand Up @@ -104,9 +104,9 @@ impl From<crate::OTError> for ReceiverActorError {
}
}

impl From<enum_try_as_inner::Error<crate::kos::ReceiverState>> for ReceiverActorError {
fn from(value: enum_try_as_inner::Error<crate::kos::ReceiverState>) -> Self {
ReceiverError::StateError(value.to_string()).into()
impl From<crate::kos::ReceiverStateError> for ReceiverActorError {
fn from(err: crate::kos::ReceiverStateError) -> Self {
ReceiverError::from(err).into()
}
}

Expand All @@ -122,11 +122,11 @@ impl<T> From<futures::channel::mpsc::TrySendError<T>> for ReceiverActorError {
}
}

impl<T> From<enum_try_as_inner::Error<Message<T>>> for ReceiverActorError {
fn from(value: enum_try_as_inner::Error<Message<T>>) -> Self {
impl<T> From<MessageError<T>> for ReceiverActorError {
fn from(err: MessageError<T>) -> Self {
ReceiverActorError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
value.to_string(),
err.to_string(),
))
}
}
Expand Down
28 changes: 2 additions & 26 deletions ot/mpz-ot/src/actor/kos/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod sender;
use futures::{SinkExt, StreamExt};
use utils_aio::{sink::IoSink, stream::IoStream};

use crate::kos::{msgs::Message as KosMessage, ReceiverError, SenderError};
use crate::kos::msgs::Message as KosMessage;

pub use error::{ReceiverActorError, SenderActorError};
pub use receiver::{ReceiverActor, SharedReceiver};
Expand All @@ -26,35 +26,11 @@ pub(crate) fn into_kos_stream<'a, St: IoStream<msgs::Message<T>> + Send + Unpin,
stream: &'a mut St,
) -> impl IoStream<KosMessage<T>> + Send + Unpin + 'a {
StreamExt::map(stream, |msg| match msg {
Ok(msg) => match msg.into_protocol() {
Ok(msg) => Ok(msg),
Err(err) => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
err.to_string(),
)),
},
Ok(msg) => msg.try_into_protocol().map_err(From::from),
Err(err) => Err(err),
})
}

impl<T> From<enum_try_as_inner::Error<msgs::Message<T>>> for SenderError {
fn from(value: enum_try_as_inner::Error<msgs::Message<T>>) -> Self {
SenderError::from(std::io::Error::new(
std::io::ErrorKind::InvalidData,
value.to_string(),
))
}
}

impl<T> From<enum_try_as_inner::Error<msgs::Message<T>>> for ReceiverError {
fn from(value: enum_try_as_inner::Error<msgs::Message<T>>) -> Self {
ReceiverError::from(std::io::Error::new(
std::io::ErrorKind::InvalidData,
value.to_string(),
))
}
}

#[cfg(test)]
mod tests {
use crate::{
Expand Down
7 changes: 7 additions & 0 deletions ot/mpz-ot/src/actor/kos/msgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@ use mpz_ot_core::{

/// KOS actor message
#[derive(Debug, Clone, EnumTryAsInner, Serialize, Deserialize)]
#[derive_err(Debug)]
#[allow(missing_docs)]
pub enum Message<BaseOT> {
ActorMessage(ActorMessage),
Protocol(KosMessage<BaseOT>),
}

impl<BaseOT> From<MessageError<BaseOT>> for std::io::Error {
fn from(err: MessageError<BaseOT>) -> Self {
std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string())
}
}

impl<T> From<ActorMessage> for Message<T> {
fn from(value: ActorMessage) -> Self {
Message::ActorMessage(value)
Expand Down
6 changes: 3 additions & 3 deletions ot/mpz-ot/src/actor/kos/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ where
let mut keys = self
.receiver
.state_mut()
.as_extension_mut()?
.try_as_extension_mut()?
.keys(choices.len())?;

let derandomize = keys.derandomize(choices)?;
Expand Down Expand Up @@ -205,7 +205,7 @@ where
_ = caller_response.send(
self.receiver
.state_mut()
.as_verify_mut()
.try_as_verify_mut()
.map_err(ReceiverError::from)
.and_then(|receiver| {
receiver.remove_record(*id).map_err(ReceiverError::from)
Expand Down Expand Up @@ -262,7 +262,7 @@ where

/// Handles a message from the KOS sender actor.
async fn handle_msg(&mut self, msg: Message<BaseOT::Msg>) -> Result<(), ReceiverActorError> {
let msg = msg.into_actor_message()?;
let msg = msg.try_into_actor_message()?;

match msg {
ActorMessage::TransferPayload(TransferPayload { id, payload }) => {
Expand Down
4 changes: 2 additions & 2 deletions ot/mpz-ot/src/actor/kos/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ where
futures::select! {
// Processes a message received from the Receiver.
msg = self.stream.select_next_some() => {
self.handle_msg(msg?.into_actor_message()?)?;
self.handle_msg(msg?.try_into_actor_message()?)?;
}
// Processes a command from a controller.
cmd = self.commands.select_next_some() => {
Expand Down Expand Up @@ -185,7 +185,7 @@ where
let keys = self
.sender
.state_mut()
.as_extension_mut()
.try_as_extension_mut()
.map_err(SenderError::from)
.and_then(|sender| {
sender
Expand Down
30 changes: 21 additions & 9 deletions ot/mpz-ot/src/chou_orlandi/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use mpz_ot_core::chou_orlandi::msgs::Message;
use mpz_ot_core::chou_orlandi::msgs::MessageError;

use crate::OTError;

Expand All @@ -10,7 +10,7 @@ pub enum SenderError {
IOError(#[from] std::io::Error),
#[error(transparent)]
CoreError(#[from] mpz_ot_core::chou_orlandi::SenderError),
#[error("invalid state: expected {0}")]
#[error("{0}")]
StateError(String),
#[error(transparent)]
CointossError(#[from] mpz_core::cointoss::CointossError),
Expand All @@ -27,11 +27,17 @@ impl From<SenderError> for OTError {
}
}

impl From<enum_try_as_inner::Error<Message>> for SenderError {
fn from(value: enum_try_as_inner::Error<Message>) -> Self {
impl From<crate::chou_orlandi::sender::StateError> for SenderError {
fn from(err: crate::chou_orlandi::sender::StateError) -> Self {
SenderError::StateError(err.to_string())
}
}

impl From<MessageError> for SenderError {
fn from(err: MessageError) -> Self {
SenderError::from(std::io::Error::new(
std::io::ErrorKind::InvalidData,
value.to_string(),
err.to_string(),
))
}
}
Expand All @@ -44,7 +50,7 @@ pub enum ReceiverError {
IOError(#[from] std::io::Error),
#[error(transparent)]
CoreError(#[from] mpz_ot_core::chou_orlandi::ReceiverError),
#[error("invalid state: expected {0}")]
#[error("{0}")]
StateError(String),
#[error(transparent)]
CointossError(#[from] mpz_core::cointoss::CointossError),
Expand All @@ -61,11 +67,17 @@ impl From<ReceiverError> for OTError {
}
}

impl From<enum_try_as_inner::Error<Message>> for ReceiverError {
fn from(value: enum_try_as_inner::Error<Message>) -> Self {
impl From<crate::chou_orlandi::receiver::StateError> for ReceiverError {
fn from(err: crate::chou_orlandi::receiver::StateError) -> Self {
ReceiverError::StateError(err.to_string())
}
}

impl From<MessageError> for ReceiverError {
fn from(err: MessageError) -> Self {
ReceiverError::from(std::io::Error::new(
std::io::ErrorKind::InvalidData,
value.to_string(),
err.to_string(),
))
}
}
33 changes: 11 additions & 22 deletions ot/mpz-ot/src/chou_orlandi/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use crate::{CommittedOTReceiver, OTError, OTReceiver, OTSetup};
use super::ReceiverError;

#[derive(Debug, EnumTryAsInner)]
enum State {
#[derive_err(Debug)]
pub(crate) enum State {
Initialized {
config: ReceiverConfig,
seed: Option<[u8; 32]>,
Expand All @@ -30,12 +31,6 @@ enum State {
Error,
}

impl From<enum_try_as_inner::Error<State>> for ReceiverError {
fn from(value: enum_try_as_inner::Error<State>) -> Self {
ReceiverError::StateError(value.to_string())
}
}

/// Chou-Orlandi receiver.
#[derive(Debug)]
pub struct Receiver {
Expand Down Expand Up @@ -84,10 +79,8 @@ impl OTSetup for Receiver {
return Ok(());
}

let (config, seed) = self
.state
.replace(State::Error)
.into_initialized()
let (config, seed) = std::mem::replace(&mut self.state, State::Error)
.try_into_initialized()
.map_err(ReceiverError::from)?;

// If the receiver is committed, we generate the seed using a cointoss.
Expand All @@ -110,7 +103,7 @@ impl OTSetup for Receiver {
let sender_setup = stream
.expect_next()
.await?
.into_sender_setup()
.try_into_sender_setup()
.map_err(ReceiverError::from)?;

let receiver = Backend::spawn(move || receiver.setup(sender_setup)).await;
Expand Down Expand Up @@ -138,7 +131,7 @@ async fn execute_cointoss<
let payload = stream
.expect_next()
.await?
.into_cointoss_receiver_payload()?;
.try_into_cointoss_receiver_payload()?;

let (seeds, payload) = sender.finalize(payload)?;

Expand Down Expand Up @@ -167,10 +160,8 @@ where
stream: &mut St,
choices: &[T],
) -> Result<Vec<Block>, OTError> {
let mut receiver = self
.state
.replace(State::Error)
.into_setup()
let mut receiver = std::mem::replace(&mut self.state, State::Error)
.try_into_setup()
.map_err(ReceiverError::from)?;

let choices = choices.to_vec();
Expand All @@ -186,7 +177,7 @@ where
let sender_payload = stream
.expect_next()
.await?
.into_sender_payload()
.try_into_sender_payload()
.map_err(ReceiverError::from)?;

let (receiver, data) = Backend::spawn(move || {
Expand All @@ -213,10 +204,8 @@ impl CommittedOTReceiver<bool, Block> for Receiver {
sink: &mut Si,
_stream: &mut St,
) -> Result<(), OTError> {
let receiver = self
.state
.replace(State::Error)
.into_setup()
let receiver = std::mem::replace(&mut self.state, State::Error)
.try_into_setup()
.map_err(ReceiverError::from)?;

let Some(cointoss_payload) = self.cointoss_payload.take() else {
Expand Down
Loading

0 comments on commit 1ac6779

Please sign in to comment.