diff --git a/cmd/daemon/root.go b/cmd/daemon/root.go index ab4a3a1a..9bc45153 100644 --- a/cmd/daemon/root.go +++ b/cmd/daemon/root.go @@ -84,6 +84,7 @@ func Execute(args []string, logWriter io.Writer, version, commit, date string) i } openvpnClient.SetOAuth2Client(oauth2Client) + httpHandler, err := httphandler.New(conf, oauth2Client) if err != nil { logger.Error(err.Error()) diff --git a/internal/httphandler/handler_test.go b/internal/httphandler/handler_test.go index 72cc4e1d..001b6197 100644 --- a/internal/httphandler/handler_test.go +++ b/internal/httphandler/handler_test.go @@ -52,7 +52,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -83,7 +83,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -117,7 +117,7 @@ func TestHandler(t *testing.T) { PKCE: true, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -149,7 +149,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -180,7 +180,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -212,7 +212,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -244,7 +244,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -276,7 +276,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -308,7 +308,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, @@ -340,7 +340,7 @@ func TestHandler(t *testing.T) { }, }, OpenVpn: config.OpenVpn{ - Bypass: config.OpenVpnBypass{CommonNames: []string{}}, + Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)}, AuthTokenUser: true, }, }, diff --git a/internal/oauth2/handler.go b/internal/oauth2/handler.go index 6e886853..0b5e44d2 100644 --- a/internal/oauth2/handler.go +++ b/internal/oauth2/handler.go @@ -183,7 +183,9 @@ func (c Client) postCodeExchangeHandler(logger *slog.Logger, session state.State } } -func (c Client) postCodeExchangeHandlerStoreRefreshToken(ctx context.Context, logger *slog.Logger, session state.State, clientID string, tokens *oidc.Tokens[*idtoken.Claims]) { +func (c Client) postCodeExchangeHandlerStoreRefreshToken( + ctx context.Context, logger *slog.Logger, session state.State, clientID string, tokens *oidc.Tokens[*idtoken.Claims], +) { if !c.conf.OAuth2.Refresh.Enabled { return } @@ -199,6 +201,7 @@ func (c Client) postCodeExchangeHandlerStoreRefreshToken(ctx context.Context, lo refreshToken, err := c.provider.GetRefreshToken(tokens) if err != nil { logLevel := slog.LevelWarn + if errors.Is(err, ErrNoRefreshToken) { if session.SessionState == "AuthenticatedEmptyUser" || session.SessionState == "Authenticated" { logLevel = slog.LevelDebug @@ -217,7 +220,7 @@ func (c Client) postCodeExchangeHandlerStoreRefreshToken(ctx context.Context, lo } } -func (c Client) httpErrorHandler(w http.ResponseWriter, httpStatus int, errorType string, errorDesc string, encryptedSession string) { +func (c Client) httpErrorHandler(w http.ResponseWriter, httpStatus int, errorType, errorDesc, encryptedSession string) { logger := c.logger session, err := state.NewWithEncodedToken(encryptedSession, c.conf.HTTP.Secret.String()) diff --git a/internal/oauth2/provider.go b/internal/oauth2/provider.go index b06547b3..f2799d3a 100644 --- a/internal/oauth2/provider.go +++ b/internal/oauth2/provider.go @@ -27,7 +27,7 @@ func New(ctx context.Context, logger *slog.Logger, conf config.Config, httpClien return nil, fmt.Errorf("error fetch configuration for provider %s: %w", provider.GetName(), err) } - c := &Client{ + client := &Client{ storage: tokenStorage, openvpn: openvpn, conf: conf, @@ -41,15 +41,15 @@ func New(ctx context.Context, logger *slog.Logger, conf config.Config, httpClien return nil, fmt.Errorf("error parsing authorize params: %w", err) } - c.authorizeParams = append(c.authorizeParams, authorizeParams...) + client.authorizeParams = append(client.authorizeParams, authorizeParams...) if providerConfig.AuthCodeOptions != nil { - c.authorizeParams = append(c.authorizeParams, func() []oauth2.AuthCodeOption { + client.authorizeParams = append(client.authorizeParams, func() []oauth2.AuthCodeOption { return providerConfig.AuthCodeOptions }) } - options := c.getRelyingPartyOptions(httpClient) + options := client.getRelyingPartyOptions(httpClient) scopes := conf.OAuth2.Scopes if len(scopes) == 0 { @@ -57,38 +57,39 @@ func New(ctx context.Context, logger *slog.Logger, conf config.Config, httpClien } if providerConfig.Endpoint == (oauth2.Endpoint{}) { - c.relyingParty, err = newOIDCRelyingParty(ctx, logger, conf, provider, scopes, options) + client.relyingParty, err = newOIDCRelyingParty(ctx, logger, conf, provider, scopes, options) if err != nil { - return nil, fmt.Errorf("error oidc provider: %w", err) + return nil, err } } else { - c.relyingParty, err = newOAuthRelyingParty(ctx, logger, conf, provider, scopes, options, providerConfig) - + client.relyingParty, err = newOAuthRelyingParty(ctx, logger, conf, provider, scopes, options, providerConfig) if err != nil { - return nil, fmt.Errorf("error oauth2 provider: %w", err) + return nil, err } } - return c, nil + return client, nil } // newOIDCRelyingParty creates a new [rp.NewRelyingPartyOIDC]. This is used for providers that support OIDC. -func newOIDCRelyingParty(ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option) (rp.RelyingParty, error) { +func newOIDCRelyingParty( + ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option, +) (rp.RelyingParty, error) { if !config.IsURLEmpty(conf.OAuth2.Endpoints.Discovery) { - logger.Log(ctx, slog.LevelInfo, fmt.Sprintf( + logger.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf( "discover oidc auto configuration with provider %s for issuer %s with custom discovery url %s", provider.GetName(), conf.OAuth2.Issuer.String(), conf.OAuth2.Endpoints.Discovery.String(), )) options = append(options, rp.WithCustomDiscoveryUrl(conf.OAuth2.Endpoints.Discovery.String())) } else { - logger.Log(ctx, slog.LevelInfo, fmt.Sprintf( + logger.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf( "discover oidc auto configuration with provider %s for issuer %s", provider.GetName(), conf.OAuth2.Issuer.String(), )) } - return rp.NewRelyingPartyOIDC( + replyingParty, err := rp.NewRelyingPartyOIDC( logging.ToContext(ctx, logger), conf.OAuth2.Issuer.String(), conf.OAuth2.Client.ID, @@ -97,26 +98,40 @@ func newOIDCRelyingParty(ctx context.Context, logger *slog.Logger, conf config.C scopes, options..., ) + + if err != nil { + return nil, fmt.Errorf("error oidc provider: %w", err) + } + + return replyingParty, nil } // newOAuthRelyingParty creates a new [rp.NewRelyingPartyOAuth]. This is used for providers that do not support OIDC. -func newOAuthRelyingParty(ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option, providerConfig types.ProviderConfig) (rp.RelyingParty, error) { - logger.Log(ctx, slog.LevelInfo, fmt.Sprintf( +func newOAuthRelyingParty( + ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option, providerConfig types.ProviderConfig, +) (rp.RelyingParty, error) { + logger.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf( "manually configure oauth2 provider with provider %s and providerConfig %s and %s", provider.GetName(), providerConfig.AuthURL, providerConfig.TokenURL, )) if provider.GetName() == "generic" { - logger.Log(ctx, slog.LevelWarn, "generic provider with manual configuration is used. Validation of user data is not possible.") + logger.LogAttrs(ctx, slog.LevelWarn, "generic provider with manual configuration is used. Validation of user data is not possible.") } - return rp.NewRelyingPartyOAuth(&oauth2.Config{ + replyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{ ClientID: conf.OAuth2.Client.ID, ClientSecret: conf.OAuth2.Client.Secret.String(), RedirectURL: conf.HTTP.BaseURL.JoinPath("/oauth2/callback").String(), Scopes: scopes, Endpoint: providerConfig.Endpoint, }, options...) + + if err != nil { + return nil, fmt.Errorf("error oauth2 provider: %w", err) + } + + return replyingParty, nil } func (c Client) getRelyingPartyOptions(httpClient *http.Client) []rp.Option { diff --git a/internal/oauth2/refresh.go b/internal/oauth2/refresh.go index 44949edd..3b93d85b 100644 --- a/internal/oauth2/refresh.go +++ b/internal/oauth2/refresh.go @@ -69,6 +69,7 @@ func (c Client) RefreshClientAuth(ctx context.Context, logger *slog.Logger, clie refreshToken, err = c.provider.GetRefreshToken(tokens) if err != nil { logLevel := slog.LevelWarn + if errors.Is(err, ErrNoRefreshToken) { if session.SessionState == "AuthenticatedEmptyUser" || session.SessionState == "Authenticated" { logLevel = slog.LevelDebug diff --git a/internal/openvpn/passthrough_unix_test.go b/internal/openvpn/passthrough_unix_test.go index edfcb90a..9172c973 100644 --- a/internal/openvpn/passthrough_unix_test.go +++ b/internal/openvpn/passthrough_unix_test.go @@ -20,7 +20,6 @@ import ( "time" "github.com/jkroepke/openvpn-auth-oauth2/internal/config" - "github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2" "github.com/jkroepke/openvpn-auth-oauth2/internal/openvpn" "github.com/jkroepke/openvpn-auth-oauth2/internal/tokenstorage" "github.com/jkroepke/openvpn-auth-oauth2/internal/utils/testutils" @@ -107,10 +106,7 @@ func TestPassthroughFull(t *testing.T) { defer cancel() tokenStorage := tokenstorage.NewInMemory(ctx, testutils.Secret, time.Hour) - provider := oauth2.New(context.Background(), logger.Logger, tt.conf, http.DefaultClient, tokenStorage, nil) - openVPNClient := openvpn.New(ctx, logger.Logger, tt.conf, provider) - - defer openVPNClient.Shutdown() + _, openVPNClient := testutils.SetupOpenVPNOAuth2Clients(t, ctx, conf, logger.Logger, http.DefaultClient, tokenStorage) wg := sync.WaitGroup{} @@ -345,7 +341,6 @@ func TestPassthroughFull(t *testing.T) { <-ctx.Done() - openVPNClient.Shutdown() wg.Wait() }) } diff --git a/internal/utils/testutils/main.go b/internal/utils/testutils/main.go index a3d5be12..eb376308 100644 --- a/internal/utils/testutils/main.go +++ b/internal/utils/testutils/main.go @@ -218,6 +218,7 @@ func SetupResourceServer(tb testing.TB, clientListener net.Listener) (*httptest. })) resourceServer := httptest.NewServer(mux) + tb.Cleanup(func() { resourceServer.Close() }) diff --git a/internal/utils/testutils/openvpn.go b/internal/utils/testutils/openvpn.go index ecf35316..4999c3b3 100644 --- a/internal/utils/testutils/openvpn.go +++ b/internal/utils/testutils/openvpn.go @@ -13,9 +13,9 @@ func NewFakeOpenVPNClient() FakeOpenVPNClient { } func (FakeOpenVPNClient) AcceptClient(_ *slog.Logger, _ state.ClientIdentifier, _ string) { - return + } func (FakeOpenVPNClient) DenyClient(_ *slog.Logger, _ state.ClientIdentifier, _ string) { - return + } diff --git a/internal/utils/testutils/storage.go b/internal/utils/testutils/storage.go index e9cbcd64..30be90d9 100644 --- a/internal/utils/testutils/storage.go +++ b/internal/utils/testutils/storage.go @@ -15,5 +15,5 @@ func (FakeStorage) Set(_, _ string) error { } func (FakeStorage) Delete(_ string) { - return + }