Skip to content

Commit

Permalink
tests: use implemented server for libp2p transport tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tinyzimmer committed Oct 30, 2023
1 parent 14f7431 commit 8f4b696
Showing 1 changed file with 77 additions and 76 deletions.
153 changes: 77 additions & 76 deletions pkg/meshnet/transport/libp2p/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/webmeshproj/webmesh/pkg/context"
"github.com/webmeshproj/webmesh/pkg/crypto"
Expand All @@ -45,6 +46,67 @@ import (
"github.com/webmeshproj/webmesh/pkg/plugins/clients"
)

var testNode = &v1.MeshNode{
Id: "test-node",
PublicKey: must(crypto.MustGenerateKey().PublicKey().Encode),
}

func must(fn func() (string, error)) string {
s, err := fn()
if err != nil {
panic(err)
}
return s
}

type TestMeshAPI struct {
v1.UnimplementedMeshServer
}

func (*TestMeshAPI) GetMeshGraph(context.Context, *emptypb.Empty) (*v1.MeshGraph, error) {
// Leave unimplemented.
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
}

func (*TestMeshAPI) GetNode(context.Context, *v1.GetNodeRequest) (*v1.MeshNode, error) {
// Return a dummy node
return testNode, nil
}

func (*TestMeshAPI) ListNodes(context.Context, *emptypb.Empty) (*v1.NodeList, error) {
// Use for a custom error message
return nil, status.Errorf(codes.Internal, "something went wrong")
}

func RunClientConnTests(ctx context.Context, t *testing.T, c *grpc.ClientConn) {
t.Helper()
cli := v1.NewMeshClient(c)
_, err := cli.GetMeshGraph(ctx, &emptypb.Empty{})
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
_, err = cli.ListNodes(ctx, &emptypb.Empty{})
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Internal {
t.Fatal("Expected internal error, got", err)
}
node, err := cli.GetNode(ctx, &v1.GetNodeRequest{})
if err != nil {
t.Fatal("GetNode:", err)
}
if node.Id != testNode.Id {
t.Fatal("Expected node ID", testNode.Id, "got", node.Id)
}
if node.PublicKey != testNode.PublicKey {
t.Fatal("Expected node public key", testNode.PublicKey, "got", node.PublicKey)
}
}

func TestRPCTransport(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -89,7 +151,7 @@ func TestRPCTransport(t *testing.T) {
}
// Create a dummy gRPC server and register an unimplemented service.
srv := grpc.NewServer()
v1.RegisterMeshServer(srv, v1.UnimplementedMeshServer{})
v1.RegisterMeshServer(srv, &TestMeshAPI{})
go func() {
err := srv.Serve(server.RPCListener())
if err != nil {
Expand All @@ -107,15 +169,7 @@ func TestRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
// We should actually get an unimplemented error here.
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
})

t.Run("DialByMultiaddr", func(t *testing.T) {
Expand All @@ -125,15 +179,7 @@ func TestRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
// We should actually get an unimplemented error here.
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
}
})
})
Expand All @@ -159,7 +205,7 @@ func TestRPCTransport(t *testing.T) {
// Create a dummy gRPC server and register an unimplemented service.
srv := grpc.NewServer()
t.Cleanup(srv.Stop)
v1.RegisterMeshServer(srv, v1.UnimplementedMeshServer{})
v1.RegisterMeshServer(srv, &TestMeshAPI{})
go func() {
err := srv.Serve(server.RPCListener())
if err != nil {
Expand All @@ -176,15 +222,7 @@ func TestRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
// We should actually get an unimplemented error here.
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
}
})

Expand Down Expand Up @@ -227,7 +265,7 @@ func TestRPCTransport(t *testing.T) {
idauthcli := clients.NewInProcessClient(idauthsrv)
srv := grpc.NewServer(grpc.ChainUnaryInterceptor(plugins.NewAuthUnaryInterceptor(idauthcli.Auth())))
t.Cleanup(srv.Stop)
v1.RegisterMeshServer(srv, v1.UnimplementedMeshServer{})
v1.RegisterMeshServer(srv, &TestMeshAPI{})
go func() {
err := srv.Serve(server.RPCListener())
if err != nil {
Expand All @@ -244,15 +282,7 @@ func TestRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
// We should actually get an unimplemented error here.
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
}
})
// Test that an unallowed ID can use the server, but will be rejected.
Expand Down Expand Up @@ -307,7 +337,7 @@ func TestRPCTransport(t *testing.T) {
}
srv := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsconf)))
t.Cleanup(srv.Stop)
v1.RegisterMeshServer(srv, v1.UnimplementedMeshServer{})
v1.RegisterMeshServer(srv, &TestMeshAPI{})
go func() {
err := srv.Serve(server.RPCListener())
if err != nil {
Expand All @@ -324,15 +354,7 @@ func TestRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
// We should actually get an unimplemented error here.
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
}
})

Expand Down Expand Up @@ -392,7 +414,7 @@ func TestRPCTransport(t *testing.T) {
t.Fatal(err)
}
srv := grpc.NewServer(servercreds)
v1.RegisterMeshServer(srv, v1.UnimplementedMeshServer{})
v1.RegisterMeshServer(srv, &TestMeshAPI{})
go func() {
err := srv.Serve(server.RPCListener())
if err != nil {
Expand All @@ -419,14 +441,7 @@ func TestRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
}
})

Expand Down Expand Up @@ -501,7 +516,7 @@ func TestDiscoveryRPCTransport(t *testing.T) {
server.Announce(ctx, rendezvous, time.Minute)
t.Cleanup(func() { _ = server.Close() })
srv := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
v1.RegisterMeshServer(srv, v1.UnimplementedMeshServer{})
v1.RegisterMeshServer(srv, &TestMeshAPI{})
go func() {
err := srv.Serve(server.RPCListener())
if err != nil {
Expand Down Expand Up @@ -532,14 +547,7 @@ func TestDiscoveryRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
})

t.Run("PrestartedHosts", func(t *testing.T) {
Expand All @@ -565,7 +573,7 @@ func TestDiscoveryRPCTransport(t *testing.T) {
server.Announce(ctx, rendezvous, time.Minute)
t.Cleanup(func() { _ = server.Close() })
srv := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
v1.RegisterMeshServer(srv, v1.UnimplementedMeshServer{})
v1.RegisterMeshServer(srv, &TestMeshAPI{})
go func() {
err := srv.Serve(server.RPCListener())
if err != nil {
Expand Down Expand Up @@ -602,13 +610,6 @@ func TestDiscoveryRPCTransport(t *testing.T) {
t.Fatal("Dial server address:", err)
}
defer c.Close()
cli := v1.NewMeshClient(c)
_, err = cli.GetNode(ctx, &v1.GetNodeRequest{})
if err == nil {
t.Fatal("Expected error, got nil")
}
if status.Code(err) != codes.Unimplemented {
t.Fatal("Expected unimplemented error, got", err)
}
RunClientConnTests(ctx, t, c)
})
}

0 comments on commit 8f4b696

Please sign in to comment.