Skip to content

Commit

Permalink
Merge pull request #214 from DasSkelett/feature/use-id-token-claims
Browse files Browse the repository at this point in the history
Add option to read claims from OIDC ID Token
  • Loading branch information
DasSkelett authored Aug 3, 2022
2 parents 1e56ae8 + dfe7161 commit c6f8447
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 56 deletions.
30 changes: 22 additions & 8 deletions docs/4-auth.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,34 @@ auth:
# Your OIDC client credentials which would be provided by your OIDC provider
clientID: "<client-id>"
clientSecret: "<client-secret>"
# List of scopes to request defaults to ["openid"]
scopes:
- openid
# The full redirect URL
# The path can be almost anything as long as it doesn't
# conflict with a path that the web UI uses.
# /callback is recommended.
redirectURL: "https://wg-access-server.example.com/callback"
# List of scopes to request claims for. Must include 'openid'.
# Must include 'email' if 'emailDomains' is used. Can include 'profile' to show the user's name in the UI.
# Add custom ones if required for 'claimMapping'.
# Defaults to ["openid"]
scopes:
- openid
- profile
- email
# You can optionally restrict access to users with an email address
# that matches an allowed domain.
# If empty or omitted then all email domains will be allowed.
emailDomains:
- example.com
# This is an advanced feature that allows you to define
# OIDC claim mapping expressions.
# This feature is used to define wg-access-server admins
# based off a claim in your OIDC token
# See https://github.com/Knetic/govaluate/blob/9aa49832a739dcd78a5542ff189fb82c3e423116/MANUAL.md for how to write rules
# This is an advanced feature that allows you to define OIDC claim mapping expressions.
# This feature is used to define wg-access-server admins based off a claim in your OIDC token.
# A JSON-like object of claimKey: claimValue pairs as returned by the issuer is passed to the evaluation function.
# See https://github.com/Knetic/govaluate/blob/9aa49832a739dcd78a5542ff189fb82c3e423116/MANUAL.md for the syntax.
claimMapping:
# This example works if you have a custom group_membership claim which is a list of strings
admin: "'WireguardAdmins' in group_membership"
# Let wg-access-server retrieve the claims from the ID Token instead of querying the UserInfo endpoint.
# Some OIDC authorization provider implementations (e.g. ADFS) only publish claims in the ID Token.
claimsFromIDToken: false
gitlab:
name: "My Gitlab Backend"
baseURL: "https://mygitlab.example.com"
Expand All @@ -88,3 +96,9 @@ auth:
emailDomains:
- example.com
```
## OIDC Provider specifics
### Active Directory Federation Services (ADFS)
Please see [this helpful issue comment](https://github.com/freifunkMUC/wg-access-server/issues/213#issuecomment-1172656633) for instructions for ADFS 2016 and above.
148 changes: 103 additions & 45 deletions pkg/authnz/authconfig/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/url"
"strconv"
"strings"
"time"

"github.com/freifunkMUC/wg-access-server/pkg/authnz/authruntime"
"github.com/freifunkMUC/wg-access-server/pkg/authnz/authsession"
Expand All @@ -21,24 +20,28 @@ import (
"gopkg.in/yaml.v2"
)

// OIDCConfig implements an OIDC client using the [Authorization Code Flow]
// [Authorization Code Flow]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth
type OIDCConfig struct {
Name string `yaml:"name"`
Issuer string `yaml:"issuer"`
ClientID string `yaml:"clientID"`
ClientSecret string `yaml:"clientSecret"`
Scopes []string `yaml:"scopes"`
RedirectURL string `yaml:"redirectURL"`
EmailDomains []string `yaml:"emailDomains"`
ClaimMapping map[string]ruleExpression `yaml:"claimMapping"`
Name string `yaml:"name"`
Issuer string `yaml:"issuer"`
ClientID string `yaml:"clientID"`
ClientSecret string `yaml:"clientSecret"`
Scopes []string `yaml:"scopes"`
RedirectURL string `yaml:"redirectURL"`
EmailDomains []string `yaml:"emailDomains"`
ClaimMapping map[string]ruleExpression `yaml:"claimMapping"`
ClaimsFromIDToken bool `yaml:"claimsFromIDToken"`
}

func (c *OIDCConfig) Provider() *authruntime.Provider {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
// The context for the oidc.Provider must be long-lived for verifying ID tokens later-on
ctx := context.Background()
provider, err := oidc.NewProvider(ctx, c.Issuer)
if err != nil {
logrus.Fatal(errors.Wrap(err, "failed to create oidc provider"))
}
verifier := provider.Verifier(&oidc.Config{ClientID: c.ClientID})

if c.Scopes == nil {
c.Scopes = []string{"openid"}
Expand All @@ -63,14 +66,15 @@ func (c *OIDCConfig) Provider() *authruntime.Provider {
c.loginHandler(runtime, oauthConfig)(w, r)
},
RegisterRoutes: func(router *mux.Router, runtime *authruntime.ProviderRuntime) error {
router.HandleFunc(redirectURL.Path, c.callbackHandler(runtime, oauthConfig, provider))
router.HandleFunc(redirectURL.Path, c.callbackHandler(runtime, oauthConfig, provider, verifier))
return nil
},
}
}

func (c *OIDCConfig) loginHandler(runtime *authruntime.ProviderRuntime, oauthConfig *oauth2.Config) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 1. Client prepares an Authentication Request containing the desired request parameters.
oauthStateString := authutil.RandomString(32)
err := runtime.SetSession(w, r, &authsession.AuthSession{
Nonce: &oauthStateString,
Expand All @@ -79,75 +83,110 @@ func (c *OIDCConfig) loginHandler(runtime *authruntime.ProviderRuntime, oauthCon
http.Error(w, "no session", http.StatusUnauthorized)
return
}
url := oauthConfig.AuthCodeURL(oauthStateString)
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
// 2. Client sends the request to the Authorization Server.
authCodeURL := oauthConfig.AuthCodeURL(oauthStateString)
http.Redirect(w, r, authCodeURL, http.StatusTemporaryRedirect)
}
}

func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauthConfig *oauth2.Config, provider *oidc.Provider) http.HandlerFunc {
func (c *OIDCConfig) callbackHandler(runtime *authruntime.ProviderRuntime, oauthConfig *oauth2.Config,
provider *oidc.Provider, verifier *oidc.IDTokenVerifier) http.HandlerFunc {

return func(w http.ResponseWriter, r *http.Request) {
// 3. Authorization Server Authenticates the End-User.
// 4. Authorization Server obtains End-User Consent/Authorization.
// 5. Authorization Server sends the End-User back to the Client with an Authorization Code.

s, err := runtime.GetSession(r)
if err != nil {
http.Error(w, "no session", http.StatusBadRequest)
return
}

// Make sure the returned state matches the one saved in the session cookie to prevent CSRF attacks
state := r.FormValue("state")
if s.Nonce == nil || *s.Nonce != state {
http.Error(w, "bad nonce", http.StatusBadRequest)
if s.Nonce == nil {
http.Error(w, "no state associated with session", http.StatusBadRequest)
return
} else if *s.Nonce != state {
http.Error(w, "bad state value", http.StatusBadRequest)
return
}

code := r.FormValue("code")
authCode := r.FormValue("code")

token, err := oauthConfig.Exchange(r.Context(), code)
// 6. Client requests a response using the Authorization Code at the Token Endpoint.
// 7. Client receives a response that contains an ID Token and Access Token in the response body.
oauth2Token, err := oauthConfig.Exchange(r.Context(), authCode)
if err != nil {
panic(errors.Wrap(err, "Unable to exchange tokens"))
}

info, err := provider.UserInfo(r.Context(), oauthConfig.TokenSource(r.Context(), token))
if err != nil {
panic(errors.Wrap(err, "Unable to get user info"))
// 8. Client validates the ID token and retrieves the End-User's Subject Identifier.
oidcClaims := make(map[string]interface{})
if !c.ClaimsFromIDToken {
// Use the UserInfo endpoint to retrieve the claims
logrus.Debug("retrieving claims from UserInfo endpoint")
info, err := provider.UserInfo(r.Context(), oauthConfig.TokenSource(r.Context(), oauth2Token))
if err != nil {
panic(errors.Wrap(err, "Unable to get UserInfo"))
}

// Dump the claims
err = info.Claims(&oidcClaims)
if err != nil {
panic(errors.Wrap(err, "Unable to unmarshal claims from UserInfo JSON"))
}
} else {
// Extract and parse the ID token to retrieve the claims
logrus.Debug("retrieving claims from ID Token")
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
panic(errors.New("No id_token field in oauth2 token"))
}
// Parse and verify ID Token payload
idToken, err := verifier.Verify(r.Context(), rawIDToken)
if err != nil {
panic(errors.Wrap(err, "Failed to verify ID Token"))
}

// Dump the claims
err = idToken.Claims(&oidcClaims)
if err != nil {
panic(errors.Wrap(err, "Unable to unmarshal claims from ID Token JSON"))
}
}

if msg, valid := verifyEmailDomain(c.EmailDomains, info.Email); !valid {
email, _ := oidcClaims["email"].(string)
if msg, valid := verifyEmailDomain(c.EmailDomains, email); !valid {
http.Error(w, msg, http.StatusForbidden)
return
}

oidcProfileData := make(map[string]interface{})
err = info.Claims(&oidcProfileData)
claims, err := evaluateClaimMapping(c.ClaimMapping, oidcClaims)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

claims := &authsession.Claims{}
for claimName, rule := range c.ClaimMapping {
result, err := rule.Evaluate(oidcProfileData)

if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

// If result is 'false' or an empty string then don't include the Claim
if val, ok := result.(bool); ok && val {
claims.Add(claimName, strconv.FormatBool(val))
} else if val, ok := result.(string); ok && len(val) > 0 {
claims.Add(claimName, val)
}
// Build the authnz Identity for the user, they are now considered logged in
var subject string
if sub, ok := oidcClaims["sub"].(string); ok {
subject = sub
} else {
panic(errors.New("No 'sub' claim returned from authorization provider"))
}

identity := &authsession.Identity{
Provider: c.Name,
Subject: info.Subject,
Email: info.Email,
Subject: subject,
Claims: *claims,
}
if name, ok := oidcProfileData["name"].(string); ok {
if name, ok := oidcClaims["name"].(string); ok {
identity.Name = name
}
if email != "" {
identity.Email = email
}

err = runtime.SetSession(w, r, &authsession.AuthSession{
Identity: identity,
Expand Down Expand Up @@ -183,6 +222,25 @@ func verifyEmailDomain(allowedDomains []string, email string) (string, bool) {
return "email domain not authorized", false
}

// evaluateClaimMapping translates OIDC claims to custom authnz claims.
func evaluateClaimMapping(claimMapping map[string]ruleExpression, oidcClaims map[string]interface{}) (*authsession.Claims, error) {
claims := &authsession.Claims{}
for claimName, rule := range claimMapping {
result, err := rule.Evaluate(oidcClaims)
if err != nil {
return nil, err
}

// If result is 'false' or an empty string then don't include the Claim
if val, ok := result.(bool); ok && val {
claims.Add(claimName, strconv.FormatBool(val))
} else if val, ok := result.(string); ok && len(val) > 0 {
claims.Add(claimName, val)
}
}
return claims, nil
}

type ruleExpression struct {
*govaluate.EvaluableExpression
}
Expand Down
46 changes: 46 additions & 0 deletions pkg/authnz/authconfig/oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package authconfig

import (
"reflect"
"testing"

"gopkg.in/Knetic/govaluate.v2"

"github.com/freifunkMUC/wg-access-server/pkg/authnz/authsession"
)

func Test_evaluateClaimMapping(t *testing.T) {
type args struct {
claimMapping map[string]ruleExpression
oidcClaims map[string]interface{}
}

expr, _ := govaluate.NewEvaluableExpression("'WireguardAdmins' in group_membership")

tests := []struct {
name string
args args
want authsession.Claims
wantErr bool
}{
{
args: args{
claimMapping: map[string]ruleExpression{"admin": {expr}},
oidcClaims: map[string]interface{}{"group_membership": []interface{}{"wgas", "WireguardAdmins"}},
},
want: authsession.Claims{{Name: "admin", Value: "true"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := evaluateClaimMapping(tt.args.claimMapping, tt.args.oidcClaims)
if (err != nil) != tt.wantErr {
t.Errorf("evaluateClaimMapping() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(*got, tt.want) {
t.Errorf("evaluateClaimMapping() got = %v, want %v", got, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/authnz/authruntime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (p *ProviderRuntime) ClearSession(w http.ResponseWriter, r *http.Request) e
}

func (p *ProviderRuntime) Restart(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/signin", http.StatusTemporaryRedirect)
http.Redirect(w, r, "/signin?signout=1", http.StatusTemporaryRedirect)
}

func (p *ProviderRuntime) Done(w http.ResponseWriter, r *http.Request) {
Expand Down
3 changes: 2 additions & 1 deletion pkg/authnz/authutil/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (
"github.com/sirupsen/logrus"
)

// RandomString returns a base64url-encoded string of random data of size bytes
func RandomString(size int) string {
blk := make([]byte, size)
_, err := rand.Read(blk)
if err != nil {
logrus.Fatal(errors.Wrap(err, "failed to make a random string"))
}
return base64.StdEncoding.EncodeToString(blk)
return base64.URLEncoding.EncodeToString(blk)
}
2 changes: 1 addition & 1 deletion pkg/authnz/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func New(config authconfig.AuthConfig, claimsMiddleware authsession.ClaimsMiddle
}

router.HandleFunc("/signin", func(w http.ResponseWriter, r *http.Request) {
if !config.DesiresSigninPage() && len(providers) == 1 {
if r.FormValue("signout") != "1" && !config.DesiresSigninPage() && len(providers) == 1 {
// we only have one provider, so jump directly to that
providers[0].Invoke(w, r, runtime)
return
Expand Down

0 comments on commit c6f8447

Please sign in to comment.