Skip to content

Commit

Permalink
Adds unary interceptor tests (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickzelei authored Oct 31, 2023
1 parent 7afc40f commit febcb60
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 2 deletions.
67 changes: 67 additions & 0 deletions backend/internal/connect/interceptors/auth/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package auth_interceptor

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"

"connectrpc.com/connect"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
"github.com/stretchr/testify/assert"
)

func Test_Interceptor_WrapUnary_Disallow_All(t *testing.T) {
interceptor := NewInterceptor(func(ctx context.Context, header http.Header) (context.Context, error) {
return nil, errors.New("no dice")
})

mux := http.NewServeMux()
mux.Handle(mgmtv1alpha1connect.UserAccountServiceGetUserProcedure, connect.NewUnaryHandler(
mgmtv1alpha1connect.UserAccountServiceGetUserProcedure,
func(ctx context.Context, r *connect.Request[mgmtv1alpha1.GetUserRequest]) (*connect.Response[mgmtv1alpha1.GetUserRequest], error) {
return nil, connect.NewError(connect.CodeInternal, errors.New("oh no"))
},
connect.WithInterceptors(interceptor),
))
srv := startHTTPServer(t, mux)

client := mgmtv1alpha1connect.NewUserAccountServiceClient(srv.Client(), srv.URL)
resp, err := client.GetUser(context.Background(), connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
assert.Error(t, err)
assert.ErrorContains(t, err, "no dice")
assert.Nil(t, resp)
}

func Test_Interceptor_WrapUnary_Allow_All(t *testing.T) {
interceptor := NewInterceptor(func(ctx context.Context, header http.Header) (context.Context, error) {
return ctx, nil
})

mux := http.NewServeMux()
mux.Handle(mgmtv1alpha1connect.UserAccountServiceGetUserProcedure, connect.NewUnaryHandler(
mgmtv1alpha1connect.UserAccountServiceGetUserProcedure,
func(ctx context.Context, r *connect.Request[mgmtv1alpha1.GetUserRequest]) (*connect.Response[mgmtv1alpha1.GetUserRequest], error) {
return nil, connect.NewError(connect.CodeInternal, errors.New("oh no"))
},
connect.WithInterceptors(interceptor),
))
srv := startHTTPServer(t, mux)

client := mgmtv1alpha1connect.NewUserAccountServiceClient(srv.Client(), srv.URL)
resp, err := client.GetUser(context.Background(), connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
assert.Error(t, err)
assert.ErrorContains(t, err, "oh no")
assert.Nil(t, resp)
}

func startHTTPServer(tb testing.TB, h http.Handler) *httptest.Server {
tb.Helper()
srv := httptest.NewUnstartedServer(h)
srv.EnableHTTP2 = true
srv.Start()
tb.Cleanup(srv.Close)
return srv
}
9 changes: 7 additions & 2 deletions backend/internal/connect/interceptors/logger/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func NewInterceptor(logger *slog.Logger) connect.Interceptor {

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
newCtx := context.WithValue(ctx, loggerContextKey{}, &loggerContextData{logger: i.logger.With()})
newCtx := setLoggerContext(ctx, clonelogger(i.logger))
return next(newCtx, request)
}
}
Expand All @@ -32,7 +32,12 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn

func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
newCtx := context.WithValue(ctx, loggerContextKey{}, &loggerContextData{logger: i.logger.With()})
newCtx := setLoggerContext(ctx, clonelogger(i.logger))
return next(newCtx, conn)
}
}

func clonelogger(logger *slog.Logger) *slog.Logger {
c := *logger
return &c
}
48 changes: 48 additions & 0 deletions backend/internal/connect/interceptors/logger/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package logger_interceptor

import (
"context"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"testing"

"connectrpc.com/connect"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
"github.com/stretchr/testify/assert"
)

func Test_Interceptor_WrapUnary_InjectLogger(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
interceptor := NewInterceptor(logger)

var ctxlogger *slog.Logger

mux := http.NewServeMux()
mux.Handle(mgmtv1alpha1connect.UserAccountServiceGetUserProcedure, connect.NewUnaryHandler(
mgmtv1alpha1connect.UserAccountServiceGetUserProcedure,
func(ctx context.Context, r *connect.Request[mgmtv1alpha1.GetUserRequest]) (*connect.Response[mgmtv1alpha1.GetUserResponse], error) {
ctxlogger = GetLoggerFromContextOrDefault(ctx)
return connect.NewResponse(&mgmtv1alpha1.GetUserResponse{UserId: "123"}), nil
},
connect.WithInterceptors(interceptor),
))
srv := startHTTPServer(t, mux)

assert.Nil(t, ctxlogger, "ctxlogger has not been set yet")
client := mgmtv1alpha1connect.NewUserAccountServiceClient(srv.Client(), srv.URL)
_, err := client.GetUser(context.Background(), connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
assert.Nil(t, err)
assert.NotNil(t, ctxlogger)
}

func startHTTPServer(tb testing.TB, h http.Handler) *httptest.Server {
tb.Helper()
srv := httptest.NewUnstartedServer(h)
srv.EnableHTTP2 = true
srv.Start()
tb.Cleanup(srv.Close)
return srv
}
4 changes: 4 additions & 0 deletions backend/internal/connect/interceptors/logger/logger-ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ func GetLoggerFromContextOrDefault(ctx context.Context) *slog.Logger {
}
return data.GetLogger()
}

func setLoggerContext(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, loggerContextKey{}, &loggerContextData{logger: logger})
}
21 changes: 21 additions & 0 deletions backend/internal/connect/interceptors/logger/logger-ctx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package logger_interceptor

import (
"context"
"log/slog"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_GetLoggerFromContextOrDefault(t *testing.T) {
assert.NotNil(t, GetLoggerFromContextOrDefault(context.Background()))
}

func Test_GetLoggerFromContextOrDefault_NonDefault(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := setLoggerContext(context.Background(), logger)
ctxlogger := GetLoggerFromContextOrDefault(ctx)
assert.Equal(t, logger, ctxlogger)
}

0 comments on commit febcb60

Please sign in to comment.