diff --git a/config/base_config.go b/config/base_config.go index 63bdc9f..8164970 100644 --- a/config/base_config.go +++ b/config/base_config.go @@ -2,7 +2,6 @@ package config import ( "errors" - "github.com/open-policy-agent/opa/rego" "log" "net/http" "time" @@ -18,8 +17,6 @@ type Config struct { // You must set either URL or Policy. Policy string `json:"policy,omitempty"` - instantiatedPolicy *rego.PreparedEvalQuery - // Query is the name of the policy to query. Query string `json:"query,omitempty"` @@ -31,13 +28,15 @@ type Config struct { ExceptedResult bool `json:"excepted_result,omitempty"` // DeniedStatusCode is the status code that should be returned if the request is denied. + // Default to http.StatusForbidden. DeniedStatusCode int `json:"denied_status,omitempty"` // DeniedMessage is the message that should be returned if the request is denied. + // Default to "Forbidden". DeniedMessage string `json:"denied_message,omitempty"` - // Headers is a list of headers to send to the OPA server. - // All headers are sent to the OPA server except those in the IgnoredHeaders list. + // Headers is a list of headers to send to the OPA server in addition. + // All headers in the request are sent to the OPA server except those in the IgnoredHeaders list. Headers map[string][]string `json:"headers,omitempty"` // IgnoredHeaders is a list of headers to ignore when sending to the OPA server. @@ -61,6 +60,12 @@ func (c *Config) Validate() error { c.Logger = log.Default() } } + if c.DeniedStatusCode == 0 { + c.DeniedStatusCode = http.StatusForbidden + } + if c.DeniedMessage == "" { + c.DeniedMessage = "Forbidden" + } if c.Timeout == 0 { c.Timeout = 10 * time.Second } diff --git a/config/base_config_test.go b/config/base_config_test.go index d4c2025..435f99f 100644 --- a/config/base_config_test.go +++ b/config/base_config_test.go @@ -1,7 +1,6 @@ package config import ( - "github.com/open-policy-agent/opa/rego" "log" "net/http" "testing" @@ -12,7 +11,6 @@ func TestConfig_Validate(t *testing.T) { type fields struct { URL string Policy string - instantiatedPolicy *rego.PreparedEvalQuery Query string InputCreationMethod func(r *http.Request) (map[string]interface{}, error) ExceptedResult bool @@ -110,7 +108,6 @@ func TestConfig_Validate(t *testing.T) { c := &Config{ URL: tt.fields.URL, Policy: tt.fields.Policy, - instantiatedPolicy: tt.fields.instantiatedPolicy, Query: tt.fields.Query, InputCreationMethod: tt.fields.InputCreationMethod, ExceptedResult: tt.fields.ExceptedResult, diff --git a/fiber_middleware.go b/fiber_middleware.go index 852c591..2b58f63 100644 --- a/fiber_middleware.go +++ b/fiber_middleware.go @@ -9,7 +9,7 @@ import ( "net/http" ) -// InputCreationMethod is the method that is used to create the input for the policy. +// FiberInputCreationMethod is the method that is used to create the input for the policy. type FiberInputCreationMethod func(c *fiber.Ctx) (map[string]interface{}, error) type FiberMiddleware struct { @@ -40,7 +40,7 @@ func NewFiberMiddleware(cfg *config.Config, input FiberInputCreationMethod) (*Fi func (g *FiberMiddleware) Use() func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { if g.Config.Debug { - g.Config.Logger.Printf("[opa-middleware-fiber] Request: %s", c.Request().URI()) + g.Config.Logger.Printf("[opa-middleware-fiber] Request received") } result, err := g.query(c) if err != nil { diff --git a/gin_middleware.go b/gin_middleware.go index fa66de7..9a2d32b 100644 --- a/gin_middleware.go +++ b/gin_middleware.go @@ -12,7 +12,7 @@ type GinInputCreationMethod func(c *gin.Context) (map[string]interface{}, error) type GinMiddleware struct { Config *config.Config - // BindingMethod is a function that returns the value to be sent to the OPA server. + // InputCreationMethod is a function that returns the value to be sent to the OPA server. InputCreationMethod GinInputCreationMethod `json:"binding_method,omitempty"` } @@ -44,7 +44,7 @@ func NewGinMiddleware(cfg *config.Config, input GinInputCreationMethod) (*GinMid func (g *GinMiddleware) Use() func(c *gin.Context) { return func(c *gin.Context) { if g.Config.Debug { - g.Config.Logger.Printf("[opa-middleware-gin] Request: %s", c.Request.URL.String()) + g.Config.Logger.Printf("[opa-middleware-gin] Request received") } result, err := g.query(c) if err != nil { diff --git a/http_middleware.go b/http_middleware.go index 83d8722..e2cc143 100644 --- a/http_middleware.go +++ b/http_middleware.go @@ -31,7 +31,7 @@ func NewHTTPMiddleware(cfg *config.Config, next http.Handler) (*HTTPMiddleware, // ServeHTTP serves the http request. Act as Use acts in other frameworks. func (h *HTTPMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if h.Config.Debug { - h.Config.Logger.Printf("[opa-middleware-http] Request: %s", req.URL.String()) + h.Config.Logger.Printf("[opa-middleware-http] Request received") } result, err := h.query(req) if err != nil {