Skip to content

Commit

Permalink
Parse role claims (#7713)
Browse files Browse the repository at this point in the history
* extract and test role claim parsing

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

* add failing test

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

* read segmented roles claim as array and string

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

* reuse more code by extracting WalkSegments

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

* add TestSplitWithEscaping

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

* docs and error for unhandled case

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

* add claims test

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

* add missing ReadStringClaim docs

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>

---------

Signed-off-by: Jörn Friedrich Dreyer <[email protected]>
  • Loading branch information
butonic authored Dec 4, 2023
1 parent 81ace6d commit 23e59b5
Show file tree
Hide file tree
Showing 5 changed files with 409 additions and 32 deletions.
62 changes: 62 additions & 0 deletions ocis-pkg/oidc/claims.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package oidc

import (
"fmt"
"strings"
)

const (
Iss = "iss"
Sub = "sub"
Expand All @@ -12,3 +17,60 @@ const (
OwncloudUUID = "ownclouduuid"
OcisRoutingPolicy = "ocis.routing.policy"
)

// SplitWithEscaping splits s into segments using separator which can be escaped using the escape string
// See https://codereview.stackexchange.com/a/280193
func SplitWithEscaping(s string, separator string, escapeString string) []string {
a := strings.Split(s, separator)

for i := len(a) - 2; i >= 0; i-- {
if strings.HasSuffix(a[i], escapeString) {
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
a = append(a[:i+1], a[i+2:]...)
}
}
return a
}

// WalkSegments uses the given array of segments to walk the claims and return whatever interface was found
func WalkSegments(segments []string, claims map[string]interface{}) (interface{}, error) {
i := 0
for ; i < len(segments)-1; i++ {
switch castedClaims := claims[segments[i]].(type) {
case map[string]interface{}:
claims = castedClaims
case map[interface{}]interface{}:
claims = make(map[string]interface{}, len(castedClaims))
for k, v := range castedClaims {
if s, ok := k.(string); ok {
claims[s] = v
} else {
return nil, fmt.Errorf("could not walk claims path, key '%v' is not a string", k)
}
}
default:
return nil, fmt.Errorf("unsupported type '%v'", castedClaims)
}
}
return claims[segments[i]], nil
}

// ReadStringClaim returns the string obtained by following the . seperated path in the claims
func ReadStringClaim(path string, claims map[string]interface{}) (string, error) {
// check the simple case first
value, _ := claims[path].(string)
if value != "" {
return value, nil
}

claim, err := WalkSegments(SplitWithEscaping(path, ".", "\\"), claims)
if err != nil {
return "", err
}

if value, _ = claim.(string); value != "" {
return value, nil
}

return value, fmt.Errorf("claim path '%s' not set or empty", path)
}
182 changes: 182 additions & 0 deletions ocis-pkg/oidc/claims_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package oidc_test

import (
"encoding/json"
"reflect"
"testing"

"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
)

type splitWithEscapingTest struct {
// Name of the subtest.
name string

// string to split
s string

// seperator to use
seperator string

// escape character to use for escaping
escape string

expectedParts []string
}

func (swet splitWithEscapingTest) run(t *testing.T) {
parts := oidc.SplitWithEscaping(swet.s, swet.seperator, swet.escape)
if len(swet.expectedParts) != len(parts) {
t.Errorf("mismatching length")
}
for i, v := range swet.expectedParts {
if parts[i] != v {
t.Errorf("expected part %d to be '%s', got '%s'", i, v, parts[i])
}
}
}

func TestSplitWithEscaping(t *testing.T) {
tests := []splitWithEscapingTest{
{
name: "plain claim name",
s: "roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"roles"},
},
{
name: "claim with .",
s: "my.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my", "roles"},
},
{
name: "claim with escaped .",
s: "my\\.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my.roles"},
},
{
name: "claim with escaped . left",
s: "my\\.other.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my.other", "roles"},
},
{
name: "claim with escaped . right",
s: "my.other\\.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my", "other.roles"},
},
}
for _, test := range tests {
t.Run(test.name, test.run)
}
}

type walkSegmentsTest struct {
// Name of the subtest.
name string

// path segments to walk
segments []string

// seperator to use
claims map[string]interface{}

expected interface{}

wantErr bool
}

func (wst walkSegmentsTest) run(t *testing.T) {
v, err := oidc.WalkSegments(wst.segments, wst.claims)
if err != nil && !wst.wantErr {
t.Errorf("%v", err)
}
if err == nil && wst.wantErr {
t.Errorf("expected error")
}
if !reflect.DeepEqual(v, wst.expected) {
t.Errorf("expected %v got %v", wst.expected, v)
}
}

func TestWalkSegments(t *testing.T) {
byt := []byte(`{"first":{"second":{"third":["value1","value2"]},"foo":"bar"},"fizz":"buzz"}`)
var dat map[string]interface{}
if err := json.Unmarshal(byt, &dat); err != nil {
t.Errorf("%v", err)
}

tests := []walkSegmentsTest{
{
name: "one segment, single value",
segments: []string{"first"},
claims: map[string]interface{}{
"first": "value",
},
expected: "value",
wantErr: false,
},
{
name: "one segment, array value",
segments: []string{"first"},
claims: map[string]interface{}{
"first": []string{"value1", "value2"},
},
expected: []string{"value1", "value2"},
wantErr: false,
},
{
name: "two segments, single value",
segments: []string{"first", "second"},
claims: map[string]interface{}{
"first": map[string]interface{}{
"second": "value",
},
},
expected: "value",
wantErr: false,
},
{
name: "two segments, array value",
segments: []string{"first", "second"},
claims: map[string]interface{}{
"first": map[string]interface{}{
"second": []string{"value1", "value2"},
},
},
expected: []string{"value1", "value2"},
wantErr: false,
},
{
name: "three segments, array value from json",
segments: []string{"first", "second", "third"},
claims: dat,
expected: []interface{}{"value1", "value2"},
wantErr: false,
},
{
name: "three segments, array value with interface key",
segments: []string{"first", "second", "third"},
claims: map[string]interface{}{
"first": map[interface{}]interface{}{
"second": map[interface{}]interface{}{
"third": []string{"value1", "value2"},
},
},
},
expected: []string{"value1", "value2"},
wantErr: false,
},
}
for _, test := range tests {
t.Run(test.name, test.run)
}
}
16 changes: 1 addition & 15 deletions services/proxy/pkg/middleware/account_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"strings"

"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend"
"github.com/owncloud/ocis/v2/services/proxy/pkg/userroles"
Expand Down Expand Up @@ -43,19 +42,6 @@ type accountResolver struct {
userCS3Claim string
}

// from https://codereview.stackexchange.com/a/280193
func splitWithEscaping(s string, separator string, escapeString string) []string {
a := strings.Split(s, separator)

for i := len(a) - 2; i >= 0; i-- {
if strings.HasSuffix(a[i], escapeString) {
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
a = append(a[:i+1], a[i+2:]...)
}
}
return a
}

func readUserIDClaim(path string, claims map[string]interface{}) (string, error) {
// happy path
value, _ := claims[path].(string)
Expand All @@ -64,7 +50,7 @@ func readUserIDClaim(path string, claims map[string]interface{}) (string, error)
}

// try splitting path at .
segments := splitWithEscaping(path, ".", "\\")
segments := oidc.SplitWithEscaping(path, ".", "\\")
subclaims := claims
lastSegment := len(segments) - 1
for i := range segments {
Expand Down
61 changes: 44 additions & 17 deletions services/proxy/pkg/userroles/oidcroles.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
cs3 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
"github.com/cs3org/reva/v2/pkg/utils"
"github.com/owncloud/ocis/v2/ocis-pkg/middleware"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
settingssvc "github.com/owncloud/ocis/v2/protogen/gen/ocis/services/settings/v0"
"go-micro.dev/v4/metadata"
)
Expand All @@ -29,6 +30,45 @@ func NewOIDCRoleAssigner(opts ...Option) UserRoleAssigner {
}
}

func extractRoles(rolesClaim string, claims map[string]interface{}) (map[string]struct{}, error) {

claimRoles := map[string]struct{}{}
// happy path
value, _ := claims[rolesClaim].(string)
if value != "" {
claimRoles[value] = struct{}{}
return claimRoles, nil
}

claim, err := oidc.WalkSegments(oidc.SplitWithEscaping(rolesClaim, ".", "\\"), claims)
if err != nil {
return nil, err
}

switch v := claim.(type) {
case []string:
for _, cr := range v {
claimRoles[cr] = struct{}{}
}
case []interface{}:
for _, cri := range v {
cr, ok := cri.(string)
if !ok {
err := errors.New("invalid role in claims")
return nil, err
}

claimRoles[cr] = struct{}{}
}
case string:
claimRoles[v] = struct{}{}
default:
return nil, errors.New("no roles in user claims")
}

return claimRoles, nil
}

// UpdateUserRoleAssignment assigns the role "User" to the supplied user. Unless the user
// already has a different role assigned.
func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *cs3.User, claims map[string]interface{}) (*cs3.User, error) {
Expand All @@ -39,23 +79,10 @@ func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *c
return nil, err
}

claimRolesRaw, ok := claims[ra.rolesClaim].([]interface{})
if !ok {
logger.Error().Str("rolesClaim", ra.rolesClaim).Msg("No roles in user claims")
return nil, errors.New("no roles in user claims")
}

logger.Debug().Str("rolesClaim", ra.rolesClaim).Interface("rolesInClaim", claims[ra.rolesClaim]).Msg("got roles in claim")
claimRoles := map[string]struct{}{}
for _, cri := range claimRolesRaw {
cr, ok := cri.(string)
if !ok {
err := errors.New("invalid role in claims")
logger.Error().Err(err).Interface("claimValue", cri).Msg("Is not a valid string.")
return nil, err
}

claimRoles[cr] = struct{}{}
claimRoles, err := extractRoles(ra.rolesClaim, claims)
if err != nil {
logger.Error().Err(err).Msg("Error mapping role names to role ids")
return nil, err
}

if len(claimRoles) == 0 {
Expand Down
Loading

0 comments on commit 23e59b5

Please sign in to comment.