Skip to content

Commit 1cc7a5f

Browse files
committed
primitives - ValidatorMessage with a Generic
1 parent fc3939d commit 1cc7a5f

File tree

6 files changed

+234
-83
lines changed

6 files changed

+234
-83
lines changed

primitives/src/sentry.rs

Lines changed: 67 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,66 @@
1-
use crate::targeting::Rules;
2-
use crate::validator::MessageTypes;
3-
use crate::{BigNum, Channel, ChannelId, ValidatorId};
1+
use crate::{
2+
targeting::Rules,
3+
validator::Type as MessageType,
4+
validator::{ApproveState, Heartbeat, MessageTypes, NewState},
5+
BigNum, Channel, ChannelId, ValidatorId,
6+
};
47
use chrono::{DateTime, Utc};
58
use serde::{Deserialize, Serialize};
6-
use std::collections::HashMap;
7-
use std::fmt;
8-
use std::hash::Hash;
9+
use std::{collections::HashMap, fmt, hash::Hash};
910

1011
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
1112
#[serde(rename_all = "camelCase")]
1213
pub struct LastApproved {
1314
/// NewState can be None if the channel is brand new
14-
pub new_state: Option<NewStateValidatorMessage>,
15+
pub new_state: Option<MessageResponse<NewState>>,
1516
/// ApproveState can be None if the channel is brand new
16-
pub approve_state: Option<ApproveStateValidatorMessage>,
17+
pub approve_state: Option<MessageResponse<ApproveState>>,
1718
}
1819

1920
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
20-
pub struct NewStateValidatorMessage {
21+
pub struct MessageResponse<T: MessageType> {
2122
pub from: ValidatorId,
2223
pub received: DateTime<Utc>,
23-
pub msg: MessageTypes,
24+
pub msg: message::Message<T>,
2425
}
2526

26-
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
27-
pub struct ApproveStateValidatorMessage {
28-
pub from: ValidatorId,
29-
pub received: DateTime<Utc>,
30-
pub msg: MessageTypes,
31-
}
27+
pub mod message {
28+
use std::{convert::TryFrom, ops::Deref};
3229

33-
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
34-
pub struct HeartbeatValidatorMessage {
35-
pub from: ValidatorId,
36-
pub received: DateTime<Utc>,
37-
pub msg: MessageTypes,
30+
use crate::validator::messages::*;
31+
use serde::{Deserialize, Serialize};
32+
33+
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
34+
#[serde(try_from = "MessageTypes", into = "MessageTypes")]
35+
pub struct Message<T: Type>(T);
36+
37+
impl<T: Type> Message<T> {
38+
pub fn into_inner(self) -> T {
39+
self.0
40+
}
41+
}
42+
43+
impl<T: Type> Deref for Message<T> {
44+
type Target = T;
45+
46+
fn deref(&self) -> &Self::Target {
47+
&self.0
48+
}
49+
}
50+
51+
impl<T: Type> TryFrom<MessageTypes> for Message<T> {
52+
type Error = MessageTypeError<T>;
53+
54+
fn try_from(value: MessageTypes) -> Result<Self, Self::Error> {
55+
<T as TryFrom<MessageTypes>>::try_from(value).map(Self)
56+
}
57+
}
58+
59+
impl<T: Type> Into<MessageTypes> for Message<T> {
60+
fn into(self) -> MessageTypes {
61+
self.0.into()
62+
}
63+
}
3864
}
3965

4066
#[serde(tag = "type", rename_all = "SCREAMING_SNAKE_CASE")]
@@ -119,7 +145,7 @@ pub struct LastApprovedResponse {
119145
/// None -> withHeartbeat=true wasn't passed
120146
/// Some(vec![]) (empty vec) or Some(heartbeats) - withHeartbeat=true was passed
121147
#[serde(default, skip_serializing_if = "Option::is_none")]
122-
pub heartbeats: Option<Vec<HeartbeatValidatorMessage>>,
148+
pub heartbeats: Option<Vec<MessageResponse<Heartbeat>>>,
123149
}
124150

125151
#[derive(Serialize, Deserialize, Debug)]
@@ -232,16 +258,16 @@ pub mod channel_list {
232258

233259
#[cfg(feature = "postgres")]
234260
mod postgres {
235-
use super::{
236-
ApproveStateValidatorMessage, HeartbeatValidatorMessage, NewStateValidatorMessage,
237-
ValidatorMessage,
261+
use super::{MessageResponse, ValidatorMessage};
262+
use crate::{
263+
sentry::EventAggregate,
264+
validator::{messages::Type as MessageType, MessageTypes},
238265
};
239-
use crate::sentry::EventAggregate;
240-
use crate::validator::MessageTypes;
241266
use bytes::BytesMut;
242267
use postgres_types::{accepts, to_sql_checked, IsNull, Json, ToSql, Type};
243-
use std::error::Error;
244-
use tokio_postgres::Row;
268+
use serde::Deserialize;
269+
use std::convert::TryFrom;
270+
use tokio_postgres::{Error, Row};
245271

246272
impl From<&Row> for EventAggregate {
247273
fn from(row: &Row) -> Self {
@@ -263,33 +289,20 @@ mod postgres {
263289
}
264290
}
265291

266-
impl From<&Row> for ApproveStateValidatorMessage {
267-
fn from(row: &Row) -> Self {
268-
Self {
269-
from: row.get("from"),
270-
received: row.get("received"),
271-
msg: row.get::<_, Json<MessageTypes>>("msg").0,
272-
}
273-
}
274-
}
292+
impl<T> TryFrom<&Row> for MessageResponse<T>
293+
where
294+
T: MessageType,
295+
for<'de> T: Deserialize<'de>,
296+
{
297+
type Error = Error;
275298

276-
impl From<&Row> for NewStateValidatorMessage {
277-
fn from(row: &Row) -> Self {
278-
Self {
299+
fn try_from(row: &Row) -> Result<Self, Self::Error> {
300+
Ok(Self {
279301
from: row.get("from"),
280302
received: row.get("received"),
281-
msg: row.get::<_, Json<MessageTypes>>("msg").0,
282-
}
283-
}
284-
}
285-
286-
impl From<&Row> for HeartbeatValidatorMessage {
287-
fn from(row: &Row) -> Self {
288-
Self {
289-
from: row.get("from"),
290-
received: row.get("received"),
291-
msg: row.get::<_, Json<MessageTypes>>("msg").0,
292-
}
303+
// guard against mistakes from wrong Queries
304+
msg: row.try_get::<_, Json<_>>("msg")?.0,
305+
})
293306
}
294307
}
295308

@@ -298,7 +311,7 @@ mod postgres {
298311
&self,
299312
ty: &Type,
300313
w: &mut BytesMut,
301-
) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
314+
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
302315
Json(self).to_sql(ty, w)
303316
}
304317

primitives/src/validator.rs

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,149 @@ pub struct ValidatorDesc {
100100

101101
// Validator Message Types
102102

103-
mod messages {
103+
pub mod messages {
104+
use std::{any::type_name, convert::TryFrom, fmt, marker::PhantomData};
105+
use thiserror::Error;
106+
104107
use crate::BalancesMap;
105108
use chrono::{DateTime, Utc};
106109
use serde::{Deserialize, Serialize};
107110

111+
#[derive(Error, Debug)]
112+
pub struct MessageTypeError<T: Type> {
113+
expected: PhantomData<T>,
114+
actual: String,
115+
}
116+
117+
impl<T: Type> MessageTypeError<T> {
118+
pub fn for_actual<A: Type>(_actual: &A) -> Self {
119+
Self {
120+
expected: PhantomData::default(),
121+
actual: type_name::<A>().to_string(),
122+
}
123+
}
124+
}
125+
126+
impl<T: Type> fmt::Display for MessageTypeError<T> {
127+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128+
write!(
129+
f,
130+
"Expected {} message type but the actual is {}",
131+
type_name::<T>(),
132+
self.actual
133+
)
134+
}
135+
}
136+
137+
pub trait Type:
138+
fmt::Debug
139+
+ Into<MessageTypes>
140+
+ TryFrom<MessageTypes, Error = MessageTypeError<Self>>
141+
+ Clone
142+
+ PartialEq
143+
+ Eq
144+
{
145+
}
146+
147+
impl Type for Accounting {}
148+
impl TryFrom<MessageTypes> for Accounting {
149+
type Error = MessageTypeError<Self>;
150+
151+
fn try_from(value: MessageTypes) -> Result<Self, Self::Error> {
152+
match value {
153+
MessageTypes::ApproveState(msg) => Err(MessageTypeError::for_actual(&msg)),
154+
MessageTypes::NewState(msg) => Err(MessageTypeError::for_actual(&msg)),
155+
MessageTypes::RejectState(msg) => Err(MessageTypeError::for_actual(&msg)),
156+
MessageTypes::Heartbeat(msg) => Err(MessageTypeError::for_actual(&msg)),
157+
MessageTypes::Accounting(accounting) => Ok(accounting),
158+
}
159+
}
160+
}
161+
impl Into<MessageTypes> for Accounting {
162+
fn into(self) -> MessageTypes {
163+
MessageTypes::Accounting(self)
164+
}
165+
}
166+
167+
impl Type for ApproveState {}
168+
impl TryFrom<MessageTypes> for ApproveState {
169+
type Error = MessageTypeError<Self>;
170+
171+
fn try_from(value: MessageTypes) -> Result<Self, Self::Error> {
172+
match value {
173+
MessageTypes::NewState(msg) => Err(MessageTypeError::for_actual(&msg)),
174+
MessageTypes::RejectState(msg) => Err(MessageTypeError::for_actual(&msg)),
175+
MessageTypes::Heartbeat(msg) => Err(MessageTypeError::for_actual(&msg)),
176+
MessageTypes::Accounting(msg) => Err(MessageTypeError::for_actual(&msg)),
177+
MessageTypes::ApproveState(approve_state) => Ok(approve_state),
178+
}
179+
}
180+
}
181+
impl Into<MessageTypes> for ApproveState {
182+
fn into(self) -> MessageTypes {
183+
MessageTypes::ApproveState(self)
184+
}
185+
}
186+
187+
impl Type for NewState {}
188+
impl TryFrom<MessageTypes> for NewState {
189+
type Error = MessageTypeError<Self>;
190+
191+
fn try_from(value: MessageTypes) -> Result<Self, Self::Error> {
192+
match value {
193+
MessageTypes::ApproveState(msg) => Err(MessageTypeError::for_actual(&msg)),
194+
MessageTypes::RejectState(msg) => Err(MessageTypeError::for_actual(&msg)),
195+
MessageTypes::Heartbeat(msg) => Err(MessageTypeError::for_actual(&msg)),
196+
MessageTypes::Accounting(msg) => Err(MessageTypeError::for_actual(&msg)),
197+
MessageTypes::NewState(new_state) => Ok(new_state),
198+
}
199+
}
200+
}
201+
impl Into<MessageTypes> for NewState {
202+
fn into(self) -> MessageTypes {
203+
MessageTypes::NewState(self)
204+
}
205+
}
206+
207+
impl Type for RejectState {}
208+
impl TryFrom<MessageTypes> for RejectState {
209+
type Error = MessageTypeError<Self>;
210+
211+
fn try_from(value: MessageTypes) -> Result<Self, Self::Error> {
212+
match value {
213+
MessageTypes::ApproveState(msg) => Err(MessageTypeError::for_actual(&msg)),
214+
MessageTypes::NewState(msg) => Err(MessageTypeError::for_actual(&msg)),
215+
MessageTypes::Heartbeat(msg) => Err(MessageTypeError::for_actual(&msg)),
216+
MessageTypes::Accounting(msg) => Err(MessageTypeError::for_actual(&msg)),
217+
MessageTypes::RejectState(reject_state) => Ok(reject_state),
218+
}
219+
}
220+
}
221+
impl Into<MessageTypes> for RejectState {
222+
fn into(self) -> MessageTypes {
223+
MessageTypes::RejectState(self)
224+
}
225+
}
226+
227+
impl Type for Heartbeat {}
228+
impl TryFrom<MessageTypes> for Heartbeat {
229+
type Error = MessageTypeError<Self>;
230+
231+
fn try_from(value: MessageTypes) -> Result<Self, Self::Error> {
232+
match value {
233+
MessageTypes::ApproveState(msg) => Err(MessageTypeError::for_actual(&msg)),
234+
MessageTypes::NewState(msg) => Err(MessageTypeError::for_actual(&msg)),
235+
MessageTypes::RejectState(msg) => Err(MessageTypeError::for_actual(&msg)),
236+
MessageTypes::Accounting(msg) => Err(MessageTypeError::for_actual(&msg)),
237+
MessageTypes::Heartbeat(heartbeat) => Ok(heartbeat),
238+
}
239+
}
240+
}
241+
impl Into<MessageTypes> for Heartbeat {
242+
fn into(self) -> MessageTypes {
243+
MessageTypes::Heartbeat(self)
244+
}
245+
}
108246
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
109247
#[serde(rename_all = "camelCase")]
110248
pub struct Accounting {

0 commit comments

Comments
 (0)