diff --git a/pkg/client-sdk/client.go b/pkg/client-sdk/client.go index fcd36e9a..acd2068e 100644 --- a/pkg/client-sdk/client.go +++ b/pkg/client-sdk/client.go @@ -18,6 +18,7 @@ import ( filestore "github.com/ark-network/ark/pkg/client-sdk/wallet/singlekey/store/file" inmemorystore "github.com/ark-network/ark/pkg/client-sdk/wallet/singlekey/store/inmemory" "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/sirupsen/logrus" ) const ( @@ -256,11 +257,13 @@ func (a *arkClient) ping( ticker := time.NewTicker(5 * time.Second) go func(t *time.Ticker) { - // nolint - a.client.Ping(ctx, paymentID) + if _, err := a.client.Ping(ctx, paymentID); err != nil { + logrus.Warnf("failed to ping asp: %s", err) + } for range t.C { - // nolint - a.client.Ping(ctx, paymentID) + if _, err := a.client.Ping(ctx, paymentID); err != nil { + logrus.Warnf("failed to ping asp: %s", err) + } } }(ticker) diff --git a/pkg/client-sdk/client/client.go b/pkg/client-sdk/client/client.go index d24b1fa4..45436d92 100644 --- a/pkg/client-sdk/client/client.go +++ b/pkg/client-sdk/client/client.go @@ -31,7 +31,7 @@ type ASPClient interface { ) error GetEventStream( ctx context.Context, paymentID string, - ) (<-chan RoundEventChannel, error) + ) (<-chan RoundEventChannel, func(), error) Ping(ctx context.Context, paymentID string) (RoundEvent, error) FinalizePayment( ctx context.Context, signedForfeitTxs []string, signedRoundTx string, diff --git a/pkg/client-sdk/client/grpc/client.go b/pkg/client-sdk/client/grpc/client.go index 5d283e5f..3a1b00b5 100644 --- a/pkg/client-sdk/client/grpc/client.go +++ b/pkg/client-sdk/client/grpc/client.go @@ -14,6 +14,7 @@ import ( "github.com/ark-network/ark/pkg/client-sdk/client" "github.com/ark-network/ark/pkg/client-sdk/internal/utils" "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -22,7 +23,6 @@ import ( type grpcClient struct { conn *grpc.ClientConn svc arkv1.ArkServiceClient - eventsCh chan client.RoundEventChannel treeCache *utils.Cache[tree.CongestionTree] } @@ -47,10 +47,9 @@ func NewClient(aspUrl string) (client.ASPClient, error) { } svc := arkv1.NewArkServiceClient(conn) - eventsCh := make(chan client.RoundEventChannel) treeCache := utils.NewCache[tree.CongestionTree]() - return &grpcClient{conn, svc, eventsCh, treeCache}, nil + return &grpcClient{conn, svc, treeCache}, nil } func (c *grpcClient) Close() { @@ -60,34 +59,47 @@ func (c *grpcClient) Close() { func (a *grpcClient) GetEventStream( ctx context.Context, paymentID string, -) (<-chan client.RoundEventChannel, error) { +) (<-chan client.RoundEventChannel, func(), error) { req := &arkv1.GetEventStreamRequest{} stream, err := a.svc.GetEventStream(ctx, req) if err != nil { - return nil, err + return nil, nil, err } + eventsCh := make(chan client.RoundEventChannel) + go func() { - defer close(a.eventsCh) + defer close(eventsCh) for { - resp, err := stream.Recv() - if err != nil { - a.eventsCh <- client.RoundEventChannel{Err: err} + select { + case <-stream.Context().Done(): return - } + default: + resp, err := stream.Recv() + if err != nil { + eventsCh <- client.RoundEventChannel{Err: err} + return + } - ev, err := event{resp}.toRoundEvent() - if err != nil { - a.eventsCh <- client.RoundEventChannel{Err: err} - return - } + ev, err := event{resp}.toRoundEvent() + if err != nil { + eventsCh <- client.RoundEventChannel{Err: err} + return + } - a.eventsCh <- client.RoundEventChannel{Event: ev} + eventsCh <- client.RoundEventChannel{Event: ev} + } } }() - return a.eventsCh, nil + closeFn := func() { + if err := stream.CloseSend(); err != nil { + logrus.Warnf("failed to close stream: %v", err) + } + } + + return eventsCh, closeFn, nil } func (a *grpcClient) GetInfo(ctx context.Context) (*client.Info, error) { @@ -183,6 +195,10 @@ func (a *grpcClient) Ping( return nil, err } + if resp.GetEvent() == nil { + return nil, nil + } + return event{resp}.toRoundEvent() } diff --git a/pkg/client-sdk/client/rest/client.go b/pkg/client-sdk/client/rest/client.go index 67d1bbfe..3512048a 100644 --- a/pkg/client-sdk/client/rest/client.go +++ b/pkg/client-sdk/client/rest/client.go @@ -25,7 +25,6 @@ import ( type restClient struct { svc ark_service.ClientService - eventsCh chan client.RoundEventChannel requestTimeout time.Duration treeCache *utils.Cache[tree.CongestionTree] } @@ -38,41 +37,46 @@ func NewClient(aspUrl string) (client.ASPClient, error) { if err != nil { return nil, err } - eventsCh := make(chan client.RoundEventChannel) reqTimeout := 15 * time.Second treeCache := utils.NewCache[tree.CongestionTree]() - return &restClient{svc, eventsCh, reqTimeout, treeCache}, nil + return &restClient{svc, reqTimeout, treeCache}, nil } func (c *restClient) Close() {} func (a *restClient) GetEventStream( ctx context.Context, paymentID string, -) (<-chan client.RoundEventChannel, error) { +) (<-chan client.RoundEventChannel, func(), error) { + eventsCh := make(chan client.RoundEventChannel) + stopCh := make(chan struct{}) + go func(payID string) { - defer close(a.eventsCh) + defer close(eventsCh) + defer close(stopCh) timeout := time.After(a.requestTimeout) for { select { + case <-stopCh: + return case <-timeout: - a.eventsCh <- client.RoundEventChannel{ + eventsCh <- client.RoundEventChannel{ Err: fmt.Errorf("timeout reached"), } return default: event, err := a.Ping(ctx, payID) if err != nil { - a.eventsCh <- client.RoundEventChannel{ + eventsCh <- client.RoundEventChannel{ Err: err, } return } if event != nil { - a.eventsCh <- client.RoundEventChannel{ + eventsCh <- client.RoundEventChannel{ Event: event, } } @@ -82,7 +86,11 @@ func (a *restClient) GetEventStream( } }(paymentID) - return a.eventsCh, nil + close := func() { + stopCh <- struct{}{} + } + + return eventsCh, close, nil } func (a *restClient) GetInfo( diff --git a/pkg/client-sdk/covenant_client.go b/pkg/client-sdk/covenant_client.go index f792a130..93e28500 100644 --- a/pkg/client-sdk/covenant_client.go +++ b/pkg/client-sdk/covenant_client.go @@ -924,7 +924,7 @@ func (a *covenantArkClient) handleRoundStream( mustSignRoundTx bool, receivers []client.Output, ) (string, error) { - eventsCh, err := a.client.GetEventStream(ctx, paymentID) + eventsCh, close, err := a.client.GetEventStream(ctx, paymentID) if err != nil { return "", err } @@ -934,7 +934,10 @@ func (a *covenantArkClient) handleRoundStream( pingStop = a.ping(ctx, paymentID) } - defer pingStop() + defer func() { + pingStop() + close() + }() for { select { diff --git a/pkg/client-sdk/covenantless_client.go b/pkg/client-sdk/covenantless_client.go index 87524e7c..6cb32cb1 100644 --- a/pkg/client-sdk/covenantless_client.go +++ b/pkg/client-sdk/covenantless_client.go @@ -961,7 +961,7 @@ func (a *covenantlessArkClient) handleRoundStream( receivers []client.Output, roundEphemeralKey *secp256k1.PrivateKey, ) (string, error) { - eventsCh, err := a.client.GetEventStream(ctx, paymentID) + eventsCh, close, err := a.client.GetEventStream(ctx, paymentID) if err != nil { return "", err } @@ -971,7 +971,10 @@ func (a *covenantlessArkClient) handleRoundStream( pingStop = a.ping(ctx, paymentID) } - defer pingStop() + defer func() { + pingStop() + close() + }() var signerSession bitcointree.SignerSession diff --git a/server/internal/interface/grpc/handlers/arkservice.go b/server/internal/interface/grpc/handlers/arkservice.go index 20e77d94..2dfe1961 100644 --- a/server/internal/interface/grpc/handlers/arkservice.go +++ b/server/internal/interface/grpc/handlers/arkservice.go @@ -18,8 +18,9 @@ import ( ) type listener struct { - id string - ch chan *arkv1.GetEventStreamResponse + id string + done chan struct{} + ch chan *arkv1.GetEventStreamResponse } type handler struct { @@ -297,21 +298,25 @@ func (h *handler) GetRoundById( } func (h *handler) GetEventStream(_ *arkv1.GetEventStreamRequest, stream arkv1.ArkService_GetEventStreamServer) error { + doneCh := make(chan struct{}) + listener := &listener{ - id: uuid.NewString(), - ch: make(chan *arkv1.GetEventStreamResponse), + id: uuid.NewString(), + done: doneCh, + ch: make(chan *arkv1.GetEventStreamResponse), } + h.pushListener(listener) defer h.removeListener(listener.id) defer close(listener.ch) - - h.pushListener(listener) + defer close(doneCh) for { select { case <-stream.Context().Done(): return nil - + case <-doneCh: + return nil case ev := <-listener.ch: if err := stream.Send(ev); err != nil { return err @@ -472,6 +477,7 @@ func (h *handler) listenToEvents() { channel := h.svc.GetEventsChannel(context.Background()) for event := range channel { var ev *arkv1.GetEventStreamResponse + shouldClose := false switch e := event.(type) { case domain.RoundFinalizationStarted: @@ -487,6 +493,7 @@ func (h *handler) listenToEvents() { }, } case domain.RoundFinalized: + shouldClose = true ev = &arkv1.GetEventStreamResponse{ Event: &arkv1.GetEventStreamResponse_RoundFinalized{ RoundFinalized: &arkv1.RoundFinalizedEvent{ @@ -496,6 +503,7 @@ func (h *handler) listenToEvents() { }, } case domain.RoundFailed: + shouldClose = true ev = &arkv1.GetEventStreamResponse{ Event: &arkv1.GetEventStreamResponse_RoundFailed{ RoundFailed: &arkv1.RoundFailed{ @@ -538,8 +546,14 @@ func (h *handler) listenToEvents() { } if ev != nil { - for _, listener := range h.listeners { - listener.ch <- ev + logrus.Debugf("forwarding event to %d listeners", len(h.listeners)) + for _, l := range h.listeners { + go func(l *listener) { + l.ch <- ev + if shouldClose { + l.done <- struct{}{} + } + }(l) } } } diff --git a/server/test/e2e/covenantless/e2e_test.go b/server/test/e2e/covenantless/e2e_test.go index d12ff8f5..479b412b 100644 --- a/server/test/e2e/covenantless/e2e_test.go +++ b/server/test/e2e/covenantless/e2e_test.go @@ -308,6 +308,8 @@ func TestAliceSeveralPaymentsToBob(t *testing.T) { _, err = alice.SendOffChain(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 1000)}) require.NoError(t, err) + time.Sleep(2 * time.Second) + bobVtxos, _, err := bob.ListVtxos(ctx) require.NoError(t, err) require.Len(t, bobVtxos, 1) @@ -321,6 +323,8 @@ func TestAliceSeveralPaymentsToBob(t *testing.T) { _, err = alice.SendOffChain(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)}) require.NoError(t, err) + time.Sleep(2 * time.Second) + bobVtxos, _, err = bob.ListVtxos(ctx) require.NoError(t, err) require.Len(t, bobVtxos, 2) @@ -328,15 +332,43 @@ func TestAliceSeveralPaymentsToBob(t *testing.T) { _, err = alice.SendOffChain(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)}) require.NoError(t, err) + time.Sleep(2 * time.Second) + bobVtxos, _, err = bob.ListVtxos(ctx) require.NoError(t, err) require.Len(t, bobVtxos, 3) - _, err = bob.Claim(ctx) + _, err = alice.SendAsync(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)}) require.NoError(t, err) + time.Sleep(2 * time.Second) + + bobVtxos, _, err = bob.ListVtxos(ctx) + require.NoError(t, err) + require.Len(t, bobVtxos, 4) + _, err = alice.Claim(ctx) require.NoError(t, err) + + _, err = alice.SendAsync(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)}) + require.NoError(t, err) + + time.Sleep(2 * time.Second) + + bobVtxos, _, err = bob.ListVtxos(ctx) + require.NoError(t, err) + require.Len(t, bobVtxos, 5) + + // bobVtxos should be unique + uniqueVtxos := make(map[string]struct{}) + for _, v := range bobVtxos { + uniqueVtxos[fmt.Sprintf("%s:%d", v.Txid, v.VOut)] = struct{}{} + } + require.Len(t, uniqueVtxos, 5) + + _, err = bob.Claim(ctx) + require.NoError(t, err) + } func runClarkCommand(arg ...string) (string, error) {