Skip to content

Commit

Permalink
feat: rework eval policy and support setting headers to downstream req
Browse files Browse the repository at this point in the history
  • Loading branch information
lsjostro committed Oct 4, 2024
1 parent 96dbad0 commit 6226835
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 42 deletions.
49 changes: 42 additions & 7 deletions authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"html/template"
"log/slog"
"maps"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -116,6 +117,7 @@ func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequ
attribute.String("header_match_name", provider.HeaderMatch.Name),
),
)
var policyHeaders []*core.HeaderValueOption

// if PreAuthPolicy is defined evaluate the request
if provider.preAuthPolicy != nil {
Expand All @@ -126,22 +128,32 @@ func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequ
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}

allowed, bypassAuth, err := provider.preAuthPolicy.Eval(ctx, resInput)
decision, err := provider.preAuthPolicy.Eval(ctx, resInput)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}
slog.Debug("PreAuth policy result", slog.Bool("allowed", allowed))

if !allowed {
if allowed, ok := decision["allow"].(bool); ok && !allowed {
span.SetStatus(codes.Error, "PreAuth policy denied the request")
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_Forbidden, nil, nil, "PreAuth policy denied the request")), nil
}

if bypassAuth {
// check for headers to add downstream from decision log
if h, ok := decision["headers"].(map[string]any); ok {
policyHeaders = s.addHeaders(h)

hv := make(map[string]string)
for k, v := range h {
hv[k] = v.(string)
}
maps.Copy(httpReq.Headers, hv)
}

if bypass, ok := decision["bypass_auth"].(bool); ok && bypass {
span.SetStatus(codes.Ok, "bypassAuth policy allowed the request")
return connect.NewResponse(s.authResponse(true, envoy_type.StatusCode_OK, nil, nil, "skipAuth policy allowed the request")), nil
return connect.NewResponse(s.authResponse(true, envoy_type.StatusCode_OK, policyHeaders, nil, "")), nil
}
}

Expand All @@ -162,16 +174,26 @@ func (s *Service) Check(ctx context.Context, req *connect.Request[auth.CheckRequ
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}

allowed, _, err := provider.postAuthPolicy.Eval(ctx, respInput)
decision, err := provider.postAuthPolicy.Eval(ctx, respInput)
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_BadGateway, nil, nil, err.Error())), nil
}
if !allowed {

if allowed, ok := decision["allow"].(bool); ok && !allowed {
span.SetStatus(codes.Error, "PostAuth policy denied the request")
return connect.NewResponse(s.authResponse(false, envoy_type.StatusCode_Forbidden, nil, nil, "PostAuth policy denied the request")), nil
}

if h, ok := decision["headers"].(map[string]any); ok {
policyHeaders = append(policyHeaders, s.addHeaders(h)...)
}
}

// if OkResponse add policy headers if any
if resp.GetStatus().GetCode() == int32(rpc.OK) && len(policyHeaders) > 0 {
resp.GetOkResponse().Headers = append(resp.GetOkResponse().Headers, policyHeaders...)
}

// Return response to envoy
Expand Down Expand Up @@ -536,6 +558,19 @@ func (s *Service) setAuthorizationHeader(token string) *core.HeaderValueOption {
}
}

func (s *Service) addHeaders(h map[string]any) []*core.HeaderValueOption {
var headers []*core.HeaderValueOption
for k, v := range h {
headers = append(headers, &core.HeaderValueOption{
Header: &core.HeaderValue{
Key: k,
Value: v.(string),
},
})
}
return headers
}

func (s *Service) getCodeQueryParam(fullURL string) (string, error) {
u, err := url.Parse(fullURL)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
CGO_ENABLED = "0";

meta = {
desciption = "Envoy OIDC Authserver";
description = "Envoy OIDC Authserver";
homepage = "https://github.com/shelmangroup/envoy-oidc-authserver";
mainProgram = "envoy-oidc-authserver";
};
Expand Down
45 changes: 13 additions & 32 deletions policy/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,12 @@ func (h tracePrintHook) Print(c print.Context, msg string) error {
// NewPolicy creates a new Policy with the given policy.
func NewPolicy(name, policy string) (*Policy, error) {
ctx := context.Background()
query := "allow = data.authz.allow"
// Allow skipAuth for PreAuth policy
if name == "PreAuth" {
query = "allow = data.authz.allow; bypass_auth = data.authz.bypass_auth"
}

ph := &tracePrintHook{
Name: name,
}

r, err := rego.New(
rego.Query(query),
rego.Query("data.authz"),
rego.Module("OpenPolicyAgent", policy),
rego.EnablePrintStatements(true),
rego.PrintHook(ph),
Expand All @@ -75,8 +69,8 @@ func NewPolicy(name, policy string) (*Policy, error) {
}, nil
}

// Eval evaluates the policy with the given input and returns the result.
func (p *Policy) Eval(ctx context.Context, input map[string]any) (bool, bool, error) {
// Eval evaluates the policy with the given input and returns the decision log.
func (p *Policy) Eval(ctx context.Context, input map[string]any) (map[string]any, error) {
ctx, span := tracer.Start(ctx, p.name+"PolicyEval")
defer span.End()

Expand All @@ -91,37 +85,24 @@ func (p *Policy) Eval(ctx context.Context, input map[string]any) (bool, bool, er
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
return false, false, err
return nil, err
}

if len(rs) == 0 {
return false, false, errors.New("no results returned! No default value for `allow` and/or `bypass_auth` in policy?")
return nil, errors.New("no results returned! No default value for `allow` and/or `bypass_auth` in policy?")
}

allow, ok := rs[0].Bindings["allow"].(bool)
if !ok {
return false, false, errors.New("no allow result")
}

if !allow {
span.SetStatus(codes.Error, "policy denied")
return false, false, nil
}

bypassAuth, ok := rs[0].Bindings["bypass_auth"].(bool)
if !ok {
bypassAuth = false
var decision map[string]interface{}
switch d := rs[0].Expressions[0].Value.(type) {
case map[string]interface{}:
decision = d
default:
return nil, errors.New("invalid decision type")
}

span.AddEvent("decision_log",
trace.WithAttributes(
attribute.Bool("allowed", allow),
attribute.Bool("bypass_auth", bypassAuth),
),
)
slog.Debug("policy eval", slog.Any("decision_log", decision))

span.SetStatus(codes.Ok, "allowed")
return allow, bypassAuth, nil
return decision, nil
}

func RequestOrResponseToInput(req any) (map[string]any, error) {
Expand Down
7 changes: 5 additions & 2 deletions policy/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ func TestEvalCheckRequest(t *testing.T) {
_, err = RequestOrResponseToInput(nil)
require.Error(t, err)

allowed, bypass, err := p.Eval(ctx, input)
decision, err := p.Eval(ctx, input)
require.NoError(t, err)
allowed := decision["allow"].(bool)
require.Equal(t, scenario.expected, allowed)
bypass := decision["bypass_auth"].(bool)
require.Equal(t, scenario.expected, bypass)
})
}
Expand Down Expand Up @@ -124,8 +126,9 @@ func TestEvalCheckResponse(t *testing.T) {
input, err := RequestOrResponseToInput(res)
require.NoError(t, err)

allowed, _, err := p.Eval(ctx, input)
decision, err := p.Eval(ctx, input)
require.NoError(t, err)
allowed := decision["allow"].(bool)
require.Equal(t, scenario.expected, allowed)
})
}
Expand Down
21 changes: 21 additions & 0 deletions run/config/providers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ providers:
}
bypass_auth if {
print("bypass", pre_authed)
pre_authed
}
Expand All @@ -41,6 +42,11 @@ providers:
glob.match("/api/info", ["/"], httpreq.path)
}
action_allowed if {
httpreq.method == "GET"
glob.match("/headers", ["/"], httpreq.path)
}
action_allowed if {
httpreq.method == "GET"
print("request path:", httpreq.path)
Expand All @@ -54,6 +60,10 @@ providers:
token.payload.email == "[email protected]"
}
headers["x-ext-authz-allow"] := "true" if { allow == true }
headers["x-ext-authz-pre-policy"] := "true"
headers["foo"] := "bar"
jwks_request(url) := http.send({
"url": url, "method": "GET", "force_cache": true, "force_cache_duration_seconds": 3600
})
Expand All @@ -65,6 +75,13 @@ providers:
[_, payload, _] := io.jwt.decode(parsed_jwt)
}
valid_tenents := { "securityteam", "devteam", "foo", "bar" }
token_payload_groups := [ "OpsTeam", "SecurityTeam" ]
groups_set := { lower(x) | x := token_payload_groups[_] }
groups := groups_set & valid_tenents
headers["X-Scope-OrgID"] := concat("|", array.concat(["fake"], [ x | x := groups[_] ]))
postAuthPolicy: |
package authz
import rego.v1
Expand All @@ -75,6 +92,10 @@ providers:
token.payload.email == "[email protected]"
}
headers["x-ext-authz-allow"] := "true" if { allow == true }
headers["x-ext-authz-post-policy"] := "true"
headers["foo"] := "baz"
token := { "payload": payload } if {
[_, payload, _] := io.jwt.decode(input.parsed_jwt)
}
Expand Down

0 comments on commit 6226835

Please sign in to comment.