diff --git a/backend/ethereum/channel/backend.go b/backend/ethereum/channel/backend.go index e5e70d68..c8e34e7c 100644 --- a/backend/ethereum/channel/backend.go +++ b/backend/ethereum/channel/backend.go @@ -49,6 +49,7 @@ var ( abiParams abi.Type abiState abi.Type abiProgress abi.Method + abiRegister abi.Method ) func init() { @@ -74,6 +75,10 @@ func init() { if abiProgress, ok = adj.Methods["progress"]; !ok { panic("Could not find method progress in adjudicator contract.") } + + if abiRegister, ok = adj.Methods["register"]; !ok { + panic("Could not find method register in adjudicator contract.") + } } // Backend implements the interface defined in channel/Backend.go. diff --git a/backend/ethereum/channel/subscription.go b/backend/ethereum/channel/subscription.go index d61753e6..36232a8b 100644 --- a/backend/ethereum/channel/subscription.go +++ b/backend/ethereum/channel/subscription.go @@ -21,6 +21,7 @@ import ( "sync" "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" @@ -104,6 +105,7 @@ evloop: func (r *RegisteredSub) processNext(ctx context.Context, a *Adjudicator, _next *subscription.Event) (err error) { next, ok := _next.Data.(*adjudicator.AdjudicatorChannelUpdate) + next.Raw = _next.Log if !ok { log.Panicf("unexpected event type: %T", _next.Data) } @@ -115,7 +117,7 @@ func (r *RegisteredSub) processNext(ctx context.Context, a *Adjudicator, _next * // if newer version or same version and newer timeout, replace if current.Version() < next.Version || current.Version() == next.Version && currentTimeout.Time < next.Timeout { var e channel.AdjudicatorEvent - e, err = a.convertEvent(ctx, next, _next.Log.TxHash) + e, err = a.convertEvent(ctx, next) if err != nil { return } @@ -126,7 +128,7 @@ func (r *RegisteredSub) processNext(ctx context.Context, a *Adjudicator, _next * } default: // next-channel is empty var e channel.AdjudicatorEvent - e, err = a.convertEvent(ctx, next, _next.Log.TxHash) + e, err = a.convertEvent(ctx, next) if err != nil { return } @@ -161,14 +163,40 @@ func (r *RegisteredSub) Err() error { return <-r.err } -func (a *Adjudicator) convertEvent(ctx context.Context, e *adjudicator.AdjudicatorChannelUpdate, txHash common.Hash) (channel.AdjudicatorEvent, error) { +func (a *Adjudicator) convertEvent(ctx context.Context, e *adjudicator.AdjudicatorChannelUpdate) (channel.AdjudicatorEvent, error) { base := channel.NewAdjudicatorEventBase(e.ChannelID, NewBlockTimeout(a.ContractInterface, e.Timeout), e.Version) switch e.Phase { case phaseDispute: - return &channel.RegisteredEvent{AdjudicatorEventBase: *base}, nil + args, err := a.fetchRegisterCallData(ctx, e.Raw.TxHash) + if err != nil { + return nil, errors.WithMessage(err, "fetching call data") + } + + ch, ok := args.get(e.ChannelID) + if !ok { + return nil, errors.Errorf("channel not found in calldata: %v", e.ChannelID) + } + + var app channel.App + var zeroAddress common.Address + if ch.Params.App == zeroAddress { + app = channel.NoApp() + } else { + app, err = channel.Resolve(wallet.AsWalletAddr(ch.Params.App)) + if err != nil { + return nil, err + } + } + state := FromEthState(app, &ch.State) + + return &channel.RegisteredEvent{ + AdjudicatorEventBase: *base, + State: &state, + Sigs: ch.Sigs, + }, nil case phaseForceExec: - args, err := a.fetchProgressCallData(ctx, txHash) + args, err := a.fetchProgressCallData(ctx, e.Raw.TxHash) if err != nil { return nil, errors.WithMessage(err, "fetching call data") } @@ -200,24 +228,53 @@ type progressCallData struct { } func (a *Adjudicator) fetchProgressCallData(ctx context.Context, txHash common.Hash) (*progressCallData, error) { + var args progressCallData + err := a.fetchCallData(ctx, txHash, abiProgress, &args) + return &args, errors.WithMessage(err, "fetching call data") +} + +type registerCallData struct { + Channel adjudicator.AdjudicatorSignedState + SubChannels []adjudicator.AdjudicatorSignedState +} + +func (args *registerCallData) get(id channel.ID) (*adjudicator.AdjudicatorSignedState, bool) { + ch := &args.Channel + if ch.State.ChannelID == id { + return ch, true + } + for _, ch := range args.SubChannels { + if ch.State.ChannelID == id { + return &ch, true + } + } + return nil, false +} + +func (a *Adjudicator) fetchRegisterCallData(ctx context.Context, txHash common.Hash) (*registerCallData, error) { + var args registerCallData + err := a.fetchCallData(ctx, txHash, abiRegister, &args) + return &args, errors.WithMessage(err, "fetching call data") +} + +func (a *Adjudicator) fetchCallData(ctx context.Context, txHash common.Hash, method abi.Method, args interface{}) error { tx, _, err := a.ContractBackend.TransactionByHash(ctx, txHash) if err != nil { err = cherrors.CheckIsChainNotReachableError(err) - return nil, errors.WithMessage(err, "getting transaction") + return errors.WithMessage(err, "getting transaction") } - argsData := tx.Data()[len(abiProgress.ID):] + argsData := tx.Data()[len(method.ID):] - argsI, err := abiProgress.Inputs.UnpackValues(argsData) + argsI, err := method.Inputs.UnpackValues(argsData) if err != nil { - return nil, errors.WithMessage(err, "unpacking") + return errors.WithMessage(err, "unpacking") } - var args progressCallData - err = abiProgress.Inputs.Copy(&args, argsI) + err = method.Inputs.Copy(args, argsI) if err != nil { - return nil, errors.WithMessage(err, "copying into struct") + return errors.WithMessage(err, "copying into struct") } - return &args, nil + return nil } diff --git a/channel/adjudicator.go b/channel/adjudicator.go index 807a485b..19926cf9 100644 --- a/channel/adjudicator.go +++ b/channel/adjudicator.go @@ -144,6 +144,8 @@ type ( // registration on the blockchain. RegisteredEvent struct { AdjudicatorEventBase // Channel ID and Refutation phase timeout + State *State + Sigs []wallet.Sig } // ConcludedEvent signals channel conclusion. @@ -191,13 +193,15 @@ func (b AdjudicatorEventBase) Timeout() Timeout { return b.TimeoutV } func (b AdjudicatorEventBase) Version() uint64 { return b.VersionV } // NewRegisteredEvent creates a new RegisteredEvent. -func NewRegisteredEvent(id ID, timeout Timeout, version uint64) *RegisteredEvent { +func NewRegisteredEvent(id ID, timeout Timeout, version uint64, state *State, sigs []wallet.Sig) *RegisteredEvent { return &RegisteredEvent{ AdjudicatorEventBase: AdjudicatorEventBase{ IDV: id, TimeoutV: timeout, VersionV: version, }, + State: state, + Sigs: sigs, } } @@ -214,6 +218,17 @@ func NewProgressedEvent(id ID, timeout Timeout, state *State, idx Index) *Progre } } +// NewConcludedEvent creates a new ConcludedEvent. +func NewConcludedEvent(id ID, timeout Timeout, version uint64) *ConcludedEvent { + return &ConcludedEvent{ + AdjudicatorEventBase: AdjudicatorEventBase{ + IDV: id, + TimeoutV: timeout, + VersionV: version, + }, + } +} + // ElapsedTimeout is a Timeout that is always elapsed. type ElapsedTimeout struct{} diff --git a/client/adjudicate.go b/client/adjudicate.go index a060bda2..7a9d3c56 100644 --- a/client/adjudicate.go +++ b/client/adjudicate.go @@ -21,6 +21,7 @@ import ( "perun.network/go-perun/channel" "perun.network/go-perun/pkg/sync" + "perun.network/go-perun/wire" ) // AdjudicatorEventHandler represents an interface for handling adjudicator events. @@ -167,38 +168,80 @@ func (c *Channel) ProgressBy(ctx context.Context, update func(*channel.State)) e return errors.WithMessage(c.adjudicator.Progress(ctx, *pr), "progressing") } -// Settle concludes a channel and withdraws the funds. +// Settle concludes the channel and withdraws the funds. // -// Returns TxTimedoutError when the program times out waiting for a transaction -// to be mined. -// Returns ChainNotReachableError if the connection to the blockchain network -// fails when sending a transaction to / reading from the blockchain. -func (c *Channel) Settle(ctx context.Context, secondary bool) error { - return c.SettleWithSubchannels(ctx, nil, secondary) -} - -// SettleWithSubchannels concludes a channel and withdraws the funds. -// -// If the channel is a ledger channel with locked funds, additionally subStates -// can be supplied to also conclude the corresponding sub-channels. +// This only works if the channel is concludable. +// - A ledger channel is concludable, if it is final or if it has been disputed +// before and the dispute timeout has passed. +// - Sub-channels and virtual channels are only concludable if they are final +// and do not have any sub-channels. Otherwise, this means a dispute has +// occurred and the corresponding ledger channel must be disputed. // // Returns TxTimedoutError when the program times out waiting for a transaction // to be mined. // Returns ChainNotReachableError if the connection to the blockchain network // fails when sending a transaction to / reading from the blockchain. -func (c *Channel) SettleWithSubchannels(ctx context.Context, subStates channel.StateMap, secondary bool) error { - // Lock channel machine. - if !c.machMtx.TryLockCtx(ctx) { - return errors.WithMessage(ctx.Err(), "locking machine") +func (c *Channel) Settle(ctx context.Context, secondary bool) (err error) { + // Lock machines of channel and all subchannels recursively. + l, err := c.tryLockRecursive(ctx) + defer l.Unlock() + if err != nil { + return errors.WithMessage(err, "locking recursive") + } + + // Set phase `Withdrawing`. + if err = c.applyRecursive(func(c *Channel) error { + if c.machine.Phase() == channel.Withdrawn { + return nil + } + return c.machine.SetWithdrawing(ctx) + }); err != nil { + return errors.WithMessage(err, "setting phase `Withdrawing` recursive") + } + + // Settle. + err = c.settle(ctx, secondary) + if err != nil { + return + } + + // Set phase `Withdrawn`. + if err = c.applyRecursive(func(c *Channel) error { + // Skip if already withdrawn. + if c.machine.Phase() == channel.Withdrawn { + return nil + } + return c.machine.SetWithdrawn(ctx) + }); err != nil { + return errors.WithMessage(err, "setting phase `Withdrawn` recursive") } - defer c.machMtx.Unlock() - if err := c.machine.SetWithdrawing(ctx); err != nil { - return errors.WithMessage(err, "setting machine to withdrawing phase") + // Decrement account usage. + if err = c.applyRecursive(func(c *Channel) (err error) { + // Skip if we are not a participant, e.g., if this is a virtual channel and we are the hub. + if c.IsVirtualChannel() { + ourID := c.parent.Peers()[c.parent.Idx()] + if !c.hasParticipant(ourID) { + return + } + } + c.wallet.DecrementUsage(c.machine.Account().Address()) + return + }); err != nil { + return errors.WithMessage(err, "decrementing account usage") } + c.Log().Info("Withdrawal successful.") + return nil +} + +func (c *Channel) settle(ctx context.Context, secondary bool) error { switch { case c.IsLedgerChannel(): + subStates, err := c.subChannelStateMap() + if err != nil { + return errors.WithMessage(err, "creating sub-channel state map") + } req := c.machine.AdjudicatorReq() req.Secondary = secondary if err := c.adjudicator.Withdraw(ctx, req, subStates); err != nil { @@ -224,14 +267,17 @@ func (c *Channel) SettleWithSubchannels(ctx context.Context, subStates channel.S default: panic("invalid channel type") } + return nil +} - if err := c.machine.SetWithdrawn(ctx); err != nil { - return errors.WithMessage(err, "setting machine phase") +// hasParticipant returns we are participating in the channel. +func (c *Channel) hasParticipant(id wire.Address) bool { + for _, p := range c.Peers() { + if id.Equals(p) { + return true + } } - - c.Log().Info("Withdrawal successful.") - c.wallet.DecrementUsage(c.machine.Account().Address()) - return nil + return false } func (c *Channel) setMachinePhase(ctx context.Context, e channel.AdjudicatorEvent) (err error) { @@ -266,21 +312,13 @@ func (a mutexList) Unlock() { // tryLockRecursive tries to lock the channel and all of its sub-channels. // It returns a list of all the mutexes that have been locked. func (c *Channel) tryLockRecursive(ctx context.Context) (l mutexList, err error) { - l = mutexList{} - f := func(c *Channel) error { + err = c.applyRecursive(func(c *Channel) error { if !c.machMtx.TryLockCtx(ctx) { return errors.Errorf("locking machine mutex in time: %v", ctx.Err()) } l = append(l, &c.machMtx) return nil - } - - err = f(c) - if err != nil { - return - } - - err = c.applyToSubChannelsRecursive(f) + }) return } @@ -306,13 +344,8 @@ func (c *Channel) applyToSubChannelsRecursive(f func(*Channel) error) (err error return } -// setRegisteringRecursive sets the machine phase of the channel and all of its sub-channels to `Registering`. -// Assumes that the channel machine has been locked. -func (c *Channel) setRegisteringRecursive(ctx context.Context) (err error) { - f := func(c *Channel) error { - return c.machine.SetRegistering(ctx) - } - +// applyRecursive applies the function to the channel and its sub-channels recursively. +func (c *Channel) applyRecursive(f func(*Channel) error) (err error) { err = f(c) if err != nil { return err @@ -322,20 +355,20 @@ func (c *Channel) setRegisteringRecursive(ctx context.Context) (err error) { return } +// setRegisteringRecursive sets the machine phase of the channel and all of its sub-channels to `Registering`. +// Assumes that the channel machine has been locked. +func (c *Channel) setRegisteringRecursive(ctx context.Context) (err error) { + return c.applyRecursive(func(c *Channel) error { + return c.machine.SetRegistering(ctx) + }) +} + // setRegisteredRecursive sets the machine phase of the channel and all of its sub-channels to `Registered`. // Assumes that the channel machine has been locked. func (c *Channel) setRegisteredRecursive(ctx context.Context) (err error) { - f := func(c *Channel) error { + return c.applyRecursive(func(c *Channel) error { return c.machine.SetRegistered(ctx) - } - - err = f(c) - if err != nil { - return err - } - - err = c.applyToSubChannelsRecursive(f) - return + }) } // gatherSubChannelStates gathers the state of all sub-channels recursively. @@ -352,3 +385,14 @@ func (c *Channel) gatherSubChannelStates() (states []channel.SignedState, err er }) return } + +// gatherSubChannelStates gathers the state of all sub-channels recursively. +// Assumes sub-channels are locked. +func (c *Channel) subChannelStateMap() (states channel.StateMap, err error) { + states = channel.MakeStateMap() + err = c.applyToSubChannelsRecursive(func(c *Channel) error { + states[c.ID()] = c.state() + return nil + }) + return +} diff --git a/client/client_role_test.go b/client/client_role_test.go index e6d8c5e7..f5c0ecca 100644 --- a/client/client_role_test.go +++ b/client/client_role_test.go @@ -15,17 +15,13 @@ package client_test import ( - "context" "math/rand" - "sync" "testing" "time" "github.com/stretchr/testify/assert" - "perun.network/go-perun/channel" "perun.network/go-perun/client" ctest "perun.network/go-perun/client/test" - "perun.network/go-perun/log" wtest "perun.network/go-perun/wallet/test" wiretest "perun.network/go-perun/wire/test" ) @@ -34,31 +30,22 @@ const roleOperationTimeout = 1 * time.Second func NewSetups(rng *rand.Rand, names []string) []ctest.RoleSetup { var ( - bus = wiretest.NewSerializingLocalBus() - n = len(names) - setup = make([]ctest.RoleSetup, n) + bus = wiretest.NewSerializingLocalBus() + n = len(names) + setup = make([]ctest.RoleSetup, n) + backend = ctest.NewMockBackend(rng) ) for i := 0; i < n; i++ { - acc := wtest.NewRandomAccount(rng) - - // The use of a delayed funder simulates that channel participants may - // receive their funding confirmation at different times. - var funder channel.Funder - if i == 0 { - funder = &logFunderWithDelay{log.WithField("role", names[i])} - } else { - funder = &logFunder{log.WithField("role", names[i])} - } - setup[i] = ctest.RoleSetup{ Name: names[i], - Identity: acc, + Identity: wtest.NewRandomAccount(rng), Bus: bus, - Funder: funder, - Adjudicator: &logAdjudicator{log.WithField("role", names[i]), sync.RWMutex{}, nil}, + Funder: backend, + Adjudicator: backend, Wallet: wtest.NewWallet(), Timeout: roleOperationTimeout, + Backend: backend, } } @@ -84,86 +71,3 @@ func NewClients(rng *rand.Rand, names []string, t *testing.T) []*Client { } return clients } - -type ( - logFunder struct { - log log.Logger - } - - logFunderWithDelay struct { - log log.Logger - } - - logAdjudicator struct { - log log.Logger - mu sync.RWMutex - latestEvent channel.AdjudicatorEvent - } -) - -func (f *logFunder) Fund(_ context.Context, req channel.FundingReq) error { - f.log.Infof("Funding: %v", req) - return nil -} - -func (f *logFunderWithDelay) Fund(_ context.Context, req channel.FundingReq) error { - time.Sleep(100 * time.Millisecond) - f.log.Infof("Funding: %v", req) - return nil -} - -func (a *logAdjudicator) Register(_ context.Context, req channel.AdjudicatorReq, subChannels []channel.SignedState) error { - a.log.Infof("Register: %v", req) - e := channel.NewRegisteredEvent( - req.Params.ID(), - &channel.ElapsedTimeout{}, - req.Tx.Version, - ) - a.setEvent(e) - return nil -} - -func (a *logAdjudicator) Progress(_ context.Context, req channel.ProgressReq) error { - a.log.Infof("Progress: %v", req) - a.setEvent(channel.NewProgressedEvent( - req.Params.ID(), - &channel.ElapsedTimeout{}, - req.NewState.Clone(), - req.Idx, - )) - return nil -} - -func (a *logAdjudicator) Withdraw(_ context.Context, req channel.AdjudicatorReq, subStates channel.StateMap) error { - a.log.Infof("Withdraw: %v, %v", req, subStates) - return nil -} - -func (a *logAdjudicator) Subscribe(_ context.Context, params *channel.Params) (channel.AdjudicatorSubscription, error) { - a.log.Infof("SubscribeRegistered: %v", params) - return &simSubscription{a}, nil -} - -func (a *logAdjudicator) setEvent(e channel.AdjudicatorEvent) { - a.mu.Lock() - defer a.mu.Unlock() - a.latestEvent = e -} - -type simSubscription struct { - a *logAdjudicator -} - -func (s *simSubscription) Next() channel.AdjudicatorEvent { - s.a.mu.RLock() - defer s.a.mu.RUnlock() - return s.a.latestEvent -} - -func (s *simSubscription) Close() error { - return nil -} - -func (s *simSubscription) Err() error { - return nil -} diff --git a/client/client_test.go b/client/client_test.go index b7185eb1..f7116004 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package client +package client_test import ( "context" @@ -23,6 +23,8 @@ import ( "github.com/stretchr/testify/require" "perun.network/go-perun/channel" + "perun.network/go-perun/client" + ctest "perun.network/go-perun/client/test" "perun.network/go-perun/pkg/test" wtest "perun.network/go-perun/wallet/test" "perun.network/go-perun/wire" @@ -41,64 +43,34 @@ func (d DummyBus) SubscribeClient(wire.Consumer, wire.Address) error { return nil } -type DummyFunder struct { - t *testing.T -} - -func (d *DummyFunder) Fund(context.Context, channel.FundingReq) error { - d.t.Error("DummyFunder.Fund called") - return errors.New("DummyFunder.Fund called") -} - -type DummyAdjudicator struct { - t *testing.T -} - -func (d *DummyAdjudicator) Register(context.Context, channel.AdjudicatorReq, []channel.SignedState) error { - d.t.Error("DummyAdjudicator.Register called") - return errors.New("DummyAdjudicator.Register called") -} - -func (d *DummyAdjudicator) Progress(context.Context, channel.ProgressReq) error { - d.t.Error("DummyAdjudicator.Register called") - return errors.New("DummyAdjudicator.Progress called") -} - -func (d *DummyAdjudicator) Withdraw(context.Context, channel.AdjudicatorReq, channel.StateMap) error { - d.t.Error("DummyAdjudicator.Withdraw called") - return errors.New("DummyAdjudicator.Withdraw called") -} - -func (d *DummyAdjudicator) Subscribe(context.Context, *channel.Params) (channel.AdjudicatorSubscription, error) { - d.t.Error("DummyAdjudicator.SubscribeRegistered called") - return nil, errors.New("DummyAdjudicator.SubscribeRegistered called") -} - func TestClient_New_NilArgs(t *testing.T) { rng := test.Prng(t) id := wtest.NewRandomAddress(rng) - b, f, a, w := &DummyBus{t}, &DummyFunder{t}, &DummyAdjudicator{t}, wtest.RandomWallet() - assert.Panics(t, func() { New(nil, b, f, a, w) }) - assert.Panics(t, func() { New(id, nil, f, a, w) }) - assert.Panics(t, func() { New(id, b, nil, a, w) }) - assert.Panics(t, func() { New(id, b, f, nil, w) }) - assert.Panics(t, func() { New(id, b, f, a, nil) }) + backend := &ctest.MockBackend{} + b, f, a, w := &DummyBus{t}, backend, backend, wtest.RandomWallet() + assert.Panics(t, func() { client.New(nil, b, f, a, w) }) + assert.Panics(t, func() { client.New(id, nil, f, a, w) }) + assert.Panics(t, func() { client.New(id, b, nil, a, w) }) + assert.Panics(t, func() { client.New(id, b, f, nil, w) }) + assert.Panics(t, func() { client.New(id, b, f, a, nil) }) } func TestClient_Handle_NilArgs(t *testing.T) { rng := test.Prng(t) - c, err := New(wtest.NewRandomAddress(rng), &DummyBus{t}, &DummyFunder{t}, &DummyAdjudicator{t}, wtest.RandomWallet()) + backend := &ctest.MockBackend{} + c, err := client.New(wtest.NewRandomAddress(rng), &DummyBus{t}, backend, backend, wtest.RandomWallet()) require.NoError(t, err) - dummyUH := UpdateHandlerFunc(func(*channel.State, ChannelUpdate, *UpdateResponder) {}) + dummyUH := client.UpdateHandlerFunc(func(*channel.State, client.ChannelUpdate, *client.UpdateResponder) {}) assert.Panics(t, func() { c.Handle(nil, dummyUH) }) - dummyPH := ProposalHandlerFunc(func(ChannelProposal, *ProposalResponder) {}) + dummyPH := client.ProposalHandlerFunc(func(client.ChannelProposal, *client.ProposalResponder) {}) assert.Panics(t, func() { c.Handle(dummyPH, nil) }) } func TestClient_New(t *testing.T) { rng := test.Prng(t) - c, err := New(wtest.NewRandomAddress(rng), &DummyBus{t}, &DummyFunder{t}, &DummyAdjudicator{t}, wtest.RandomWallet()) + backend := &ctest.MockBackend{} + c, err := client.New(wtest.NewRandomAddress(rng), &DummyBus{t}, backend, backend, wtest.RandomWallet()) assert.NoError(t, err) require.NotNil(t, c) } diff --git a/client/test/backend.go b/client/test/backend.go new file mode 100644 index 00000000..8cde7a28 --- /dev/null +++ b/client/test/backend.go @@ -0,0 +1,314 @@ +// Copyright 2021 - See NOTICE file for copyright holders. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "bytes" + "context" + "math/big" + "math/rand" + "sync" + "time" + + "perun.network/go-perun/channel" + "perun.network/go-perun/log" + "perun.network/go-perun/pkg/io" + "perun.network/go-perun/wallet" +) + +type ( + // MockBackend is a mocked backend useful for testing. + MockBackend struct { + log log.Logger + rng rng + mu sync.Mutex + latestEvents map[channel.ID]channel.AdjudicatorEvent + eventSubs map[channel.ID][]chan channel.AdjudicatorEvent + balances map[addressMapKey]map[assetMapKey]*big.Int + } + + rng interface { + Intn(n int) int + } + + threadSafeRng struct { + mu sync.Mutex + r *rand.Rand + } +) + +// NewMockBackend creates a new backend object. +func NewMockBackend(rng *rand.Rand) *MockBackend { + return &MockBackend{ + log: log.Get(), + rng: newThreadSafePrng(rng), + latestEvents: make(map[channel.ID]channel.AdjudicatorEvent), + eventSubs: make(map[channel.ID][]chan channel.AdjudicatorEvent), + balances: make(map[string]map[string]*big.Int), + } +} + +func newThreadSafePrng(r *rand.Rand) *threadSafeRng { + return &threadSafeRng{ + mu: sync.Mutex{}, + r: r, + } +} + +func (g *threadSafeRng) Intn(n int) int { + g.mu.Lock() + defer g.mu.Unlock() + + return g.r.Intn(n) +} + +// Fund funds the channel. +func (b *MockBackend) Fund(_ context.Context, req channel.FundingReq) error { + time.Sleep(time.Duration(b.rng.Intn(100)) * time.Millisecond) + b.log.Infof("Funding: %+v", req) + return nil +} + +// Register registers the channel. +func (b *MockBackend) Register(_ context.Context, req channel.AdjudicatorReq, subChannels []channel.SignedState) error { + b.log.Infof("Register: %+v", req) + + b.mu.Lock() + defer b.mu.Unlock() + + // Check concluded. + ch := req.Params.ID() + if b.isConcluded(ch) { + log.Debug("register: already concluded:", ch) + return nil + } + + channels := append([]channel.SignedState{ + { + Params: req.Params, + State: req.Tx.State, + Sigs: req.Tx.Sigs, + }, + }, subChannels...) + + for _, ch := range channels { + b.setLatestEvent( + ch.Params.ID(), + channel.NewRegisteredEvent( + ch.Params.ID(), + &channel.ElapsedTimeout{}, + ch.State.Version, + ch.State, + ch.Sigs, + ), + ) + } + return nil +} + +func (b *MockBackend) setLatestEvent(ch channel.ID, e channel.AdjudicatorEvent) { + b.latestEvents[ch] = e + // Update subscriptions. + if channelSubs, ok := b.eventSubs[ch]; ok { + for _, events := range channelSubs { + // Remove previous latest event. + select { + case <-events: + default: + } + // Add latest event. + events <- e + } + } +} + +// Progress progresses the channel state. +func (b *MockBackend) Progress(_ context.Context, req channel.ProgressReq) error { + b.log.Infof("Progress: %+v", req) + + b.mu.Lock() + defer b.mu.Unlock() + + b.setLatestEvent( + req.Params.ID(), + channel.NewProgressedEvent( + req.Params.ID(), + &channel.ElapsedTimeout{}, + req.NewState.Clone(), + req.Idx, + ), + ) + return nil +} + +// outcomeRecursive returns the accumulated outcome of the channel and its sub-channels. +func outcomeRecursive(state *channel.State, subStates channel.StateMap) (outcome channel.Balances) { + outcome = state.Balances.Clone() + for _, subAlloc := range state.Locked { + subOutcome := outcomeRecursive(subStates[subAlloc.ID], subStates) + for a, bals := range subOutcome { + for p, bal := range bals { + _p := p + if len(subAlloc.IndexMap) > 0 { + _p = int(subAlloc.IndexMap[p]) + } + outcome[a][_p].Add(outcome[a][_p], bal) + } + } + } + return +} + +// Withdraw withdraws the channel funds. +func (b *MockBackend) Withdraw(_ context.Context, req channel.AdjudicatorReq, subStates channel.StateMap) error { + b.mu.Lock() + defer b.mu.Unlock() + + // Check concluded. + ch := req.Params.ID() + if b.isConcluded(ch) { + log.Debug("withdraw: already concluded:", ch) + return nil + } + + outcome := outcomeRecursive(req.Tx.State, subStates) + b.log.Infof("Withdraw: %+v, %+v, %+v", req, subStates, outcome) + for a, assetOutcome := range outcome { + asset := req.Tx.Allocation.Assets[a] + for p, amount := range assetOutcome { + participant := req.Params.Parts[p] + b.addBalance(participant, asset, amount) + } + } + + b.setLatestEvent(ch, channel.NewConcludedEvent(ch, &channel.ElapsedTimeout{}, req.Tx.Version)) + return nil +} + +func (b *MockBackend) isConcluded(ch channel.ID) bool { + e, ok := b.latestEvents[ch] + if !ok { + return false + } + if _, ok := e.(*channel.ConcludedEvent); !ok { + return false + } + return true +} + +func (b *MockBackend) addBalance(p wallet.Address, a channel.Asset, v *big.Int) { + bal := b.getBalance(p, a) + bal = new(big.Int).Add(bal, v) + b.setBalance(p, a, bal) +} + +func (b *MockBackend) getBalance(p wallet.Address, a channel.Asset) *big.Int { + partBals, ok := b.balances[newAddressMapKey(p)] + if !ok { + return big.NewInt(0) + } + bal, ok := partBals[newAssetMapKey(a)] + if !ok { + return big.NewInt(0) + } + return new(big.Int).Set(bal) +} + +type ( + addressMapKey = string + assetMapKey = string +) + +func newAddressMapKey(a wallet.Address) addressMapKey { + return encodableAsString(a) +} + +func newAssetMapKey(a channel.Asset) assetMapKey { + return encodableAsString(a) +} + +func encodableAsString(e io.Encoder) string { + var buf bytes.Buffer + if err := e.Encode(&buf); err != nil { + panic(err) + } + return buf.String() +} + +// GetBalance returns the balance for the participant and asset. +func (b *MockBackend) GetBalance(p wallet.Address, a channel.Asset) *big.Int { + b.mu.Lock() + defer b.mu.Unlock() + return b.getBalance(p, a) +} + +func (b *MockBackend) setBalance(p wallet.Address, a channel.Asset, v *big.Int) { + partKey := newAddressMapKey(p) + partBals, ok := b.balances[partKey] + if !ok { + log.Debug("part not found", p) + partBals = make(map[string]*big.Int) + b.balances[partKey] = partBals + } + log.Debug("set balance:", p, v) + partBals[newAssetMapKey(a)] = new(big.Int).Set(v) +} + +// Subscribe creates an event subscription. +func (b *MockBackend) Subscribe(ctx context.Context, params *channel.Params) (channel.AdjudicatorSubscription, error) { + b.log.Infof("SubscribeRegistered: %+v", params) + + b.mu.Lock() + defer b.mu.Unlock() + + sub := &mockSubscription{ + ctx: ctx, + events: make(chan channel.AdjudicatorEvent, 1), + err: make(chan error, 1), + } + b.eventSubs[params.ID()] = append(b.eventSubs[params.ID()], sub.events) + + // Feed latest event if any. + if e, ok := b.latestEvents[params.ID()]; ok { + sub.events <- e + } + + return sub, nil +} + +type mockSubscription struct { + ctx context.Context + events chan channel.AdjudicatorEvent + err chan error +} + +func (s *mockSubscription) Next() channel.AdjudicatorEvent { + select { + case e := <-s.events: + return e + case <-s.ctx.Done(): + s.err <- s.ctx.Err() + return nil + } +} + +func (s *mockSubscription) Close() error { + close(s.events) + return nil +} + +func (s *mockSubscription) Err() error { + return <-s.err +} diff --git a/client/test/role.go b/client/test/role.go index aaa7aff0..cd00aff1 100644 --- a/client/test/role.go +++ b/client/test/role.go @@ -63,6 +63,7 @@ type ( Wallet wallettest.Wallet PR persistence.PersistRestorer // Optional PersistRestorer Timeout time.Duration // Timeout waiting for other role, not challenge duration + Backend *MockBackend } // ExecConfig contains additional config parameters for the tests. diff --git a/client/virtual_channel.go b/client/virtual_channel.go index 1e6cd643..af98b521 100644 --- a/client/virtual_channel.go +++ b/client/virtual_channel.go @@ -16,6 +16,7 @@ package client import ( "context" + "fmt" "time" "github.com/pkg/errors" @@ -97,6 +98,50 @@ func (c *Client) handleVirtualChannelFundingProposal( c.acceptProposal(responder) } +func (c *Channel) watchVirtual() error { + log := c.Log().WithField("proc", fmt.Sprintf("virtual channel watcher %v", c.ID())) + defer log.Info("Watcher returned.") + + // Subscribe to state changes + ctx := c.Ctx() + sub, err := c.adjudicator.Subscribe(ctx, c.Params()) + if err != nil { + return errors.WithMessage(err, "subscribing to adjudicator state changes") + } + defer func() { + if err := sub.Close(); err != nil { + log.Warn(err) + } + }() + + // Wait for state changed event + for e := sub.Next(); e != nil; e = sub.Next() { + // Update channel + switch e := e.(type) { + case *channel.RegisteredEvent: + if e.Version() > c.State().Version { + err := c.pushVirtualUpdate(ctx, e.State, e.Sigs) + if err != nil { + log.Warnf("error updating virtual channel: %v", err) + } + } + + case *channel.ProgressedEvent: + log.Errorf("Virtual channel progressed: %v", e.ID()) + + case *channel.ConcludedEvent: + log.Infof("Virtual channel concluded: %v", e.ID()) + + default: + log.Errorf("unsupported type: %T", e) + } + } + + err = sub.Err() + log.Debugf("Subscription closed: %v", err) + return errors.WithMessage(err, "subscription closed") +} + // dummyAcount represents an address but cannot be used for signing. type dummyAccount struct { address wallet.Address @@ -156,6 +201,33 @@ func (c *Client) persistVirtualChannel(ctx context.Context, parent *Channel, pee return ch, nil } +func (c *Channel) pushVirtualUpdate(ctx context.Context, state *channel.State, sigs []wallet.Sig) error { + if !c.machMtx.TryLockCtx(ctx) { + return errors.Errorf("locking machine mutex in time: %v", ctx.Err()) + } + defer c.machMtx.Unlock() + + m := c.machine + if err := m.ForceUpdate(ctx, state, hubIndex); err != nil { + return err + } + + for i, sig := range sigs { + idx := channel.Index(i) + if err := m.AddSig(ctx, idx, sig); err != nil { + return err + } + } + + var err error + if state.IsFinal { + err = m.EnableFinal(ctx) + } else { + err = m.EnableUpdate(ctx) + } + return err +} + func (c *Client) validateVirtualChannelFundingProposal( ch *Channel, prop *virtualChannelFundingProposal, @@ -265,8 +337,16 @@ func (c *Client) matchFundingProposal(ctx context.Context, a, b interface{}) boo // Store state for withdrawal after dispute. parent := channels[0] peers := c.gatherPeers(channels...) - _, err = c.persistVirtualChannel(ctx, parent, peers, *prop0.Initial.Params, *prop0.Initial.State, prop0.Initial.Sigs) - return err == nil + virtual, err := c.persistVirtualChannel(ctx, parent, peers, *prop0.Initial.Params, *prop0.Initial.State, prop0.Initial.Sigs) + if err != nil { + return false + } + + go func() { + err := virtual.watchVirtual() + c.log.Debugf("channel %v: watcher stopped: %v", virtual.ID(), err) + }() + return true } func castToFundingProposals(inputs ...interface{}) ([]*virtualChannelFundingProposal, error) { diff --git a/client/virtual_channel_test.go b/client/virtual_channel_test.go index be7c6c88..0867ceda 100644 --- a/client/virtual_channel_test.go +++ b/client/virtual_channel_test.go @@ -17,6 +17,7 @@ package client_test import ( "context" "math/big" + "math/rand" "testing" "time" @@ -26,6 +27,7 @@ import ( "perun.network/go-perun/channel" chtest "perun.network/go-perun/channel/test" "perun.network/go-perun/client" + ctest "perun.network/go-perun/client/test" "perun.network/go-perun/pkg/sync" "perun.network/go-perun/pkg/test" "perun.network/go-perun/wire" @@ -68,7 +70,47 @@ func TestVirtualChannelsOptimistic(t *testing.T) { assert.NoError(t, err, "Bob: invalid final balances") } +func TestVirtualChannelsDispute(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testDuration) + defer cancel() + + vct := setupVirtualChannelTest(t, ctx) + assert := assert.New(t) + + chs := []*client.Channel{vct.chAliceIngrid, vct.chIngridAlice, vct.chBobIngrid, vct.chIngridBob} + // Register the channels in a random order. + for _, i := range rand.Perm(len(chs)) { + err := chs[i].Register(ctx) + assert.NoErrorf(err, "register channel: %d", i) + } + + time.Sleep(100 * time.Millisecond) // Sleep to ensure that registered events have been processed. + + // Settle the channels in a random order. + for _, i := range rand.Perm(len(chs)) { + err := chs[i].Settle(ctx, false) + assert.NoErrorf(err, "settle channel: %d", i) + } + + // Test final balances. + vct.testFinalBalancesDispute(t) +} + +func (vct *virtualChannelTest) testFinalBalancesDispute(t *testing.T) { + assert := assert.New(t) + backend, asset := vct.backend, vct.asset + got, expected := backend.GetBalance(vct.alice.Identity.Address(), asset), vct.finalBalsAlice[0] + assert.Truef(got.Cmp(expected) == 0, "alice: wrong final balance: got %v, expected %v", got, expected) + got, expected = backend.GetBalance(vct.bob.Identity.Address(), asset), vct.finalBalsBob[0] + assert.Truef(got.Cmp(expected) == 0, "bob: wrong final balance: got %v, expected %v", got, expected) + got, expected = backend.GetBalance(vct.ingrid.Identity.Address(), asset), vct.finalBalIngrid + assert.Truef(got.Cmp(expected) == 0, "ingrid: wrong final balance: got %v, expected %v", got, expected) +} + type virtualChannelTest struct { + alice *Client + bob *Client + ingrid *Client chAliceIngrid *client.Channel chIngridAlice *client.Channel chBobIngrid *client.Channel @@ -78,7 +120,10 @@ type virtualChannelTest struct { virtualBalsUpdated []*big.Int finalBalsAlice []*big.Int finalBalsBob []*big.Int + finalBalIngrid *big.Int errs chan error + backend *ctest.MockBackend + asset channel.Asset } func setupVirtualChannelTest(t *testing.T, ctx context.Context) (vct virtualChannelTest) { @@ -87,12 +132,14 @@ func setupVirtualChannelTest(t *testing.T, ctx context.Context) (vct virtualChan // Set test values. asset := chtest.NewRandomAsset(rng) + vct.asset = asset initBalsAlice := []*big.Int{big.NewInt(10), big.NewInt(10)} // with Ingrid initBalsBob := []*big.Int{big.NewInt(10), big.NewInt(10)} // with Ingrid initBalsVirtual := []*big.Int{big.NewInt(5), big.NewInt(5)} // Alice proposes vct.virtualBalsUpdated = []*big.Int{big.NewInt(2), big.NewInt(8)} // Send 3. vct.finalBalsAlice = []*big.Int{big.NewInt(7), big.NewInt(13)} vct.finalBalsBob = []*big.Int{big.NewInt(13), big.NewInt(7)} + vct.finalBalIngrid = new(big.Int).Add(vct.finalBalsAlice[1], vct.finalBalsBob[1]) vct.errs = make(chan error, 10) // Setup clients. @@ -102,6 +149,8 @@ func setupVirtualChannelTest(t *testing.T, ctx context.Context) (vct virtualChan t, ) alice, bob, ingrid := clients[0], clients[1], clients[2] + vct.alice, vct.bob, vct.ingrid = alice, bob, ingrid + vct.backend = alice.Backend // Assumes all clients have same backend. _channelsIngrid := make(chan *client.Channel, 1) var openingProposalHandlerIngrid client.ProposalHandlerFunc = func(cp client.ChannelProposal, pr *client.ProposalResponder) {