1
+ use std:: {
2
+ convert:: TryFrom ,
3
+ marker:: PhantomData ,
4
+ } ;
5
+
6
+ use crate :: { balances_map:: UnifiedMap , Address , channel_v5:: Channel , UnifiedNum } ;
7
+ use chrono:: { DateTime , Utc } ;
8
+ use serde:: { Deserialize , Deserializer , Serialize } ;
9
+ use thiserror:: Error ;
10
+
11
+ #[ derive( Serialize , Debug , Clone , PartialEq , Eq ) ]
12
+ #[ serde( rename_all = "camelCase" ) ]
13
+ pub struct Accounting < S : BalancesState > {
14
+ pub channel : Channel ,
15
+ #[ serde( flatten) ]
16
+ pub balances : Balances < S > ,
17
+ pub updated : Option < DateTime < Utc > > ,
18
+ pub created : DateTime < Utc > ,
19
+ }
20
+
21
+ #[ derive( Serialize , Debug , Clone , PartialEq , Eq , Default ) ]
22
+ #[ serde( rename_all = "camelCase" ) ]
23
+ pub struct Balances < S > {
24
+ pub earners : UnifiedMap ,
25
+ pub spenders : UnifiedMap ,
26
+ #[ serde( skip_serializing, skip_deserializing) ]
27
+ state : PhantomData < S > ,
28
+ }
29
+
30
+ impl Balances < UncheckedState > {
31
+ pub fn check ( self ) -> Result < Balances < CheckedState > , Error > {
32
+ let earned = self
33
+ . earners
34
+ . values ( )
35
+ . sum :: < Option < UnifiedNum > > ( )
36
+ . ok_or_else ( || Error :: Overflow ( "earners overflow" . to_string ( ) ) ) ?;
37
+ let spent = self
38
+ . spenders
39
+ . values ( )
40
+ . sum :: < Option < UnifiedNum > > ( )
41
+ . ok_or_else ( || Error :: Overflow ( "spenders overflow" . to_string ( ) ) ) ?;
42
+
43
+ if earned != spent {
44
+ Err ( Error :: PayoutMismatch { spent, earned } )
45
+ } else {
46
+ Ok ( Balances {
47
+ earners : self . earners ,
48
+ spenders : self . spenders ,
49
+ state : PhantomData :: < CheckedState > :: default ( ) ,
50
+ } )
51
+ }
52
+ }
53
+ }
54
+
55
+ impl < S : BalancesState > Balances < S > {
56
+ pub fn spend (
57
+ & mut self ,
58
+ spender : Address ,
59
+ earner : Address ,
60
+ amount : UnifiedNum ,
61
+ ) -> Result < ( ) , OverflowError > {
62
+ let spent = self . spenders . entry ( spender) . or_default ( ) ;
63
+ * spent = spent
64
+ . checked_add ( & amount)
65
+ . ok_or_else ( || OverflowError :: Spender ( spender) ) ?;
66
+
67
+ let earned = self . earners . entry ( earner) . or_default ( ) ;
68
+ * earned = earned
69
+ . checked_add ( & amount)
70
+ . ok_or_else ( || OverflowError :: Earner ( earner) ) ?;
71
+
72
+ Ok ( ( ) )
73
+ }
74
+ }
75
+
76
+ #[ derive( Debug ) ]
77
+ pub enum OverflowError {
78
+ Spender ( Address ) ,
79
+ Earner ( Address ) ,
80
+ }
81
+
82
+ #[ derive( Debug , Error ) ]
83
+ pub enum Error {
84
+ #[ error( "Overflow of computation {0}" ) ]
85
+ Overflow ( String ) ,
86
+ #[ error( "Payout mismatch between spent ({spent}) and earned ({earned})" ) ]
87
+ PayoutMismatch {
88
+ spent : UnifiedNum ,
89
+ earned : UnifiedNum ,
90
+ } ,
91
+ }
92
+
93
+ pub trait BalancesState { }
94
+
95
+ #[ derive( Debug , Clone , PartialEq , Eq , Default ) ]
96
+ pub struct CheckedState ;
97
+ impl BalancesState for CheckedState { }
98
+
99
+ #[ derive( Debug , Clone , PartialEq , Eq , Default ) ]
100
+ pub struct UncheckedState ;
101
+ impl BalancesState for UncheckedState { }
102
+
103
+ impl TryFrom < Balances < UncheckedState > > for Balances < CheckedState > {
104
+ type Error = Error ;
105
+
106
+ fn try_from ( value : Balances < UncheckedState > ) -> Result < Self , Self :: Error > {
107
+ value. check ( )
108
+ }
109
+ }
110
+
111
+ /// This modules implements the needed non-generic structs that help with Deserialization of the `Balances<S>`
112
+ mod de {
113
+ use super :: * ;
114
+
115
+ #[ derive( Deserialize ) ]
116
+ struct DeserializeAccounting {
117
+ pub channel : Channel ,
118
+ #[ serde( flatten) ]
119
+ pub balances : DeserializeBalances ,
120
+ pub created : DateTime < Utc > ,
121
+ pub updated : Option < DateTime < Utc > > ,
122
+ }
123
+
124
+ impl < ' de > Deserialize < ' de > for Accounting < UncheckedState > {
125
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
126
+ where
127
+ D : Deserializer < ' de > ,
128
+ {
129
+ let de_acc = DeserializeAccounting :: deserialize ( deserializer) ?;
130
+
131
+ Ok ( Self {
132
+ channel : de_acc. channel ,
133
+ balances : Balances :: < UncheckedState > :: try_from ( de_acc. balances ) . map_err ( serde:: de:: Error :: custom) ?,
134
+ created : de_acc. created ,
135
+ updated : de_acc. updated ,
136
+ } )
137
+ }
138
+ }
139
+
140
+ impl < ' de > Deserialize < ' de > for Accounting < CheckedState > {
141
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
142
+ where
143
+ D : Deserializer < ' de > ,
144
+ {
145
+ let unchecked_acc = Accounting :: < UncheckedState > :: deserialize ( deserializer) ?;
146
+
147
+ Ok ( Self {
148
+ channel : unchecked_acc. channel ,
149
+ balances : unchecked_acc. balances . check ( ) . map_err ( serde:: de:: Error :: custom) ?,
150
+ created : unchecked_acc. created ,
151
+ updated : unchecked_acc. updated ,
152
+ } )
153
+ }
154
+ }
155
+
156
+ #[ derive( Deserialize , Debug , Clone , PartialEq , Eq ) ]
157
+ struct DeserializeBalances {
158
+ pub earners : UnifiedMap ,
159
+ pub spenders : UnifiedMap ,
160
+ }
161
+
162
+ impl < ' de > Deserialize < ' de > for Balances < CheckedState > {
163
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
164
+ where
165
+ D : Deserializer < ' de > ,
166
+ {
167
+ let unchecked_balances = Balances :: < UncheckedState > :: deserialize ( deserializer) ?;
168
+
169
+ unchecked_balances. check ( ) . map_err ( serde:: de:: Error :: custom)
170
+ }
171
+ }
172
+
173
+ impl < ' de > Deserialize < ' de > for Balances < UncheckedState > {
174
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
175
+ where
176
+ D : Deserializer < ' de > ,
177
+ {
178
+ let deser_balances = DeserializeBalances :: deserialize ( deserializer) ?;
179
+
180
+ Ok ( Balances {
181
+ earners : deser_balances. earners ,
182
+ spenders : deser_balances. spenders ,
183
+ state : PhantomData :: < UncheckedState > :: default ( ) ,
184
+ } )
185
+ }
186
+ }
187
+
188
+ impl From < DeserializeBalances > for Balances < UncheckedState > {
189
+ fn from ( value : DeserializeBalances ) -> Self {
190
+ Self {
191
+ earners : value. earners ,
192
+ spenders : value. spenders ,
193
+ state : PhantomData :: < UncheckedState > :: default ( ) ,
194
+ }
195
+ }
196
+ }
197
+ }
198
+
199
+ #[ cfg( feature = "postgres" ) ]
200
+ mod postgres {
201
+ use super :: * ;
202
+ use postgres_types:: Json ;
203
+ use tokio_postgres:: Row ;
204
+
205
+ impl TryFrom < & Row > for Accounting < CheckedState > {
206
+ type Error = Error ;
207
+
208
+ fn try_from ( row : & Row ) -> Result < Self , Self :: Error > {
209
+ let balances = Balances :: < UncheckedState > {
210
+ earners : row. get :: < _ , Json < _ > > ( "earners" ) . 0 ,
211
+ spenders : row. get :: < _ , Json < _ > > ( "spenders" ) . 0 ,
212
+ state : PhantomData :: default ( ) ,
213
+ } . check ( ) ?;
214
+
215
+ Ok ( Self {
216
+ channel : row. get ( "channel" ) ,
217
+ balances,
218
+ updated : row. get ( "updated" ) ,
219
+ created : row. get ( "created" ) ,
220
+ } )
221
+ }
222
+ }
223
+ }
0 commit comments