Skip to content

Commit

Permalink
improved unit tests to tss/comm
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanMweiss committed Feb 10, 2025
1 parent 6720675 commit 46167eb
Showing 1 changed file with 101 additions and 1 deletion.
102 changes: 101 additions & 1 deletion node/pkg/tss/comm/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"sync/atomic"
Expand Down Expand Up @@ -45,7 +46,9 @@ func (m *mockTssMessageHandler) FetchPartyId(*x509.Certificate) (*tsscommv1.Part
func (m *mockTssMessageHandler) ProducedOutputMessages() <-chan tss.Sendable {
return m.chn
}
func (m *mockTssMessageHandler) HandleIncomingTssMessage(msg tss.Incoming) {}
func (m *mockTssMessageHandler) HandleIncomingTssMessage(msg tss.Incoming) {
fmt.Println("received message from", msg.GetSource())
}

// wraps regular server and changes its Send function.
type testServer struct {
Expand Down Expand Up @@ -671,3 +674,100 @@ func TestDialWithDefaultPort(t *testing.T) {

t.FailNow()
}

type mockJustHandleIncomingMessage struct {
tss.ReliableMessenger
receivedData chan tss.Incoming
}

func (m *mockJustHandleIncomingMessage) HandleIncomingTssMessage(msg tss.Incoming) {
m.receivedData <- msg
close(m.receivedData)
}

func TestDialWithDefaultPortDeliverCorrectSrc(t *testing.T) {
a := require.New(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*40)
defer cancel()
ctx = testutils.MakeSupervisorContext(ctx)

en, err := _loadGuardians(3)
a.NoError(err)

streamReceiverEngine := en[0]
senderEngine := en[1]

streamReceiverPath := "localhost:5930"

// ensuring when a message arrives, the server idetntifies the source according to
// the tls key, then returns the tss.PartyID according to what is
// stored in the guardian storage.
expectedText := "This text is what i expect to see in the incoming message."
for _, v := range streamReceiverEngine.Guardians {
if v.Id == senderEngine.Self.Id {
v.Id = expectedText
continue
}
v.Id = ""
}

// Setting the ID as they will be sent and used to connect to the other party.
for _, v := range senderEngine.Guardians {
if v.Id == streamReceiverEngine.Self.Id {
v.Id = streamReceiverPath // ensuring the server connects
continue
}
v.Id = ""
}

incomingDataChan := make(chan tss.Incoming)
listenerServer, err := NewServer(streamReceiverPath, supervisor.Logger(ctx),
&mockJustHandleIncomingMessage{
ReliableMessenger: streamReceiverEngine,
receivedData: incomingDataChan,
},
)
a.NoError(err)

ListenerWrapper := listenerServer.(*server)
ListenerWrapper.ctx = ctx

l, err := net.Listen("tcp", streamReceiverPath)
a.NoError(err)
defer l.Close()

gserver := grpc.NewServer(ListenerWrapper.makeServerCredentials())
defer gserver.Stop()

tsscommv1.RegisterDirectLinkServer(gserver, ListenerWrapper)
go gserver.Serve(l)

msgChan := make(chan tss.Sendable)
sender, err := NewServer("nonsensePort", supervisor.Logger(ctx), &tssMockJustForMessageGeneration{
ReliableMessenger: senderEngine,
chn: msgChan,
})
a.NoError(err)

tmp := sender.(*server)
tmp.ctx = ctx
tmp.run() // demanding this server run.

time.Sleep(time.Second * 1)

//should set up connection with the stream r

msgChan <- &tss.Echo{
Echo: &tsscommv1.Echo{},
Recipients: []*tsscommv1.PartyId{&tsscommv1.PartyId{Id: streamReceiverPath}},
}

select {
case <-ctx.Done():
t.FailNow()
case incoming := <-incomingDataChan:
// ensuring the incoming message has the correct ID without any port.
a.Equal(expectedText, incoming.GetSource().Id)
return
}
}

0 comments on commit 46167eb

Please sign in to comment.