Skip to content

Commit

Permalink
chore: modernize code
Browse files Browse the repository at this point in the history
Signed-off-by: Jan-Otto Kröpke <[email protected]>
  • Loading branch information
jkroepke committed Dec 21, 2024
1 parent afb7d34 commit 66d29e6
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 39 deletions.
1 change: 1 addition & 0 deletions cmd/daemon/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func Execute(args []string, logWriter io.Writer, version, commit, date string) i
}

openvpnClient.SetOAuth2Client(oauth2Client)

httpHandler, err := httphandler.New(conf, oauth2Client)
if err != nil {
logger.Error(err.Error())
Expand Down
20 changes: 10 additions & 10 deletions internal/httphandler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestHandler(t *testing.T) {
PKCE: true,
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -180,7 +180,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -212,7 +212,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -244,7 +244,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -276,7 +276,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -308,7 +308,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down Expand Up @@ -340,7 +340,7 @@ func TestHandler(t *testing.T) {
},
},
OpenVpn: config.OpenVpn{
Bypass: config.OpenVpnBypass{CommonNames: []string{}},
Bypass: config.OpenVpnBypass{CommonNames: make([]string, 0)},
AuthTokenUser: true,
},
},
Expand Down
7 changes: 5 additions & 2 deletions internal/oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@ func (c Client) postCodeExchangeHandler(logger *slog.Logger, session state.State
}
}

func (c Client) postCodeExchangeHandlerStoreRefreshToken(ctx context.Context, logger *slog.Logger, session state.State, clientID string, tokens *oidc.Tokens[*idtoken.Claims]) {
func (c Client) postCodeExchangeHandlerStoreRefreshToken(
ctx context.Context, logger *slog.Logger, session state.State, clientID string, tokens *oidc.Tokens[*idtoken.Claims],
) {
if !c.conf.OAuth2.Refresh.Enabled {
return
}
Expand All @@ -199,6 +201,7 @@ func (c Client) postCodeExchangeHandlerStoreRefreshToken(ctx context.Context, lo
refreshToken, err := c.provider.GetRefreshToken(tokens)
if err != nil {
logLevel := slog.LevelWarn

if errors.Is(err, ErrNoRefreshToken) {
if session.SessionState == "AuthenticatedEmptyUser" || session.SessionState == "Authenticated" {
logLevel = slog.LevelDebug
Expand All @@ -217,7 +220,7 @@ func (c Client) postCodeExchangeHandlerStoreRefreshToken(ctx context.Context, lo
}
}

func (c Client) httpErrorHandler(w http.ResponseWriter, httpStatus int, errorType string, errorDesc string, encryptedSession string) {
func (c Client) httpErrorHandler(w http.ResponseWriter, httpStatus int, errorType, errorDesc, encryptedSession string) {
logger := c.logger

session, err := state.NewWithEncodedToken(encryptedSession, c.conf.HTTP.Secret.String())
Expand Down
51 changes: 33 additions & 18 deletions internal/oauth2/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func New(ctx context.Context, logger *slog.Logger, conf config.Config, httpClien
return nil, fmt.Errorf("error fetch configuration for provider %s: %w", provider.GetName(), err)
}

c := &Client{
client := &Client{
storage: tokenStorage,
openvpn: openvpn,
conf: conf,
Expand All @@ -41,54 +41,55 @@ func New(ctx context.Context, logger *slog.Logger, conf config.Config, httpClien
return nil, fmt.Errorf("error parsing authorize params: %w", err)
}

c.authorizeParams = append(c.authorizeParams, authorizeParams...)
client.authorizeParams = append(client.authorizeParams, authorizeParams...)

if providerConfig.AuthCodeOptions != nil {
c.authorizeParams = append(c.authorizeParams, func() []oauth2.AuthCodeOption {
client.authorizeParams = append(client.authorizeParams, func() []oauth2.AuthCodeOption {
return providerConfig.AuthCodeOptions
})
}

options := c.getRelyingPartyOptions(httpClient)
options := client.getRelyingPartyOptions(httpClient)

scopes := conf.OAuth2.Scopes
if len(scopes) == 0 {
scopes = providerConfig.Scopes
}

if providerConfig.Endpoint == (oauth2.Endpoint{}) {
c.relyingParty, err = newOIDCRelyingParty(ctx, logger, conf, provider, scopes, options)
client.relyingParty, err = newOIDCRelyingParty(ctx, logger, conf, provider, scopes, options)
if err != nil {
return nil, fmt.Errorf("error oidc provider: %w", err)
return nil, err
}
} else {
c.relyingParty, err = newOAuthRelyingParty(ctx, logger, conf, provider, scopes, options, providerConfig)

client.relyingParty, err = newOAuthRelyingParty(ctx, logger, conf, provider, scopes, options, providerConfig)
if err != nil {
return nil, fmt.Errorf("error oauth2 provider: %w", err)
return nil, err
}
}

return c, nil
return client, nil
}

// newOIDCRelyingParty creates a new [rp.NewRelyingPartyOIDC]. This is used for providers that support OIDC.
func newOIDCRelyingParty(ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option) (rp.RelyingParty, error) {
func newOIDCRelyingParty(
ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option,
) (rp.RelyingParty, error) {
if !config.IsURLEmpty(conf.OAuth2.Endpoints.Discovery) {
logger.Log(ctx, slog.LevelInfo, fmt.Sprintf(
logger.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf(
"discover oidc auto configuration with provider %s for issuer %s with custom discovery url %s",
provider.GetName(), conf.OAuth2.Issuer.String(), conf.OAuth2.Endpoints.Discovery.String(),
))

options = append(options, rp.WithCustomDiscoveryUrl(conf.OAuth2.Endpoints.Discovery.String()))
} else {
logger.Log(ctx, slog.LevelInfo, fmt.Sprintf(
logger.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf(
"discover oidc auto configuration with provider %s for issuer %s",
provider.GetName(), conf.OAuth2.Issuer.String(),
))
}

return rp.NewRelyingPartyOIDC(
replyingParty, err := rp.NewRelyingPartyOIDC(
logging.ToContext(ctx, logger),
conf.OAuth2.Issuer.String(),
conf.OAuth2.Client.ID,
Expand All @@ -97,26 +98,40 @@ func newOIDCRelyingParty(ctx context.Context, logger *slog.Logger, conf config.C
scopes,
options...,
)

if err != nil {
return nil, fmt.Errorf("error oidc provider: %w", err)
}

return replyingParty, nil
}

// newOAuthRelyingParty creates a new [rp.NewRelyingPartyOAuth]. This is used for providers that do not support OIDC.
func newOAuthRelyingParty(ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option, providerConfig types.ProviderConfig) (rp.RelyingParty, error) {
logger.Log(ctx, slog.LevelInfo, fmt.Sprintf(
func newOAuthRelyingParty(

Check failure on line 110 in internal/oauth2/provider.go

View workflow job for this annotation

GitHub Actions / lint

argument-limit: maximum number of arguments per function exceeded; max 6 but got 7 (revive)
ctx context.Context, logger *slog.Logger, conf config.Config, provider Provider, scopes []string, options []rp.Option, providerConfig types.ProviderConfig,
) (rp.RelyingParty, error) {
logger.LogAttrs(ctx, slog.LevelInfo, fmt.Sprintf(
"manually configure oauth2 provider with provider %s and providerConfig %s and %s",
provider.GetName(), providerConfig.AuthURL, providerConfig.TokenURL,
))

if provider.GetName() == "generic" {
logger.Log(ctx, slog.LevelWarn, "generic provider with manual configuration is used. Validation of user data is not possible.")
logger.LogAttrs(ctx, slog.LevelWarn, "generic provider with manual configuration is used. Validation of user data is not possible.")
}

return rp.NewRelyingPartyOAuth(&oauth2.Config{
replyingParty, err := rp.NewRelyingPartyOAuth(&oauth2.Config{
ClientID: conf.OAuth2.Client.ID,
ClientSecret: conf.OAuth2.Client.Secret.String(),
RedirectURL: conf.HTTP.BaseURL.JoinPath("/oauth2/callback").String(),
Scopes: scopes,
Endpoint: providerConfig.Endpoint,
}, options...)

if err != nil {
return nil, fmt.Errorf("error oauth2 provider: %w", err)
}

return replyingParty, nil
}

func (c Client) getRelyingPartyOptions(httpClient *http.Client) []rp.Option {
Expand Down
1 change: 1 addition & 0 deletions internal/oauth2/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func (c Client) RefreshClientAuth(ctx context.Context, logger *slog.Logger, clie
refreshToken, err = c.provider.GetRefreshToken(tokens)
if err != nil {
logLevel := slog.LevelWarn

if errors.Is(err, ErrNoRefreshToken) {
if session.SessionState == "AuthenticatedEmptyUser" || session.SessionState == "Authenticated" {
logLevel = slog.LevelDebug
Expand Down
7 changes: 1 addition & 6 deletions internal/openvpn/passthrough_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"time"

"github.com/jkroepke/openvpn-auth-oauth2/internal/config"
"github.com/jkroepke/openvpn-auth-oauth2/internal/oauth2"
"github.com/jkroepke/openvpn-auth-oauth2/internal/openvpn"
"github.com/jkroepke/openvpn-auth-oauth2/internal/tokenstorage"
"github.com/jkroepke/openvpn-auth-oauth2/internal/utils/testutils"
Expand Down Expand Up @@ -107,10 +106,7 @@ func TestPassthroughFull(t *testing.T) {
defer cancel()

tokenStorage := tokenstorage.NewInMemory(ctx, testutils.Secret, time.Hour)
provider := oauth2.New(context.Background(), logger.Logger, tt.conf, http.DefaultClient, tokenStorage, nil)
openVPNClient := openvpn.New(ctx, logger.Logger, tt.conf, provider)

defer openVPNClient.Shutdown()
_, openVPNClient := testutils.SetupOpenVPNOAuth2Clients(t, ctx, conf, logger.Logger, http.DefaultClient, tokenStorage)

wg := sync.WaitGroup{}

Expand Down Expand Up @@ -345,7 +341,6 @@ func TestPassthroughFull(t *testing.T) {

<-ctx.Done()

openVPNClient.Shutdown()
wg.Wait()
})
}
Expand Down
1 change: 1 addition & 0 deletions internal/utils/testutils/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ func SetupResourceServer(tb testing.TB, clientListener net.Listener) (*httptest.
}))

resourceServer := httptest.NewServer(mux)

tb.Cleanup(func() {
resourceServer.Close()
})
Expand Down
4 changes: 2 additions & 2 deletions internal/utils/testutils/openvpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ func NewFakeOpenVPNClient() FakeOpenVPNClient {
}

func (FakeOpenVPNClient) AcceptClient(_ *slog.Logger, _ state.ClientIdentifier, _ string) {
return

}

func (FakeOpenVPNClient) DenyClient(_ *slog.Logger, _ state.ClientIdentifier, _ string) {
return

}
2 changes: 1 addition & 1 deletion internal/utils/testutils/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ func (FakeStorage) Set(_, _ string) error {
}

func (FakeStorage) Delete(_ string) {
return

}

0 comments on commit 66d29e6

Please sign in to comment.