From fe953e5237dbfcc141dc8c86700259a083330df8 Mon Sep 17 00:00:00 2001 From: XIAO YU Date: Tue, 15 Oct 2024 18:53:10 +0900 Subject: [PATCH] refactor(appx): optimize JWT token validation with concurrent validators (#49) * refactor(appx): optimize JWT token validation with concurrent validators * Add context cancellation to stop unnecessary validations * Add tests for change * Resolve data race in JWT validator test --- appx/jwt_validator.go | 62 ++++++++++++---- appx/jwt_validator_test.go | 130 +++++++++++++++++++++++++++++++- appx/tracer.go | 6 +- appx/tracer_test.go | 148 +++++++++++++++++++++++++++++++++++++ 4 files changed, 328 insertions(+), 18 deletions(-) create mode 100644 appx/tracer_test.go diff --git a/appx/jwt_validator.go b/appx/jwt_validator.go index 03feecd..a7f5636 100644 --- a/appx/jwt_validator.go +++ b/appx/jwt_validator.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "github.com/auth0/go-jwt-middleware/v2/validator" "github.com/reearth/reearthx/log" @@ -29,7 +30,7 @@ func NewJWTValidatorWithError( audience []string, opts ...validator.Option, ) (*JWTValidatorWithError, error) { - validator, err := validator.New( + v, err := validator.New( keyFunc, signatureAlgorithm, issuerURL, @@ -40,7 +41,7 @@ func NewJWTValidatorWithError( return nil, err } return &JWTValidatorWithError{ - validator: validator, + validator: v, iss: issuerURL, aud: slices.Clone(audience), }, nil @@ -49,9 +50,9 @@ func NewJWTValidatorWithError( func (v *JWTValidatorWithError) ValidateToken(ctx context.Context, token string) (interface{}, error) { res, err := v.validator.ValidateToken(ctx, token) if err != nil { - err = fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err) + return nil, fmt.Errorf("invalid JWT: iss=%s aud=%v err=%w", v.iss, v.aud, err) } - return res, err + return res, nil } type JWTMultipleValidator []JWTValidator @@ -62,20 +63,53 @@ func NewJWTMultipleValidator(providers []JWTProvider) (JWTMultipleValidator, err }) } -// ValidateToken Trys to validate the token with each validator +// ValidateToken tries to validate the token with each validator concurrently // NOTE: the last validation error only is returned -func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (res interface{}, err error) { +func (mv JWTMultipleValidator) ValidateToken(ctx context.Context, tokenString string) (interface{}, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + type result struct { + res interface{} + err error + } + + resultChan := make(chan result, len(mv)) + var wg sync.WaitGroup + for _, v := range mv { - var err2 error - res, err2 = v.ValidateToken(ctx, tokenString) - if err2 == nil { - err = nil - return + wg.Add(1) + go func(validator JWTValidator) { + defer wg.Done() + res, err := validator.ValidateToken(ctx, tokenString) + select { + case resultChan <- result{res, err}: + case <-ctx.Done(): + return + } + }(v) + } + + go func() { + wg.Wait() + close(resultChan) + }() + + var lastErr error + for i := 0; i < len(mv); i++ { + select { + case r := <-resultChan: + if r.err == nil { + cancel() + return r.res, nil + } + lastErr = errors.Join(lastErr, r.err) + case <-ctx.Done(): + return nil, ctx.Err() } - err = errors.Join(err, err2) } log.Debugfc(ctx, "auth: invalid JWT token: %s", tokenString) - log.Errorfc(ctx, "auth: invalid JWT token: %v", err) - return + log.Errorfc(ctx, "auth: invalid JWT token: %v", lastErr) + return nil, lastErr } diff --git a/appx/jwt_validator_test.go b/appx/jwt_validator_test.go index 6b1c4b8..bdc4003 100644 --- a/appx/jwt_validator_test.go +++ b/appx/jwt_validator_test.go @@ -6,6 +6,7 @@ import ( "crypto/rsa" "encoding/json" "net/http" + "sync" "testing" "time" @@ -23,7 +24,10 @@ func TestMultiValidator(t *testing.T) { key := lo.Must(rsa.GenerateKey(rand.Reader, 2048)) httpmock.Activate() - defer httpmock.DeactivateAndReset() + t.Cleanup(func() { + httpmock.DeactivateAndReset() + }) + httpmock.RegisterResponder( http.MethodGet, "https://example.com/.well-known/openid-configuration", @@ -121,4 +125,128 @@ func TestMultiValidator(t *testing.T) { res3, err := v.ValidateToken(context.Background(), tokenString3) assert.ErrorIs(t, err, jwt2.ErrInvalidIssuer) assert.Nil(t, res3) + + t.Run("all validators fail", func(t *testing.T) { + invalidTokenString := "invalid.token.string" + + res, err := v.ValidateToken(context.Background(), invalidTokenString) + assert.Error(t, err) + assert.Nil(t, res) + + // Check if the error is a combination of multiple errors + var multiErr interface{ Unwrap() []error } + assert.ErrorAs(t, err, &multiErr) + errs := multiErr.Unwrap() + assert.Len(t, errs, 2) + + // Check if both errors are related to invalid token + for _, e := range errs { + assert.Contains(t, e.Error(), "invalid JWT") + } + }) + + t.Run("first validator succeeds", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + res, err := v.ValidateToken(context.Background(), tokenString) + assert.NoError(t, err) + assert.NotNil(t, res) + claims, ok := res.(*validator.ValidatedClaims) + assert.True(t, ok) + assert.Equal(t, "https://example.com/", claims.RegisteredClaims.Issuer) + }) + + t.Run("second validator succeeds", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + res, err := v.ValidateToken(context.Background(), tokenString) + assert.NoError(t, err) + assert.NotNil(t, res) + claims, ok := res.(*validator.ValidatedClaims) + assert.True(t, ok) + assert.Equal(t, "https://example.com/", claims.RegisteredClaims.Issuer) + }) + + t.Run("all validators fail", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example3.com/", AUD: []string{"d"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + res, err := v.ValidateToken(context.Background(), tokenString) + assert.Error(t, err) + assert.Nil(t, res) + + var multiErr interface{ Unwrap() []error } + assert.ErrorAs(t, err, &multiErr) + errs := multiErr.Unwrap() + assert.Len(t, errs, 2) + + for _, e := range errs { + assert.Contains(t, e.Error(), "invalid JWT") + } + }) + + t.Run("context cancellation", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + res, err := v.ValidateToken(ctx, tokenString) + assert.Error(t, err) + assert.Nil(t, res) + assert.ErrorIs(t, err, context.Canceled) + }) + + t.Run("mixed valid and invalid tokens", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + // Test with valid token + res, err := v.ValidateToken(context.Background(), tokenString) + assert.NoError(t, err) + assert.NotNil(t, res) + + // Test with invalid token + res, err = v.ValidateToken(context.Background(), "invalid.token") + assert.Error(t, err) + assert.Nil(t, res) + }) + + t.Run("concurrent validations", func(t *testing.T) { + v, err := NewJWTMultipleValidator([]JWTProvider{ + {ISS: "https://example.com/", AUD: []string{"a", "b"}, ALG: &jwt.SigningMethodRS256.Name}, + {ISS: "https://example2.com/", AUD: []string{"c"}, ALG: &jwt.SigningMethodRS256.Name}, + }) + assert.NoError(t, err) + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + res, err := v.ValidateToken(context.Background(), tokenString) + assert.NoError(t, err) + assert.NotNil(t, res) + }() + } + wg.Wait() + }) } diff --git a/appx/tracer.go b/appx/tracer.go index b3fed80..ceb2975 100644 --- a/appx/tracer.go +++ b/appx/tracer.go @@ -25,7 +25,7 @@ type TracerConfig struct { TracerSample float64 } -func InitTracer(ctx context.Context, conf TracerConfig) io.Closer { +func InitTracer(ctx context.Context, conf *TracerConfig) io.Closer { if conf.Tracer == TRACER_GCP { initGCPTracer(ctx, conf) } else if conf.Tracer == TRACER_JAEGER { @@ -34,7 +34,7 @@ func InitTracer(ctx context.Context, conf TracerConfig) io.Closer { return nil } -func initGCPTracer(ctx context.Context, conf TracerConfig) { +func initGCPTracer(ctx context.Context, conf *TracerConfig) { exporter, err := texporter.New() if err != nil { log.Fatalc(ctx, err) @@ -50,7 +50,7 @@ func initGCPTracer(ctx context.Context, conf TracerConfig) { log.Infofc(ctx, "tracer: initialized cloud trace with sample fraction: %g", conf.TracerSample) } -func initJaegerTracer(conf TracerConfig) io.Closer { +func initJaegerTracer(conf *TracerConfig) io.Closer { cfg := jaegercfg.Configuration{ Sampler: &jaegercfg.SamplerConfig{ Type: jaeger.SamplerTypeConst, diff --git a/appx/tracer_test.go b/appx/tracer_test.go new file mode 100644 index 0000000..b11f637 --- /dev/null +++ b/appx/tracer_test.go @@ -0,0 +1,148 @@ +package appx + +import ( + "context" + "io" + "strings" + "testing" + + "github.com/reearth/reearthx/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.opentelemetry.io/otel" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +// Mock for GCP exporter +type mockGCPExporter struct { + mock.Mock +} + +func (m *mockGCPExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error { + args := m.Called(ctx, spans) + return args.Error(0) +} + +func (m *mockGCPExporter) Shutdown(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +// Mock for Jaeger closer +type mockCloser struct { + mock.Mock +} + +func (m *mockCloser) Close() error { + args := m.Called() + return args.Error(0) +} + +type testLogWriter struct { + strings.Builder +} + +func (w *testLogWriter) Write(p []byte) (int, error) { + return w.Builder.Write(p) +} + +func TestInitTracer(t *testing.T) { + // Create function variables + var testInitGCPTracer func(ctx context.Context, conf *TracerConfig) + var testInitJaegerTracer func(conf *TracerConfig) io.Closer + + // Create a test wrapper for InitTracer that uses the function variables + testInitTracer := func(ctx context.Context, conf *TracerConfig) io.Closer { + if conf.Tracer == TRACER_GCP { + testInitGCPTracer(ctx, conf) + return nil + } else if conf.Tracer == TRACER_JAEGER { + return testInitJaegerTracer(conf) + } + return nil + } + + tests := []struct { + name string + config *TracerConfig + setup func() + expected io.Closer + }{ + { + name: "GCP Tracer", + config: &TracerConfig{ + Name: "test-gcp", + Tracer: TRACER_GCP, + TracerSample: 0.5, + }, + setup: func() { + testInitGCPTracer = func(ctx context.Context, conf *TracerConfig) { + // Mock the GCP tracer initialization + mockExporter := &mockGCPExporter{} + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(mockExporter)) + otel.SetTracerProvider(tp) + log.Infofc(ctx, "tracer: initialized cloud trace with sample fraction: %g", conf.TracerSample) + } + }, + expected: nil, + }, + { + name: "Jaeger Tracer", + config: &TracerConfig{ + Name: "test-jaeger", + Tracer: TRACER_JAEGER, + TracerSample: 0.5, + }, + setup: func() { + testInitJaegerTracer = func(conf *TracerConfig) io.Closer { + // Mock the Jaeger tracer initialization + mockCloser := &mockCloser{} + mockCloser.On("Close").Return(nil) + log.Infof("tracer: initialized jaeger tracer with sample fraction: %g", conf.TracerSample) + return mockCloser + } + }, + expected: &mockCloser{}, + }, + { + name: "Unknown Tracer", + config: &TracerConfig{ + Name: "test-unknown", + Tracer: "unknown", + TracerSample: 0.5, + }, + setup: func() {}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup() + + // Capture log output + logWriter := &testLogWriter{} + log.SetOutput(logWriter) + defer log.SetOutput(nil) + + ctx := context.Background() + closer := testInitTracer(ctx, tt.config) + + if tt.expected == nil { + assert.Nil(t, closer) + } else { + assert.NotNil(t, closer) + assert.IsType(t, tt.expected, closer) + } + + // Check if the log output contains the expected message + logOutput := logWriter.String() + expectedLogMessage := "tracer: initialized" + if tt.config.Tracer != "unknown" { + assert.Contains(t, logOutput, expectedLogMessage) + } else { + assert.Empty(t, logOutput) + } + }) + } +}