Skip to content

Commit

Permalink
fixes GetEventStream
Browse files Browse the repository at this point in the history
  • Loading branch information
louisinger committed Sep 19, 2024
1 parent 5626e47 commit 1bb1e28
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 45 deletions.
11 changes: 7 additions & 4 deletions pkg/client-sdk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
filestore "github.com/ark-network/ark/pkg/client-sdk/wallet/singlekey/store/file"
inmemorystore "github.com/ark-network/ark/pkg/client-sdk/wallet/singlekey/store/inmemory"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -256,11 +257,13 @@ func (a *arkClient) ping(
ticker := time.NewTicker(5 * time.Second)

go func(t *time.Ticker) {
// nolint
a.client.Ping(ctx, paymentID)
if _, err := a.client.Ping(ctx, paymentID); err != nil {
logrus.Warnf("failed to ping asp: %s", err)
}
for range t.C {
// nolint
a.client.Ping(ctx, paymentID)
if _, err := a.client.Ping(ctx, paymentID); err != nil {
logrus.Warnf("failed to ping asp: %s", err)
}
}
}(ticker)

Expand Down
2 changes: 1 addition & 1 deletion pkg/client-sdk/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type ASPClient interface {
) error
GetEventStream(
ctx context.Context, paymentID string,
) (<-chan RoundEventChannel, error)
) (<-chan RoundEventChannel, func(), error)
Ping(ctx context.Context, paymentID string) (RoundEvent, error)
FinalizePayment(
ctx context.Context, signedForfeitTxs []string, signedRoundTx string,
Expand Down
50 changes: 33 additions & 17 deletions pkg/client-sdk/client/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/ark-network/ark/pkg/client-sdk/client"
"github.com/ark-network/ark/pkg/client-sdk/internal/utils"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
Expand All @@ -22,7 +23,6 @@ import (
type grpcClient struct {
conn *grpc.ClientConn
svc arkv1.ArkServiceClient
eventsCh chan client.RoundEventChannel
treeCache *utils.Cache[tree.CongestionTree]
}

Expand All @@ -47,10 +47,9 @@ func NewClient(aspUrl string) (client.ASPClient, error) {
}

svc := arkv1.NewArkServiceClient(conn)
eventsCh := make(chan client.RoundEventChannel)
treeCache := utils.NewCache[tree.CongestionTree]()

return &grpcClient{conn, svc, eventsCh, treeCache}, nil
return &grpcClient{conn, svc, treeCache}, nil
}

func (c *grpcClient) Close() {
Expand All @@ -60,34 +59,47 @@ func (c *grpcClient) Close() {

func (a *grpcClient) GetEventStream(
ctx context.Context, paymentID string,
) (<-chan client.RoundEventChannel, error) {
) (<-chan client.RoundEventChannel, func(), error) {
req := &arkv1.GetEventStreamRequest{}
stream, err := a.svc.GetEventStream(ctx, req)
if err != nil {
return nil, err
return nil, nil, err
}

eventsCh := make(chan client.RoundEventChannel)

go func() {
defer close(a.eventsCh)
defer close(eventsCh)

for {
resp, err := stream.Recv()
if err != nil {
a.eventsCh <- client.RoundEventChannel{Err: err}
select {
case <-stream.Context().Done():
return
}
default:
resp, err := stream.Recv()
if err != nil {
eventsCh <- client.RoundEventChannel{Err: err}
return
}

ev, err := event{resp}.toRoundEvent()
if err != nil {
a.eventsCh <- client.RoundEventChannel{Err: err}
return
}
ev, err := event{resp}.toRoundEvent()
if err != nil {
eventsCh <- client.RoundEventChannel{Err: err}
return
}

a.eventsCh <- client.RoundEventChannel{Event: ev}
eventsCh <- client.RoundEventChannel{Event: ev}
}
}
}()

return a.eventsCh, nil
closeFn := func() {
if err := stream.CloseSend(); err != nil {
logrus.Warnf("failed to close stream: %v", err)
}
}

return eventsCh, closeFn, nil
}

func (a *grpcClient) GetInfo(ctx context.Context) (*client.Info, error) {
Expand Down Expand Up @@ -183,6 +195,10 @@ func (a *grpcClient) Ping(
return nil, err
}

if resp.GetEvent() == nil {
return nil, nil
}

return event{resp}.toRoundEvent()
}

Expand Down
26 changes: 17 additions & 9 deletions pkg/client-sdk/client/rest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (

type restClient struct {
svc ark_service.ClientService
eventsCh chan client.RoundEventChannel
requestTimeout time.Duration
treeCache *utils.Cache[tree.CongestionTree]
}
Expand All @@ -38,41 +37,46 @@ func NewClient(aspUrl string) (client.ASPClient, error) {
if err != nil {
return nil, err
}
eventsCh := make(chan client.RoundEventChannel)
reqTimeout := 15 * time.Second
treeCache := utils.NewCache[tree.CongestionTree]()

return &restClient{svc, eventsCh, reqTimeout, treeCache}, nil
return &restClient{svc, reqTimeout, treeCache}, nil
}

func (c *restClient) Close() {}

func (a *restClient) GetEventStream(
ctx context.Context, paymentID string,
) (<-chan client.RoundEventChannel, error) {
) (<-chan client.RoundEventChannel, func(), error) {
eventsCh := make(chan client.RoundEventChannel)
stopCh := make(chan struct{})

go func(payID string) {
defer close(a.eventsCh)
defer close(eventsCh)
defer close(stopCh)

timeout := time.After(a.requestTimeout)

for {
select {
case <-stopCh:
return
case <-timeout:
a.eventsCh <- client.RoundEventChannel{
eventsCh <- client.RoundEventChannel{
Err: fmt.Errorf("timeout reached"),
}
return
default:
event, err := a.Ping(ctx, payID)
if err != nil {
a.eventsCh <- client.RoundEventChannel{
eventsCh <- client.RoundEventChannel{
Err: err,
}
return
}

if event != nil {
a.eventsCh <- client.RoundEventChannel{
eventsCh <- client.RoundEventChannel{
Event: event,
}
}
Expand All @@ -82,7 +86,11 @@ func (a *restClient) GetEventStream(
}
}(paymentID)

return a.eventsCh, nil
close := func() {
stopCh <- struct{}{}
}

return eventsCh, close, nil
}

func (a *restClient) GetInfo(
Expand Down
7 changes: 5 additions & 2 deletions pkg/client-sdk/covenant_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ func (a *covenantArkClient) handleRoundStream(
mustSignRoundTx bool,
receivers []client.Output,
) (string, error) {
eventsCh, err := a.client.GetEventStream(ctx, paymentID)
eventsCh, close, err := a.client.GetEventStream(ctx, paymentID)
if err != nil {
return "", err
}
Expand All @@ -934,7 +934,10 @@ func (a *covenantArkClient) handleRoundStream(
pingStop = a.ping(ctx, paymentID)
}

defer pingStop()
defer func() {
pingStop()
close()
}()

for {
select {
Expand Down
7 changes: 5 additions & 2 deletions pkg/client-sdk/covenantless_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ func (a *covenantlessArkClient) handleRoundStream(
receivers []client.Output,
roundEphemeralKey *secp256k1.PrivateKey,
) (string, error) {
eventsCh, err := a.client.GetEventStream(ctx, paymentID)
eventsCh, close, err := a.client.GetEventStream(ctx, paymentID)
if err != nil {
return "", err
}
Expand All @@ -971,7 +971,10 @@ func (a *covenantlessArkClient) handleRoundStream(
pingStop = a.ping(ctx, paymentID)
}

defer pingStop()
defer func() {
pingStop()
close()
}()

var signerSession bitcointree.SignerSession

Expand Down
32 changes: 23 additions & 9 deletions server/internal/interface/grpc/handlers/arkservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ import (
)

type listener struct {
id string
ch chan *arkv1.GetEventStreamResponse
id string
done chan struct{}
ch chan *arkv1.GetEventStreamResponse
}

type handler struct {
Expand Down Expand Up @@ -297,21 +298,25 @@ func (h *handler) GetRoundById(
}

func (h *handler) GetEventStream(_ *arkv1.GetEventStreamRequest, stream arkv1.ArkService_GetEventStreamServer) error {
doneCh := make(chan struct{})

listener := &listener{
id: uuid.NewString(),
ch: make(chan *arkv1.GetEventStreamResponse),
id: uuid.NewString(),
done: doneCh,
ch: make(chan *arkv1.GetEventStreamResponse),
}

h.pushListener(listener)
defer h.removeListener(listener.id)
defer close(listener.ch)

h.pushListener(listener)
defer close(doneCh)

for {
select {
case <-stream.Context().Done():
return nil

case <-doneCh:
return nil
case ev := <-listener.ch:
if err := stream.Send(ev); err != nil {
return err
Expand Down Expand Up @@ -472,6 +477,7 @@ func (h *handler) listenToEvents() {
channel := h.svc.GetEventsChannel(context.Background())
for event := range channel {
var ev *arkv1.GetEventStreamResponse
shouldClose := false

switch e := event.(type) {
case domain.RoundFinalizationStarted:
Expand All @@ -487,6 +493,7 @@ func (h *handler) listenToEvents() {
},
}
case domain.RoundFinalized:
shouldClose = true
ev = &arkv1.GetEventStreamResponse{
Event: &arkv1.GetEventStreamResponse_RoundFinalized{
RoundFinalized: &arkv1.RoundFinalizedEvent{
Expand All @@ -496,6 +503,7 @@ func (h *handler) listenToEvents() {
},
}
case domain.RoundFailed:
shouldClose = true
ev = &arkv1.GetEventStreamResponse{
Event: &arkv1.GetEventStreamResponse_RoundFailed{
RoundFailed: &arkv1.RoundFailed{
Expand Down Expand Up @@ -538,8 +546,14 @@ func (h *handler) listenToEvents() {
}

if ev != nil {
for _, listener := range h.listeners {
listener.ch <- ev
logrus.Debugf("forwarding event to %d listeners", len(h.listeners))
for _, l := range h.listeners {
go func(l *listener) {
l.ch <- ev
if shouldClose {
l.done <- struct{}{}
}
}(l)
}
}
}
Expand Down
34 changes: 33 additions & 1 deletion server/test/e2e/covenantless/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ func TestAliceSeveralPaymentsToBob(t *testing.T) {
_, err = alice.SendOffChain(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 1000)})
require.NoError(t, err)

time.Sleep(2 * time.Second)

bobVtxos, _, err := bob.ListVtxos(ctx)
require.NoError(t, err)
require.Len(t, bobVtxos, 1)
Expand All @@ -321,22 +323,52 @@ func TestAliceSeveralPaymentsToBob(t *testing.T) {
_, err = alice.SendOffChain(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)})
require.NoError(t, err)

time.Sleep(2 * time.Second)

bobVtxos, _, err = bob.ListVtxos(ctx)
require.NoError(t, err)
require.Len(t, bobVtxos, 2)

_, err = alice.SendOffChain(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)})
require.NoError(t, err)

time.Sleep(2 * time.Second)

bobVtxos, _, err = bob.ListVtxos(ctx)
require.NoError(t, err)
require.Len(t, bobVtxos, 3)

_, err = bob.Claim(ctx)
_, err = alice.SendAsync(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)})
require.NoError(t, err)

time.Sleep(2 * time.Second)

bobVtxos, _, err = bob.ListVtxos(ctx)
require.NoError(t, err)
require.Len(t, bobVtxos, 4)

_, err = alice.Claim(ctx)
require.NoError(t, err)

_, err = alice.SendAsync(ctx, false, []arksdk.Receiver{arksdk.NewBitcoinReceiver(bobAddress, 10000)})
require.NoError(t, err)

time.Sleep(2 * time.Second)

bobVtxos, _, err = bob.ListVtxos(ctx)
require.NoError(t, err)
require.Len(t, bobVtxos, 5)

// bobVtxos should be unique
uniqueVtxos := make(map[string]struct{})
for _, v := range bobVtxos {
uniqueVtxos[fmt.Sprintf("%s:%d", v.Txid, v.VOut)] = struct{}{}
}
require.Len(t, uniqueVtxos, 5)

_, err = bob.Claim(ctx)
require.NoError(t, err)

}

func runClarkCommand(arg ...string) (string, error) {
Expand Down

0 comments on commit 1bb1e28

Please sign in to comment.