From 2e57b7e964f07181d32c7923f26d9205c2df6955 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 17 Jun 2019 17:44:09 +0200 Subject: [PATCH 01/47] agent/backend: add the api to get rules Add the first rule fields allowing to instantiate a callback at a given hookpoint. The data entry being dynamic according to the rule, an explicit json parser is required in order to unmarshal the json string to the correct type according to the json key `type`. Rules can be retrieved either at login in the app-login request's response, or on explicit request as a heartbeat command sent by the backend. --- agent/internal/backend/api/api.go | 26 +++++++++++++ agent/internal/backend/api/jsonpb.go | 45 +++++++++++++++++++++++ agent/internal/backend/api/jsonpb_test.go | 20 ++++++++++ agent/internal/backend/client.go | 13 +++++++ agent/internal/backend/client_test.go | 44 +++++++++++++++++++++- agent/internal/config/config.go | 5 ++- 6 files changed, 149 insertions(+), 4 deletions(-) diff --git a/agent/internal/backend/api/api.go b/agent/internal/backend/api/api.go index f60fcb03..eb3a5c6a 100644 --- a/agent/internal/backend/api/api.go +++ b/agent/internal/backend/api/api.go @@ -164,6 +164,27 @@ type BatchRequest_Event struct { } type Rule struct { + Name string `json:"name"` + Hookpoint Hookpoint `json:"hookpoint"` + Data RuleData `json:"data"` +} + +type Hookpoint struct { + Class string `json:"klass"` + Method string `json:"method"` + Callback string `json:"callback_class"` +} + +type RuleData struct { + Values []RuleDataEntry `json:"values"` +} + +type RuleDataEntry Struct + +const CustomErrorPageType = "custom_error_page" + +type CustomErrorPageRuleDataEntry struct { + StatusCode int `json:"status_code"` } type Dependency struct { @@ -800,3 +821,8 @@ func NewBlockedUserEventProperties_OutputFromFace(that BlockedUserEventPropertie this.User = that.GetUser() return this } + +type RulesPackResponse struct { + PackID string `json:"pack_id"` + Rules []Rule `json:"rules"` +} diff --git a/agent/internal/backend/api/jsonpb.go b/agent/internal/backend/api/jsonpb.go index acd7578b..c8129e9e 100644 --- a/agent/internal/backend/api/jsonpb.go +++ b/agent/internal/backend/api/jsonpb.go @@ -7,6 +7,9 @@ package api import ( "encoding/json" "fmt" + + "github.com/pkg/errors" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" ) var RequestRecordVersion = "20171208" @@ -104,3 +107,45 @@ func (identify *RequestRecord_Observed_SDKEvent_Args_Identify) MarshalJSON() ([] } return args.MarshalJSON() } + +// UnmarshalJSON parses rules data to their actual type. The actual type is +// given by the json structure key `type`. +func (v *RuleDataEntry) UnmarshalJSON(data []byte) error { + var discriminant struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &discriminant); err != nil { + return sqerrors.Wrap(err, "json unmarshal") + } + + var value interface{} + switch t := discriminant.Type; t { + case CustomErrorPageType: + value = &CustomErrorPageRuleDataEntry{} + default: + return sqerrors.Wrap(errors.Errorf("unexpected type of rule data value `%s`", t), "json unmarshal") + } + + if err := json.Unmarshal(data, value); err != nil { + return sqerrors.Wrap(err, "json unmarshal") + } + + v.Value = value + return nil +} + +// MarshalJSON serializes the type to the json representation whose type is +// provided by the key `type`. +func (v *RuleDataEntry) MarshalJSON() ([]byte, error) { + var discriminant interface{} + switch actual := v.Value.(type) { + case *CustomErrorPageRuleDataEntry: + discriminant = struct { + Type string `json:"type"` + *CustomErrorPageRuleDataEntry // Inlined + }{ + Type: CustomErrorPageType, CustomErrorPageRuleDataEntry: actual, + } + } + return json.Marshal(discriminant) +} diff --git a/agent/internal/backend/api/jsonpb_test.go b/agent/internal/backend/api/jsonpb_test.go index 717fad12..93aa4d80 100644 --- a/agent/internal/backend/api/jsonpb_test.go +++ b/agent/internal/backend/api/jsonpb_test.go @@ -213,6 +213,26 @@ func TestStruct(t *testing.T) { require.Equal(t, parsedPB, pb) } +func TestRuleDataValue(t *testing.T) { + t.Run("CustomErrorPage", func(t *testing.T) { + msg := &api.RuleDataEntry{ + Value: &api.CustomErrorPageRuleDataEntry{StatusCode: 33}, + } + + // Check it can be marshaled to the expected JSON struct. + buf, err := json.Marshal(msg) + require.NoError(t, err) + + // Check it can be unmarshaled back to json. + parsed := new(api.RuleDataEntry) + err = json.Unmarshal(buf, parsed) + require.NoError(t, err) + + // Check both are equal + require.Equal(t, parsed, msg) + }) +} + func FuzzStruct(e *api.Struct, c fuzz.Continue) { nbFields := c.Uint32() % 10 if nbFields == 0 { diff --git a/agent/internal/backend/client.go b/agent/internal/backend/client.go index 27e4d6b9..27a231fc 100644 --- a/agent/internal/backend/client.go +++ b/agent/internal/backend/client.go @@ -139,6 +139,19 @@ func (c *Client) ActionsPack() (*api.ActionsPackResponse, error) { return res, nil } +func (c *Client) RulesPack() (*api.RulesPackResponse, error) { + httpReq, err := c.newRequest(&config.BackendHTTPAPIEndpoint.RulesPack) + if err != nil { + return nil, err + } + httpReq.Header.Set(config.BackendHTTPAPIHeaderSession, c.session) + res := new(api.RulesPackResponse) + if err := c.Do(httpReq, nil, res); err != nil { + return nil, err + } + return res, nil +} + // Do performs the request whose body is pbs[0] pointer, while the expected // response is pbs[1] pointer. They are optional, and must be used according to // the cases request case. diff --git a/agent/internal/backend/client_test.go b/agent/internal/backend/client_test.go index cbe3fa2c..67ac43cf 100644 --- a/agent/internal/backend/client_test.go +++ b/agent/internal/backend/client_test.go @@ -23,14 +23,15 @@ import ( var ( logger = plog.NewLogger(plog.Debug, os.Stderr, 0) cfg = config.New(logger) - fuzzer = fuzz.New().Funcs(FuzzStruct, FuzzCommandRequest) + fuzzer = fuzz.New().Funcs(FuzzStruct, FuzzCommandRequest, FuzzRuleDataValue) ) func TestClient(t *testing.T) { RegisterTestingT(t) - g := NewGomegaWithT(t) t.Run("AppLogin", func(t *testing.T) { + g := NewGomegaWithT(t) + token := testlib.RandString(2, 50) appName := testlib.RandString(2, 50) @@ -59,6 +60,8 @@ func TestClient(t *testing.T) { }) t.Run("AppBeat", func(t *testing.T) { + g := NewGomegaWithT(t) + statusCode := http.StatusOK endpointCfg := &config.BackendHTTPAPIEndpoint.AppBeat @@ -77,6 +80,8 @@ func TestClient(t *testing.T) { }) t.Run("Batch", func(t *testing.T) { + g := NewGomegaWithT(t) + statusCode := http.StatusOK endpointCfg := &config.BackendHTTPAPIEndpoint.Batch @@ -94,6 +99,8 @@ func TestClient(t *testing.T) { }) t.Run("ActionsPack", func(t *testing.T) { + g := NewGomegaWithT(t) + statusCode := http.StatusOK endpointCfg := &config.BackendHTTPAPIEndpoint.ActionsPack @@ -110,7 +117,28 @@ func TestClient(t *testing.T) { g.Expect(res).Should(Equal(response)) }) + t.Run("RulesPack", func(t *testing.T) { + g := NewGomegaWithT(t) + + statusCode := http.StatusOK + + endpointCfg := &config.BackendHTTPAPIEndpoint.RulesPack + + response := NewRandomRulesPackResponse() + + client, server := initFakeServerSession(endpointCfg, nil, response, statusCode, nil) + defer server.Close() + + res, err := client.RulesPack() + g.Expect(err).NotTo(HaveOccurred()) + // A request has been received + g.Expect(len(server.ReceivedRequests())).ToNot(Equal(0)) + g.Expect(res).Should(Equal(response)) + }) + t.Run("AppLogout", func(t *testing.T) { + g := NewGomegaWithT(t) + statusCode := http.StatusOK endpointCfg := &config.BackendHTTPAPIEndpoint.AppLogout @@ -282,6 +310,12 @@ func NewRandomActionsPackResponse() *api.ActionsPackResponse { return pb } +func NewRandomRulesPackResponse() *api.RulesPackResponse { + pb := new(api.RulesPackResponse) + fuzzer.Fuzz(pb) + return pb +} + func FuzzStruct(e *api.Struct, c fuzz.Continue) { v := struct { A string @@ -305,3 +339,9 @@ func FuzzCommandRequest(e *api.CommandRequest, c fuzz.Continue) { c.Fuzz(&e.Name) c.Fuzz(&e.Uuid) } + +func FuzzRuleDataValue(e *api.RuleDataEntry, c fuzz.Continue) { + v := &api.CustomErrorPageRuleDataEntry{} + c.Fuzz(&v.StatusCode) + e.Value = v +} diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index d6b1e522..87e34c3e 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -49,7 +49,7 @@ var ( // List of endpoint addresses, relative to the base URL. BackendHTTPAPIEndpoint = struct { - AppLogin, AppLogout, AppBeat, AppException, Batch, ActionsPack HTTPAPIEndpoint + AppLogin, AppLogout, AppBeat, AppException, Batch, ActionsPack, RulesPack HTTPAPIEndpoint }{ AppLogin: HTTPAPIEndpoint{http.MethodPost, "/sqreen/v1/app-login"}, AppLogout: HTTPAPIEndpoint{http.MethodGet, "/sqreen/v0/app-logout"}, @@ -57,8 +57,9 @@ var ( AppException: HTTPAPIEndpoint{http.MethodPost, "/sqreen/v0/app_sqreen_exception"}, Batch: HTTPAPIEndpoint{http.MethodPost, "/sqreen/v0/batch"}, ActionsPack: HTTPAPIEndpoint{http.MethodGet, "/sqreen/v0/actionspack"}, + RulesPack: HTTPAPIEndpoint{http.MethodGet, "/sqreen/v0/rulespack"}, } - + // Header name of the API token. BackendHTTPAPIHeaderToken = "X-Api-Key" From e255cd66c4404d1df8d61dda85804ea2c91d054a Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 18 Jun 2019 14:46:31 +0200 Subject: [PATCH 02/47] agent/command: add the rules reloading command Use the /rulesreload endpoint in order to get the rules when the `rules_reload` command is received. --- agent/internal/command.go | 6 ++++++ agent/internal/command_test.go | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/agent/internal/command.go b/agent/internal/command.go index 66fd2ef2..b91f6389 100644 --- a/agent/internal/command.go +++ b/agent/internal/command.go @@ -27,6 +27,7 @@ type CommandManagerAgent interface { InstrumentationDisable() error ActionsReload() error SetCIDRWhitelist([]string) error + RulesReload() error } func NewCommandManager(agent CommandManagerAgent, logger *plog.Logger) *CommandManager { @@ -41,6 +42,7 @@ func NewCommandManager(agent CommandManagerAgent, logger *plog.Logger) *CommandM "instrumentation_remove": mng.InstrumentationRemove, "actions_reload": mng.ActionsReload, "ips_whitelist": mng.IPSWhitelist, + "rules_reload": mng.RulesReload, } return mng @@ -105,6 +107,10 @@ func (m *CommandManager) IPSWhitelist(args []json.RawMessage) error { return m.agent.SetCIDRWhitelist(cidrs) } +func (m *CommandManager) RulesReload([]json.RawMessage) error { + return m.agent.RulesReload() +} + // commandResult converts an error to a command result API object. func commandResult(logger *plog.Logger, err error) api.CommandResult { if err != nil { diff --git a/agent/internal/command_test.go b/agent/internal/command_test.go index 04e7a27a..b6946123 100644 --- a/agent/internal/command_test.go +++ b/agent/internal/command_test.go @@ -78,6 +78,10 @@ func TestCommandManager(t *testing.T) { {json.RawMessage(`["a", "b", "c"]`), json.RawMessage(`["a", "b", "c"]`)}, }, }, + { + Command: "rules_reload", + AgentExpectedCall: agent.ExpectRulesReload, + }, } for _, tc := range testCases { @@ -285,6 +289,11 @@ func (a *agentMockup) SetCIDRWhitelist(cidrs []string) error { return ret.Error(0) } +func (a *agentMockup) RulesReload() error { + ret := a.Called() + return ret.Error(0) +} + func (a *agentMockup) ExpectInstrumentationEnable(...interface{}) *mock.Call { return a.On("InstrumentationEnable") } @@ -300,3 +309,7 @@ func (a *agentMockup) ExpectActionsReload(...interface{}) *mock.Call { func (a *agentMockup) ExpectSetCIDRWhitelist(args ...interface{}) *mock.Call { return a.On("SetCIDRWhitelist", args...) } + +func (a *agentMockup) ExpectRulesReload(...interface{}) *mock.Call { + return a.On("RulesReload") +} From b67257a8f502f936f838764fce513318cdabe341 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 18 Jun 2019 14:50:00 +0200 Subject: [PATCH 03/47] agent: add the top-level rule management The agent gets the list of rules when it receives the `rules_reload` command and enables them when it receives the `instrumentation_enable` command. Rules are disabled by the `instrumentation_disable` command. The rule engine provides this required interface. --- agent/internal/agent.go | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index dacc90f2..53423453 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -120,17 +120,17 @@ func Start() { } type Agent struct { - logger *plog.Logger - eventMng *eventManager - metricsMng *metricsManager - ctx context.Context - cancel context.CancelFunc - isDone chan struct{} - config *config.Config - appInfo *app.Info - client *backend.Client - actors *actor.Store - rulespackId string + logger *plog.Logger + eventMng *eventManager + metricsMng *metricsManager + ctx context.Context + cancel context.CancelFunc + isDone chan struct{} + config *config.Config + appInfo *app.Info + client *backend.Client + actors *actor.Store + rules *rule.Engine } // Error channel buffer length. @@ -156,6 +156,7 @@ func New(cfg *config.Config) *Agent { appInfo: app.NewInfo(logger), client: backend.NewClient(cfg.BackendHTTPAPIBaseURL(), cfg, logger), actors: actor.NewStore(logger), + rules: rules.NewEngine(logger), } } @@ -202,7 +203,7 @@ func (a *Agent) Serve() error { // Create the command manager to process backend commands commandMng := NewCommandManager(a, a.logger) - // Process commands that may have been received on login. + // Process commands that may have been received at login. commandResults := commandMng.Do(appLoginRes.Commands) heartbeat := time.Duration(appLoginRes.Features.HeartbeatDelay) * time.Second @@ -213,7 +214,6 @@ func (a *Agent) Serve() error { a.logger.Info("up and running - heartbeat set to ", heartbeat) ticker := time.Tick(heartbeat) - a.rulespackId = appLoginRes.PackId batchSize := int(appLoginRes.Features.BatchSize) if batchSize == 0 { batchSize = config.MaxEventsPerHeatbeat @@ -275,6 +275,8 @@ func (a *Agent) Serve() error { } func (a *Agent) InstrumentationEnable() error { + a.ReloadRules() + a.rules.Enable() sdk.SetAgent(a) a.logger.Info("instrumentation enabled") return nil @@ -284,8 +286,9 @@ func (a *Agent) InstrumentationEnable() error { // now the SDK. func (a *Agent) InstrumentationDisable() error { sdk.SetAgent(nil) - a.logger.Info("instrumentation disabled") + a.rules.Disable() err := a.actors.SetActions(nil) + a.logger.Info("instrumentation disabled") return err } @@ -303,6 +306,16 @@ func (a *Agent) SetCIDRWhitelist(cidrs []string) error { return a.actors.SetCIDRWhitelist(cidrs) } +func (a *Agent) RulesReload() error { + rulespack, err := a.client.RulesPack() + if err != nil { + a.logger.Error(err) + return err + } + a.rules.SetRules(rulespack.PackID, rulespack.Rules) + return nil +} + func (a *Agent) GracefulStop() { if a.config.Disable() { return From 7f4007952bcd2f9867321cafae7fa8432631a8f2 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 18 Jun 2019 23:09:24 +0200 Subject: [PATCH 04/47] agent/rule: implement a first simple rule engine Implement the first simple rule engine going through the received rules to: 1. Find the hook using the symbol. 2. Instantiate the callback. 3. Attach the callback to the found hook. Enabling and disabling is separated in order to provide the require interface for commands `instrumentation_enable`, `instrumentation_disable` and `rules_reload`. --- agent/internal/agent.go | 9 +- agent/internal/rule/doc.go | 7 ++ agent/internal/rule/rule.go | 169 +++++++++++++++++++++++++++++++ agent/internal/rule/rule_test.go | 145 ++++++++++++++++++++++++++ 4 files changed, 327 insertions(+), 3 deletions(-) create mode 100644 agent/internal/rule/doc.go create mode 100644 agent/internal/rule/rule.go create mode 100644 agent/internal/rule/rule_test.go diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 53423453..5ffb98f0 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -17,6 +17,7 @@ import ( "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/config" "github.com/sqreen/go-agent/agent/internal/plog" + "github.com/sqreen/go-agent/agent/internal/rule" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" "github.com/sqreen/go-agent/agent/sqlib/sqsafe" "github.com/sqreen/go-agent/agent/sqlib/sqtime" @@ -156,7 +157,7 @@ func New(cfg *config.Config) *Agent { appInfo: app.NewInfo(logger), client: backend.NewClient(cfg.BackendHTTPAPIBaseURL(), cfg, logger), actors: actor.NewStore(logger), - rules: rules.NewEngine(logger), + rules: rule.NewEngine(logger), } } @@ -275,7 +276,9 @@ func (a *Agent) Serve() error { } func (a *Agent) InstrumentationEnable() error { - a.ReloadRules() + if err := a.RulesReload(); err != nil { + return err + } a.rules.Enable() sdk.SetAgent(a) a.logger.Info("instrumentation enabled") @@ -439,5 +442,5 @@ func (a *Agent) AddExceptionEvent(e *ExceptionEvent) { } func (a *Agent) RulespackID() string { - return a.rulespackId + return a.rules.PackID() } diff --git a/agent/internal/rule/doc.go b/agent/internal/rule/doc.go new file mode 100644 index 00000000..e8608108 --- /dev/null +++ b/agent/internal/rule/doc.go @@ -0,0 +1,7 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +// This package manages the rules by instantiating the callbacks and attaching +// them to their corresponding hooks. +package rule diff --git a/agent/internal/rule/rule.go b/agent/internal/rule/rule.go new file mode 100644 index 00000000..72ce6c47 --- /dev/null +++ b/agent/internal/rule/rule.go @@ -0,0 +1,169 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +// Package rule implements the engine to manage rules. +// +// Main requirements: +// - Rules can be globally enabled or disabled, independently from setting +// the list of rules. +// - Rule hookpoints can be undefined, ie. the backend sent more rules than +// actually required. +// - Errors regarding hookpoint or callbacks should be handled. +// - Setting new rules when already enabled and having active rules should be +// atomic at the hook level. For example, having a new SQLi rule should not +// introduce a time when it is disabled, but should instead be replaced with +// the new one atomically. +package rule + +import ( + "fmt" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/config" + "github.com/sqreen/go-agent/agent/internal/plog" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" +) + +type Engine struct { + logger Logger + // Map rules to their corresponding symbol in order to be able to modify them + // at run time by atomically replacing a running rule. + rules ruleDescriptors + packID string + cfg *config.Config + enabled bool +} + +// Logger interface required by this package. +type Logger interface { + plog.DebugLogger + plog.ErrorLogger +} + +// NewEngine returns a new rule engine. +func NewEngine(logger Logger) *Engine { + return &Engine{ + logger: logger, + } +} + +// PackID returns the ID of the current pack of rules. +func (e *Engine) PackID() string { + return e.packID +} + +// SetRules set the currents rules. If rules were already set, it will replace +// them by atomically modifying the hooks, and removing what is left. +func (e *Engine) SetRules(packID string, rules []api.Rule) { + // Create the net rule descriptors and replace the existing ones + ruleDescriptors := newRuleDescriptors(e.logger, rules) + e.setRules(packID, ruleDescriptors) +} + +func (e *Engine) setRules(packID string, descriptors ruleDescriptors) { + for symbol, rule := range descriptors { + if e.enabled { + // TODO: chain multiple callbacks per hookpoint using a callback of callbacks + // Attach the callback to the hook + err := rule.hook.Attach(rule.prolog, rule.epilog) + if err != nil { + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not attach the callbacks", rule.name))) + continue + } + } + // Remove from the previous rules pack the entries that were redefined in + // this one. + delete(e.rules, symbol) + } + // Disable previously enabled rules that were not replaced by new ones. + for _, rule := range e.rules { + err := rule.hook.Attach(nil, nil) + if err != nil { + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not attach the callbacks", rule.name))) + continue + } + } + // Save the rules pack ID and the list of enabled hooks + e.packID = packID + e.rules = descriptors +} + +// newRuleDescriptors walks the list of received rules and creates the map of +// rule descriptors indexed by their symbol. A rule descriptor contains all it +// needs to enable and disable rules at run time. +func newRuleDescriptors(logger Logger, rules []api.Rule) ruleDescriptors { + // Create and configure the list of callbacks according to the given rules + ruleDescriptors := make(ruleDescriptors) + for _, r := range rules { + hookpoint := r.Hookpoint + // Find the symbol + symbol := fmt.Sprintf("%s.%s", hookpoint.Class, hookpoint.Method) + hook := sqhook.Find(symbol) + if hook == nil { + logger.Debugf("rule `%s` ignored: symbol `%s` cannot be hooked", r.Name, symbol) + continue + } + // Get the callback data from the API message + var data []interface{} + if nbData := len(r.Data.Values); nbData > 0 { + data = make([]interface{}, 0, nbData) + for _, e := range r.Data.Values { + data = append(data, e.Value) + } + } + // Instantiate the callback + prolog, epilog, err := NewCallbacks(hookpoint.Callback, data) + if err != nil { + logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not instantiate the callbacks", r.Name))) + continue + } + // Create the rule descriptor with everything required to be able to enable + // or disable it afterwards. + ruleDescriptors.Add(symbol, ruleDescriptor{ + name: r.Name, + hook: hook, + prolog: prolog, + epilog: epilog, + }) + } + if len(ruleDescriptors) == 0 { + return nil + } + return ruleDescriptors +} + +// Enable the hooks of the ongoing configured rules. +func (e *Engine) Enable() { + for _, r := range e.rules { + err := r.hook.Attach(r.prolog, r.epilog) + if err != nil { + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not attach the callbacks", r.name))) + } + } + e.enabled = true +} + +// Disable the hooks currently attached to callbacks. +func (e *Engine) Disable() { + for _, r := range e.rules { + err := r.hook.Attach(nil, nil) + if err != nil { + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not disable the callbacks", r.name))) + } + } + e.enabled = false +} + +type ruleDescriptors map[string]ruleDescriptor + +type ruleDescriptor struct { + name string + hook *sqhook.Hook + epilog, prolog sqhook.Callback +} + +func (m ruleDescriptors) Add(symbol string, descriptor ruleDescriptor) { + m[symbol] = descriptor +} diff --git a/agent/internal/rule/rule_test.go b/agent/internal/rule/rule_test.go new file mode 100644 index 00000000..b91583e5 --- /dev/null +++ b/agent/internal/rule/rule_test.go @@ -0,0 +1,145 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package rule_test + +import ( + "net/http" + "os" + "reflect" + "testing" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/plog" + "github.com/sqreen/go-agent/agent/internal/rule" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" + "github.com/stretchr/testify/require" +) + +func func1(_ http.ResponseWriter, _ *http.Request, _ http.Header, _ int, _ []byte) {} +func func2(_ http.ResponseWriter, _ *http.Request, _ http.Header, _ int, _ []byte) {} + +type empty struct{} + +func TestEngineUsage(t *testing.T) { + logger := plog.NewLogger(plog.Debug, os.Stderr, 0) + engine := rule.NewEngine(logger) + hookFunc1 := sqhook.New(func1) + require.NotNil(t, hookFunc1) + hookFunc2 := sqhook.New(func2) + require.NotNil(t, hookFunc2) + + t.Run("empty state", func(t *testing.T) { + require.Empty(t, engine.PackID()) + engine.SetRules("my pack id", nil) + require.Equal(t, engine.PackID(), "my pack id") + // No problem enabling/disabling the engine + engine.Enable() + engine.Disable() + engine.Enable() + engine.SetRules("my other pack id", []api.Rule{}) + require.Equal(t, engine.PackID(), "my other pack id") + }) + + t.Run("multiple rules", func(t *testing.T) { + engine.Disable() + engine.SetRules("yet another pack id", []api.Rule{ + { + Name: "a valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func1", + Callback: "WriteCustomErrorPage", + }, + }, + { + Name: "another valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func2", + Callback: "WriteCustomErrorPage", + }, + }, + }) + + t.Run("callbacks are not attached when disabled", func(t *testing.T) { + // Check the callbacks were not attached because rules are disabled + prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + require.Nil(t, prologFunc1) + require.Nil(t, epilogFunc1) + prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + require.Nil(t, prologFunc2) + require.Nil(t, epilogFunc2) + }) + + t.Run("enabling the rules attaches the callbacks", func(t *testing.T) { + // Enable the rules + engine.Enable() + // Check the callbacks were now attached + prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + require.NotNil(t, prologFunc1) + require.Nil(t, epilogFunc1) + prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + require.NotNil(t, prologFunc2) + require.Nil(t, epilogFunc2) + }) + + t.Run("disabling the rules removes the callbacks", func(t *testing.T) { + // Disable the rules + engine.Disable() + // Check the callbacks were all removed for func1 and not func2 + prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + require.Nil(t, prologFunc1) + require.Nil(t, epilogFunc1) + prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + require.Nil(t, prologFunc2) + require.Nil(t, epilogFunc2) + }) + + t.Run("enabling the rules again sets back the callbacks", func(t *testing.T) { + // Enable again the rules + engine.Enable() + // Check the callbacks are attached again + prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + require.NotNil(t, prologFunc1) + require.Nil(t, epilogFunc1) + prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + require.NotNil(t, prologFunc2) + require.Nil(t, epilogFunc2) + }) + }) + + t.Run("modify enabled rules", func(t *testing.T) { + // Modify the rules while enabled + engine.SetRules("a pack id", []api.Rule{ + { + Name: "another valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func2", + Callback: "WriteCustomErrorPage", + }, + }, + }) + // Check the callbacks were removed for func1 and not func2 + prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + require.Nil(t, prologFunc1) + require.Nil(t, epilogFunc1) + prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + require.NotNil(t, prologFunc2) + require.Nil(t, epilogFunc2) + }) + + t.Run("replace the enabled rules with an empty array of rules", func(t *testing.T) { + // Set the rules with an empty array while enabled + engine.SetRules("yet another pack id", []api.Rule{}) + // Check the callbacks were all removed for func1 and not func2 + prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + require.Nil(t, prologFunc1) + require.Nil(t, epilogFunc1) + prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + require.Nil(t, prologFunc2) + require.Nil(t, epilogFunc2) + }) +} From f2e0a88b3abb3d8db969e4e65387064013050940 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 26 Jun 2019 11:40:26 +0200 Subject: [PATCH 05/47] agent/sqlib/sqhook: add a pure Go hook library Package sqhook provides a pure Go implementation of hooks to be inserted into function definitions in order to be able to attach prolog and epilog callbacks giving read/write access to the arguments and returned values of the function call at run time. --- agent/sqlib/sqhook/hook.go | 285 +++++++++++++++++++++++++ agent/sqlib/sqhook/hook_test.go | 368 ++++++++++++++++++++++++++++++++ 2 files changed, 653 insertions(+) create mode 100644 agent/sqlib/sqhook/hook.go create mode 100644 agent/sqlib/sqhook/hook_test.go diff --git a/agent/sqlib/sqhook/hook.go b/agent/sqlib/sqhook/hook.go new file mode 100644 index 00000000..5c318916 --- /dev/null +++ b/agent/sqlib/sqhook/hook.go @@ -0,0 +1,285 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +// Package sqhook provides a pure Go implementation of hooks to be inserted +// into function definitions in order to be able to attach prolog and epilog +// callbacks giving read/write access to the arguments and returned values of +// the function call at run time. +// +// A hook needs to be globally created and associated to a function symbol at +// package initialization time. Prolog and epilog callbacks can then be +// accessed in the function call to pass the call arguments and return values. +// +// On the other side, callbacks can be attached to a hook at run time. Prolog +// callbacks get read/write access to the arguments of the function call before +// it gets executed, while epilog callbacks get read/write access to the return +// values before returning from the call. Therefore, the callbacks' signature +// need to match the function signature. +// +// Given a function F: +// func F(A, B, C) (R, S, T) +// The expected prolog signature is: +// type prolog = func(*sqhook.Context, *A, *B, *C) error +// The expected epilog signature is: +// type epilog = func(*sqhook.Context, *R, *S, *T) +// +// Example: +// // Define the hook globally +// var exampleHook *sqhook.Hook +// +// // Initialization needs to be done in the init() function because of some +// // Go initialization limitations. +// func init() { +// exampleHook = sqhook.New(Example) +// } +// +// func Example(arg1 int, arg2 string) (ret1 []byte, ret2 error) { +// // Use the hook first and call its callbacks +// { +// type Prolog = func(*sqhook.Context, *int, *string) error +// type Epilog = func(*sqhook.Context, *[]byte, *error) +// // Create a call context +// ctx := sqhook.Context{} +// prolog, epilog := exampleHook.Callbacks() +// // If an epilog is set, defer the call to the epilog +// if epilog, ok := epilog.(Epilog); ok { +// // Pass pointers to the return values +// defer epilog(&ctx, &ret1, &ret2) +// } +// // If a prolog is set, call it +// if prolog, ok := prolog.(Prolog); ok { +// // Pass pointers to the arguments +// err := prolog(&ctx, &w, &r, &headers, &statusCode, &body) +// // If an error is returned, the function execution is aborted. +// // The deferred epilog call will still be executed before returning. +// if err != nil { +// return +// } +// } +// } +// // .. function code ... +// } +// +// +// Main requirements: +// - Concurrent access and modification of callbacks. +// - Reentrant implementation of callbacks with a call context when data needs +// to be shared between the prolog and epilog. +// +// - Fast call dispatch for callbacks that don't need to be generic, ie. +// callbacks that are designed to be attached to specific functions. +// Type-assertion instead of `reflect.Call` is therefore used while generic +// callbacks that are not tied a specific function will be attached using +// `reflect.MakeFunc` in order to match the function signature. The usage +// of dynamic calls using `reflect` is indeed much slower and consumes +// memory. +// +// Design constraints: +// +// - There are no compilation-time functions or macros that would have allowed +// to provide helpers setting up the hooks in the function definitions. +// - Access and modification of callbacks need to be atomic. +// - There are no way to add custom sections to the binary file, which would +// have made possible defining the index of hooks at compilation-time ( +// things that can be easily done with GCC). +// +package sqhook + +import ( + "fmt" + "reflect" + "runtime" + "sync/atomic" + "unsafe" + + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" +) + +var index = make(map[string]*Hook) + +type Hook struct { + // The function type where the hook is used. + fnType reflect.Type + // Pointer to a structure containing the callbacks in order to be able to + // atomically modify the pointer. + attached *callbacks +} + +type callbacks struct { + prolog, epilog Callback +} + +// Callback is a function expecting a Context pointer as first argument, +// followed by the pointers to the arguments of the hooked function for a +// prolog, and followed by the pointers to the returned values for an epilog. +type Callback interface{} + +// Context is a call context for the hook. It is shared between the prolog and +// epilog and is unique for each function call. It allows callbacks to provide +// reentrant implementations when memory needs to be shared for a given call. +type Context []interface{} + +// MethodReceiver is store in the context when hooking a method. +type MethodReceiver interface{} + +type Error int + +// Errors that hooks can return in order to modify the control flow of the +// function. +const ( + _ Error = iota + // Abort the execution of the function by returning from it. + AbortError +) + +func (e Error) Error() string { + return fmt.Sprintf("Error(%d)", e) +} + +// Static assertion that `Error` implements interface `error` +var _ error = Error(0) + +// New returns a hook for function `fn` to be used in the function definition +// in order to be able to attach callbacks to it. It returns nil if the fn is +// not a non-nil function or if the symbol name of `fn` cannot be retrieved. +func New(fn interface{}) *Hook { + // Check fn is a non-nil function value. + if fn == nil { + return nil + } + v := reflect.ValueOf(fn) + fnType := v.Type() + if fnType.Kind() != reflect.Func { + return nil + } + // If the symbol name cannot be retrieved + symbol := runtime.FuncForPC(v.Pointer()).Name() + if symbol == "" { + return nil + } + // Create the hook, store it in the map and return it. + hook := &Hook{ + fnType: fnType, + } + index[symbol] = hook + return hook +} + +// Find returns the hook associated to the given symbol string when it was +// created using `New()`, nil otherwise. +func Find(symbol string) *Hook { + return index[symbol] +} + +// Attach atomically attaches prolog and epilog callbacks to the hook. It is +// possible to pass nil values when only one type of callback is required. If +// both arguments are nil, the callbacks are removed. +func (h *Hook) Attach(prolog, epilog Callback) error { + if h == nil { + return sqerrors.New("cannot attach callbacks to a nil hook") + } + var cbs *callbacks + if prolog != nil || epilog != nil { + cbs = &callbacks{} + if prolog != nil { + // Create the list of argument types + argTypes := make([]reflect.Type, 0, h.fnType.NumIn()) + for i := 0; i < h.fnType.NumIn(); i++ { + argTypes = append(argTypes, h.fnType.In(i)) + } + if err := validateProlog(prolog, argTypes); err != nil { + return err + } + cbs.prolog = prolog + } + if epilog != nil { + // Create the list of return types + retTypes := make([]reflect.Type, 0, h.fnType.NumOut()) + for i := 0; i < h.fnType.NumOut(); i++ { + retTypes = append(retTypes, h.fnType.Out(i)) + } + if err := validateEpilog(epilog, retTypes); err != nil { + return err + } + cbs.epilog = epilog + } + } + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&h.attached)), unsafe.Pointer(cbs)) + return nil +} + +// Callbacks atomically accesses the attached prolog and epilog callbacks. +func (h *Hook) Callbacks() (prolog, epilog Callback) { + if h == nil { + return nil, nil + } + attached := (*callbacks)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&h.attached)))) + if attached == nil { + return nil, nil + } + return attached.prolog, attached.epilog +} + +// validateProlog validates that the prolog has the expected signature. +func validateProlog(prolog Callback, argTypes []reflect.Type) error { + if err := validateCallback(prolog, argTypes); err != nil { + return err + } + t := reflect.TypeOf(prolog) + if t.NumOut() != 1 || !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return sqerrors.New("validation: the prolog callback should return an error value") + } + return nil +} + +// validateEpilog validates that the epilog has the expected signature. +func validateEpilog(epilog Callback, argTypes []reflect.Type) error { + if err := validateCallback(epilog, argTypes); err != nil { + return err + } + t := reflect.TypeOf(epilog) + if t.NumOut() != 0 { + return sqerrors.New("validation: the epilog callback should not return values") + } + return nil +} + +// validateCallback validates the fact that the callback is a function whose +// first argument is the hook context and the rest of its arguments can be +// assigned the hook argument values. +func validateCallback(callback Callback, argTypes []reflect.Type) (err error) { + defer func() { + if err != nil { + err = sqerrors.Wrap(err, "validation error") + } + }() + callbackType := reflect.TypeOf(callback) + // Check the callback is a function + if callbackType.Kind() != reflect.Func { + return sqerrors.New("the callback argument is not a function") + } + callbackArgc := callbackType.NumIn() + // Check the callback accepts a hook context as first argument + if callbackArgc < 1 { + return sqerrors.New("the callback should expect a hook context as first argument") + } + if !reflect.TypeOf((*Context)(nil)).AssignableTo(callbackType.In(0)) { + return sqerrors.New("the callback should expect a hook context as first argument") + } + // Check the argument count + fnArgc := len(argTypes) + if callbackArgc-1 != fnArgc && callbackArgc != fnArgc { + return sqerrors.Errorf("the callback arguments count `%d` is not compatible to the hook arguments count `%d`", callbackArgc, fnArgc) + } + // Check arguments are assignable + var i int + for i = 1; i < callbackArgc; i++ { + argPtrType := reflect.PtrTo(argTypes[i-1]) + callbackArgType := callbackType.In(i) + if !argPtrType.AssignableTo(callbackArgType) { + return sqerrors.Errorf("hook argument `%d` of type `%s` cannot be assigned to the callback argument `%d` of type `%s`", i-1, argPtrType, i, callbackArgType) + } + } + return nil +} diff --git a/agent/sqlib/sqhook/hook_test.go b/agent/sqlib/sqhook/hook_test.go new file mode 100644 index 00000000..18dd8890 --- /dev/null +++ b/agent/sqlib/sqhook/hook_test.go @@ -0,0 +1,368 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package sqhook_test + +import ( + "errors" + "fmt" + "reflect" + "testing" + + fuzz "github.com/google/gofuzz" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" + "github.com/stretchr/testify/require" +) + +type example struct{} + +func (example) method() {} +func (example) ExportedMethod() {} +func (*example) methodPointerReceiver() {} + +func function(_ int, _ string, _ bool) error { return nil } +func ExportedFunction(_ int, _ string, _ bool) error { return nil } + +func TestNew(t *testing.T) { + for _, tc := range []struct { + value interface{} + shouldSucceed bool + }{ + {func() {}, true}, + {nil, false}, + {(func())(nil), false}, + {example.method, true}, + {example.ExportedMethod, true}, + {(*example).methodPointerReceiver, true}, + {function, true}, + {ExportedFunction, true}, + {33, false}, + } { + tc := tc + t.Run(fmt.Sprintf("%T", tc.value), func(t *testing.T) { + hook := sqhook.New(tc.value) + if tc.shouldSucceed { + require.NotNil(t, hook) + } else { + require.Nil(t, hook) + } + }) + } +} + +func TestFind(t *testing.T) { + pkgName := reflect.TypeOf(example{}).PkgPath() + for _, tc := range []struct { + value interface{} + symbol string + }{ + {example.method, "example.method"}, + {example.ExportedMethod, "example.ExportedMethod"}, + {(*example).methodPointerReceiver, "(*example).methodPointerReceiver"}, + {function, "function"}, + {ExportedFunction, "ExportedFunction"}, + } { + tc := tc + t.Run(fmt.Sprintf("%T", tc.value), func(t *testing.T) { + hook := sqhook.New(tc.value) + require.NotNil(t, hook) + got := sqhook.Find(pkgName + "." + tc.symbol) + require.NotNil(t, got) + require.Equal(t, hook, got) + }) + } +} + +func TestAttach(t *testing.T) { + for _, tc := range []struct { + function, expectedProlog, expectedEpilog interface{} + notExpectedPrologs, notExpectedEpilogs []interface{} + }{ + { + function: func() {}, + expectedProlog: func(*sqhook.Context) error { return nil }, + expectedEpilog: func(*sqhook.Context) {}, + notExpectedPrologs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) {}, + func() {}, + }, + notExpectedEpilogs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) error { return nil }, + func() {}, + }, + }, + { + function: example.method, + expectedProlog: func(*sqhook.Context) error { return nil }, + expectedEpilog: func(*sqhook.Context) {}, + notExpectedPrologs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) {}, + func() {}, + }, + notExpectedEpilogs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) error { return nil }, + func() {}, + }, + }, + { + function: example.ExportedMethod, + expectedProlog: func(*sqhook.Context) error { return nil }, + expectedEpilog: func(*sqhook.Context) {}, + notExpectedPrologs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) {}, + func() {}, + }, + notExpectedEpilogs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) error { return nil }, + func() {}, + }, + }, + { + function: (*example).methodPointerReceiver, + expectedProlog: func(*sqhook.Context) error { return nil }, + expectedEpilog: func(*sqhook.Context) {}, + notExpectedPrologs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) {}, + func() {}, + }, + notExpectedEpilogs: []interface{}{ + func() error { return nil }, + func(*sqhook.Context) error { return nil }, + func() {}, + }, + }, + { + function: function, + expectedProlog: func(*sqhook.Context, *int, *string, *bool) error { return nil }, + expectedEpilog: func(*sqhook.Context, *error) {}, + notExpectedPrologs: []interface{}{ + func(*sqhook.Context, *int, *bool, *bool) error { return nil }, + func(*int, *string, *bool) error { return nil }, + func(*sqhook.Context, *int, *string, *bool) {}, + func(*sqhook.Context, int, string, bool) error { return nil }, + func(*sqhook.Context, *int, *bool) error { return nil }, + func(*sqhook.Context, *int) {}, + }, + notExpectedEpilogs: []interface{}{ + func(*error) {}, + func(*sqhook.Context, error) {}, + func() {}, + }, + }, + { + function: ExportedFunction, + expectedProlog: func(*sqhook.Context, *int, *string, *bool) error { return nil }, + expectedEpilog: func(*sqhook.Context, *error) {}, + notExpectedPrologs: []interface{}{ + func(*sqhook.Context, *int, *string, *string) error { return nil }, + func(*int, *string, *bool) error { return nil }, + func(*sqhook.Context, *int, *string, *bool) {}, + func(*sqhook.Context, int, string, bool) error { return nil }, + func(*sqhook.Context, *int, *bool) error { return nil }, + }, + notExpectedEpilogs: []interface{}{ + func(*error) {}, + func(*sqhook.Context, error) {}, + func(*sqhook.Context, *int) {}, + func() {}, + }, + }, + } { + tc := tc + t.Run(fmt.Sprintf("%T", tc.function), func(t *testing.T) { + t.Run("expected callbacks", func(t *testing.T) { + t.Run("non-nil prolog and epilog", func(t *testing.T) { + hook := sqhook.New(tc.function) + require.NotNil(t, hook) + err := hook.Attach(tc.expectedProlog, tc.expectedEpilog) + require.NoError(t, err) + prolog, epilog := hook.Callbacks() + require.Equal(t, reflect.ValueOf(prolog).Pointer(), reflect.ValueOf(tc.expectedProlog).Pointer()) + require.Equal(t, reflect.ValueOf(epilog).Pointer(), reflect.ValueOf(tc.expectedEpilog).Pointer()) + }) + t.Run("nil prolog", func(t *testing.T) { + hook := sqhook.New(tc.function) + require.NotNil(t, hook) + err := hook.Attach(nil, tc.expectedEpilog) + require.NoError(t, err) + prolog, epilog := hook.Callbacks() + require.Nil(t, prolog) + require.Equal(t, reflect.ValueOf(epilog).Pointer(), reflect.ValueOf(tc.expectedEpilog).Pointer()) + }) + t.Run("nil epilog", func(t *testing.T) { + hook := sqhook.New(tc.function) + require.NotNil(t, hook) + err := hook.Attach(tc.expectedProlog, nil) + require.NoError(t, err) + prolog, epilog := hook.Callbacks() + require.Equal(t, reflect.ValueOf(prolog).Pointer(), reflect.ValueOf(tc.expectedProlog).Pointer()) + require.Nil(t, epilog) + }) + t.Run("nil prolog and epilog", func(t *testing.T) { + hook := sqhook.New(tc.function) + require.NotNil(t, hook) + err := hook.Attach(nil, nil) + require.NoError(t, err) + prolog, epilog := hook.Callbacks() + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + }) + t.Run("not expected callbacks", func(t *testing.T) { + for _, notExpectedProlog := range tc.notExpectedPrologs { + notExpectedProlog := notExpectedProlog + t.Run(fmt.Sprintf("%T", notExpectedProlog), func(t *testing.T) { + hook := sqhook.New(tc.function) + require.NotNil(t, hook) + err := hook.Attach(notExpectedProlog, tc.expectedEpilog) + require.Error(t, err) + prolog, epilog := hook.Callbacks() + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + } + for _, notExpectedEpilog := range tc.notExpectedEpilogs { + notExpectedEpilog := notExpectedEpilog + t.Run(fmt.Sprintf("%T", notExpectedEpilog), func(t *testing.T) { + hook := sqhook.New(tc.function) + require.NotNil(t, hook) + err := hook.Attach(tc.expectedProlog, notExpectedEpilog) + require.Error(t, err) + prolog, epilog := hook.Callbacks() + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + } + }) + }) + } +} + +func TestEnableDisable(t *testing.T) { + hook := sqhook.New(example.ExportedMethod) + require.NotNil(t, hook) + err := hook.Attach(func(*sqhook.Context) error { return nil }, func(*sqhook.Context) {}) + require.NoError(t, err) + prolog, epilog := hook.Callbacks() + require.NotNil(t, prolog) + require.NotNil(t, epilog) + hook.Attach(nil, nil) + prolog, epilog = hook.Callbacks() + require.Nil(t, prolog) + require.Nil(t, epilog) +} + +func TestUsage(t *testing.T) { + t.Run("nil hook", func(t *testing.T) { + hook := sqhook.New(33) + require.Nil(t, hook) + err := hook.Attach("oops", "no") + require.Error(t, err) + prolog, epilog := hook.Callbacks() + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + + t.Run("attaching nil", func(t *testing.T) { + hook := sqhook.New(example.method) + require.NotNil(t, hook) + err := hook.Attach(func(*sqhook.Context) error { return nil }, nil) + require.NoError(t, err) + prolog, epilog := hook.Callbacks() + require.NotNil(t, prolog) + require.Nil(t, epilog) + err = hook.Attach(nil, nil) + require.NoError(t, err) + prolog, epilog = hook.Callbacks() + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + + t.Run("hooking a function and reading and writing the arguments and return values", func(t *testing.T) { + var hook *sqhook.Hook + + // Fuzz the initial call arguments, and the arguments and return values the + // callback will use to modify them. + var ( + callA, expectedA int + callB, expectedB string + callC, expectedC bool + callD, expectedD []byte + expectedE float32 + expectedF error + ) + fuzz := fuzz.New() + fuzz.Fuzz(&callA) + fuzz.Fuzz(&callB) + fuzz.Fuzz(&callC) + fuzz.Fuzz(&callD) + fuzz.Fuzz(&expectedA) + fuzz.Fuzz(&expectedB) + fuzz.Fuzz(&expectedC) + fuzz.Fuzz(&expectedD) + fuzz.Fuzz(&expectedE) + expectedF = errors.New("the error") + + example := func(a int, b string, c bool, d []byte) (e float32, f error) { + { + type Prolog = func(*sqhook.Context, *int, *string, *bool, *[]byte) error + type Epilog = func(*sqhook.Context, *float32, *error) + ctx := sqhook.Context{} + prolog, epilog := hook.Callbacks() + if epilog, ok := epilog.(Epilog); ok { + defer epilog(&ctx, &e, &f) + } + if prolog, ok := prolog.(Prolog); ok { + err := prolog(&ctx, &a, &b, &c, &d) + if err != nil { + return + } + } + } + // Check the arguments were modified + require.Equal(t, expectedA, a) + require.Equal(t, expectedB, b) + require.Equal(t, expectedC, c) + require.Equal(t, expectedD, d) + // Return some values that should get modified by the epilog callback + return 33, nil + } + + // Define a hook a attach prolog and epilog callbacks that will modify the + // arguments and return values + hook = sqhook.New(example) + require.NotNil(t, hook) + err := hook.Attach( + func(ctx *sqhook.Context, a *int, b *string, c *bool, d *[]byte) error { + require.Equal(t, callA, *a) + require.Equal(t, callB, *b) + require.Equal(t, callC, *c) + require.Equal(t, callD, *d) + // Modify the arguments + *a = expectedA + *b = expectedB + *c = expectedC + *d = expectedD + return nil + }, + func(ctx *sqhook.Context, e *float32, f *error) { + // Modify the return values + *e = expectedE + *f = expectedF + }) + require.NoError(t, err) + + e, f := example(callA, callB, callC, callD) + // Check the returned values were also modified + require.Equal(t, expectedE, e) + require.Equal(t, expectedF, f) + }) +} From 3fdfdbb0dd5dff2ecbd697b5792866f2b0121a42 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 26 Jun 2019 11:32:25 +0200 Subject: [PATCH 06/47] agent/actor: make security responses hookable - Factorize the security response HTTP handler out in a `WriteResponse()` function. - Add a hook to `WriteResponse` so that the rule's callback modifying its arguments can be attached. - The security response still works without rules. --- agent/internal/actor/http.go | 6 +-- agent/internal/httphandler/write-response.go | 48 ++++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 agent/internal/httphandler/write-response.go diff --git a/agent/internal/actor/http.go b/agent/internal/actor/http.go index 7adc45af..a780e3c3 100644 --- a/agent/internal/actor/http.go +++ b/agent/internal/actor/http.go @@ -10,6 +10,7 @@ import ( "net/http" "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/httphandler" "github.com/sqreen/go-agent/agent/types" "github.com/sqreen/go-agent/sdk" ) @@ -20,8 +21,6 @@ const ( blockUserEventName = "app.sqreen.action.block_user" ) -const sqreenBlockPage = ` Sqreen has detected an attack.

Uh Oh! Sqreen has detected an attack.

If you are the application owner, check the Sqreen dashboard for more information.

` - // NewIPActionHTTPHandler returns a HTTP handler that should be applied at the // request handler level to perform the security response. func NewIPActionHTTPHandler(action Action, ip net.IP) http.Handler { @@ -56,8 +55,7 @@ func newBlockHTTPHandler(eventName string, properties types.EventProperties) *bl func (a *blockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { record := sdk.FromContext(r.Context()) record.TrackEvent(a.eventName).WithProperties(a.eventProperties) - w.WriteHeader(500) - _, _ = w.Write([]byte(sqreenBlockPage)) + httphandler.WriteResponse(w, r, nil, 500, nil) } // blockedIPEventProperties implements `types.EventProperties` to be marshaled diff --git a/agent/internal/httphandler/write-response.go b/agent/internal/httphandler/write-response.go new file mode 100644 index 00000000..6ca09d24 --- /dev/null +++ b/agent/internal/httphandler/write-response.go @@ -0,0 +1,48 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package httphandler + +import ( + "net/http" + + "github.com/sqreen/go-agent/agent/sqlib/sqhook" +) + +var writeResponseHook *sqhook.Hook + +func init() { + writeResponseHook = sqhook.New(WriteResponse) +} + +// WriteResponse writes an HTTP response according to the given arguments. +// The statusCode is the only mandatory argument. Headers and body can be nil. +func WriteResponse(w http.ResponseWriter, r *http.Request, headers http.Header, statusCode int, body []byte) { + { + type Prolog = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]uint8) error + type Epilog = func(*sqhook.Context) + ctx := sqhook.Context{} + prolog, epilog := writeResponseHook.Callbacks() + if epilog, ok := epilog.(Epilog); ok { + defer epilog(&ctx) + } + if prolog, ok := prolog.(Prolog); ok { + err := prolog(&ctx, &w, &r, &headers, &statusCode, &body) + if err != nil { + return + } + } + } + + if len(headers) != 0 { + responseHeaders := w.Header() + for k, v := range headers { + responseHeaders[k] = v + } + } + w.WriteHeader(statusCode) + if len(body) != 0 { + _, _ = w.Write(body) + } +} From 3deffd7037e1027ec8d0a12ad5f8e2e647fd0f32 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 26 Jun 2019 11:25:12 +0200 Subject: [PATCH 07/47] agent/rule/callback: implement the callbacks factory Callbacks are identified by unique strings that are used to find the corresponding callback constructor. This constructor is passed the rule data so that it can bind it to the returned callbacks. This implementation is first used with the `WriteCustomErrorPage` callback used by the `custom-error-page` rule. --- agent/internal/rule/callback.go | 29 +++++++ .../rule/callback/write-custom-error-page.go | 44 +++++++++++ .../callback/write-custom-error-page_test.go | 76 +++++++++++++++++++ agent/internal/rule/callback_test.go | 53 +++++++++++++ agent/sqlib/sqerrors/errors.go | 6 ++ 5 files changed, 208 insertions(+) create mode 100644 agent/internal/rule/callback.go create mode 100644 agent/internal/rule/callback/write-custom-error-page.go create mode 100644 agent/internal/rule/callback/write-custom-error-page_test.go create mode 100644 agent/internal/rule/callback_test.go diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go new file mode 100644 index 00000000..d49375b8 --- /dev/null +++ b/agent/internal/rule/callback.go @@ -0,0 +1,29 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package rule + +import ( + "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" +) + +// CallbackConstructorFunc is a function returning a callback function +// configured with the given data. The data types are known by the constructor +// that can type-assert them. +type CallbacksConstructorFunc func(data []interface{}) (prolog, epilog sqhook.Callback, err error) + +// NewCallbacks returns the prolog and epilog callbacks of the given callback +// name. And error is returned if the callback name is unknown. +func NewCallbacks(name string, data []interface{}) (prolog, epilog sqhook.Callback, err error) { + var callbacksCtor CallbacksConstructorFunc + switch name { + default: + return nil, nil, sqerrors.Errorf("undefined callback name `%s`", name) + case "WriteCustomErrorPage": + callbacksCtor = callback.NewWriteCustomErrorPageCallbacks + } + return callbacksCtor(data) +} diff --git a/agent/internal/rule/callback/write-custom-error-page.go b/agent/internal/rule/callback/write-custom-error-page.go new file mode 100644 index 00000000..3edaf3d0 --- /dev/null +++ b/agent/internal/rule/callback/write-custom-error-page.go @@ -0,0 +1,44 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback + +import ( + "net/http" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" +) + +// NewWriteCustomErrorPageCallbacks returns the native prolog and epilog +// callbacks modifying the arguments of `httphandler.WriteResponse` in order to +// modify the http status code and error page that are provided by the rule's +// data. +func NewWriteCustomErrorPageCallbacks(data []interface{}) (prolog, epilog sqhook.Callback, err error) { + var statusCode = 500 + if len(data) > 0 { + d0 := data[0] + cfg, ok := d0.(*api.CustomErrorPageRuleDataEntry) + if !ok { + return nil, nil, sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + } + statusCode = cfg.StatusCode + } + return newWriteCustomErrorPagePrologCallback(statusCode, []byte(blockedBySqreenPage)), nil, nil +} + +type WriteCustomErrorPagePrologCallbackType = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) error + +// The prolog callback modifies the function arguments in order to replace the +// written status code and body. +func newWriteCustomErrorPagePrologCallback(statusCode int, body []byte) WriteCustomErrorPagePrologCallbackType { + return func(_ *sqhook.Context, _ *http.ResponseWriter, _ **http.Request, _ *http.Header, callerStatusCode *int, callerBody *[]byte) error { + *callerStatusCode = statusCode + *callerBody = body + return nil + } +} + +const blockedBySqreenPage = ` Sqreen has detected an attack.

Uh Oh! Sqreen has detected an attack.

If you are the application owner, check the Sqreen dashboard for more information.

` diff --git a/agent/internal/rule/callback/write-custom-error-page_test.go b/agent/internal/rule/callback/write-custom-error-page_test.go new file mode 100644 index 00000000..8f70917d --- /dev/null +++ b/agent/internal/rule/callback/write-custom-error-page_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback_test + +import ( + "testing" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/stretchr/testify/require" +) + +func TestNewWriteCustomErrorPageCallbacks(t *testing.T) { + t.Run("with incorrect data", func(t *testing.T) { + for _, data := range [][]interface{}{ + {33}, + {"yet another wrong type"}, + } { + prolog, epilog, err := callback.NewWriteCustomErrorPageCallbacks(data) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + } + }) + + t.Run("with correct data", func(t *testing.T) { + for _, tc := range []struct { + testName string + data []interface{} + expectedStatusCode int + }{ + { + testName: "default behaviour with nil data", + data: nil, + expectedStatusCode: 500, + }, + { + testName: "default behaviour with empty array", + data: nil, + expectedStatusCode: 500, + }, + { + testName: "actual rule data", + data: []interface{}{ + &api.CustomErrorPageRuleDataEntry{ + StatusCode: 33, + }, + }, + expectedStatusCode: 33, + }, + } { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + // Instantiate the callback with the given correct rule data + prolog, epilog, err := callback.NewWriteCustomErrorPageCallbacks(tc.data) + require.NoError(t, err) + require.NotNil(t, prolog) + require.Nil(t, epilog) + // Call it and check the behaviour follows the rule's data + actualProlog, ok := prolog.(callback.WriteCustomErrorPagePrologCallbackType) + require.True(t, ok) + var ( + statusCode int + body []byte + ) + err = actualProlog(nil, nil, nil, nil, &statusCode, &body) + // Check it behaves as expected + require.NoError(t, err) + require.Equal(t, tc.expectedStatusCode, statusCode) + require.NotNil(t, body) + }) + } + }) +} diff --git a/agent/internal/rule/callback_test.go b/agent/internal/rule/callback_test.go new file mode 100644 index 00000000..a0d1dbbb --- /dev/null +++ b/agent/internal/rule/callback_test.go @@ -0,0 +1,53 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package rule_test + +import ( + "testing" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/rule" + "github.com/stretchr/testify/require" +) + +func TestNewCallbacks(t *testing.T) { + for _, tc := range []struct { + testName string + name string + data []interface{} + shouldSucceed bool + }{ + { + testName: "not existing", + name: "iDontExist", + data: nil, + shouldSucceed: false, + }, + { + testName: "empty string", + name: "", + data: nil, + shouldSucceed: false, + }, + { + testName: "WriteCustomErrorPage", + name: "WriteCustomErrorPage", + data: []interface{}{ + &api.CustomErrorPageRuleDataEntry{}, + }, + shouldSucceed: true, + }, + } { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + _, _, err := rule.NewCallbacks(tc.name, tc.data) + if tc.shouldSucceed { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} diff --git a/agent/sqlib/sqerrors/errors.go b/agent/sqlib/sqerrors/errors.go index b4b803ba..aef072ae 100644 --- a/agent/sqlib/sqerrors/errors.go +++ b/agent/sqlib/sqerrors/errors.go @@ -64,6 +64,12 @@ func New(message string) error { return WithTimestamp(errors.New(message)) } +// Errorf returns a new errors whose message is formatted by `fmt.Sprintf`. The +// returned error is annotated with a timestamp, a message and a stack trace. +func Errorf(format string, args ...interface{}) error { + return New(fmt.Sprintf(format, args...)) +} + // Wrap annotates the given error `err` with a timestamp, a message and a stack // trace. func Wrap(err error, message string) error { From f9362d0a669d85b6160baf806422984db9e44982 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 26 Jun 2019 11:42:16 +0200 Subject: [PATCH 08/47] agent/config: optional local rules json file Add a new configuration option in order to pass a JSON file containing an array of rules that will be appended to the rules received from the backend. Mostly helpful when developping new rules. --- agent/internal/agent.go | 17 +++++++++++++++++ agent/internal/config/config.go | 8 ++++++++ 2 files changed, 25 insertions(+) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 5ffb98f0..6230facd 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -6,6 +6,8 @@ package internal import ( "context" + "encoding/json" + "io/ioutil" "net/http" "os" "time" @@ -315,6 +317,21 @@ func (a *Agent) RulesReload() error { a.logger.Error(err) return err } + + // Insert local rules if any + localRulesJSON := a.config.LocalRulesFile() + buf, err := ioutil.ReadFile(localRulesJSON) + if err == nil { + var localRules []api.Rule + err = json.Unmarshal(buf, &localRules) + if err == nil { + rulespack.Rules = append(rulespack.Rules, localRules...) + } + } + if err != nil { + a.logger.Error(sqerrors.Wrap(err, "config: could not read the local rules file")) + } + a.rules.SetRules(rulespack.PackID, rulespack.Rules) return nil } diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index 87e34c3e..a43540eb 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -204,6 +204,7 @@ const ( configKeyBackendHTTPAPIProxy = `proxy` configKeyDisable = `disable` configKeyStripHTTPReferer = `strip_http_referer` + configKeyRules = `rules` ) // User configuration's default values. @@ -246,6 +247,7 @@ func New(logger *plog.Logger) *Config { manager.SetDefault(configKeyBackendHTTPAPIProxy, "") manager.SetDefault(configKeyDisable, "") manager.SetDefault(configKeyStripHTTPReferer, "") + manager.SetDefault(configKeyRules, "") err := manager.ReadInConfig() if err != nil { @@ -303,6 +305,12 @@ func (c *Config) StripHTTPReferer() bool { return strip != "" } +// LocalRulesFile returns a JSON file containing custom rules in an array. They +// are added to the rules received from server. +func (c *Config) LocalRulesFile() string { + return sanitizeString(c.GetString(configKeyRules)) +} + func sanitizeString(s string) string { return strings.TrimSpace(s) } From 87b7e70ff0d4fb95aa872a5d50986031774a7427 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 4 Jul 2019 16:40:31 +0200 Subject: [PATCH 09/47] sqlib/sqhook: hook creation cannot fail Creating a hook cannot return nil anymore in case of an error but rather panic because: - it would be caught at Go initialization time and should therefore be impossible to avoid, unless the program is deployed into production straight away. - it should avoid expecting hooks to be working while they don't, and therefore avoid not being able to attach a rule relying on a failed hook. Note that failed hooks are not supposed to happen. It is otherwise a bug. For these two reasons, panic-ing makes perfect sense. --- agent/sqlib/sqhook/hook.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/agent/sqlib/sqhook/hook.go b/agent/sqlib/sqhook/hook.go index 5c318916..db24f2f2 100644 --- a/agent/sqlib/sqhook/hook.go +++ b/agent/sqlib/sqhook/hook.go @@ -146,17 +146,17 @@ var _ error = Error(0) func New(fn interface{}) *Hook { // Check fn is a non-nil function value. if fn == nil { - return nil + panic("nil argument") } v := reflect.ValueOf(fn) fnType := v.Type() if fnType.Kind() != reflect.Func { - return nil + panic("the argument is not a function type") } // If the symbol name cannot be retrieved symbol := runtime.FuncForPC(v.Pointer()).Name() if symbol == "" { - return nil + panic("could not read the symbol name of the function") } // Create the hook, store it in the map and return it. hook := &Hook{ @@ -176,9 +176,6 @@ func Find(symbol string) *Hook { // possible to pass nil values when only one type of callback is required. If // both arguments are nil, the callbacks are removed. func (h *Hook) Attach(prolog, epilog Callback) error { - if h == nil { - return sqerrors.New("cannot attach callbacks to a nil hook") - } var cbs *callbacks if prolog != nil || epilog != nil { cbs = &callbacks{} @@ -211,9 +208,6 @@ func (h *Hook) Attach(prolog, epilog Callback) error { // Callbacks atomically accesses the attached prolog and epilog callbacks. func (h *Hook) Callbacks() (prolog, epilog Callback) { - if h == nil { - return nil, nil - } attached := (*callbacks)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&h.attached)))) if attached == nil { return nil, nil From e14af09773ee5859e62ccd1f2779d65775a3d8a0 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 4 Jul 2019 16:57:19 +0200 Subject: [PATCH 10/47] sqlib/sqhook: adapt tests to new non-nil hook requirement --- agent/sqlib/sqhook/hook_test.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/agent/sqlib/sqhook/hook_test.go b/agent/sqlib/sqhook/hook_test.go index 18dd8890..6a187bf7 100644 --- a/agent/sqlib/sqhook/hook_test.go +++ b/agent/sqlib/sqhook/hook_test.go @@ -41,11 +41,11 @@ func TestNew(t *testing.T) { } { tc := tc t.Run(fmt.Sprintf("%T", tc.value), func(t *testing.T) { - hook := sqhook.New(tc.value) if tc.shouldSucceed { + hook := sqhook.New(tc.value) require.NotNil(t, hook) } else { - require.Nil(t, hook) + require.Panics(t, func() { sqhook.New(tc.value) }) } }) } @@ -261,16 +261,6 @@ func TestEnableDisable(t *testing.T) { } func TestUsage(t *testing.T) { - t.Run("nil hook", func(t *testing.T) { - hook := sqhook.New(33) - require.Nil(t, hook) - err := hook.Attach("oops", "no") - require.Error(t, err) - prolog, epilog := hook.Callbacks() - require.Nil(t, prolog) - require.Nil(t, epilog) - }) - t.Run("attaching nil", func(t *testing.T) { hook := sqhook.New(example.method) require.NotNil(t, hook) From 113461b4a10c5962b1f85ccf792f32054bc0d205 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 27 Jun 2019 11:19:27 +0200 Subject: [PATCH 11/47] agent/rule/callback: add a redirection callback for blocking responses The blocking behaviour is also configured by a second rule named `custom_error_redirection`, which changes the blocking behaviour to a user-defined HTTP redirection. To do so, a new callback function `WriteHTTPRedirection` was added in order to change the `WriteResponse()` arguments to perform the configured redirection. --- agent/internal/backend/api/api.go | 9 ++- agent/internal/backend/api/jsonpb.go | 5 +- agent/internal/rule/callback.go | 2 + .../rule/callback/write-http-redirection.go | 52 +++++++++++++++++ .../callback/write-http-redirection_test.go | 57 +++++++++++++++++++ 5 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 agent/internal/rule/callback/write-http-redirection.go create mode 100644 agent/internal/rule/callback/write-http-redirection_test.go diff --git a/agent/internal/backend/api/api.go b/agent/internal/backend/api/api.go index eb3a5c6a..52364d56 100644 --- a/agent/internal/backend/api/api.go +++ b/agent/internal/backend/api/api.go @@ -181,12 +181,19 @@ type RuleData struct { type RuleDataEntry Struct -const CustomErrorPageType = "custom_error_page" +const ( + CustomErrorPageType = "custom_error_page" + RedirectionType = "redirection" +) type CustomErrorPageRuleDataEntry struct { StatusCode int `json:"status_code"` } +type RedirectionRuleDataEntry struct { + RedirectionURL string `json:"redirection_url"` +} + type Dependency struct { Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name"` Version string `protobuf:"bytes,2,opt,name=version,proto3" json:"version"` diff --git a/agent/internal/backend/api/jsonpb.go b/agent/internal/backend/api/jsonpb.go index c8129e9e..56be6ea9 100644 --- a/agent/internal/backend/api/jsonpb.go +++ b/agent/internal/backend/api/jsonpb.go @@ -8,7 +8,6 @@ import ( "encoding/json" "fmt" - "github.com/pkg/errors" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" ) @@ -122,8 +121,10 @@ func (v *RuleDataEntry) UnmarshalJSON(data []byte) error { switch t := discriminant.Type; t { case CustomErrorPageType: value = &CustomErrorPageRuleDataEntry{} + case RedirectionType: + value = &RedirectionRuleDataEntry{} default: - return sqerrors.Wrap(errors.Errorf("unexpected type of rule data value `%s`", t), "json unmarshal") + return sqerrors.Errorf("unexpected type of rule data value `%s`", t) } if err := json.Unmarshal(data, value); err != nil { diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index d49375b8..be21336a 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -24,6 +24,8 @@ func NewCallbacks(name string, data []interface{}) (prolog, epilog sqhook.Callba return nil, nil, sqerrors.Errorf("undefined callback name `%s`", name) case "WriteCustomErrorPage": callbacksCtor = callback.NewWriteCustomErrorPageCallbacks + case "WriteHTTPRedirection": + callbacksCtor = callback.NewWriteHTTPRedirectionCallbacks } return callbacksCtor(data) } diff --git a/agent/internal/rule/callback/write-http-redirection.go b/agent/internal/rule/callback/write-http-redirection.go new file mode 100644 index 00000000..924542a7 --- /dev/null +++ b/agent/internal/rule/callback/write-http-redirection.go @@ -0,0 +1,52 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback + +import ( + "net/http" + "net/url" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" +) + +// NewWriteHTTPRedirectionCallbacks returns the native prolog and epilog +// callbacks modifying the arguments of `httphandler.WriteResponse` in order to +// modify the http status code and headers in order to perform an HTTP +// redirection to the URL provided by the rule's data. +func NewWriteHTTPRedirectionCallbacks(data []interface{}) (prolog, epilog sqhook.Callback, err error) { + var redirectionURL string + if len(data) > 0 { + d0 := data[0] + cfg, ok := d0.(*api.RedirectionRuleDataEntry) + if !ok { + return nil, nil, sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + } + redirectionURL = cfg.RedirectionURL + } + if redirectionURL == "" { + return nil, nil, sqerrors.New("unexpected empty redirection url") + } + if _, err := url.ParseRequestURI(redirectionURL); err != nil { + return nil, nil, sqerrors.Wrap(err, "validation error of the redirection url") + } + return newWriteHTTPRedirectionPrologCallback(redirectionURL), nil, nil +} + +type WriteHTTPRedirectionPrologCallbackType = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) error + +// The prolog callback modifies the function arguments in order to perform an +// HTTP redirection. +func newWriteHTTPRedirectionPrologCallback(url string) WriteHTTPRedirectionPrologCallbackType { + return func(_ *sqhook.Context, _ *http.ResponseWriter, _ **http.Request, callerHeaders *http.Header, callerStatusCode *int, _ *[]byte) error { + *callerStatusCode = http.StatusSeeOther + if *callerHeaders == nil { + *callerHeaders = make(http.Header) + } + callerHeaders.Set("Location", url) + return nil + } +} diff --git a/agent/internal/rule/callback/write-http-redirection_test.go b/agent/internal/rule/callback/write-http-redirection_test.go new file mode 100644 index 00000000..21d89dd9 --- /dev/null +++ b/agent/internal/rule/callback/write-http-redirection_test.go @@ -0,0 +1,57 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback_test + +import ( + "net/http" + "testing" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/stretchr/testify/require" +) + +func TestNewWriteHTTPRedirectionCallbacks(t *testing.T) { + t.Run("with incorrect data", func(t *testing.T) { + for _, data := range [][]interface{}{ + nil, + {}, + {33}, + {"yet another wrong type"}, + {&api.CustomErrorPageRuleDataEntry{}}, + {&api.RedirectionRuleDataEntry{}}, + {&api.RedirectionRuleDataEntry{"http//sqreen.com"}}, + } { + prolog, epilog, err := callback.NewWriteHTTPRedirectionCallbacks(data) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + } + }) + + t.Run("with correct data", func(t *testing.T) { + // Instantiate the callback with the given correct rule data + expectedURL := "http://sqreen.com" + prolog, epilog, err := callback.NewWriteHTTPRedirectionCallbacks([]interface{}{ + &api.RedirectionRuleDataEntry{RedirectionURL: expectedURL}, + }) + require.NoError(t, err) + require.NotNil(t, prolog) + require.Nil(t, epilog) + // Call it and check the behaviour follows the rule's data + actualProlog, ok := prolog.(callback.WriteHTTPRedirectionPrologCallbackType) + require.True(t, ok) + var ( + statusCode int + headers http.Header + ) + err = actualProlog(nil, nil, nil, &headers, &statusCode, nil) + // Check it behaves as expected + require.NoError(t, err) + require.Equal(t, http.StatusSeeOther, statusCode) + require.NotNil(t, headers) + require.Equal(t, expectedURL, headers.Get("Location")) + }) +} From d75911c903b883b3ae68624224792b852dc2e28d Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 28 Jun 2019 10:21:57 +0200 Subject: [PATCH 12/47] agent/actor/action: add redirection security response for ip addresses Add support for redirection of IP addresses by adding a new type of action whose HTTP handler responds to the HTTP request with a HTTP redirection. The redirection URL is sent by the backend in the action parameters. --- agent/internal/actor/action.go | 26 +++++++- agent/internal/actor/action_test.go | 99 +++++++++++++++++++++-------- agent/internal/actor/actor.go | 23 ++++++- agent/internal/actor/http.go | 77 ++++++++++++++++++++-- agent/internal/backend/api/api.go | 34 ++++++++++ agent/internal/request.go | 6 +- 6 files changed, 229 insertions(+), 36 deletions(-) diff --git a/agent/internal/actor/action.go b/agent/internal/actor/action.go index 27543a17..2ddff127 100644 --- a/agent/internal/actor/action.go +++ b/agent/internal/actor/action.go @@ -5,13 +5,17 @@ package actor import ( + "net/url" "time" + + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" ) // Action kinds. const ( - actionKind_BlockIP = "block_ip" - actionKind_BlockUser = "block_user" + actionKindBlockIP = "block_ip" + actionKindBlockUser = "block_user" + actionKindRedirectIP = "redirect_ip" ) // Action is an interface common to each concrete action type stored in the data @@ -40,6 +44,24 @@ func (a blockAction) ActionID() string { return a.ID } +type redirectAction struct { + ID, URL string +} + +func newRedirectAction(id, location string) (*redirectAction, error) { + if _, err := url.ParseRequestURI(location); err != nil { + return nil, sqerrors.Wrap(err, "validation of the redirection location url") + } + return &redirectAction{ + ID: id, + URL: location, + }, nil +} + +func (a *redirectAction) ActionID() string { + return a.ID +} + // timedAction is an Action with a time deadline after which it is considered // expired. type timedAction struct { diff --git a/agent/internal/actor/action_test.go b/agent/internal/actor/action_test.go index 9f0e3a77..150f8a21 100644 --- a/agent/internal/actor/action_test.go +++ b/agent/internal/actor/action_test.go @@ -16,40 +16,85 @@ import ( ) func TestAction(t *testing.T) { - action := newBlockAction(testlib.RandString(1, 20)) + t.Run("Blocking action", func(t *testing.T) { + action := newBlockAction(testlib.RandString(1, 20)) - t.Run("with duration", func(t *testing.T) { - t.Run("not expired", func(t *testing.T) { - action := withDuration(action, 10*time.Hour) - require.False(t, action.Expired()) + t.Run("with duration", func(t *testing.T) { + t.Run("not expired", func(t *testing.T) { + action := withDuration(action, 10*time.Hour) + require.False(t, action.Expired()) + }) + t.Run("expired", func(t *testing.T) { + action := withDuration(action, 0) + require.True(t, action.Expired()) + }) }) - t.Run("expired", func(t *testing.T) { - action := withDuration(action, 0) - require.True(t, action.Expired()) + + t.Run("HTTP Handler", func(t *testing.T) { + t.Run("Block IP", func(t *testing.T) { + handler, err := NewIPActionHTTPHandler(action, net.IPv4(1, 2, 3, 4)) + require.NotNil(t, handler) + require.Nil(t, err) + // Use the handler + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + require.Equal(t, rec.Code, 500) + // TODO: check the sdk event + }) + + t.Run("Block User", func(t *testing.T) { + handler := NewUserActionHTTPHandler(action, map[string]string{"uid": testlib.RandString(1, 250)}) + require.NotNil(t, handler) + // Use the handler + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + require.Equal(t, rec.Code, 500) + // TODO: check the sdk event + }) }) }) - t.Run("HTTP Handler", func(t *testing.T) { - t.Run("Block IP", func(t *testing.T) { - handler := NewIPActionHTTPHandler(action, net.IPv4(1, 2, 3, 4)) - require.NotNil(t, handler) - // Use the handler - req := httptest.NewRequest(http.MethodPost, "/", nil) - rec := httptest.NewRecorder() - handler.ServeHTTP(rec, req) - require.Equal(t, rec.Code, 500) - // TODO: check the sdk event + t.Run("Redirection action", func(t *testing.T) { + t.Run("invalid location url", func(t *testing.T) { + action, err := newRedirectAction(testlib.RandString(1, 20), "http//toto") + require.Nil(t, action) + require.Error(t, err) }) - t.Run("Block User", func(t *testing.T) { - handler := NewUserActionHTTPHandler(action, map[string]string{"uid": testlib.RandString(1, 250)}) - require.NotNil(t, handler) - // Use the handler - req := httptest.NewRequest(http.MethodPost, "/", nil) - rec := httptest.NewRecorder() - handler.ServeHTTP(rec, req) - require.Equal(t, rec.Code, 500) - // TODO: check the sdk event + t.Run("valid location url", func(t *testing.T) { + action, err := newRedirectAction(testlib.RandString(1, 20), "http://sqreen.com") + require.NotNil(t, action) + require.NoError(t, err) + + t.Run("with duration", func(t *testing.T) { + t.Run("not expired", func(t *testing.T) { + action := withDuration(action, 10*time.Hour) + require.False(t, action.Expired()) + }) + t.Run("expired", func(t *testing.T) { + action := withDuration(action, 0) + require.True(t, action.Expired()) + }) + }) + + t.Run("HTTP Handler", func(t *testing.T) { + t.Run("Redirect IP", func(t *testing.T) { + handler, err := NewIPActionHTTPHandler(action, net.IPv4(1, 2, 3, 4)) + require.NotNil(t, handler) + require.Nil(t, err) + // Use the handler + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + require.Equal(t, rec.Code, http.StatusSeeOther) + require.Equal(t, rec.Header().Get("Location"), action.URL) + // TODO: check the sdk event + }) + + }) }) }) + } diff --git a/agent/internal/actor/actor.go b/agent/internal/actor/actor.go index 8bb0dde5..1934f5d8 100644 --- a/agent/internal/actor/actor.go +++ b/agent/internal/actor/actor.go @@ -207,10 +207,12 @@ func newActionStore(actions []api.ActionsPackResponse_Action) (*actionStore, err func (s *actionStore) addAction(action api.ActionsPackResponse_Action) (err error) { switch action.Action { - case actionKind_BlockIP: + case actionKindBlockIP: err = s.addBlockIPAction(action) - case actionKind_BlockUser: + case actionKindBlockUser: err = s.addBlockUserAction(action) + case actionKindRedirectIP: + err = s.addRedirectIPAction(action) } return err } @@ -231,6 +233,23 @@ func (s *actionStore) addBlockIPAction(action api.ActionsPackResponse_Action) er return s.addCIDRList(cidrs, blockIP) } +func (s *actionStore) addRedirectIPAction(action api.ActionsPackResponse_Action) error { + duration, err := float64ToDuration(action.Duration) + if err != nil { + return err + } + var redirectIP Action + redirectIP, err = newRedirectAction(action.ActionId, action.Parameters.Url) + if duration > 0 { + redirectIP = withDuration(redirectIP, duration) + } + cidrs := action.Parameters.IpCidr + if len(cidrs) == 0 { + return errors.Errorf("could not add action `%s`: empty list of CIDRs", action.ActionId) + } + return s.addCIDRList(cidrs, redirectIP) +} + // Convert a float64 to a `time.Duration` by making sure it doesn't overflow. func float64ToDuration(duration float64) (time.Duration, error) { if duration <= math.MinInt64 || duration >= math.MaxInt64 { diff --git a/agent/internal/actor/http.go b/agent/internal/actor/http.go index a780e3c3..a189e2d3 100644 --- a/agent/internal/actor/http.go +++ b/agent/internal/actor/http.go @@ -11,20 +11,28 @@ import ( "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/httphandler" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" "github.com/sqreen/go-agent/agent/types" "github.com/sqreen/go-agent/sdk" ) // Event names. const ( - blockIPEventName = "app.sqreen.action.block_ip" - blockUserEventName = "app.sqreen.action.block_user" + blockIPEventName = "app.sqreen.action.block_ip" + blockUserEventName = "app.sqreen.action.block_user" + redirectIPEventName = "app.sqreen.action.redirect_ip" ) // NewIPActionHTTPHandler returns a HTTP handler that should be applied at the // request handler level to perform the security response. -func NewIPActionHTTPHandler(action Action, ip net.IP) http.Handler { - return newBlockHTTPHandler(blockIPEventName, newBlockedIPEventProperties(action, ip)) +func NewIPActionHTTPHandler(action Action, ip net.IP) (http.Handler, error) { + switch actual := action.(type) { + case blockAction: + return newBlockHTTPHandler(blockIPEventName, newBlockedIPEventProperties(actual, ip)), nil + case *redirectAction: + return newRedirectHTTPHandler(redirectIPEventName, newRedirectedIPEventProperties(actual, ip), actual.URL), nil + } + return nil, sqerrors.Errorf("unexpected IP action type `%T`", action) } // NewUserActionHTTPHandler returns a HTTP handler that should be applied at @@ -118,3 +126,64 @@ func (p *blockedUserEventProperties) GetOutput() api.BlockedUserEventProperties_ func (p *blockedUserEventProperties) GetUser() map[string]string { return p.userID } + +// redirectHTTPHandler implements the http.Handler interface and holds the event +// data corresponding to the action. +type redirectHTTPHandler struct { + eventName string + eventProperties types.EventProperties + location string +} + +// Static assertion that http.Handler is implemented. +var _ http.Handler = &redirectHTTPHandler{} + +func newRedirectHTTPHandler(eventName string, properties types.EventProperties, location string) *redirectHTTPHandler { + return &redirectHTTPHandler{ + eventName: eventName, + eventProperties: properties, + location: location, + } +} + +// ServeHTTP writes the HTTP status code 500 into the HTTP response writer `w`. +// The caller needs to abort the request. +func (a *redirectHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + record := sdk.FromContext(r.Context()) + record.TrackEvent(a.eventName).WithProperties(a.eventProperties) + w.Header().Set("Location", a.location) + w.WriteHeader(http.StatusSeeOther) +} + +// redirectedIPEventProperties implements `types.EventProperties` to be marshaled +// to an SDK event property structure. +type redirectedIPEventProperties struct { + action *redirectAction + ip net.IP +} + +// Static assert that blockedIPEventProperties implements types.EventProperties. +var _ types.EventProperties = &redirectedIPEventProperties{} + +func newRedirectedIPEventProperties(action *redirectAction, ip net.IP) *blockedIPEventProperties { + return &blockedIPEventProperties{ + action: action, + ip: ip, + } +} +func (p *redirectedIPEventProperties) MarshalJSON() ([]byte, error) { + pb := api.NewRedirectedIPEventPropertiesFromFace(p) + return json.Marshal(pb) +} +func (p *redirectedIPEventProperties) GetActionId() string { + return p.action.ActionID() +} +func (p *redirectedIPEventProperties) GetOutput() api.RedirectedIPEventPropertiesOutput { + return *api.NewRedirectedIPEventPropertiesOutputFromFace(p) +} +func (p *redirectedIPEventProperties) GetIpAddress() string { + return p.ip.String() +} +func (p *redirectedIPEventProperties) GetURL() string { + return p.action.URL +} diff --git a/agent/internal/backend/api/api.go b/agent/internal/backend/api/api.go index 52364d56..b122f129 100644 --- a/agent/internal/backend/api/api.go +++ b/agent/internal/backend/api/api.go @@ -829,6 +829,40 @@ func NewBlockedUserEventProperties_OutputFromFace(that BlockedUserEventPropertie return this } +type RedirectedIPEventProperties struct { + ActionId string `json:"action_id,omitempty"` + Output RedirectedIPEventPropertiesOutput `json:"output"` +} + +type RedirectedIPEventPropertiesOutput struct { + IpAddress string `json:"ip_address"` + URL string `json:"url"` +} + +func NewRedirectedIPEventPropertiesFromFace(that RedirectedIPEventPropertiesFace) *RedirectedIPEventProperties { + return &RedirectedIPEventProperties{ + ActionId: that.GetActionId(), + Output: that.GetOutput(), + } +} + +type RedirectedIPEventPropertiesFace interface { + GetActionId() string + GetOutput() RedirectedIPEventPropertiesOutput +} + +type RedirectedIPEventPropertiesOutputFace interface { + GetIpAddress() string + GetURL() string +} + +func NewRedirectedIPEventPropertiesOutputFromFace(that RedirectedIPEventPropertiesOutputFace) *RedirectedIPEventPropertiesOutput { + return &RedirectedIPEventPropertiesOutput{ + IpAddress: that.GetIpAddress(), + URL: that.GetURL(), + } +} + type RulesPackResponse struct { PackID string `json:"pack_id"` Rules []Rule `json:"rules"` diff --git a/agent/internal/request.go b/agent/internal/request.go index d60e0eeb..ad27fd15 100644 --- a/agent/internal/request.go +++ b/agent/internal/request.go @@ -17,6 +17,7 @@ import ( "github.com/sqreen/go-agent/agent/internal/actor" "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/config" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" "github.com/sqreen/go-agent/agent/types" ) @@ -158,7 +159,10 @@ func (ctx *HTTPRequestRecord) SecurityResponse() http.Handler { if !exists { return nil } - ctx.lastSecurityResponseHandler = actor.NewIPActionHTTPHandler(action, ip) + ctx.lastSecurityResponseHandler, err = actor.NewIPActionHTTPHandler(action, ip) + if err != nil { + agent.logger.Error(sqerrors.Wrap(err, "security response")) + } return ctx.lastSecurityResponseHandler } From bf5029e7faf9301cd94323f43473dcfd265b66cf Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 28 Jun 2019 13:55:53 +0200 Subject: [PATCH 13/47] agent/actor/action: add redirection security response for user identifiers Add support for the redirection of user identifiers by adding a new type of action whose HTTP handler responds to the HTTP request with a HTTP redirection. The redirection URL is sent by the backend in the action parameters. --- agent/internal/actor/action.go | 7 ++-- agent/internal/actor/action_test.go | 16 +++++++-- agent/internal/actor/actor.go | 19 +++++++++++ agent/internal/actor/http.go | 53 ++++++++++++++++++++++++----- agent/internal/backend/api/api.go | 45 ++++++++++++++++++++---- agent/internal/request.go | 8 +++-- 6 files changed, 126 insertions(+), 22 deletions(-) diff --git a/agent/internal/actor/action.go b/agent/internal/actor/action.go index 2ddff127..a20db830 100644 --- a/agent/internal/actor/action.go +++ b/agent/internal/actor/action.go @@ -13,9 +13,10 @@ import ( // Action kinds. const ( - actionKindBlockIP = "block_ip" - actionKindBlockUser = "block_user" - actionKindRedirectIP = "redirect_ip" + actionKindBlockIP = "block_ip" + actionKindBlockUser = "block_user" + actionKindRedirectIP = "redirect_ip" + actionKindRedirectUser = "redirect_user" ) // Action is an interface common to each concrete action type stored in the data diff --git a/agent/internal/actor/action_test.go b/agent/internal/actor/action_test.go index 150f8a21..64888fab 100644 --- a/agent/internal/actor/action_test.go +++ b/agent/internal/actor/action_test.go @@ -44,7 +44,8 @@ func TestAction(t *testing.T) { }) t.Run("Block User", func(t *testing.T) { - handler := NewUserActionHTTPHandler(action, map[string]string{"uid": testlib.RandString(1, 250)}) + handler, err := NewUserActionHTTPHandler(action, map[string]string{"uid": testlib.RandString(1, 250)}) + require.NoError(t, err) require.NotNil(t, handler) // Use the handler req := httptest.NewRequest(http.MethodPost, "/", nil) @@ -93,8 +94,19 @@ func TestAction(t *testing.T) { // TODO: check the sdk event }) + t.Run("Redirect User", func(t *testing.T) { + handler, err := NewUserActionHTTPHandler(action, map[string]string{"uid": testlib.RandString(1, 250)}) + require.NoError(t, err) + require.NotNil(t, handler) + // Use the handler + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + require.Equal(t, rec.Code, http.StatusSeeOther) + require.Equal(t, rec.Header().Get("Location"), action.URL) + // TODO: check the sdk event + }) }) }) }) - } diff --git a/agent/internal/actor/actor.go b/agent/internal/actor/actor.go index 1934f5d8..4dae3122 100644 --- a/agent/internal/actor/actor.go +++ b/agent/internal/actor/actor.go @@ -213,6 +213,8 @@ func (s *actionStore) addAction(action api.ActionsPackResponse_Action) (err erro err = s.addBlockUserAction(action) case actionKindRedirectIP: err = s.addRedirectIPAction(action) + case actionKindRedirectUser: + err = s.addRedirectUserAction(action) } return err } @@ -250,6 +252,23 @@ func (s *actionStore) addRedirectIPAction(action api.ActionsPackResponse_Action) return s.addCIDRList(cidrs, redirectIP) } +func (s *actionStore) addRedirectUserAction(action api.ActionsPackResponse_Action) error { + duration, err := float64ToDuration(action.Duration) + if err != nil { + return err + } + var redirectUser Action + redirectUser, err = newRedirectAction(action.ActionId, action.Parameters.Url) + if duration > 0 { + redirectUser = withDuration(redirectUser, duration) + } + users := action.Parameters.Users + if len(users) == 0 { + return errors.Errorf("could not add action `%s`: empty list of users", action.ActionId) + } + return s.addUserList(users, redirectUser) +} + // Convert a float64 to a `time.Duration` by making sure it doesn't overflow. func float64ToDuration(duration float64) (time.Duration, error) { if duration <= math.MinInt64 || duration >= math.MaxInt64 { diff --git a/agent/internal/actor/http.go b/agent/internal/actor/http.go index a189e2d3..b99e7f41 100644 --- a/agent/internal/actor/http.go +++ b/agent/internal/actor/http.go @@ -18,9 +18,10 @@ import ( // Event names. const ( - blockIPEventName = "app.sqreen.action.block_ip" - blockUserEventName = "app.sqreen.action.block_user" - redirectIPEventName = "app.sqreen.action.redirect_ip" + blockIPEventName = "app.sqreen.action.block_ip" + blockUserEventName = "app.sqreen.action.block_user" + redirectIPEventName = "app.sqreen.action.redirect_ip" + redirectUserEventName = "app.sqreen.action.redirect_user" ) // NewIPActionHTTPHandler returns a HTTP handler that should be applied at the @@ -37,8 +38,14 @@ func NewIPActionHTTPHandler(action Action, ip net.IP) (http.Handler, error) { // NewUserActionHTTPHandler returns a HTTP handler that should be applied at // the request handler level to perform the security response. -func NewUserActionHTTPHandler(action Action, userID map[string]string) http.Handler { - return newBlockHTTPHandler(blockUserEventName, newBlockedUserEventProperties(action, userID)) +func NewUserActionHTTPHandler(action Action, userID map[string]string) (http.Handler, error) { + switch actual := action.(type) { + case blockAction: + return newBlockHTTPHandler(blockUserEventName, newBlockedUserEventProperties(actual, userID)), nil + case *redirectAction: + return newRedirectHTTPHandler(redirectUserEventName, newRedirectedUserEventProperties(actual, userID), actual.URL), nil + } + return nil, sqerrors.Errorf("unexpected user action type `%T`", action) } // blockHTTPHandler implements the http.Handler interface and holds the event @@ -96,7 +103,7 @@ func (p *blockedIPEventProperties) GetIpAddress() string { return p.ip.String() } -// blockedIPEventProperties implements `types.EventProperties` to be marshaled +// blockedUserEventProperties implements `types.EventProperties` to be marshaled // to an SDK event property structure. type blockedUserEventProperties struct { action Action @@ -120,13 +127,43 @@ func (p *blockedUserEventProperties) MarshalJSON() ([]byte, error) { func (p *blockedUserEventProperties) GetActionId() string { return p.action.ActionID() } -func (p *blockedUserEventProperties) GetOutput() api.BlockedUserEventProperties_Output { - return *api.NewBlockedUserEventProperties_OutputFromFace(p) +func (p *blockedUserEventProperties) GetOutput() api.BlockedUserEventPropertiesOutput { + return *api.NewBlockedUserEventPropertiesOutputFromFace(p) } func (p *blockedUserEventProperties) GetUser() map[string]string { return p.userID } +// redirectedUserEventProperties implements `types.EventProperties` to be marshaled +// to an SDK event property structure. +type redirectedUserEventProperties struct { + action Action + userID map[string]string +} + +// Static assert that redirectedUserEventProperties implements `types.EventProperties`. +var _ types.EventProperties = &redirectedUserEventProperties{} + +func newRedirectedUserEventProperties(action Action, userID map[string]string) *redirectedUserEventProperties { + return &redirectedUserEventProperties{ + action: action, + userID: userID, + } +} +func (p *redirectedUserEventProperties) MarshalJSON() ([]byte, error) { + pb := api.NewRedirectedUserEventPropertiesFromFace(p) + return json.Marshal(pb) +} +func (p *redirectedUserEventProperties) GetActionId() string { + return p.action.ActionID() +} +func (p *redirectedUserEventProperties) GetOutput() api.RedirectedUserEventPropertiesOutput { + return *api.NewRedirectedUserEventPropertiesOutputFromFace(p) +} +func (p *redirectedUserEventProperties) GetUser() map[string]string { + return p.userID +} + // redirectHTTPHandler implements the http.Handler interface and holds the event // data corresponding to the action. type redirectHTTPHandler struct { diff --git a/agent/internal/backend/api/api.go b/agent/internal/backend/api/api.go index b122f129..984954b5 100644 --- a/agent/internal/backend/api/api.go +++ b/agent/internal/backend/api/api.go @@ -799,11 +799,11 @@ func NewBlockedIPEventProperties_OutputFromFace(that BlockedIPEventProperties_Ou } type BlockedUserEventProperties struct { - ActionId string `json:"action_id"` - Output BlockedUserEventProperties_Output `json:"output"` + ActionId string `json:"action_id"` + Output BlockedUserEventPropertiesOutput `json:"output"` } -type BlockedUserEventProperties_Output struct { +type BlockedUserEventPropertiesOutput struct { User map[string]string `json:"user"` } @@ -816,15 +816,15 @@ func NewBlockedUserEventPropertiesFromFace(that BlockedUserEventPropertiesFace) type BlockedUserEventPropertiesFace interface { GetActionId() string - GetOutput() BlockedUserEventProperties_Output + GetOutput() BlockedUserEventPropertiesOutput } -type BlockedUserEventProperties_OutputFace interface { +type BlockedUserEventPropertiesOutputFace interface { GetUser() map[string]string } -func NewBlockedUserEventProperties_OutputFromFace(that BlockedUserEventProperties_OutputFace) *BlockedUserEventProperties_Output { - this := &BlockedUserEventProperties_Output{} +func NewBlockedUserEventPropertiesOutputFromFace(that BlockedUserEventPropertiesOutputFace) *BlockedUserEventPropertiesOutput { + this := &BlockedUserEventPropertiesOutput{} this.User = that.GetUser() return this } @@ -863,6 +863,37 @@ func NewRedirectedIPEventPropertiesOutputFromFace(that RedirectedIPEventProperti } } +type RedirectedUserEventProperties struct { + ActionId string `json:"action_id"` + Output RedirectedUserEventPropertiesOutput `json:"output"` +} + +type RedirectedUserEventPropertiesOutput struct { + User map[string]string `json:"user"` +} + +func NewRedirectedUserEventPropertiesFromFace(that RedirectedUserEventPropertiesFace) *RedirectedUserEventProperties { + return &RedirectedUserEventProperties{ + ActionId: that.GetActionId(), + Output: that.GetOutput(), + } +} + +type RedirectedUserEventPropertiesFace interface { + GetActionId() string + GetOutput() RedirectedUserEventPropertiesOutput +} + +type RedirectedUserEventPropertiesOutputFace interface { + GetUser() map[string]string +} + +func NewRedirectedUserEventPropertiesOutputFromFace(that RedirectedUserEventPropertiesOutputFace) *RedirectedUserEventPropertiesOutput { + this := &RedirectedUserEventPropertiesOutput{} + this.User = that.GetUser() + return this +} + type RulesPackResponse struct { PackID string `json:"pack_id"` Rules []Rule `json:"rules"` diff --git a/agent/internal/request.go b/agent/internal/request.go index ad27fd15..a9042bbe 100644 --- a/agent/internal/request.go +++ b/agent/internal/request.go @@ -161,7 +161,7 @@ func (ctx *HTTPRequestRecord) SecurityResponse() http.Handler { } ctx.lastSecurityResponseHandler, err = actor.NewIPActionHTTPHandler(action, ip) if err != nil { - agent.logger.Error(sqerrors.Wrap(err, "security response")) + agent.logger.Error(sqerrors.Wrap(err, "ip security response")) } return ctx.lastSecurityResponseHandler } @@ -179,7 +179,11 @@ func (ctx *HTTPRequestRecord) UserSecurityResponse() http.Handler { if !exists { return nil } - ctx.lastUserSecurityResponseHandler = actor.NewUserActionHTTPHandler(action, userID) + var err error + ctx.lastUserSecurityResponseHandler, err = actor.NewUserActionHTTPHandler(action, userID) + if err != nil { + agent.logger.Error(sqerrors.Wrap(err, "user security response")) + } return ctx.lastUserSecurityResponseHandler } From 479f67e586339bba14dfdd5915e2609b742c9dad Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 8 Jul 2019 14:20:58 +0200 Subject: [PATCH 14/47] agent/request/actions: add more context into the error logs --- agent/internal/request.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agent/internal/request.go b/agent/internal/request.go index a9042bbe..7a0971e7 100644 --- a/agent/internal/request.go +++ b/agent/internal/request.go @@ -161,7 +161,7 @@ func (ctx *HTTPRequestRecord) SecurityResponse() http.Handler { } ctx.lastSecurityResponseHandler, err = actor.NewIPActionHTTPHandler(action, ip) if err != nil { - agent.logger.Error(sqerrors.Wrap(err, "ip security response")) + agent.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("could not create the http handler for an ip security response: action `%v` - ip `%s`:", action.ActionID(), ip))) } return ctx.lastSecurityResponseHandler } @@ -182,7 +182,7 @@ func (ctx *HTTPRequestRecord) UserSecurityResponse() http.Handler { var err error ctx.lastUserSecurityResponseHandler, err = actor.NewUserActionHTTPHandler(action, userID) if err != nil { - agent.logger.Error(sqerrors.Wrap(err, "user security response")) + agent.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("could not create the http handler for a user security response: action `%v` - user `%v`:", action.ActionID(), userID))) } return ctx.lastUserSecurityResponseHandler } From 7cdb0466b36af7acb7c8d0fe4a13a66bc58b70bf Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 2 Jul 2019 11:32:40 +0200 Subject: [PATCH 15/47] sdk/middleware/sqhttp: common http middleware implementation - Factorize from `sqhttp.Middleware` an HTTP middleware that will be used by every other middleware implementations we have so far. - Add a specific hookable security headers insertion function. - Use this new middleware with `sqhttp.Middleware`. --- sdk/middleware/sqhttp/http.go | 99 ++++++++++++++++++++++++++++++++--- 1 file changed, 91 insertions(+), 8 deletions(-) diff --git a/sdk/middleware/sqhttp/http.go b/sdk/middleware/sqhttp/http.go index ad0831e2..4a823e0a 100644 --- a/sdk/middleware/sqhttp/http.go +++ b/sdk/middleware/sqhttp/http.go @@ -7,7 +7,9 @@ package sqhttp import ( "net/http" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" "github.com/sqreen/go-agent/sdk" + "golang.org/x/xerrors" ) // Middleware is Sqreen's middleware function for `net/http` to monitor and @@ -44,28 +46,109 @@ import ( // http.Handle("/foo", sqhttp.Middleware(http.HandlerFunc(fn))) // func Middleware(next http.Handler) http.Handler { + // Simply adapt http.Handler to Handler in order to call MiddlewareWithError + // to get the middleware function. + m := MiddlewareWithError(HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + next.ServeHTTP(w, r) + return nil + })) + // And now return a function adapting Handler to http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = m.ServeHTTP(w, r) + }) +} + +// MiddlewareWithError is a helper middleware to define other middlewares for +// other frameworks thanks to the error returned by the handlers in order +// to know if a request is being aborted. +func MiddlewareWithError(next Handler) Handler { + // TODO: move this middleware function into the agent internal package (which + // needs restructuring the SDK) + return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (err error) { + if err := addSecurityHeaders(w); err != nil { + return err + } // Create a new sqreen request wrapper. req := sdk.NewHTTPRequest(r) defer req.Close() // Use the newly created request compliant with `sdk.FromContext()`. r = req.Request() - // Check if an early security action is already required such as based on // the request IP address. if handler := req.SecurityResponse(); handler != nil { handler.ServeHTTP(w, r) - return + return AbortRequestError{} } - // Call next handler. - next.ServeHTTP(w, r) - - // Check if a security response should be applied now after having used - // `Identify()` and `MatchSecurityResponse()`. + err = next.ServeHTTP(w, r) + // If the returned error is not nil nor a security response, return it now. + var secResponse sdk.SecurityResponseMatch + if err != nil && !xerrors.As(err, &secResponse) { + return err + } + // Otherwise check if a security response should be applied now, after + // having used `Identify()` and `MatchSecurityResponse()`. if handler := req.UserSecurityResponse(); handler != nil { handler.ServeHTTP(w, r) - return + return AbortRequestError{} } + return nil }) } + +// Handler is equivalent to http.Handler but returns an error when the request +// should no longer be handled. +type Handler interface { + ServeHTTP(w http.ResponseWriter, r *http.Request) error +} + +// HandlerFunc is equivalent to http.HandlerFunc but returns an error when the +// request should no longer be handled. +type HandlerFunc func(http.ResponseWriter, *http.Request) error + +// ServeHTTP calls f(w, r). +func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error { + return f(w, r) +} + +// AbortRequestError is returned by handlers when some security response was +// triggered and handled the response. The request handling should therefore +// stop. +type AbortRequestError struct { + Message string +} + +func (AbortRequestError) Error() string { + return "request aborted" +} + +// addSecurityHeaders is a mean to add a hook to the function closure returned +// by MiddlewareWithError() since it is not possible to get the symbol of +// function closures at compilation-time, so it is not possible to create a hook +// with the address of the function closure. The solution for this precise case +// where only a prolog is enough is therefore to simply define a function having +// a hook and called by the closure. +func addSecurityHeaders(w http.ResponseWriter) (err error) { + { + type Prolog = func(*sqhook.Context, *http.ResponseWriter) error + type Epilog = func(*sqhook.Context, *error) + ctx := sqhook.Context{} + prolog, epilog := addSecurityHeaderHook.Callbacks() + if epilog, ok := epilog.(Epilog); ok { + defer epilog(&ctx, &err) + } + if prolog, ok := prolog.(Prolog); ok { + if err := prolog(&ctx, &w); err != nil { + return err + } + } + } + + return nil +} + +var addSecurityHeaderHook *sqhook.Hook + +func init() { + addSecurityHeaderHook = sqhook.New(addSecurityHeaders) +} From 6055441333e6af214014bdc3ce5341172c4db7e1 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 2 Jul 2019 15:57:51 +0200 Subject: [PATCH 16/47] agent/rule/callback: http security headers callback Add a new callback to attach to `sqhttp.addSecurityHeaders` in order to add security headers defined in the rule's data. Unfortunately with this rule, the data is not structured the same way as the blocking behaviour rules that were implemented and a kind of hack had to be inserted into the json unmarshaler in order to support this case. Instead of receiving an array of structures with a `type` key, this rule is sent with an array of array of 2 string values, the first one for the header key, the second one for the header value. So the idea for now is to fallback to this type when the attempt to unmarshal a structure fails. --- agent/internal/backend/api/jsonpb.go | 11 +++- agent/internal/rule/callback.go | 2 + .../rule/callback/add-security-headers.go | 50 ++++++++++++++++ .../callback/add-security-headers_test.go | 59 +++++++++++++++++++ 4 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 agent/internal/rule/callback/add-security-headers.go create mode 100644 agent/internal/rule/callback/add-security-headers_test.go diff --git a/agent/internal/backend/api/jsonpb.go b/agent/internal/backend/api/jsonpb.go index 56be6ea9..555fa168 100644 --- a/agent/internal/backend/api/jsonpb.go +++ b/agent/internal/backend/api/jsonpb.go @@ -114,7 +114,16 @@ func (v *RuleDataEntry) UnmarshalJSON(data []byte) error { Type string `json:"type"` } if err := json.Unmarshal(data, &discriminant); err != nil { - return sqerrors.Wrap(err, "json unmarshal") + // Some rules come with values not discriminated by a `type` key + // So we try other types + // TODO: fix this in the API + var strArray []string + err = json.Unmarshal(data, &strArray) + if err != nil { + return err + } + v.Value = strArray + return nil } var value interface{} diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index be21336a..67febe30 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -26,6 +26,8 @@ func NewCallbacks(name string, data []interface{}) (prolog, epilog sqhook.Callba callbacksCtor = callback.NewWriteCustomErrorPageCallbacks case "WriteHTTPRedirection": callbacksCtor = callback.NewWriteHTTPRedirectionCallbacks + case "AddSecurityHeaders": + callbacksCtor = callback.NewAddSecurityHeadersCallbacks } return callbacksCtor(data) } diff --git a/agent/internal/rule/callback/add-security-headers.go b/agent/internal/rule/callback/add-security-headers.go new file mode 100644 index 00000000..acccbb64 --- /dev/null +++ b/agent/internal/rule/callback/add-security-headers.go @@ -0,0 +1,50 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback + +import ( + "net/http" + + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" +) + +// NewAddSecurityHeadersCallbacks returns the native prolog and epilog callbacks +// to be hooked to `sqhttp.MiddlewareWithError` in order to add HTTP headers +// provided by the rule's data. +func NewAddSecurityHeadersCallbacks(data []interface{}) (prolog, epilog sqhook.Callback, err error) { + var headers = make(http.Header, len(data)) + for _, headersKV := range data { + // TODO: move to a structured list of headers to avoid dynamic type checking + kv, ok := headersKV.([]string) + if !ok { + err = sqerrors.Errorf("unexpected number of values: header key and values are expected but got `%d` values instead", len(kv)) + return + } + if len(kv) != 2 { + err = sqerrors.Errorf("unexpected number of values: header key and values are expected but got `%d` values instead", len(kv)) + return + } + headers.Set(kv[0], kv[1]) + } + if len(headers) == 0 { + return nil, nil, sqerrors.New("there are no headers to add") + } + return newAddHeadersPrologCallback(headers), nil, nil +} + +type AddSecurityHeadersPrologCallbackType = func(*sqhook.Context, *http.ResponseWriter) error + +// The prolog callback modifies the function arguments in order to replace the +// written status code and body. +func newAddHeadersPrologCallback(headers http.Header) AddSecurityHeadersPrologCallbackType { + return func(_ *sqhook.Context, w *http.ResponseWriter) error { + responseHeaders := (*w).Header() + for k, v := range headers { + responseHeaders[k] = v + } + return nil + } +} diff --git a/agent/internal/rule/callback/add-security-headers_test.go b/agent/internal/rule/callback/add-security-headers_test.go new file mode 100644 index 00000000..ec5f8b6f --- /dev/null +++ b/agent/internal/rule/callback/add-security-headers_test.go @@ -0,0 +1,59 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/stretchr/testify/require" +) + +func TestNewAddSecurityHeadersCallbacks(t *testing.T) { + t.Run("with incorrect data", func(t *testing.T) { + for _, data := range [][]interface{}{ + nil, + {}, + {33}, + {"yet another wrong type"}, + {[]string{}}, + {nil}, + {[]string{"one"}}, + {[]string{"one", "two", "three"}}, + } { + prolog, epilog, err := callback.NewAddSecurityHeadersCallbacks(data) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + } + }) + + t.Run("with correct data", func(t *testing.T) { + // Instantiate the callback with the given correct rule data + prolog, epilog, err := callback.NewAddSecurityHeadersCallbacks([]interface{}{ + []string{"k", "v"}, + []string{"one", "two"}, + []string{"canonical-header", "the value"}, + }) + require.NoError(t, err) + require.NotNil(t, prolog) + require.Nil(t, epilog) + // Call it and check the behaviour follows the rule's data + actualProlog, ok := prolog.(callback.AddSecurityHeadersPrologCallbackType) + require.True(t, ok) + var rec http.ResponseWriter = httptest.NewRecorder() + err = actualProlog(nil, &rec) + // Check it behaves as expected + require.NoError(t, err) + expectedHeaders := http.Header{ + "K": []string{"v"}, + "One": []string{"two"}, + "Canonical-Header": []string{"the value"}, + } + require.Equal(t, expectedHeaders, rec.Header()) + }) +} From 55b2484889bfd5704a755604315352607753a08b Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 2 Jul 2019 16:20:04 +0200 Subject: [PATCH 17/47] sdk/middleware/sqgin: use the common http middleware --- sdk/middleware/sqgin/gin.go | 50 +++++++++++++------------------------ 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/sdk/middleware/sqgin/gin.go b/sdk/middleware/sqgin/gin.go index 74867439..e6570ff9 100644 --- a/sdk/middleware/sqgin/gin.go +++ b/sdk/middleware/sqgin/gin.go @@ -5,8 +5,11 @@ package sqgin import ( + "net/http" + gingonic "github.com/gin-gonic/gin" "github.com/sqreen/go-agent/sdk" + "github.com/sqreen/go-agent/sdk/middleware/sqhttp" ) // Middleware is Sqreen's middleware function for Gin to monitor and protect the @@ -53,39 +56,20 @@ import ( // func Middleware() gingonic.HandlerFunc { return func(c *gingonic.Context) { - // Get current request. - r := c.Request - // Create a new sqreen request wrapper. - req := sdk.NewHTTPRequest(r) - defer req.Close() - // Use the newly created request compliant with `sdk.FromContext()`. - r = req.Request() - // Also replace Gin's request pointer with it. - c.Request = r - - // Check if an early security action is already required such as based on - // the request IP address. - if handler := req.SecurityResponse(); handler != nil { - handler.ServeHTTP(c.Writer, r) - c.Abort() - return - } - - // Gin implements the `context.Context` interface but with string keys, so - // we need to also store the request record in Gin's context using a string - // key (previous call to `sdk.NewHTTPRequest()` stored it with a non-string - // key, as documented by `context.WithValue()` - // (https://godoc.org/context#WithValue)). - contextKey := sdk.HTTPRequestRecordContextKey.String - c.Set(contextKey, req.Record()) - - // Call next handler. - c.Next() - - // Check if a security response should be applied now after having used - // `Identify()` and `MatchSecurityResponse()`. - if handler := req.UserSecurityResponse(); handler != nil { - handler.ServeHTTP(c.Writer, r) + // Adapt sqhttp middleware to Gin's + err := sqhttp.MiddlewareWithError(sqhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + c.Request = r + // Gin implements the `context.Context` interface but with string keys, so + // we need to also store the request record in Gin's context using a string + // key (previous call to `sdk.NewHTTPRequest()` stored it with a non-string + // key, as documented by `context.WithValue()` + // (https://godoc.org/context#WithValue)). + contextKey := sdk.HTTPRequestRecordContextKey.String + c.Set(contextKey, sdk.FromContext(r.Context())) + c.Next() + return nil + })).ServeHTTP(c.Writer, c.Request) + if err != nil { c.Abort() } } From 437b8edbb240a740cd02e9cf71c6440bd8badf25 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 2 Jul 2019 16:22:55 +0200 Subject: [PATCH 18/47] sdk/middleware/sqecho: use the common http middleware --- sdk/middleware/sqecho/echo.go | 49 ++++++++---------------------- sdk/middleware/sqecho/echo_test.go | 11 ++++--- 2 files changed, 19 insertions(+), 41 deletions(-) diff --git a/sdk/middleware/sqecho/echo.go b/sdk/middleware/sqecho/echo.go index ad10672b..3c5847be 100644 --- a/sdk/middleware/sqecho/echo.go +++ b/sdk/middleware/sqecho/echo.go @@ -5,9 +5,11 @@ package sqecho import ( + "net/http" + "github.com/labstack/echo" "github.com/sqreen/go-agent/sdk" - "golang.org/x/xerrors" + "github.com/sqreen/go-agent/sdk/middleware/sqhttp" ) // Middleware is Sqreen's middleware function for Echo to monitor and protect @@ -53,44 +55,17 @@ import ( // } // func Middleware() echo.MiddlewareFunc { + // Create a middleware function by adapting to sqhttp's return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - // Get current request. - r := c.Request() - // Create a new sqreen request wrapper. - req := sdk.NewHTTPRequest(r) - defer req.Close() - // Use the newly created request compliant with `sdk.FromContext()`. - r = req.Request() - // Also replace Echo's request pointer with it. - c.SetRequest(r) - - // Check if an early security action is already required such as based on - // the request IP address. - if handler := req.SecurityResponse(); handler != nil { - handler.ServeHTTP(c.Response(), req.Request()) - return nil - } - - // Echo defines its own context interface, so we need to store it in - // Echo's context. Echo expects string keys. - contextKey := sdk.HTTPRequestRecordContextKey.String - c.Set(contextKey, req.Record()) - - // Call next handler. - err := next(c) - if err != nil && !xerrors.As(err, &sdk.SecurityResponseMatch{}) { - // The error is not a security response match - return err - } - - // Check if a security response should be applied now after having used - // `Identify()` and `MatchSecurityResponse()`. - if handler := req.UserSecurityResponse(); handler != nil { - handler.ServeHTTP(c.Response(), req.Request()) - } - - return nil + return sqhttp.MiddlewareWithError(sqhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + c.SetRequest(r) + // Echo defines its own context interface, so we need to store it in + // Echo's context. Echo expects string keys. + contextKey := sdk.HTTPRequestRecordContextKey.String + c.Set(contextKey, sdk.FromContext(r.Context())) + return next(c) + })).ServeHTTP(c.Response(), c.Request()) } } } diff --git a/sdk/middleware/sqecho/echo_test.go b/sdk/middleware/sqecho/echo_test.go index c08d3e4d..123e9e55 100644 --- a/sdk/middleware/sqecho/echo_test.go +++ b/sdk/middleware/sqecho/echo_test.go @@ -15,8 +15,10 @@ import ( "github.com/labstack/echo" "github.com/sqreen/go-agent/sdk" "github.com/sqreen/go-agent/sdk/middleware/sqecho" + "github.com/sqreen/go-agent/sdk/middleware/sqhttp" "github.com/sqreen/go-agent/tools/testlib" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" ) func TestMiddleware(t *testing.T) { @@ -164,7 +166,8 @@ func TestMiddleware(t *testing.T) { mw := sqecho.Middleware() err := mw(h)(c) // Check the request was performed as expected - require.NoError(t, err) + require.Error(t, err) + require.True(t, xerrors.Is(err, sqhttp.AbortRequestError{})) require.Equal(t, rec.Code, status) require.Equal(t, rec.Body.String(), "") }) @@ -201,9 +204,9 @@ func TestMiddleware(t *testing.T) { mw := sqecho.Middleware() err := mw(h)(c) // Check the request was performed as expected - require.NoError(t, err) - require.Equal(t, rec.Code, status) - require.Equal(t, rec.Body.String(), "") + require.Error(t, err) + require.Equal(t, status, rec.Code) + require.Equal(t, "", rec.Body.String()) }) }) } From 9fbf408c2a2f75f12729c4f9dd7576229845e407 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 2 Jul 2019 16:25:30 +0200 Subject: [PATCH 19/47] sdk/middleware/sqgrpc: use the common http middleware --- sdk/middleware/sqgrpc/grpc.go | 84 ++++++++--------------------------- 1 file changed, 19 insertions(+), 65 deletions(-) diff --git a/sdk/middleware/sqgrpc/grpc.go b/sdk/middleware/sqgrpc/grpc.go index 3998f27d..1f922fc9 100644 --- a/sdk/middleware/sqgrpc/grpc.go +++ b/sdk/middleware/sqgrpc/grpc.go @@ -49,7 +49,7 @@ import ( grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" - "github.com/sqreen/go-agent/sdk" + "github.com/sqreen/go-agent/sdk/middleware/sqhttp" "golang.org/x/xerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -59,74 +59,36 @@ import ( func UnaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - // Create a new sqreen request wrapper. - ctx, sqreened := newRequestFromMD(ctx) - defer sqreened.Close() - - // Check if an early security action is already required such as based on - // the request IP address. - if handler := sqreened.SecurityResponse(); handler != nil { - // TODO: better interface for non-standard HTTP packages to avoid this - // noopHTTPResponseWriter hack just to send the block event. - handler.ServeHTTP(noopHTTPResponseWriter{}, sqreened.Request()) + sqreened := newRequestFromMD(ctx) + var res interface{} + err := sqhttp.MiddlewareWithError(sqhttp.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) (err error) { + res, err = handler(r.Context(), req) + return err + })).ServeHTTP(noopHTTPResponseWriter{}, sqreened) + if xerrors.Is(err, sqhttp.AbortRequestError{}) { return nil, status.Error(codes.Aborted, "aborted by sqreen security action") } - - res, err := handler(ctx, req) - if err != nil && !xerrors.As(err, &sdk.SecurityResponseMatch{}) { - // The error is not a security response match - return res, err - } - - // Check if a security response should be applied now after having used - // `Identify()` and `MatchSecurityResponse()`. - if handler := sqreened.UserSecurityResponse(); handler != nil { - // TODO: same as before - handler.ServeHTTP(noopHTTPResponseWriter{}, sqreened.Request()) - return nil, status.Error(codes.Aborted, "aborted by a sqreen user action") - } - - return res, nil + return res, err } } func StreamServerInterceptor() grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - ctx, sqreened := newRequestFromMD(stream.Context()) - defer sqreened.Close() - - // Check if an early security action is already required such as based on - // the request IP address. - if handler := sqreened.SecurityResponse(); handler != nil { - // TODO: better interface for non-standard HTTP packages to avoid this - // noopHTTPResponseWriter hack just to send the block event. - handler.ServeHTTP(noopHTTPResponseWriter{}, sqreened.Request()) - return status.Error(codes.Aborted, "aborted by a sqreen security action") + sqreened := newRequestFromMD(stream.Context()) + err := sqhttp.MiddlewareWithError(sqhttp.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) (err error) { + stream := grpc_middleware.WrapServerStream(stream) + stream.WrappedContext = r.Context() + return handler(srv, stream) + })).ServeHTTP(noopHTTPResponseWriter{}, sqreened) + if xerrors.Is(err, sqhttp.AbortRequestError{}) { + return status.Error(codes.Aborted, "aborted by sqreen security action") } - - wrapped := grpc_middleware.WrapServerStream(stream) - wrapped.WrappedContext = ctx - err := handler(srv, wrapped) - if err != nil && !xerrors.As(err, &sdk.SecurityResponseMatch{}) { - // The error is not a security response match - return err - } - - // Check if a security response should be applied now after having used - // `Identify()` and `MatchSecurityResponse()`. - if handler := sqreened.UserSecurityResponse(); handler != nil { - // TODO: same as before - handler.ServeHTTP(noopHTTPResponseWriter{}, sqreened.Request()) - return status.Error(codes.Aborted, "aborted by a sqreen user action") - } - // Note that we do not control the result's payload here. So users need // to use the SDK in order to check the user security response and avoid // sending messages. A slower solution would be wrapping the stream's // Recv() and Send() methods in order to check for the security response // every time a message is received/sent, so that the connection can // be aborted. - return nil } } @@ -178,7 +140,7 @@ func (r http2Request) getHeader() http.Header { // HTTP request in order to be compatible with the current API. In the future, a // better abstraction should allow to not rely only on the standard Go HTTP // package only. -func newRequestFromMD(ctx context.Context) (context.Context, *sdk.HTTPRequest) { +func newRequestFromMD(ctx context.Context) *http.Request { // gRPC stores headers into the metadata object. r := http2Request(metautils.ExtractIncoming(ctx)) p, ok := peer.FromContext(ctx) @@ -186,7 +148,7 @@ func newRequestFromMD(ctx context.Context) (context.Context, *sdk.HTTPRequest) { if ok { remoteAddr = p.Addr.String() } - req := &http.Request{ + return &http.Request{ Method: r.getMethod(), URL: r.getURL(), Proto: "HTTP/2", @@ -197,14 +159,6 @@ func newRequestFromMD(ctx context.Context) (context.Context, *sdk.HTTPRequest) { RemoteAddr: remoteAddr, RequestURI: r.getRequestURI(), } - req = req.WithContext(ctx) - - // Create a new sqreened request. - sqreened := sdk.NewHTTPRequest(req) - // Get the new request context which includes the request record pointer. - ctx = sqreened.Request().Context() - - return ctx, sqreened } // TODO: agent interfaces should not require this hack. From c0c11e5bb0e4cb1f562a3ae2c89cf76b7bfeedbf Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 4 Jul 2019 15:04:06 +0200 Subject: [PATCH 20/47] agent/rule: add support for multiple callbacks per hook The support for multiple callbacks per hook is implemented by chaining callbacks together. It allows to keep the fast dispatch we have thanks to type-assertion - instead of using reflect - since a callback has the same function signature (same type) as the next callback. So it makes them perfect to chain them. The huge benefit also is that it keeps `sqhook` simple and efficient. The complexity for this is added at the callback level. Callbacks are chained through their closure: a pointer to the next one is passed when they are instantiated. To do so, the rule engine was modified in order to be able to find callbacks already associated to a hook, so that the previous callback pointer is passed to the currently constructed callback. All this happens in a new dataset before replacing the current one. --- agent/internal/rule/callback.go | 6 +- .../rule/callback/add-security-headers.go | 27 ++- .../callback/add-security-headers_test.go | 176 +++++++++++++++--- .../rule/callback/write-custom-error-page.go | 30 ++- .../callback/write-custom-error-page_test.go | 91 ++++----- .../rule/callback/write-http-redirection.go | 33 +++- .../callback/write-http-redirection_test.go | 4 +- agent/internal/rule/callback_test.go | 2 +- agent/internal/rule/rule.go | 80 ++++---- 9 files changed, 307 insertions(+), 142 deletions(-) diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index 67febe30..183a455f 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -13,11 +13,11 @@ import ( // CallbackConstructorFunc is a function returning a callback function // configured with the given data. The data types are known by the constructor // that can type-assert them. -type CallbacksConstructorFunc func(data []interface{}) (prolog, epilog sqhook.Callback, err error) +type CallbacksConstructorFunc func(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) // NewCallbacks returns the prolog and epilog callbacks of the given callback // name. And error is returned if the callback name is unknown. -func NewCallbacks(name string, data []interface{}) (prolog, epilog sqhook.Callback, err error) { +func NewCallbacks(name string, data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { var callbacksCtor CallbacksConstructorFunc switch name { default: @@ -29,5 +29,5 @@ func NewCallbacks(name string, data []interface{}) (prolog, epilog sqhook.Callba case "AddSecurityHeaders": callbacksCtor = callback.NewAddSecurityHeadersCallbacks } - return callbacksCtor(data) + return callbacksCtor(data, nextProlog, nextEpilog) } diff --git a/agent/internal/rule/callback/add-security-headers.go b/agent/internal/rule/callback/add-security-headers.go index acccbb64..d5849afd 100644 --- a/agent/internal/rule/callback/add-security-headers.go +++ b/agent/internal/rule/callback/add-security-headers.go @@ -14,7 +14,7 @@ import ( // NewAddSecurityHeadersCallbacks returns the native prolog and epilog callbacks // to be hooked to `sqhttp.MiddlewareWithError` in order to add HTTP headers // provided by the rule's data. -func NewAddSecurityHeadersCallbacks(data []interface{}) (prolog, epilog sqhook.Callback, err error) { +func NewAddSecurityHeadersCallbacks(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { var headers = make(http.Header, len(data)) for _, headersKV := range data { // TODO: move to a structured list of headers to avoid dynamic type checking @@ -32,19 +32,36 @@ func NewAddSecurityHeadersCallbacks(data []interface{}) (prolog, epilog sqhook.C if len(headers) == 0 { return nil, nil, sqerrors.New("there are no headers to add") } - return newAddHeadersPrologCallback(headers), nil, nil + + // Next callbacks to call + actualNextProlog, ok := nextProlog.(AddSecurityHeadersPrologCallbackType) + if nextProlog != nil && !ok { + err = sqerrors.Errorf("unexpected next prolog type `%T` instead of `%T`", nextProlog, AddSecurityHeadersPrologCallbackType(nil)) + return + } + // No epilog in this callback, so simply check and pass the given one + if _, ok := nextEpilog.(AddSecurityHeadersEpilogCallbackType); nextEpilog != nil && !ok { + err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, AddSecurityHeadersEpilogCallbackType(nil)) + return + } + return newAddHeadersPrologCallback(headers, actualNextProlog), nextEpilog, nil } type AddSecurityHeadersPrologCallbackType = func(*sqhook.Context, *http.ResponseWriter) error +type AddSecurityHeadersEpilogCallbackType = func(*sqhook.Context) // The prolog callback modifies the function arguments in order to replace the // written status code and body. -func newAddHeadersPrologCallback(headers http.Header) AddSecurityHeadersPrologCallbackType { - return func(_ *sqhook.Context, w *http.ResponseWriter) error { +func newAddHeadersPrologCallback(headers http.Header, next AddSecurityHeadersPrologCallbackType) AddSecurityHeadersPrologCallbackType { + return func(ctx *sqhook.Context, w *http.ResponseWriter) error { responseHeaders := (*w).Header() for k, v := range headers { responseHeaders[k] = v } - return nil + + if next == nil { + return nil + } + return next(ctx, w) } } diff --git a/agent/internal/rule/callback/add-security-headers_test.go b/agent/internal/rule/callback/add-security-headers_test.go index ec5f8b6f..4ec8dc98 100644 --- a/agent/internal/rule/callback/add-security-headers_test.go +++ b/agent/internal/rule/callback/add-security-headers_test.go @@ -7,15 +7,22 @@ package callback_test import ( "net/http" "net/http/httptest" + "reflect" "testing" + "github.com/sqreen/go-agent/agent/internal/rule" "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" "github.com/stretchr/testify/require" ) func TestNewAddSecurityHeadersCallbacks(t *testing.T) { - t.Run("with incorrect data", func(t *testing.T) { - for _, data := range [][]interface{}{ + RunCallbackTest(t, TestConfig{ + CallbacksCtor: callback.NewAddSecurityHeadersCallbacks, + ExpectProlog: true, + PrologType: reflect.TypeOf(callback.AddSecurityHeadersPrologCallbackType(nil)), + EpilogType: reflect.TypeOf(callback.AddSecurityHeadersEpilogCallbackType(nil)), + InvalidTestCases: [][]interface{}{ nil, {}, {33}, @@ -24,36 +31,149 @@ func TestNewAddSecurityHeadersCallbacks(t *testing.T) { {nil}, {[]string{"one"}}, {[]string{"one", "two", "three"}}, - } { - prolog, epilog, err := callback.NewAddSecurityHeadersCallbacks(data) + }, + ValidTestCases: []ValidTestCase{ + { + ValidData: []interface{}{ + []string{"k", "v"}, + []string{"one", "two"}, + []string{"canonical-header", "the value"}, + }, + TestCallbacks: func(t *testing.T, prolog, epilog sqhook.Callback) { + expectedHeaders := http.Header{ + "K": []string{"v"}, + "One": []string{"two"}, + "Canonical-Header": []string{"the value"}, + } + actualProlog, ok := prolog.(callback.AddSecurityHeadersPrologCallbackType) + require.True(t, ok) + var rec http.ResponseWriter = httptest.NewRecorder() + err := actualProlog(nil, &rec) + // Check it behaves as expected + require.NoError(t, err) + require.Equal(t, expectedHeaders, rec.Header()) + + // Test the epilog if any + if epilog != nil { + actualEpilog, ok := epilog.(callback.AddSecurityHeadersEpilogCallbackType) + require.True(t, ok) + actualEpilog(&sqhook.Context{}) + } + }, + }, + }, + }) +} + +type TestConfig struct { + CallbacksCtor rule.CallbacksConstructorFunc + ExpectEpilog, ExpectProlog bool + PrologType, EpilogType reflect.Type + InvalidTestCases [][]interface{} + ValidTestCases []ValidTestCase +} + +type ValidTestCase struct { + ValidData []interface{} + TestCallbacks func(t *testing.T, prolog, epilog sqhook.Callback) +} + +func RunCallbackTest(t *testing.T, config TestConfig) { + for _, data := range config.InvalidTestCases { + data := data + t.Run("with incorrect data", func(t *testing.T) { + prolog, epilog, err := config.CallbacksCtor(data, nil, nil) require.Error(t, err) require.Nil(t, prolog) require.Nil(t, epilog) - } - }) + }) + } + + for _, tc := range config.ValidTestCases { + tc := tc + t.Run("with correct data", func(t *testing.T) { + t.Run("without next callbacks", func(t *testing.T) { + // Instantiate the callback with the given correct rule data + prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nil, nil) + require.NoError(t, err) + checkCallbacksValues(t, config, prolog, epilog) + tc.TestCallbacks(t, prolog, epilog) + }) - t.Run("with correct data", func(t *testing.T) { - // Instantiate the callback with the given correct rule data - prolog, epilog, err := callback.NewAddSecurityHeadersCallbacks([]interface{}{ - []string{"k", "v"}, - []string{"one", "two"}, - []string{"canonical-header", "the value"}, + t.Run("with next callbacks", func(t *testing.T) { + t.Run("wrong next prolog type", func(t *testing.T) { + prolog, epilog, err := config.CallbacksCtor(tc.ValidData, 33, nil) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + + t.Run("wrong next epilog type", func(t *testing.T) { + prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nil, func() {}) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + + t.Run("with correct next prolog", func(t *testing.T) { + var called bool + nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { + called = true + return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} + }).Interface() + + prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nextProlog, nil) + require.NoError(t, err) + checkCallbacksValues(t, config, prolog, epilog) + require.NotNil(t, prolog) + tc.TestCallbacks(t, prolog, epilog) + require.True(t, called) + }) + + t.Run("with correct next epilog", func(t *testing.T) { + var called bool + nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { + called = true + return + }).Interface() + + prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nil, nextEpilog) + require.NoError(t, err) + checkCallbacksValues(t, config, prolog, epilog) + require.NotNil(t, epilog) + tc.TestCallbacks(t, prolog, epilog) + require.True(t, called) + }) + + t.Run("with both correct next callbacks", func(t *testing.T) { + var calledProlog, calledEpilog bool + nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { + calledProlog = true + return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} + }).Interface() + nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { + calledEpilog = true + return + }).Interface() + + prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nextProlog, nextEpilog) + require.NoError(t, err) + require.NotNil(t, prolog) + require.NotNil(t, epilog) + tc.TestCallbacks(t, prolog, epilog) + require.True(t, calledProlog) + require.True(t, calledEpilog) + }) + }) }) - require.NoError(t, err) + } +} + +func checkCallbacksValues(t *testing.T, config TestConfig, prolog, epilog sqhook.Callback) { + if config.ExpectProlog { require.NotNil(t, prolog) - require.Nil(t, epilog) - // Call it and check the behaviour follows the rule's data - actualProlog, ok := prolog.(callback.AddSecurityHeadersPrologCallbackType) - require.True(t, ok) - var rec http.ResponseWriter = httptest.NewRecorder() - err = actualProlog(nil, &rec) - // Check it behaves as expected - require.NoError(t, err) - expectedHeaders := http.Header{ - "K": []string{"v"}, - "One": []string{"two"}, - "Canonical-Header": []string{"the value"}, - } - require.Equal(t, expectedHeaders, rec.Header()) - }) + } + if config.ExpectEpilog { + require.NotNil(t, prolog) + } } diff --git a/agent/internal/rule/callback/write-custom-error-page.go b/agent/internal/rule/callback/write-custom-error-page.go index 3edaf3d0..8ce41400 100644 --- a/agent/internal/rule/callback/write-custom-error-page.go +++ b/agent/internal/rule/callback/write-custom-error-page.go @@ -16,28 +16,46 @@ import ( // callbacks modifying the arguments of `httphandler.WriteResponse` in order to // modify the http status code and error page that are provided by the rule's // data. -func NewWriteCustomErrorPageCallbacks(data []interface{}) (prolog, epilog sqhook.Callback, err error) { +func NewWriteCustomErrorPageCallbacks(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { var statusCode = 500 if len(data) > 0 { d0 := data[0] cfg, ok := d0.(*api.CustomErrorPageRuleDataEntry) if !ok { - return nil, nil, sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + err = sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + return } statusCode = cfg.StatusCode } - return newWriteCustomErrorPagePrologCallback(statusCode, []byte(blockedBySqreenPage)), nil, nil + + // Next callbacks to call + actualNextProlog, ok := nextProlog.(WriteCustomErrorPagePrologCallbackType) + if nextProlog != nil && !ok { + err = sqerrors.Errorf("unexpected next prolog type `%T`", nextProlog) + return + } + // No epilog in this callback, so simply check and pass the given one + if _, ok := nextEpilog.(WriteCustomErrorPageEpilogCallbackType); nextEpilog != nil && !ok { + err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, WriteCustomErrorPageEpilogCallbackType(nil)) + return + } + return newWriteCustomErrorPagePrologCallback(statusCode, []byte(blockedBySqreenPage), actualNextProlog), nextEpilog, nil } type WriteCustomErrorPagePrologCallbackType = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) error +type WriteCustomErrorPageEpilogCallbackType = func(*sqhook.Context) // The prolog callback modifies the function arguments in order to replace the // written status code and body. -func newWriteCustomErrorPagePrologCallback(statusCode int, body []byte) WriteCustomErrorPagePrologCallbackType { - return func(_ *sqhook.Context, _ *http.ResponseWriter, _ **http.Request, _ *http.Header, callerStatusCode *int, callerBody *[]byte) error { +func newWriteCustomErrorPagePrologCallback(statusCode int, body []byte, next WriteCustomErrorPagePrologCallbackType) WriteCustomErrorPagePrologCallbackType { + return func(ctx *sqhook.Context, callerWriter *http.ResponseWriter, callerRequest **http.Request, callerHeaders *http.Header, callerStatusCode *int, callerBody *[]byte) error { *callerStatusCode = statusCode *callerBody = body - return nil + + if next == nil { + return nil + } + return next(ctx, callerWriter, callerRequest, callerHeaders, callerStatusCode, callerBody) } } diff --git a/agent/internal/rule/callback/write-custom-error-page_test.go b/agent/internal/rule/callback/write-custom-error-page_test.go index 8f70917d..8bfb13c4 100644 --- a/agent/internal/rule/callback/write-custom-error-page_test.go +++ b/agent/internal/rule/callback/write-custom-error-page_test.go @@ -5,72 +5,63 @@ package callback_test import ( + "reflect" "testing" "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" "github.com/stretchr/testify/require" ) func TestNewWriteCustomErrorPageCallbacks(t *testing.T) { - t.Run("with incorrect data", func(t *testing.T) { - for _, data := range [][]interface{}{ + RunCallbackTest(t, TestConfig{ + CallbacksCtor: callback.NewWriteCustomErrorPageCallbacks, + ExpectProlog: true, + PrologType: reflect.TypeOf(callback.WriteCustomErrorPagePrologCallbackType(nil)), + EpilogType: reflect.TypeOf(callback.WriteCustomErrorPageEpilogCallbackType(nil)), + InvalidTestCases: [][]interface{}{ {33}, {"yet another wrong type"}, - } { - prolog, epilog, err := callback.NewWriteCustomErrorPageCallbacks(data) - require.Error(t, err) - require.Nil(t, prolog) - require.Nil(t, epilog) - } - }) - - t.Run("with correct data", func(t *testing.T) { - for _, tc := range []struct { - testName string - data []interface{} - expectedStatusCode int - }{ + }, + ValidTestCases: []ValidTestCase{ { - testName: "default behaviour with nil data", - data: nil, - expectedStatusCode: 500, + ValidData: nil, + TestCallbacks: testWriteCustomErrorPageCallbacks(500), }, { - testName: "default behaviour with empty array", - data: nil, - expectedStatusCode: 500, + ValidData: []interface{}{}, + TestCallbacks: testWriteCustomErrorPageCallbacks(500), }, { - testName: "actual rule data", - data: []interface{}{ - &api.CustomErrorPageRuleDataEntry{ - StatusCode: 33, - }, + ValidData: []interface{}{ + &api.CustomErrorPageRuleDataEntry{StatusCode: 33}, }, - expectedStatusCode: 33, + TestCallbacks: testWriteCustomErrorPageCallbacks(33), }, - } { - tc := tc - t.Run(tc.testName, func(t *testing.T) { - // Instantiate the callback with the given correct rule data - prolog, epilog, err := callback.NewWriteCustomErrorPageCallbacks(tc.data) - require.NoError(t, err) - require.NotNil(t, prolog) - require.Nil(t, epilog) - // Call it and check the behaviour follows the rule's data - actualProlog, ok := prolog.(callback.WriteCustomErrorPagePrologCallbackType) - require.True(t, ok) - var ( - statusCode int - body []byte - ) - err = actualProlog(nil, nil, nil, nil, &statusCode, &body) - // Check it behaves as expected - require.NoError(t, err) - require.Equal(t, tc.expectedStatusCode, statusCode) - require.NotNil(t, body) - }) - } + }, }) } + +func testWriteCustomErrorPageCallbacks(expectedStatusCode int) func(t *testing.T, prolog sqhook.Callback, epilog sqhook.Callback) { + return func(t *testing.T, prolog, epilog sqhook.Callback) { + actualProlog, ok := prolog.(callback.WriteCustomErrorPagePrologCallbackType) + require.True(t, ok) + var ( + statusCode int + body []byte + ) + err := actualProlog(nil, nil, nil, nil, &statusCode, &body) + // Check it behaves as expected + require.NoError(t, err) + require.Equal(t, expectedStatusCode, statusCode) + require.NotNil(t, body) + + // Test the epilog if any + if epilog != nil { + actualEpilog, ok := epilog.(callback.AddSecurityHeadersEpilogCallbackType) + require.True(t, ok) + actualEpilog(&sqhook.Context{}) + } + } +} diff --git a/agent/internal/rule/callback/write-http-redirection.go b/agent/internal/rule/callback/write-http-redirection.go index 924542a7..069682e5 100644 --- a/agent/internal/rule/callback/write-http-redirection.go +++ b/agent/internal/rule/callback/write-http-redirection.go @@ -17,36 +17,51 @@ import ( // callbacks modifying the arguments of `httphandler.WriteResponse` in order to // modify the http status code and headers in order to perform an HTTP // redirection to the URL provided by the rule's data. -func NewWriteHTTPRedirectionCallbacks(data []interface{}) (prolog, epilog sqhook.Callback, err error) { +func NewWriteHTTPRedirectionCallbacks(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { var redirectionURL string if len(data) > 0 { d0 := data[0] cfg, ok := d0.(*api.RedirectionRuleDataEntry) if !ok { - return nil, nil, sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + err = sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + return } redirectionURL = cfg.RedirectionURL } if redirectionURL == "" { - return nil, nil, sqerrors.New("unexpected empty redirection url") + err = sqerrors.New("unexpected empty redirection url") + return } - if _, err := url.ParseRequestURI(redirectionURL); err != nil { - return nil, nil, sqerrors.Wrap(err, "validation error of the redirection url") + if _, err = url.ParseRequestURI(redirectionURL); err != nil { + err = sqerrors.Wrap(err, "validation error of the redirection url") + return } - return newWriteHTTPRedirectionPrologCallback(redirectionURL), nil, nil + + // Next callbacks to call + actualNextProlog, ok := nextProlog.(WriteHTTPRedirectionPrologCallbackType) + if nextProlog != nil && !ok { + err = sqerrors.Errorf("unexpected next prolog type `%T`", nextProlog) + return + } + // No epilog in this callback, so simply pass the given one + return newWriteHTTPRedirectionPrologCallback(redirectionURL, actualNextProlog), nextEpilog, nil } type WriteHTTPRedirectionPrologCallbackType = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) error // The prolog callback modifies the function arguments in order to perform an // HTTP redirection. -func newWriteHTTPRedirectionPrologCallback(url string) WriteHTTPRedirectionPrologCallbackType { - return func(_ *sqhook.Context, _ *http.ResponseWriter, _ **http.Request, callerHeaders *http.Header, callerStatusCode *int, _ *[]byte) error { +func newWriteHTTPRedirectionPrologCallback(url string, next WriteHTTPRedirectionPrologCallbackType) WriteHTTPRedirectionPrologCallbackType { + return func(ctx *sqhook.Context, callerWriter *http.ResponseWriter, callerRequest **http.Request, callerHeaders *http.Header, callerStatusCode *int, callerBody *[]byte) error { *callerStatusCode = http.StatusSeeOther if *callerHeaders == nil { *callerHeaders = make(http.Header) } callerHeaders.Set("Location", url) - return nil + + if next == nil { + return nil + } + return next(ctx, callerWriter, callerRequest, callerHeaders, callerStatusCode, callerBody) } } diff --git a/agent/internal/rule/callback/write-http-redirection_test.go b/agent/internal/rule/callback/write-http-redirection_test.go index 21d89dd9..18a06e2c 100644 --- a/agent/internal/rule/callback/write-http-redirection_test.go +++ b/agent/internal/rule/callback/write-http-redirection_test.go @@ -24,7 +24,7 @@ func TestNewWriteHTTPRedirectionCallbacks(t *testing.T) { {&api.RedirectionRuleDataEntry{}}, {&api.RedirectionRuleDataEntry{"http//sqreen.com"}}, } { - prolog, epilog, err := callback.NewWriteHTTPRedirectionCallbacks(data) + prolog, epilog, err := callback.NewWriteHTTPRedirectionCallbacks(data, nil, nil) require.Error(t, err) require.Nil(t, prolog) require.Nil(t, epilog) @@ -36,7 +36,7 @@ func TestNewWriteHTTPRedirectionCallbacks(t *testing.T) { expectedURL := "http://sqreen.com" prolog, epilog, err := callback.NewWriteHTTPRedirectionCallbacks([]interface{}{ &api.RedirectionRuleDataEntry{RedirectionURL: expectedURL}, - }) + }, nil, nil) require.NoError(t, err) require.NotNil(t, prolog) require.Nil(t, epilog) diff --git a/agent/internal/rule/callback_test.go b/agent/internal/rule/callback_test.go index a0d1dbbb..3042278a 100644 --- a/agent/internal/rule/callback_test.go +++ b/agent/internal/rule/callback_test.go @@ -42,7 +42,7 @@ func TestNewCallbacks(t *testing.T) { } { tc := tc t.Run(tc.testName, func(t *testing.T) { - _, _, err := rule.NewCallbacks(tc.name, tc.data) + _, _, err := rule.NewCallbacks(tc.name, tc.data, nil, nil) if tc.shouldSucceed { require.NoError(t, err) } else { diff --git a/agent/internal/rule/rule.go b/agent/internal/rule/rule.go index 72ce6c47..f3878096 100644 --- a/agent/internal/rule/rule.go +++ b/agent/internal/rule/rule.go @@ -30,7 +30,7 @@ type Engine struct { logger Logger // Map rules to their corresponding symbol in order to be able to modify them // at run time by atomically replacing a running rule. - rules ruleDescriptors + hooks hookDescriptors packID string cfg *config.Config enabled bool @@ -58,45 +58,47 @@ func (e *Engine) PackID() string { // them by atomically modifying the hooks, and removing what is left. func (e *Engine) SetRules(packID string, rules []api.Rule) { // Create the net rule descriptors and replace the existing ones - ruleDescriptors := newRuleDescriptors(e.logger, rules) + ruleDescriptors := newHookDescriptors(e.logger, rules) e.setRules(packID, ruleDescriptors) } -func (e *Engine) setRules(packID string, descriptors ruleDescriptors) { - for symbol, rule := range descriptors { +func (e *Engine) setRules(packID string, descriptors hookDescriptors) { + for hook, callback := range descriptors { if e.enabled { // TODO: chain multiple callbacks per hookpoint using a callback of callbacks // Attach the callback to the hook - err := rule.hook.Attach(rule.prolog, rule.epilog) + err := hook.Attach(callback.prolog, callback.epilog) if err != nil { - e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not attach the callbacks", rule.name))) + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not attach the callbacks"))) continue } } // Remove from the previous rules pack the entries that were redefined in // this one. - delete(e.rules, symbol) + delete(e.hooks, hook) } + // Disable previously enabled rules that were not replaced by new ones. - for _, rule := range e.rules { - err := rule.hook.Attach(nil, nil) + for hook := range e.hooks { + err := hook.Attach(nil, nil) if err != nil { - e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not attach the callbacks", rule.name))) + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not attach the callbacks"))) continue } } // Save the rules pack ID and the list of enabled hooks e.packID = packID - e.rules = descriptors + e.hooks = descriptors } -// newRuleDescriptors walks the list of received rules and creates the map of -// rule descriptors indexed by their symbol. A rule descriptor contains all it -// needs to enable and disable rules at run time. -func newRuleDescriptors(logger Logger, rules []api.Rule) ruleDescriptors { +// newHookDescriptors walks the list of received rules and creates the map of +// hook descriptors indexed by their hook pointer. A hook descriptor contains +// all it takes to enable and disable rules at run time. +func newHookDescriptors(logger Logger, rules []api.Rule) hookDescriptors { // Create and configure the list of callbacks according to the given rules - ruleDescriptors := make(ruleDescriptors) - for _, r := range rules { + var hookDescriptors = make(hookDescriptors) + for i := len(rules) - 1; i >= 0; i-- { + r := rules[i] hookpoint := r.Hookpoint // Find the symbol symbol := fmt.Sprintf("%s.%s", hookpoint.Class, hookpoint.Method) @@ -114,32 +116,32 @@ func newRuleDescriptors(logger Logger, rules []api.Rule) ruleDescriptors { } } // Instantiate the callback - prolog, epilog, err := NewCallbacks(hookpoint.Callback, data) + next := hookDescriptors.Get(hook) + prolog, epilog, err := NewCallbacks(hookpoint.Callback, data, next.prolog, next.epilog) if err != nil { logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not instantiate the callbacks", r.Name))) continue } - // Create the rule descriptor with everything required to be able to enable - // or disable it afterwards. - ruleDescriptors.Add(symbol, ruleDescriptor{ - name: r.Name, - hook: hook, + // Create the descriptor with everything required to be able to enable or + // disable it afterwards. + hookDescriptors.Set(hook, callbacksDescriptor{ prolog: prolog, epilog: epilog, }) } - if len(ruleDescriptors) == 0 { + // Nothing in the end + if len(hookDescriptors) == 0 { return nil } - return ruleDescriptors + return hookDescriptors } // Enable the hooks of the ongoing configured rules. func (e *Engine) Enable() { - for _, r := range e.rules { - err := r.hook.Attach(r.prolog, r.epilog) + for hook, callback := range e.hooks { + err := hook.Attach(callback.prolog, callback.epilog) if err != nil { - e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not attach the callbacks", r.name))) + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not attach the callbacks `%v` and `%v` to hook `%v`", callback.prolog, callback.epilog, hook))) } } e.enabled = true @@ -147,23 +149,25 @@ func (e *Engine) Enable() { // Disable the hooks currently attached to callbacks. func (e *Engine) Disable() { - for _, r := range e.rules { - err := r.hook.Attach(nil, nil) + for hook := range e.hooks { + err := hook.Attach(nil, nil) if err != nil { - e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not disable the callbacks", r.name))) + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not disable hook `%v`", hook))) } } e.enabled = false } -type ruleDescriptors map[string]ruleDescriptor +type hookDescriptors map[*sqhook.Hook]callbacksDescriptor + +type callbacksDescriptor struct { + prolog, epilog sqhook.Callback +} -type ruleDescriptor struct { - name string - hook *sqhook.Hook - epilog, prolog sqhook.Callback +func (m hookDescriptors) Set(hook *sqhook.Hook, descriptor callbacksDescriptor) { + m[hook] = descriptor } -func (m ruleDescriptors) Add(symbol string, descriptor ruleDescriptor) { - m[symbol] = descriptor +func (m hookDescriptors) Get(hook *sqhook.Hook) callbacksDescriptor { + return m[hook] } From bd58c43422fbf0de9b40c3c56327f2b1b5b64d09 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 4 Jul 2019 15:21:59 +0200 Subject: [PATCH 21/47] sqlib/sqhook: add a stringer method to help logging error messages --- agent/sqlib/sqhook/hook.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/agent/sqlib/sqhook/hook.go b/agent/sqlib/sqhook/hook.go index db24f2f2..5b859adf 100644 --- a/agent/sqlib/sqhook/hook.go +++ b/agent/sqlib/sqhook/hook.go @@ -104,6 +104,8 @@ type Hook struct { // Pointer to a structure containing the callbacks in order to be able to // atomically modify the pointer. attached *callbacks + // Symbol name where the hook is used. Required for the stringer. + symbol string } type callbacks struct { @@ -160,6 +162,7 @@ func New(fn interface{}) *Hook { } // Create the hook, store it in the map and return it. hook := &Hook{ + symbol: symbol, fnType: fnType, } index[symbol] = hook @@ -277,3 +280,7 @@ func validateCallback(callback Callback, argTypes []reflect.Type) (err error) { } return nil } + +func (h *Hook) String() string { + return fmt.Sprintf("%s (%s)", h.symbol, h.fnType) +} From a74ce5876dd44936ef348d00bf179bb69ae8f42e Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 15 Jul 2019 15:47:19 +0200 Subject: [PATCH 22/47] agent/rule: add a http status code monitoring callback - A rule now also defines the list of metrics stores to create, along with their time period. A new interface is therefore passed to callback constructors which can be passed to the callbacks if required. It allows to provide the metrics method required for the new HTTP code monitoring callback. - The HTTP code monitoring callback updates a metrics store with the response status code. This store is defined by the rule and a callback doesn't need to know about it when there is only one metrics store - the first one defined in the rule is used. --- agent/internal/rule/callback.go | 70 ++++++++- agent/internal/rule/callback/callback_test.go | 141 ++++++++++++++++++ .../rule/callback/monitor-http-status-code.go | 41 +++++ .../callback/monitor-http-status-code_test.go | 45 ++++++ agent/internal/rule/callback/types.go | 13 ++ agent/internal/rule/callback_test.go | 18 ++- agent/internal/rule/rule.go | 30 ++-- agent/internal/rule/rule_test.go | 18 ++- 8 files changed, 348 insertions(+), 28 deletions(-) create mode 100644 agent/internal/rule/callback/callback_test.go create mode 100644 agent/internal/rule/callback/monitor-http-status-code.go create mode 100644 agent/internal/rule/callback/monitor-http-status-code_test.go create mode 100644 agent/internal/rule/callback/types.go diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index 183a455f..1e25cfc6 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -5,6 +5,12 @@ package rule import ( + "fmt" + "time" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/metrics" + "github.com/sqreen/go-agent/agent/internal/plog" "github.com/sqreen/go-agent/agent/internal/rule/callback" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" "github.com/sqreen/go-agent/agent/sqlib/sqhook" @@ -13,11 +19,11 @@ import ( // CallbackConstructorFunc is a function returning a callback function // configured with the given data. The data types are known by the constructor // that can type-assert them. -type CallbacksConstructorFunc func(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) +type CallbacksConstructorFunc func(rule callback.Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) // NewCallbacks returns the prolog and epilog callbacks of the given callback // name. And error is returned if the callback name is unknown. -func NewCallbacks(name string, data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewCallbacks(name string, rule *CallbackContext, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { var callbacksCtor CallbacksConstructorFunc switch name { default: @@ -28,6 +34,64 @@ func NewCallbacks(name string, data []interface{}, nextProlog, nextEpilog sqhook callbacksCtor = callback.NewWriteHTTPRedirectionCallbacks case "AddSecurityHeaders": callbacksCtor = callback.NewAddSecurityHeadersCallbacks + case "MonitorHTTPStatusCode": + callbacksCtor = callback.NewMonitorHTTPStatusCodeCallbacks + } + return callbacksCtor(rule, nextProlog, nextEpilog) +} + +type CallbackContext struct { + config interface{} + metricsStores map[string]*metrics.Store + defaultMetricsStore *metrics.Store + logger plog.ErrorLogger + name string +} + +func NewCallbackContext(r *api.Rule, logger plog.ErrorLogger, metricsEngine *metrics.Engine) *CallbackContext { + config := newCallbackConfig(&r.Data) + + var ( + metricsStores map[string]*metrics.Store + defaultMetricsStore *metrics.Store + ) + if len(r.Metrics) > 0 { + metricsStores = make(map[string]*metrics.Store) + for _, m := range r.Metrics { + metricsStores[m.Name] = metricsEngine.NewStore(m.Name, time.Second*time.Duration(m.Period)) + } + defaultMetricsStore = metricsStores[r.Metrics[0].Name] + } + + return &CallbackContext{ + config: config, + metricsStores: metricsStores, + defaultMetricsStore: defaultMetricsStore, + name: r.Name, + logger: logger, + } +} + +func newCallbackConfig(data *api.RuleData) (config interface{}) { + if nbData := len(data.Values); nbData > 1 { + configArray := make([]interface{}, 0, nbData) + for _, e := range data.Values { + configArray = append(configArray, e.Value) + } + config = configArray + } else if nbData == 1 { + config = data.Values[0].Value + } + return config +} + +func (d *CallbackContext) Config() interface{} { + return d.config +} + +func (d *CallbackContext) AddMetricsValue(key interface{}, value uint64) { + err := d.defaultMetricsStore.Add(key, value) + if err != nil { + d.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not add a value to the default metrics store", d.name))) } - return callbacksCtor(data, nextProlog, nextEpilog) } diff --git a/agent/internal/rule/callback/callback_test.go b/agent/internal/rule/callback/callback_test.go new file mode 100644 index 00000000..5da905eb --- /dev/null +++ b/agent/internal/rule/callback/callback_test.go @@ -0,0 +1,141 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback_test + +import ( + "reflect" + "testing" + + "github.com/sqreen/go-agent/agent/internal/rule" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type TestConfig struct { + CallbacksCtor rule.CallbacksConstructorFunc + ExpectEpilog, ExpectProlog bool + PrologType, EpilogType reflect.Type + InvalidTestCases []interface{} + ValidTestCases []ValidTestCase +} + +type ValidTestCase struct { + Rule *FakeRule + TestCallbacks func(t *testing.T, rule *FakeRule, prolog, epilog sqhook.Callback) +} + +func RunCallbackTest(t *testing.T, config TestConfig) { + for _, data := range config.InvalidTestCases { + data := data + t.Run("with incorrect data", func(t *testing.T) { + prolog, epilog, err := config.CallbacksCtor(&FakeRule{config: data}, nil, nil) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + } + + for _, tc := range config.ValidTestCases { + tc := tc + t.Run("with correct data", func(t *testing.T) { + t.Run("without next callbacks", func(t *testing.T) { + // Instantiate the callback with the given correct rule data + prolog, epilog, err := config.CallbacksCtor(tc.Rule, nil, nil) + require.NoError(t, err) + checkCallbacksValues(t, config, prolog, epilog) + tc.TestCallbacks(t, tc.Rule, prolog, epilog) + }) + + t.Run("with next callbacks", func(t *testing.T) { + t.Run("wrong next prolog type", func(t *testing.T) { + prolog, epilog, err := config.CallbacksCtor(tc.Rule, 33, nil) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + + t.Run("wrong next epilog type", func(t *testing.T) { + prolog, epilog, err := config.CallbacksCtor(tc.Rule, nil, func() {}) + require.Error(t, err) + require.Nil(t, prolog) + require.Nil(t, epilog) + }) + + t.Run("with correct next prolog", func(t *testing.T) { + var called bool + nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { + called = true + return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} + }).Interface() + + prolog, epilog, err := config.CallbacksCtor(tc.Rule, nextProlog, nil) + require.NoError(t, err) + checkCallbacksValues(t, config, prolog, epilog) + require.NotNil(t, prolog) + tc.TestCallbacks(t, tc.Rule, prolog, epilog) + require.True(t, called) + }) + + t.Run("with correct next epilog", func(t *testing.T) { + var called bool + nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { + called = true + return + }).Interface() + + prolog, epilog, err := config.CallbacksCtor(tc.Rule, nil, nextEpilog) + require.NoError(t, err) + checkCallbacksValues(t, config, prolog, epilog) + require.NotNil(t, epilog) + tc.TestCallbacks(t, tc.Rule, prolog, epilog) + require.True(t, called) + }) + + t.Run("with both correct next callbacks", func(t *testing.T) { + var calledProlog, calledEpilog bool + nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { + calledProlog = true + return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} + }).Interface() + nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { + calledEpilog = true + return + }).Interface() + + prolog, epilog, err := config.CallbacksCtor(tc.Rule, nextProlog, nextEpilog) + require.NoError(t, err) + require.NotNil(t, prolog) + require.NotNil(t, epilog) + tc.TestCallbacks(t, tc.Rule, prolog, epilog) + require.True(t, calledProlog) + require.True(t, calledEpilog) + }) + }) + }) + } +} + +func checkCallbacksValues(t *testing.T, config TestConfig, prolog, epilog sqhook.Callback) { + if config.ExpectProlog { + require.NotNil(t, prolog) + } + if config.ExpectEpilog { + require.NotNil(t, epilog) + } +} + +type FakeRule struct { + config interface{} + mock.Mock +} + +func (r *FakeRule) AddMetricsValue(key interface{}, value uint64) { + r.Called(key, value) +} + +func (r *FakeRule) Config() interface{} { + return r.config +} diff --git a/agent/internal/rule/callback/monitor-http-status-code.go b/agent/internal/rule/callback/monitor-http-status-code.go new file mode 100644 index 00000000..5754d6f7 --- /dev/null +++ b/agent/internal/rule/callback/monitor-http-status-code.go @@ -0,0 +1,41 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback + +import ( + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" +) + +func NewMonitorHTTPStatusCodeCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { + // Next callbacks to call + actualNextProlog, ok := nextProlog.(MonitorHTTPStatusCodePrologCallbackType) + if nextProlog != nil && !ok { + err = sqerrors.Errorf("unexpected next prolog type `%T` instead of `%T`", nextProlog, MonitorHTTPStatusCodePrologCallbackType(nil)) + return + } + // No epilog in this callback, so simply check and pass the given one + if _, ok := nextEpilog.(MonitorHTTPStatusCodeEpilogCallbackType); nextEpilog != nil && !ok { + err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, MonitorHTTPStatusCodeEpilogCallbackType(nil)) + return + } + return newMonitorHTTPStatusCodePrologCallback(rule, actualNextProlog), nextEpilog, nil +} + +func newMonitorHTTPStatusCodePrologCallback(rule Context, next MonitorHTTPStatusCodePrologCallbackType) MonitorHTTPStatusCodePrologCallbackType { + return func(ctx *sqhook.Context, code *int) error { + //if status := *code; status >= 400 && status <= 500 { + rule.AddMetricsValue(*code, 1) + //} + + if next == nil { + return nil + } + return next(ctx, code) + } +} + +type MonitorHTTPStatusCodePrologCallbackType = func(*sqhook.Context, *int) error +type MonitorHTTPStatusCodeEpilogCallbackType = func(*sqhook.Context) diff --git a/agent/internal/rule/callback/monitor-http-status-code_test.go b/agent/internal/rule/callback/monitor-http-status-code_test.go new file mode 100644 index 00000000..30d5b4ad --- /dev/null +++ b/agent/internal/rule/callback/monitor-http-status-code_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback_test + +import ( + "math/rand" + "reflect" + "testing" + + "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" + "github.com/stretchr/testify/require" +) + +func TestNewMonitorHTTPStatusCodeCallbacks(t *testing.T) { + RunCallbackTest(t, TestConfig{ + CallbacksCtor: callback.NewMonitorHTTPStatusCodeCallbacks, + ExpectProlog: true, + PrologType: reflect.TypeOf(callback.MonitorHTTPStatusCodePrologCallbackType(nil)), + EpilogType: reflect.TypeOf(callback.MonitorHTTPStatusCodeEpilogCallbackType(nil)), + ValidTestCases: []ValidTestCase{ + { + Rule: &FakeRule{}, + TestCallbacks: func(t *testing.T, rule *FakeRule, prolog, epilog sqhook.Callback) { + actualProlog, ok := prolog.(callback.MonitorHTTPStatusCodePrologCallbackType) + require.True(t, ok) + code := rand.Int() + rule.On("AddMetricsValue", code, uint64(1)).Return().Once() + err := actualProlog(nil, &code) + // Check it behaves as expected + require.NoError(t, err) + + // Test the epilog if any + if epilog != nil { + actualEpilog, ok := epilog.(callback.MonitorHTTPStatusCodeEpilogCallbackType) + require.True(t, ok) + actualEpilog(&sqhook.Context{}) + } + }, + }, + }, + }) +} diff --git a/agent/internal/rule/callback/types.go b/agent/internal/rule/callback/types.go new file mode 100644 index 00000000..ee4ba772 --- /dev/null +++ b/agent/internal/rule/callback/types.go @@ -0,0 +1,13 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package callback + +type Context interface { + // Get the rule configuration. + Config() interface{} + // Add a new metrics value for the given key to the default metrics store + // given by the rule. + AddMetricsValue(key interface{}, value uint64) +} diff --git a/agent/internal/rule/callback_test.go b/agent/internal/rule/callback_test.go index 3042278a..b234e432 100644 --- a/agent/internal/rule/callback_test.go +++ b/agent/internal/rule/callback_test.go @@ -16,33 +16,37 @@ func TestNewCallbacks(t *testing.T) { for _, tc := range []struct { testName string name string - data []interface{} + rule *rule.CallbackContext shouldSucceed bool }{ { testName: "not existing", name: "iDontExist", - data: nil, + rule: nil, shouldSucceed: false, }, { testName: "empty string", name: "", - data: nil, + rule: nil, shouldSucceed: false, }, { testName: "WriteCustomErrorPage", name: "WriteCustomErrorPage", - data: []interface{}{ - &api.CustomErrorPageRuleDataEntry{}, - }, + rule: rule.NewCallbackContext(&api.Rule{ + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, + }, nil, nil), shouldSucceed: true, }, } { tc := tc t.Run(tc.testName, func(t *testing.T) { - _, _, err := rule.NewCallbacks(tc.name, tc.data, nil, nil) + _, _, err := rule.NewCallbacks(tc.name, tc.rule, nil, nil) if tc.shouldSucceed { require.NoError(t, err) } else { diff --git a/agent/internal/rule/rule.go b/agent/internal/rule/rule.go index f3878096..7087f9bc 100644 --- a/agent/internal/rule/rule.go +++ b/agent/internal/rule/rule.go @@ -21,6 +21,7 @@ import ( "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/config" + "github.com/sqreen/go-agent/agent/internal/metrics" "github.com/sqreen/go-agent/agent/internal/plog" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" "github.com/sqreen/go-agent/agent/sqlib/sqhook" @@ -30,10 +31,11 @@ type Engine struct { logger Logger // Map rules to their corresponding symbol in order to be able to modify them // at run time by atomically replacing a running rule. - hooks hookDescriptors - packID string - cfg *config.Config - enabled bool + hooks hookDescriptors + packID string + cfg *config.Config + enabled bool + metricsEngine *metrics.Engine } // Logger interface required by this package. @@ -43,9 +45,10 @@ type Logger interface { } // NewEngine returns a new rule engine. -func NewEngine(logger Logger) *Engine { +func NewEngine(logger Logger, metricsEngine *metrics.Engine) *Engine { return &Engine{ - logger: logger, + logger: logger, + metricsEngine: metricsEngine, } } @@ -58,7 +61,7 @@ func (e *Engine) PackID() string { // them by atomically modifying the hooks, and removing what is left. func (e *Engine) SetRules(packID string, rules []api.Rule) { // Create the net rule descriptors and replace the existing ones - ruleDescriptors := newHookDescriptors(e.logger, rules) + ruleDescriptors := newHookDescriptors(e.logger, rules, e.metricsEngine) e.setRules(packID, ruleDescriptors) } @@ -94,7 +97,7 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { // newHookDescriptors walks the list of received rules and creates the map of // hook descriptors indexed by their hook pointer. A hook descriptor contains // all it takes to enable and disable rules at run time. -func newHookDescriptors(logger Logger, rules []api.Rule) hookDescriptors { +func newHookDescriptors(logger Logger, rules []api.Rule, metricsEngine *metrics.Engine) hookDescriptors { // Create and configure the list of callbacks according to the given rules var hookDescriptors = make(hookDescriptors) for i := len(rules) - 1; i >= 0; i-- { @@ -107,17 +110,10 @@ func newHookDescriptors(logger Logger, rules []api.Rule) hookDescriptors { logger.Debugf("rule `%s` ignored: symbol `%s` cannot be hooked", r.Name, symbol) continue } - // Get the callback data from the API message - var data []interface{} - if nbData := len(r.Data.Values); nbData > 0 { - data = make([]interface{}, 0, nbData) - for _, e := range r.Data.Values { - data = append(data, e.Value) - } - } // Instantiate the callback next := hookDescriptors.Get(hook) - prolog, epilog, err := NewCallbacks(hookpoint.Callback, data, next.prolog, next.epilog) + ruleDescriptor := NewCallbackContext(&r, logger, metricsEngine) + prolog, epilog, err := NewCallbacks(hookpoint.Callback, ruleDescriptor, next.prolog, next.epilog) if err != nil { logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not instantiate the callbacks", r.Name))) continue diff --git a/agent/internal/rule/rule_test.go b/agent/internal/rule/rule_test.go index b91583e5..fa06fac1 100644 --- a/agent/internal/rule/rule_test.go +++ b/agent/internal/rule/rule_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/internal/metrics" "github.com/sqreen/go-agent/agent/internal/plog" "github.com/sqreen/go-agent/agent/internal/rule" "github.com/sqreen/go-agent/agent/sqlib/sqhook" @@ -24,7 +25,7 @@ type empty struct{} func TestEngineUsage(t *testing.T) { logger := plog.NewLogger(plog.Debug, os.Stderr, 0) - engine := rule.NewEngine(logger) + engine := rule.NewEngine(logger, metrics.NewEngine(plog.NewLogger(plog.Debug, os.Stderr, 0))) hookFunc1 := sqhook.New(func1) require.NotNil(t, hookFunc1) hookFunc2 := sqhook.New(func2) @@ -52,6 +53,11 @@ func TestEngineUsage(t *testing.T) { Method: "func1", Callback: "WriteCustomErrorPage", }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, }, { Name: "another valid rule", @@ -60,6 +66,11 @@ func TestEngineUsage(t *testing.T) { Method: "func2", Callback: "WriteCustomErrorPage", }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, }, }) @@ -120,6 +131,11 @@ func TestEngineUsage(t *testing.T) { Method: "func2", Callback: "WriteCustomErrorPage", }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, }, }) // Check the callbacks were removed for func1 and not func2 From f21de01f188385eea0da9b1805d5956309ce6a37 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 15 Jul 2019 16:05:34 +0200 Subject: [PATCH 23/47] sdk/middleware/sqhttp: add http status code monitoring hook point The status code is written using method `WriteHeader()` of the `ResponseWriter`. We therefore wrap it to monitor the value being written through a hook point that will be enabled at run time according to the received rules. --- sdk/middleware/sqhttp/http.go | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/sdk/middleware/sqhttp/http.go b/sdk/middleware/sqhttp/http.go index 4a823e0a..a6983298 100644 --- a/sdk/middleware/sqhttp/http.go +++ b/sdk/middleware/sqhttp/http.go @@ -73,6 +73,8 @@ func MiddlewareWithError(next Handler) Handler { defer req.Close() // Use the newly created request compliant with `sdk.FromContext()`. r = req.Request() + // Wrap the response writer to monitor the http status codes. + w = ResponseWriter{w} // Check if an early security action is already required such as based on // the request IP address. if handler := req.SecurityResponse(); handler != nil { @@ -147,8 +149,39 @@ func addSecurityHeaders(w http.ResponseWriter) (err error) { return nil } -var addSecurityHeaderHook *sqhook.Hook +var ( + addSecurityHeaderHook *sqhook.Hook + responseWriterWriteHeader *sqhook.Hook +) func init() { addSecurityHeaderHook = sqhook.New(addSecurityHeaders) + responseWriterWriteHeader = sqhook.New(responseWriter.WriteHeader) +} + +type ResponseWriter = responseWriter + +type responseWriter struct { + http.ResponseWriter +} + +func (w responseWriter) WriteHeader(statusCode int) { + { + type Prolog = func(*sqhook.Context, *int) error + type Epilog = func(*sqhook.Context) + ctx := sqhook.Context{sqhook.MethodReceiver(&w)} + prolog, epilog := responseWriterWriteHeader.Callbacks() + if epilog, ok := epilog.(Epilog); ok { + defer epilog(&ctx) + } + if prolog, ok := prolog.(Prolog); ok { + if err := prolog(&ctx, &statusCode); err != nil { + return + } + } + } + + if w.ResponseWriter != nil { + w.ResponseWriter.WriteHeader(statusCode) + } } From f3c03f80f3b998e49edbd7d891ab439dde2006fc Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 15 Jul 2019 16:47:17 +0200 Subject: [PATCH 24/47] sdk/middleware: add status code monitoring to gin and echo There is unfortunately no way to properly replace the response writer for Gin and Echo. This adds a quick workaround in order to pass the status code to the monitoring callback by calling the sqhttp's response writer manually. --- sdk/middleware/sqecho/echo.go | 7 ++++++- sdk/middleware/sqgin/gin.go | 7 +++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sdk/middleware/sqecho/echo.go b/sdk/middleware/sqecho/echo.go index 3c5847be..49b6902b 100644 --- a/sdk/middleware/sqecho/echo.go +++ b/sdk/middleware/sqecho/echo.go @@ -58,12 +58,17 @@ func Middleware() echo.MiddlewareFunc { // Create a middleware function by adapting to sqhttp's return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - return sqhttp.MiddlewareWithError(sqhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + return sqhttp.MiddlewareWithError(sqhttp.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) error { c.SetRequest(r) // Echo defines its own context interface, so we need to store it in // Echo's context. Echo expects string keys. contextKey := sdk.HTTPRequestRecordContextKey.String c.Set(contextKey, sdk.FromContext(r.Context())) + c.Response().After(func() { + // Hack for now to monitor the status code because Gin doesn't use the + // HTTP ResponseWriter when overwriting it through c.Writer = ... + sqhttp.ResponseWriter{}.WriteHeader(c.Response().Status) + }) return next(c) })).ServeHTTP(c.Response(), c.Request()) } diff --git a/sdk/middleware/sqgin/gin.go b/sdk/middleware/sqgin/gin.go index e6570ff9..5b53fe4a 100644 --- a/sdk/middleware/sqgin/gin.go +++ b/sdk/middleware/sqgin/gin.go @@ -67,6 +67,7 @@ func Middleware() gingonic.HandlerFunc { contextKey := sdk.HTTPRequestRecordContextKey.String c.Set(contextKey, sdk.FromContext(r.Context())) c.Next() + monitorHTTPStatusCode(c.Writer.Status()) return nil })).ServeHTTP(c.Writer, c.Request) if err != nil { @@ -74,3 +75,9 @@ func Middleware() gingonic.HandlerFunc { } } } + +func monitorHTTPStatusCode(statusCode int) { + // Hack for now to monitor the status code because Gin doesn't use the + // HTTP ResponseWriter when overwriting it through c.Writer = ... + sqhttp.ResponseWriter{}.WriteHeader(statusCode) +} From 50a76e55867ac4eac3ba64e1c8c867f1b6b612c7 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 15 Jul 2019 16:51:41 +0200 Subject: [PATCH 25/47] agent/metrics: add a new metrics engine - The metrics engine allows to manage the metrics stores. - Metrics stores are optimized for writes of existing keys (N goroutines writing M values, where N > M). - No extra goroutines nor channels are required thanks to a new polling interface to be used on every heartbeat. --- agent/internal/metrics/metrics.go | 240 +++++++++++ agent/internal/metrics/metrics_test.go | 533 +++++++++++++++++++++++++ 2 files changed, 773 insertions(+) create mode 100644 agent/internal/metrics/metrics.go create mode 100644 agent/internal/metrics/metrics_test.go diff --git a/agent/internal/metrics/metrics.go b/agent/internal/metrics/metrics.go new file mode 100644 index 00000000..013637f3 --- /dev/null +++ b/agent/internal/metrics/metrics.go @@ -0,0 +1,240 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +// Package metrics provides shared metrics stores. A metrics store is a +// key/value store with a given time period after which the data is considered +// ready. This package provides an implementation optimized for writes updating +// already existing keys: lots of goroutines updating a smaller set of keys. +// The metrics engine allows to create and register new metrics stores that a +// single reader (Sqreen's agent) can concurrently read. Read and write +// operations and mutually exclusive - slow polling is better for aggregating +// more data while not blocking the writers too often. +// +// Main requirements: +// +// - Loss-less kv-stores. +// - Near zero time impact on the hot path (updates): no need to switch to +// another goroutines, no blocking locks. +// +// Design decisions: +// +// The former first implementation was using channels and dedicated goroutines +// sleeping until the period was passed. The major issue was the case when +// the channels were full, with the choice of either blocking the sending +// goroutine, or dropping the data to avoid blocking it. +// This design is now considered not suitable for metrics as they happen at a +// too frequently to go through a channel. A channel indeed needs at least one +// extra reader goroutine that would require too much CPU time to aggregate +// all the metrics values. +// +// Metrics store operations, insertions and updates of integer values, are +// therefore considered shorter than any "pure-Go" approach with channels and +// so on. The main challenge here comes from the map whose index cannot be +// modified concurrently. So the idea is to use a RWLock it in order to +// mutually exclude the insertions of new values, updates of existing values and +// retrieval of expired values. +// The hot path being updates of existing values, the Add() method first tries +// to only RLock the store in order to avoid locking every other +// updating-goroutine. The value being a uint64, it can be atomically updated +// without using an lock for the value. +// +// The metrics stores and engine provide a polling interface to retrieve stores +// whose period are passed. No goroutine is started to automatically swap the +// stores. This is due to the fact that metrics are sent by the Sqreen agent +// only during the heartbeat; it can therefore check for expired stores. +// Metrics stores can therefore be longer than their period and will actually +// last until they are flushed by the reader goroutine. +package metrics + +import ( + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/sqreen/go-agent/agent/internal/plog" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" +) + +// Engine manages the metrics stores in oder to create new one, and to poll +// the existing ones. Engine's methods are not thread-safe and designed to be +// used by a single goroutine. +type Engine struct { + logger plog.DebugLogger + stores map[string]*Store +} + +func NewEngine(logger plog.DebugLogger) *Engine { + return &Engine{ + logger: logger, + stores: make(map[string]*Store), + } +} + +// NewStore creates and registers a new metrics store. +func (e *Engine) NewStore(id string, period time.Duration) *Store { + store := newStore(period) + e.stores[id] = store + return store +} + +// ReadyMetrics returns the set of ready stores (ie. having data and a passed +// period). This operation blocks metrics stores operations and should be +// wisely used. +func (e *Engine) ReadyMetrics() (expiredMetrics map[string]*ReadyStore) { + expiredMetrics = make(map[string]*ReadyStore) + for id, s := range e.stores { + if s.Ready() { + ready := s.Flush() + expiredMetrics[id] = ready + e.logger.Debugf("metrics: store `%s` ready with `%d` entries", id, len(ready.Metrics())) + } + } + if len(expiredMetrics) == 0 { + return nil + } + return expiredMetrics +} + +// Store is a metrics store optimized for write accesses to already existing +// keys (cf. Add). It has a period of time after which the data is considered +// ready to be retrieved. An empty store is never considered ready and the +// deadline is computed when the first value is inserted. +type Store struct { + // Map of comparable types to uint64 pointers. + set StoreMap + lock sync.RWMutex + // Next deadline, computed when the first value is inserted. + deadline time.Time + // Minimum time duration the data should be kept. + period time.Duration +} + +type StoreMap map[interface{}]*uint64 +type ReadyStoreMap map[interface{}]uint64 + +func newStore(period time.Duration) *Store { + return &Store{ + set: make(StoreMap), + period: period, + } +} + +// Add delta to the given key, inserting it if it doesn't exist. This method +// is thread-safe and optimized for updating existing key which is lock-free +// when not concurrently retrieving (method `Flush()`) or inserting a new key. +func (s *Store) Add(key interface{}, delta uint64) error { + // Avoid panic-ing by checking the key type is not nil and comparable. + if key == nil { + return sqerrors.New("unexpected key value `nil`") + } else if !reflect.TypeOf(key).Comparable() { + return sqerrors.Errorf("unexpected non-comparable type `%T`", key) + } + + // Fast hot path: concurrently updating the value of an existing key. + // Lock the store for reading only. + s.lock.RLock() + // Lookup the value + value, exists := s.set[key] + if exists { + // The key already exists. + // Atomically update the value. + // This update operation can be therefore done concurrently. + atomic.AddUint64(value, delta) + // It is important to do it in this write-safe section that is mutually + // exclusive with Flush() which replaces the store's map using Lock(). + } + // Unlock the store + s.lock.RUnlock() + + // Slow path: the key does not exist + if !exists { + // Exclusively lock the store + s.lock.Lock() + defer s.lock.Unlock() + // Check again in case the value has been inserted while getting here. + value, exists = s.set[key] + if exists { + // The value was inserted by another concurrent goroutine. + // We can update the value without atomic operation as we exclusively + // have the lock. + *value += delta + // Note that this is not possible to unlock and perform the atomic + // operation because of possible concurrent `Flush()`. + } else { + // The value still doesn't exist and we need to insert it into the store's + // map. + value := delta + s.set[key] = &value + // Set the deadline when the first value inserted into the metrics store + if s.deadline.IsZero() { + s.deadline = time.Now().Add(s.period) + } + } + } + + return nil +} + +// Flush returns the stored data and the corresponding time window the data was +// held. It should be used when the store is `Ready()`. This method is +// thead-safe. +func (s *Store) Flush() (flushed *ReadyStore) { + // Read current time before swapping the stores in order to avoid making it in + // the critical-section. Reading it before is important in order to get + // old.finish <= new.start. + now := time.Now() + + // Exclusively lock the store in order to get the values and replace it. + s.lock.Lock() + oldMap := s.set + startedAt := s.deadline.Add(-s.period) + // Create a new map with the same capacity as the old one to avoid allocation + // time when still used the same way after the flush. + s.set = make(StoreMap, len(oldMap)) + s.deadline = time.Time{} // time.Time zero value + // Unlock the store which is ready to be used again by concurrent goroutines. + s.lock.Unlock() + + // Compute the map of values getting rid of the pointers (less GC-pressure). + readyMap := make(ReadyStoreMap, len(oldMap)) + for k, v := range oldMap { + readyMap[k] = *v + } + return &ReadyStore{ + set: readyMap, + start: startedAt, + finish: now, + } +} + +// Ready returns true when the store has values and the period passed. +// This method is thread-safe. Note that the atomic operation +// "Ready() + Flush()" doesn't exist, they should therefore be used by a single +// "flusher" goroutine. The locking of `Ready()` is indeed weaker than `Flush()` +// as it only lock the store for reading in order to avoid blocking other +// concurrent updates. +func (s *Store) Ready() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return !s.deadline.IsZero() && time.Now().After(s.deadline) +} + +// ReadyStore provides methods to get the values and the time window. +type ReadyStore struct { + set ReadyStoreMap + start, finish time.Time +} + +func (s *ReadyStore) Start() time.Time { + return s.start +} + +func (s *ReadyStore) Finish() time.Time { + return s.finish +} + +func (s *ReadyStore) Metrics() ReadyStoreMap { + return s.set +} diff --git a/agent/internal/metrics/metrics_test.go b/agent/internal/metrics/metrics_test.go new file mode 100644 index 00000000..ebc800d6 --- /dev/null +++ b/agent/internal/metrics/metrics_test.go @@ -0,0 +1,533 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package metrics_test + +import ( + "fmt" + "math" + "os" + "sync" + "testing" + "time" + + "github.com/sqreen/go-agent/agent/internal/metrics" + "github.com/sqreen/go-agent/agent/internal/plog" + "github.com/sqreen/go-agent/tools/testlib" + "github.com/stretchr/testify/require" +) + +var logger = plog.NewLogger(plog.Debug, os.Stderr, 0) + +func TestUsage(t *testing.T) { + engine := metrics.NewEngine(logger) + + t.Run("store usage", func(t *testing.T) { + t.Run("empty stores are never ready", func(t *testing.T) { + store := engine.NewStore("id 1", time.Microsecond) + require.False(t, store.Ready()) + time.Sleep(time.Microsecond) + require.False(t, store.Ready()) + }) + + t.Run("non-empty stores get ready starting as soon as a value was added", func(t *testing.T) { + store := engine.NewStore("id 1", time.Millisecond) + require.False(t, store.Ready()) + time.Sleep(time.Millisecond) + // Should be still not ready because no values were added + require.False(t, store.Ready()) + // Now add a value + store.Add("key 1", 1) + // Should be started but still not expired + require.False(t, store.Ready()) + time.Sleep(time.Microsecond) + // Should be still not expired + require.False(t, store.Ready()) + time.Sleep(time.Millisecond) + // Now should be expired + require.True(t, store.Ready()) + // Flushing the store should give the map and "restart" the store + old := store.Flush() + require.False(t, store.Ready()) + // Should not be expired while empty + time.Sleep(time.Millisecond) + require.False(t, store.Ready()) + // The old store should have the stored values + require.Equal(t, metrics.ReadyStoreMap{"key 1": 1}, old.Metrics()) + // Adding a new value to the store and then waiting for it to become ready + // should return the net value + store.Add("key 2", 3) + time.Sleep(time.Millisecond) + require.True(t, store.Ready()) + old = store.Flush() + require.Equal(t, metrics.ReadyStoreMap{"key 2": 3}, old.Metrics()) + }) + + t.Run("adding values to a store that is ready is possible", func(t *testing.T) { + store := engine.NewStore("id 1", time.Millisecond) + require.False(t, store.Ready()) + store.Add("key 1", 1) + time.Sleep(time.Millisecond) + require.True(t, store.Ready()) + store.Add("key 1", 1) + store.Add("key 2", 33) + store.Add("key 3", 33) + store.Add("key 3", 1) + + require.True(t, store.Ready()) + old := store.Flush() + require.Equal(t, metrics.ReadyStoreMap{ + "key 1": 2, + "key 2": 33, + "key 3": 34, + }, old.Metrics()) + }) + + t.Run("key types", func(t *testing.T) { + store := engine.NewStore("id 1", time.Millisecond) + + t.Run("non comparable key types are not allowed and do not panic", func(t *testing.T) { + type Struct2 struct { + a int + b string + c float32 + d []byte + } + + require.NotPanics(t, func() { + require.Error(t, store.Add([]byte("no slices"), 1)) + require.Error(t, store.Add(Struct2{ + a: 33, + b: "string", + c: 4.815162342, + d: []byte("no slice"), + }, 1)) + }) + }) + + t.Run("comparable key types are allowed and do not panic", func(t *testing.T) { + type Struct struct { + a int + b string + c float32 + d [33]byte + } + + ptr := &Struct{} + + require.NotPanics(t, func() { + require.NoError(t, store.Add("string", 1)) + require.NoError(t, store.Add(33, 1)) + require.NoError(t, store.Add(Struct{ + a: 33, + b: "string", + c: 4.815162342, + d: [33]byte{}, + }, 1)) + require.NoError(t, store.Add(ptr, 1)) + // Nil is comparable but not allowed + require.Error(t, store.Add(nil, 1)) + }) + + time.Sleep(time.Millisecond) + require.True(t, store.Ready()) + old := store.Flush() + require.Equal(t, metrics.ReadyStoreMap{ + "string": 1, + 33: 1, + Struct{ + a: 33, + b: "string", + c: 4.815162342, + d: [33]byte{}, + }: 1, + ptr: 1, + }, old.Metrics()) + }) + }) + }) + + t.Run("one reader - 8000 writers", func(t *testing.T) { + // Create a store that will be checked more often than actually required by + // its period. So that we cover the case where the store is not always + // ready. + engine := metrics.NewEngine(logger) + // The reader will be awaken 4 times per store period, so only it will see + // a ready store only once out of four. + readerPeriod := time.Microsecond + metricsStorePeriod := 4 * readerPeriod + tick := time.Tick(readerPeriod) + store := engine.NewStore("id", metricsStorePeriod) + + // Signal channel between this test and the reader to tear down the test + done := make(chan struct{}) + + // Array of metrics flushed by the reader + var metricsArray []*metrics.ReadyStore + // Time the test finished - it will be compared to the last metrics store + // finish time + var finished time.Time + + // One reader + go func() { + for { + select { + case <-tick: + if store.Ready() { + ready := store.Flush() + metricsArray = append(metricsArray, ready) + } + + case <-done: + // All goroutines are done, so read get the last data left + if ready := store.Flush(); len(ready.Metrics()) > 0 { + metricsArray = append(metricsArray, ready) + } + finished = time.Now() + // Notify we are done and so the data is ready to be checked + close(done) + return + } + } + }() + + // Start 8000 writers that will write 1000 times + nbWriters := 8000 + nbWrites := 1000 + + var startBarrier, stopBarrier sync.WaitGroup + // Create a start barrier to synchronize every goroutine's launch + startBarrier.Add(nbWriters) + // Create a stopBarrier to signal when all goroutines are done writing + // their values + stopBarrier.Add(nbWriters) + + for n := 0; n < nbWriters; n++ { + go func() { + startBarrier.Wait() // Sync the starts of the goroutines + defer stopBarrier.Done() // Signal we are done when returning + for c := 0; c < nbWrites; c++ { + _ = store.Add(c, 1) + } + }() + } + + // Save the test start time to compare it to the first metrics store's + // that should be latter. + started := time.Now() + + startBarrier.Add(-nbWriters) // Unblock the writer goroutines + stopBarrier.Wait() // Wait for the writer goroutines to be done + done <- struct{}{} // Signal the reader they are done + <-done // Wait for the reader to be done + + // Make sure there is no data left by sleeping more than needed and checking + // the store. + time.Sleep(2 * metricsStorePeriod) + require.False(t, store.Ready()) + + // Aggregate the ready metrics the reader retrieved and check the previous + // store finish time is before the current store start time. + results := make(metrics.ReadyStoreMap) + lastStoreFinish := started + for _, store := range metricsArray { + for k, v := range store.Metrics() { + results[k] += v + } + if !lastStoreFinish.IsZero() { + require.True(t, lastStoreFinish.Before(store.Start()), fmt.Sprint(lastStoreFinish, store)) + } + lastStoreFinish = store.Finish() + } + require.True(t, lastStoreFinish.Before(finished)) + + // Check each writer wrote the expected number of times. + for n := 0; n < nbWrites; n++ { + v, exists := results[n] + require.True(t, exists) + require.Equal(t, uint64(nbWriters), v) + } + }) +} + +func BenchmarkStore(b *testing.B) { + engine := metrics.NewEngine(logger) + + type structKeyType struct { + n int + s string + } + + b.Run("non-concurrent insertion", func(b *testing.B) { + b.Run("integer key type", func(b *testing.B) { + b.Run("non existing keys", func(b *testing.B) { + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + for n := 0; n < b.N; n++ { + _ = store.Add(n, 1) + } + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + for n := 0; n < b.N; n++ { + store.Store(n, 1) + } + }) + }) + + b.Run("already existing key", func(b *testing.B) { + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + for n := 0; n < b.N; n++ { + _ = store.Add(42, 1) + } + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + for n := 0; n < b.N; n++ { + store.Store(42, 1) + } + }) + }) + }) + + b.Run("structure key type", func(b *testing.B) { + b.Run("non existing keys", func(b *testing.B) { + key := structKeyType{ + s: testlib.RandString(50), + } + + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + for n := 0; n < b.N; n++ { + key.n = n + _ = store.Add(key, 1) + } + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + for n := 0; n < b.N; n++ { + key.n = n + store.Store(key, 1) + } + }) + }) + + b.Run("already existing key", func(b *testing.B) { + key := structKeyType{ + n: 42, + s: testlib.RandString(50), + } + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + for n := 0; n < b.N; n++ { + _ = store.Add(key, 1) + } + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + for n := 0; n < b.N; n++ { + store.Store(key, 1) + } + }) + }) + }) + }) + + b.Run("concurrent insertion", func(b *testing.B) { + for p := 1; p <= 1000; p *= 10 { + p := p + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + b.Run("integer key type", func(b *testing.B) { + b.Run("same non existing keys", func(b *testing.B) { + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + n := 0 + for pb.Next() { + _ = store.Add(n, 1) + n++ + } + }) + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + n := 0 + for pb.Next() { + store.Store(n, 1) + n++ + } + }) + }) + }) + + b.Run("same key", func(b *testing.B) { + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = store.Add(42, 1) + } + }) + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + store.Store(42, 1) + } + }) + }) + }) + }) + b.Run("structure key type", func(b *testing.B) { + b.Run("same non existing keys", func(b *testing.B) { + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + key := structKeyType{ + s: testlib.RandString(50), + } + for pb.Next() { + _ = store.Add(key, 1) + key.n++ + } + }) + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + key := structKeyType{ + s: testlib.RandString(50), + } + for pb.Next() { + store.Store(key, 1) + key.n++ + } + }) + }) + }) + + b.Run("same key", func(b *testing.B) { + key := structKeyType{ + s: testlib.RandString(50), + n: 42, + } + + b.Run("using MetricsStore", func(b *testing.B) { + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = store.Add(key, 1) + } + }) + }) + + b.Run("using sync.Map", func(b *testing.B) { + var store sync.Map + b.ResetTimer() + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + store.Store(key, 1) + } + }) + }) + }) + }) + }) + } + }) +} + +func BenchmarkUsage(b *testing.B) { + engine := metrics.NewEngine(logger) + + for p := 1; p <= 1000; p *= 10 { + p := p + b.Run(fmt.Sprintf("parallelism/%d", p), func(b *testing.B) { + b.Run("constant cpu time", func(b *testing.B) { + b.Run("reference without metrics", func(b *testing.B) { + b.SetParallelism(p) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + doConstantCPUProcessing(1) + } + }) + }) + + b.Run("integer key type", func(b *testing.B) { + b.Run("concurrent writes to the same key", func(b *testing.B) { + b.SetParallelism(p) + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = store.Add(418, 1) + _ = doConstantCPUProcessing(1) + } + }) + }) + + b.Run("concurrent writes to multiple keys", func(b *testing.B) { + b.SetParallelism(p) + store := engine.NewStore("id", time.Minute) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + n := 0 + for pb.Next() { + _ = store.Add(n, 1) + _ = doConstantCPUProcessing(1) + n++ + } + }) + }) + }) + }) + }) + } +} + +// go:noinline +func doConstantCPUProcessing(n int) (r int) { + for i := 0; i < int(math.Pow(1000, float64(n))); i++ { + r += useCPU(i) + } + return r +} + +// go:noinline +func useCPU(i int) int { + return i + 10 - 2*3 +} From 8f8d72f7a048f8d3fb7ea44847e0c0fc15a61264 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 15 Jul 2019 16:56:34 +0200 Subject: [PATCH 26/47] agent: adapt the agent to the new metrics engine --- agent/internal/agent.go | 87 +++++++++++---- agent/internal/backend/api/api.go | 13 ++- agent/internal/metrics.go | 169 ++---------------------------- agent/internal/request.go | 13 ++- 4 files changed, 91 insertions(+), 191 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 6230facd..d4360120 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -7,6 +7,7 @@ package internal import ( "context" "encoding/json" + "fmt" "io/ioutil" "net/http" "os" @@ -18,6 +19,7 @@ import ( "github.com/sqreen/go-agent/agent/internal/backend" "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/config" + "github.com/sqreen/go-agent/agent/internal/metrics" "github.com/sqreen/go-agent/agent/internal/plog" "github.com/sqreen/go-agent/agent/internal/rule" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" @@ -123,17 +125,22 @@ func Start() { } type Agent struct { - logger *plog.Logger - eventMng *eventManager - metricsMng *metricsManager - ctx context.Context - cancel context.CancelFunc - isDone chan struct{} - config *config.Config - appInfo *app.Info - client *backend.Client - actors *actor.Store - rules *rule.Engine + logger *plog.Logger + eventMng *eventManager + metrics *metrics.Engine + staticMetrics staticMetrics + ctx context.Context + cancel context.CancelFunc + isDone chan struct{} + config *config.Config + appInfo *app.Info + client *backend.Client + actors *actor.Store + rules *rule.Engine +} + +type staticMetrics struct { + sdkUserLoginSuccess, sdkUserLoginFailure, sdkUserSignup, whitelistedIP *metrics.Store } // Error channel buffer length. @@ -147,19 +154,27 @@ func New(cfg *config.Config) *Agent { return nil } + metrics := metrics.NewEngine(logger) + // Agent graceful stopping using context cancellation. ctx, cancel := context.WithCancel(context.Background()) return &Agent{ - logger: logger, - isDone: make(chan struct{}), - metricsMng: newMetricsManager(ctx, logger), - ctx: ctx, - cancel: cancel, - config: cfg, - appInfo: app.NewInfo(logger), - client: backend.NewClient(cfg.BackendHTTPAPIBaseURL(), cfg, logger), - actors: actor.NewStore(logger), - rules: rule.NewEngine(logger), + logger: logger, + isDone: make(chan struct{}), + metrics: metrics, + staticMetrics: staticMetrics{ + sdkUserLoginSuccess: metrics.NewStore("sdk-login-success", 60*time.Second), + sdkUserLoginFailure: metrics.NewStore("sdk-login-fail", 60*time.Second), + sdkUserSignup: metrics.NewStore("sdk-signup", 60*time.Second), + whitelistedIP: metrics.NewStore("whitelisted", 60*time.Second), + }, + ctx: ctx, + cancel: cancel, + config: cfg, + appInfo: app.NewInfo(logger), + client: backend.NewClient(cfg.BackendHTTPAPIBaseURL(), cfg, logger), + actors: actor.NewStore(logger), + rules: rule.NewEngine(logger, metrics), } } @@ -239,9 +254,8 @@ func (a *Agent) Serve() error { case <-ticker: a.logger.Debug("heartbeat") - metrics := a.metricsMng.getObservations() appBeatReq := api.AppBeatRequest{ - Metrics: metrics, + Metrics: makeAPIMetrics(a.logger, a.metrics.ReadyMetrics()), CommandResults: commandResults, } @@ -277,6 +291,33 @@ func (a *Agent) Serve() error { } } +func makeAPIMetrics(logger plog.ErrorLogger, expiredMetrics map[string]*metrics.ReadyStore) []api.MetricResponse { + var metricsArray []api.MetricResponse + if readyMetrics := expiredMetrics; len(readyMetrics) > 0 { + metricsArray = make([]api.MetricResponse, len(readyMetrics)) + for name, values := range readyMetrics { + observations := make(map[string]uint64, len(values.Metrics())) + for k, v := range values.Metrics() { + jsonKey, err := json.Marshal(k) + if err != nil { + logger.Error(sqerrors.Wrap(err, fmt.Sprintf("could not marshal to json key the value `%v` of type `%T`", k, k))) + continue + } + observations[string(jsonKey)] = v + } + if len(observations) > 0 { + metricsArray = append(metricsArray, api.MetricResponse{ + Name: name, + Start: values.Start(), + Finish: values.Finish(), + Observation: api.Struct{Value: observations}, + }) + } + } + } + return metricsArray +} + func (a *Agent) InstrumentationEnable() error { if err := a.RulesReload(); err != nil { return err diff --git a/agent/internal/backend/api/api.go b/agent/internal/backend/api/api.go index 984954b5..c3d1998d 100644 --- a/agent/internal/backend/api/api.go +++ b/agent/internal/backend/api/api.go @@ -164,9 +164,16 @@ type BatchRequest_Event struct { } type Rule struct { - Name string `json:"name"` - Hookpoint Hookpoint `json:"hookpoint"` - Data RuleData `json:"data"` + Name string `json:"name"` + Hookpoint Hookpoint `json:"hookpoint"` + Data RuleData `json:"data"` + Metrics []MetricDefinition `json:"metrics"` +} + +type MetricDefinition struct { + Kind string `json:"kind"` + Name string `json:"name"` + Period int64 `json:"period"` } type Hookpoint struct { diff --git a/agent/internal/metrics.go b/agent/internal/metrics.go index 456ad6e6..beec4470 100644 --- a/agent/internal/metrics.go +++ b/agent/internal/metrics.go @@ -4,175 +4,28 @@ package internal -import ( - "context" - "sync" - "sync/atomic" - "time" - - "github.com/pkg/errors" - "github.com/sqreen/go-agent/agent/internal/backend/api" - "github.com/sqreen/go-agent/agent/internal/plog" - "github.com/sqreen/go-agent/agent/sqlib/sqsafe" -) - -type metricsManager struct { - ctx context.Context - logger *plog.Logger - metrics sync.Map - readyLock sync.Mutex - ready []api.MetricResponse -} - -func newMetricsManager(ctx context.Context, logger *plog.Logger) *metricsManager { - return &metricsManager{ - ctx: ctx, - logger: logger, - } -} - -type metricsStore struct { - done func(start, finish time.Time, observations sync.Map) - period time.Duration - entries sync.Map - once sync.Once - swapLock sync.RWMutex - expired bool - logger *plog.Logger -} - -type metricEntry interface { - // Deterministic marshaling if possible... - bucketID() (string, error) -} - -func (m *metricsManager) get(name string) *metricsStore { - store := &metricsStore{ - logger: m.logger, - period: time.Minute, - done: func(start, finish time.Time, observations sync.Map) { - m.metrics.Delete(name) - m.logger.Debug("metrics `", name, "` ready") - m.addObservations(name, start, finish, observations) - }, - } - - actual, _ := m.metrics.LoadOrStore(name, store) - store = actual.(*metricsStore) - store.once.Do(func() { - _ = sqsafe.Go(func() error { - m.logger.Debug("bookkeeping metrics `", name, "` with period `", store.period, "`") - store.monitor(m.ctx, time.Now()) - return nil - }) - }) - - return store -} - -func (m *metricsManager) addObservations(name string, start, finish time.Time, observations sync.Map) { - observation := make(map[string]uint64) - observations.Range(func(k, v interface{}) bool { - key, ok := k.(string) - if !ok { - m.logger.Error(errors.New("unexpected metric key type")) - return true - } - - value, ok := v.(*uint64) - if !ok { - m.logger.Error(errors.New("unexpected metric value type")) - return true - } - - observation[key] = *value - return true - }) - - metric := api.MetricResponse{ - Name: name, - Start: start, - Finish: finish, - Observation: api.Struct{observation}, - } - - m.readyLock.Lock() - defer m.readyLock.Unlock() - m.ready = append(m.ready, metric) -} - -func (m *metricsManager) getObservations() []api.MetricResponse { - m.readyLock.Lock() - defer m.readyLock.Unlock() - ready := m.ready - m.ready = m.ready[0:0] - return ready -} - -func (s *metricsStore) add(e metricEntry) { - s.swapLock.RLock() - defer s.swapLock.RUnlock() - - if s.expired { - // FIXME: better design preventing this case - // For now, a few events may be dropped. - return - } - - var n uint64 = 1 - key, err := e.bucketID() - if err != nil { - // Log the error and continue. - s.logger.Error(errors.Wrap(err, "could not compute the bucket id of the metric key")) - return - } - actual, loaded := s.entries.LoadOrStore(key, &n) - if loaded { - newVal := atomic.AddUint64(actual.(*uint64), 1) - s.logger.Debug("metric store value of `", key, "` set to ", newVal) - } else { - s.logger.Debug("metric store value of `", key, "` set to ", n) - } -} - -func (s *metricsStore) monitor(ctx context.Context, start time.Time) { - var finish time.Time - select { - case <-ctx.Done(): - finish = time.Now() - case finish = <-time.After(s.period): - } - - s.swapLock.Lock() - entries := s.entries - s.entries = sync.Map{} - s.expired = true - s.swapLock.Unlock() - - s.done(start, finish, entries) -} +import "github.com/sqreen/go-agent/agent/internal/metrics" func (a *Agent) addUserEvent(event userEventFace) { - if a.config.Disable() || a.metricsMng == nil { + if a.config.Disable() || a.metrics == nil { // Disabled or not yet initialized agent return } - - var store *metricsStore + var store *metrics.Store switch actual := event.(type) { case *authUserEvent: if actual.loginSuccess { - store = a.metricsMng.get("sdk-login-success") + store = a.staticMetrics.sdkUserLoginSuccess } else { - store = a.metricsMng.get("sdk-login-fail") + store = a.staticMetrics.sdkUserLoginFailure } case *signupUserEvent: - store = a.metricsMng.get("sdk-signup") + store = a.staticMetrics.sdkUserSignup default: + // TODO: log error return } - - store.add(event) + store.Add(event, 1) } type WhitelistedIP struct { @@ -184,11 +37,11 @@ func (m WhitelistedIP) bucketID() (string, error) { } func (a *Agent) addWhitelistEvent(matchedWhitelistEntry string) { - if a.config.Disable() || a.metricsMng == nil { + if a.config.Disable() || a.metrics == nil { // Agent is disabled or not yet initialized return } - a.metricsMng.get("whitelisted").add(WhitelistedIP{ + a.staticMetrics.whitelistedIP.Add(WhitelistedIP{ MatchedWhitelistEntry: matchedWhitelistEntry, - }) + }, 1) } diff --git a/agent/internal/request.go b/agent/internal/request.go index 7a0971e7..b978573c 100644 --- a/agent/internal/request.go +++ b/agent/internal/request.go @@ -54,7 +54,6 @@ type HTTPRequestEvent struct { type userEventFace interface { isUserEvent() - metricEntry } type userEvent struct { @@ -70,12 +69,12 @@ type authUserEvent struct { func (_ *authUserEvent) isUserEvent() {} -func (e *authUserEvent) bucketID() (string, error) { +func (e *authUserEvent) MarshalJSON() ([]byte, error) { k := &userMetricKey{ id: e.userEvent.userIdentifiers, ip: e.userEvent.ip, } - return k.bucketID() + return k.MarshalJSON() } type userMetricKey struct { @@ -83,7 +82,7 @@ type userMetricKey struct { ip net.IP } -func (k *userMetricKey) bucketID() (string, error) { +func (k *userMetricKey) MarshalJSON() ([]byte, error) { var keys [][]interface{} for prop, val := range k.id { keys = append(keys, []interface{}{prop, val}) @@ -96,19 +95,19 @@ func (k *userMetricKey) bucketID() (string, error) { IP: k.ip.String(), } buf, err := json.Marshal(&v) - return string(buf), err + return buf, err } type signupUserEvent struct { *userEvent } -func (e *signupUserEvent) bucketID() (string, error) { +func (e *signupUserEvent) MarshalJSON() ([]byte, error) { k := &userMetricKey{ id: e.userEvent.userIdentifiers, ip: e.userEvent.ip, } - return k.bucketID() + return k.MarshalJSON() } func (_ *signupUserEvent) isUserEvent() {} From 9ed28e3c8280a2af4bb84753918ae2bf46ce015a Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 15 Jul 2019 17:02:55 +0200 Subject: [PATCH 27/47] agent/rule/callbacks: adapt existing callbacks to the new callback api --- .../rule/callback/add-security-headers.go | 28 ++-- .../callback/add-security-headers_test.go | 140 ++---------------- .../rule/callback/monitor-http-status-code.go | 2 - .../rule/callback/write-custom-error-page.go | 9 +- .../callback/write-custom-error-page_test.go | 20 +-- .../rule/callback/write-http-redirection.go | 16 +- .../callback/write-http-redirection_test.go | 79 +++++----- 7 files changed, 95 insertions(+), 199 deletions(-) diff --git a/agent/internal/rule/callback/add-security-headers.go b/agent/internal/rule/callback/add-security-headers.go index d5849afd..4278d662 100644 --- a/agent/internal/rule/callback/add-security-headers.go +++ b/agent/internal/rule/callback/add-security-headers.go @@ -14,20 +14,28 @@ import ( // NewAddSecurityHeadersCallbacks returns the native prolog and epilog callbacks // to be hooked to `sqhttp.MiddlewareWithError` in order to add HTTP headers // provided by the rule's data. -func NewAddSecurityHeadersCallbacks(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { - var headers = make(http.Header, len(data)) - for _, headersKV := range data { - // TODO: move to a structured list of headers to avoid dynamic type checking - kv, ok := headersKV.([]string) +func NewAddSecurityHeadersCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { + var headers http.Header + if cfg := rule.Config(); cfg != nil { + cfg, ok := rule.Config().([]interface{}) if !ok { - err = sqerrors.Errorf("unexpected number of values: header key and values are expected but got `%d` values instead", len(kv)) + err = sqerrors.Errorf("unexpected callback data type: got `%T` instead of `[][]string`", cfg) return } - if len(kv) != 2 { - err = sqerrors.Errorf("unexpected number of values: header key and values are expected but got `%d` values instead", len(kv)) - return + headers = make(http.Header, len(cfg)) + for _, headersKV := range cfg { + // TODO: move to a structured list of headers to avoid these dynamic type checking + kv, ok := headersKV.([]string) + if !ok { + err = sqerrors.Errorf("unexpected number of values: header key and values are expected but got `%d` values instead", len(kv)) + return + } + if len(kv) != 2 { + err = sqerrors.Errorf("unexpected number of values: header key and values are expected but got `%d` values instead", len(kv)) + return + } + headers.Set(kv[0], kv[1]) } - headers.Set(kv[0], kv[1]) } if len(headers) == 0 { return nil, nil, sqerrors.New("there are no headers to add") diff --git a/agent/internal/rule/callback/add-security-headers_test.go b/agent/internal/rule/callback/add-security-headers_test.go index 4ec8dc98..25c61d6f 100644 --- a/agent/internal/rule/callback/add-security-headers_test.go +++ b/agent/internal/rule/callback/add-security-headers_test.go @@ -10,7 +10,6 @@ import ( "reflect" "testing" - "github.com/sqreen/go-agent/agent/internal/rule" "github.com/sqreen/go-agent/agent/internal/rule/callback" "github.com/sqreen/go-agent/agent/sqlib/sqhook" "github.com/stretchr/testify/require" @@ -22,24 +21,24 @@ func TestNewAddSecurityHeadersCallbacks(t *testing.T) { ExpectProlog: true, PrologType: reflect.TypeOf(callback.AddSecurityHeadersPrologCallbackType(nil)), EpilogType: reflect.TypeOf(callback.AddSecurityHeadersEpilogCallbackType(nil)), - InvalidTestCases: [][]interface{}{ + InvalidTestCases: []interface{}{ nil, - {}, - {33}, - {"yet another wrong type"}, - {[]string{}}, - {nil}, - {[]string{"one"}}, - {[]string{"one", "two", "three"}}, + 33, + "yet another wrong type", + []string{}, + []string{"one"}, + []string{"one", "two", "three"}, }, ValidTestCases: []ValidTestCase{ { - ValidData: []interface{}{ - []string{"k", "v"}, - []string{"one", "two"}, - []string{"canonical-header", "the value"}, + Rule: &FakeRule{ + config: []interface{}{ + []string{"k", "v"}, + []string{"one", "two"}, + []string{"canonical-header", "the value"}, + }, }, - TestCallbacks: func(t *testing.T, prolog, epilog sqhook.Callback) { + TestCallbacks: func(t *testing.T, _ *FakeRule, prolog, epilog sqhook.Callback) { expectedHeaders := http.Header{ "K": []string{"v"}, "One": []string{"two"}, @@ -64,116 +63,3 @@ func TestNewAddSecurityHeadersCallbacks(t *testing.T) { }, }) } - -type TestConfig struct { - CallbacksCtor rule.CallbacksConstructorFunc - ExpectEpilog, ExpectProlog bool - PrologType, EpilogType reflect.Type - InvalidTestCases [][]interface{} - ValidTestCases []ValidTestCase -} - -type ValidTestCase struct { - ValidData []interface{} - TestCallbacks func(t *testing.T, prolog, epilog sqhook.Callback) -} - -func RunCallbackTest(t *testing.T, config TestConfig) { - for _, data := range config.InvalidTestCases { - data := data - t.Run("with incorrect data", func(t *testing.T) { - prolog, epilog, err := config.CallbacksCtor(data, nil, nil) - require.Error(t, err) - require.Nil(t, prolog) - require.Nil(t, epilog) - }) - } - - for _, tc := range config.ValidTestCases { - tc := tc - t.Run("with correct data", func(t *testing.T) { - t.Run("without next callbacks", func(t *testing.T) { - // Instantiate the callback with the given correct rule data - prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nil, nil) - require.NoError(t, err) - checkCallbacksValues(t, config, prolog, epilog) - tc.TestCallbacks(t, prolog, epilog) - }) - - t.Run("with next callbacks", func(t *testing.T) { - t.Run("wrong next prolog type", func(t *testing.T) { - prolog, epilog, err := config.CallbacksCtor(tc.ValidData, 33, nil) - require.Error(t, err) - require.Nil(t, prolog) - require.Nil(t, epilog) - }) - - t.Run("wrong next epilog type", func(t *testing.T) { - prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nil, func() {}) - require.Error(t, err) - require.Nil(t, prolog) - require.Nil(t, epilog) - }) - - t.Run("with correct next prolog", func(t *testing.T) { - var called bool - nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { - called = true - return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} - }).Interface() - - prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nextProlog, nil) - require.NoError(t, err) - checkCallbacksValues(t, config, prolog, epilog) - require.NotNil(t, prolog) - tc.TestCallbacks(t, prolog, epilog) - require.True(t, called) - }) - - t.Run("with correct next epilog", func(t *testing.T) { - var called bool - nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { - called = true - return - }).Interface() - - prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nil, nextEpilog) - require.NoError(t, err) - checkCallbacksValues(t, config, prolog, epilog) - require.NotNil(t, epilog) - tc.TestCallbacks(t, prolog, epilog) - require.True(t, called) - }) - - t.Run("with both correct next callbacks", func(t *testing.T) { - var calledProlog, calledEpilog bool - nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { - calledProlog = true - return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} - }).Interface() - nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { - calledEpilog = true - return - }).Interface() - - prolog, epilog, err := config.CallbacksCtor(tc.ValidData, nextProlog, nextEpilog) - require.NoError(t, err) - require.NotNil(t, prolog) - require.NotNil(t, epilog) - tc.TestCallbacks(t, prolog, epilog) - require.True(t, calledProlog) - require.True(t, calledEpilog) - }) - }) - }) - } -} - -func checkCallbacksValues(t *testing.T, config TestConfig, prolog, epilog sqhook.Callback) { - if config.ExpectProlog { - require.NotNil(t, prolog) - } - if config.ExpectEpilog { - require.NotNil(t, prolog) - } -} diff --git a/agent/internal/rule/callback/monitor-http-status-code.go b/agent/internal/rule/callback/monitor-http-status-code.go index 5754d6f7..ccff3143 100644 --- a/agent/internal/rule/callback/monitor-http-status-code.go +++ b/agent/internal/rule/callback/monitor-http-status-code.go @@ -26,9 +26,7 @@ func NewMonitorHTTPStatusCodeCallbacks(rule Context, nextProlog, nextEpilog sqho func newMonitorHTTPStatusCodePrologCallback(rule Context, next MonitorHTTPStatusCodePrologCallbackType) MonitorHTTPStatusCodePrologCallbackType { return func(ctx *sqhook.Context, code *int) error { - //if status := *code; status >= 400 && status <= 500 { rule.AddMetricsValue(*code, 1) - //} if next == nil { return nil diff --git a/agent/internal/rule/callback/write-custom-error-page.go b/agent/internal/rule/callback/write-custom-error-page.go index 8ce41400..cd3e6a07 100644 --- a/agent/internal/rule/callback/write-custom-error-page.go +++ b/agent/internal/rule/callback/write-custom-error-page.go @@ -16,13 +16,12 @@ import ( // callbacks modifying the arguments of `httphandler.WriteResponse` in order to // modify the http status code and error page that are provided by the rule's // data. -func NewWriteCustomErrorPageCallbacks(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewWriteCustomErrorPageCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { var statusCode = 500 - if len(data) > 0 { - d0 := data[0] - cfg, ok := d0.(*api.CustomErrorPageRuleDataEntry) + if cfg := rule.Config(); cfg != nil { + cfg, ok := cfg.(*api.CustomErrorPageRuleDataEntry) if !ok { - err = sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + err = sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", cfg) return } statusCode = cfg.StatusCode diff --git a/agent/internal/rule/callback/write-custom-error-page_test.go b/agent/internal/rule/callback/write-custom-error-page_test.go index 8bfb13c4..c1968b18 100644 --- a/agent/internal/rule/callback/write-custom-error-page_test.go +++ b/agent/internal/rule/callback/write-custom-error-page_test.go @@ -20,22 +20,18 @@ func TestNewWriteCustomErrorPageCallbacks(t *testing.T) { ExpectProlog: true, PrologType: reflect.TypeOf(callback.WriteCustomErrorPagePrologCallbackType(nil)), EpilogType: reflect.TypeOf(callback.WriteCustomErrorPageEpilogCallbackType(nil)), - InvalidTestCases: [][]interface{}{ - {33}, - {"yet another wrong type"}, + InvalidTestCases: []interface{}{ + 33, + "yet another wrong type", }, ValidTestCases: []ValidTestCase{ { - ValidData: nil, + Rule: &FakeRule{}, TestCallbacks: testWriteCustomErrorPageCallbacks(500), }, { - ValidData: []interface{}{}, - TestCallbacks: testWriteCustomErrorPageCallbacks(500), - }, - { - ValidData: []interface{}{ - &api.CustomErrorPageRuleDataEntry{StatusCode: 33}, + Rule: &FakeRule{ + config: &api.CustomErrorPageRuleDataEntry{StatusCode: 33}, }, TestCallbacks: testWriteCustomErrorPageCallbacks(33), }, @@ -43,8 +39,8 @@ func TestNewWriteCustomErrorPageCallbacks(t *testing.T) { }) } -func testWriteCustomErrorPageCallbacks(expectedStatusCode int) func(t *testing.T, prolog sqhook.Callback, epilog sqhook.Callback) { - return func(t *testing.T, prolog, epilog sqhook.Callback) { +func testWriteCustomErrorPageCallbacks(expectedStatusCode int) func(t *testing.T, rule *FakeRule, prolog sqhook.Callback, epilog sqhook.Callback) { + return func(t *testing.T, _ *FakeRule, prolog, epilog sqhook.Callback) { actualProlog, ok := prolog.(callback.WriteCustomErrorPagePrologCallbackType) require.True(t, ok) var ( diff --git a/agent/internal/rule/callback/write-http-redirection.go b/agent/internal/rule/callback/write-http-redirection.go index 069682e5..ba7f6b89 100644 --- a/agent/internal/rule/callback/write-http-redirection.go +++ b/agent/internal/rule/callback/write-http-redirection.go @@ -17,13 +17,12 @@ import ( // callbacks modifying the arguments of `httphandler.WriteResponse` in order to // modify the http status code and headers in order to perform an HTTP // redirection to the URL provided by the rule's data. -func NewWriteHTTPRedirectionCallbacks(data []interface{}, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewWriteHTTPRedirectionCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { var redirectionURL string - if len(data) > 0 { - d0 := data[0] - cfg, ok := d0.(*api.RedirectionRuleDataEntry) + if cfg := rule.Config(); cfg != nil { + cfg, ok := cfg.(*api.RedirectionRuleDataEntry) if !ok { - err = sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", d0) + err = sqerrors.Errorf("unexpected callback data type: got `%T` instead of `*api.CustomErrorPageRuleDataEntry`", cfg) return } redirectionURL = cfg.RedirectionURL @@ -43,11 +42,16 @@ func NewWriteHTTPRedirectionCallbacks(data []interface{}, nextProlog, nextEpilog err = sqerrors.Errorf("unexpected next prolog type `%T`", nextProlog) return } - // No epilog in this callback, so simply pass the given one + // No epilog in this callback, so simply check and pass the given one + if _, ok := nextEpilog.(WriteHTTPRedirectionEpilogCallbackType); nextEpilog != nil && !ok { + err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, WriteHTTPRedirectionEpilogCallbackType(nil)) + return + } return newWriteHTTPRedirectionPrologCallback(redirectionURL, actualNextProlog), nextEpilog, nil } type WriteHTTPRedirectionPrologCallbackType = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) error +type WriteHTTPRedirectionEpilogCallbackType = func(*sqhook.Context) // The prolog callback modifies the function arguments in order to perform an // HTTP redirection. diff --git a/agent/internal/rule/callback/write-http-redirection_test.go b/agent/internal/rule/callback/write-http-redirection_test.go index 18a06e2c..9edd432e 100644 --- a/agent/internal/rule/callback/write-http-redirection_test.go +++ b/agent/internal/rule/callback/write-http-redirection_test.go @@ -6,52 +6,57 @@ package callback_test import ( "net/http" + "reflect" "testing" "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/rule/callback" + "github.com/sqreen/go-agent/agent/sqlib/sqhook" "github.com/stretchr/testify/require" ) func TestNewWriteHTTPRedirectionCallbacks(t *testing.T) { - t.Run("with incorrect data", func(t *testing.T) { - for _, data := range [][]interface{}{ + RunCallbackTest(t, TestConfig{ + CallbacksCtor: callback.NewWriteHTTPRedirectionCallbacks, + ExpectProlog: true, + PrologType: reflect.TypeOf(callback.WriteHTTPRedirectionPrologCallbackType(nil)), + EpilogType: reflect.TypeOf(callback.WriteHTTPRedirectionEpilogCallbackType(nil)), + InvalidTestCases: []interface{}{ nil, - {}, - {33}, - {"yet another wrong type"}, - {&api.CustomErrorPageRuleDataEntry{}}, - {&api.RedirectionRuleDataEntry{}}, - {&api.RedirectionRuleDataEntry{"http//sqreen.com"}}, - } { - prolog, epilog, err := callback.NewWriteHTTPRedirectionCallbacks(data, nil, nil) - require.Error(t, err) - require.Nil(t, prolog) - require.Nil(t, epilog) - } - }) + 33, + "yet another wrong type", + &api.CustomErrorPageRuleDataEntry{}, + &api.RedirectionRuleDataEntry{}, + &api.RedirectionRuleDataEntry{"http//sqreen.com"}, + }, + ValidTestCases: []ValidTestCase{ + { + Rule: &FakeRule{ + config: &api.RedirectionRuleDataEntry{"http://sqreen.com"}, + }, + TestCallbacks: func(t *testing.T, rule *FakeRule, prolog, epilog sqhook.Callback) { + // Call it and check the behaviour follows the rule's data + actualProlog, ok := prolog.(callback.WriteHTTPRedirectionPrologCallbackType) + require.True(t, ok) + var ( + statusCode int + headers http.Header + ) + err := actualProlog(nil, nil, nil, &headers, &statusCode, nil) + // Check it behaves as expected + require.NoError(t, err) + require.Equal(t, http.StatusSeeOther, statusCode) + require.NotNil(t, headers) + require.Equal(t, "http://sqreen.com", headers.Get("Location")) - t.Run("with correct data", func(t *testing.T) { - // Instantiate the callback with the given correct rule data - expectedURL := "http://sqreen.com" - prolog, epilog, err := callback.NewWriteHTTPRedirectionCallbacks([]interface{}{ - &api.RedirectionRuleDataEntry{RedirectionURL: expectedURL}, - }, nil, nil) - require.NoError(t, err) - require.NotNil(t, prolog) - require.Nil(t, epilog) - // Call it and check the behaviour follows the rule's data - actualProlog, ok := prolog.(callback.WriteHTTPRedirectionPrologCallbackType) - require.True(t, ok) - var ( - statusCode int - headers http.Header - ) - err = actualProlog(nil, nil, nil, &headers, &statusCode, nil) - // Check it behaves as expected - require.NoError(t, err) - require.Equal(t, http.StatusSeeOther, statusCode) - require.NotNil(t, headers) - require.Equal(t, expectedURL, headers.Get("Location")) + // Test the epilog if any + if epilog != nil { + actualEpilog, ok := epilog.(callback.WriteHTTPRedirectionEpilogCallbackType) + require.True(t, ok) + actualEpilog(&sqhook.Context{}) + } + }, + }, + }, }) } From a000df0ba07dcb77b82d4074e97c36efc244a3fc Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 19 Jul 2019 08:52:17 +0200 Subject: [PATCH 28/47] agent/rule: rename the AddMetricsValue into a more generic PushMetricsValue name This method is meant to abstract the metrics store type, so we decided to rename it with a more generic name and thought that "Push" has a weaker meaning than "Add". --- agent/internal/rule/callback.go | 2 +- agent/internal/rule/callback/callback_test.go | 2 +- agent/internal/rule/callback/monitor-http-status-code.go | 2 +- agent/internal/rule/callback/monitor-http-status-code_test.go | 2 +- agent/internal/rule/callback/types.go | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index 1e25cfc6..8520d3be 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -89,7 +89,7 @@ func (d *CallbackContext) Config() interface{} { return d.config } -func (d *CallbackContext) AddMetricsValue(key interface{}, value uint64) { +func (d *CallbackContext) PushMetricsValue(key interface{}, value uint64) { err := d.defaultMetricsStore.Add(key, value) if err != nil { d.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not add a value to the default metrics store", d.name))) diff --git a/agent/internal/rule/callback/callback_test.go b/agent/internal/rule/callback/callback_test.go index 5da905eb..ce4be5d6 100644 --- a/agent/internal/rule/callback/callback_test.go +++ b/agent/internal/rule/callback/callback_test.go @@ -132,7 +132,7 @@ type FakeRule struct { mock.Mock } -func (r *FakeRule) AddMetricsValue(key interface{}, value uint64) { +func (r *FakeRule) PushMetricsValue(key interface{}, value uint64) { r.Called(key, value) } diff --git a/agent/internal/rule/callback/monitor-http-status-code.go b/agent/internal/rule/callback/monitor-http-status-code.go index ccff3143..d7c867f7 100644 --- a/agent/internal/rule/callback/monitor-http-status-code.go +++ b/agent/internal/rule/callback/monitor-http-status-code.go @@ -26,7 +26,7 @@ func NewMonitorHTTPStatusCodeCallbacks(rule Context, nextProlog, nextEpilog sqho func newMonitorHTTPStatusCodePrologCallback(rule Context, next MonitorHTTPStatusCodePrologCallbackType) MonitorHTTPStatusCodePrologCallbackType { return func(ctx *sqhook.Context, code *int) error { - rule.AddMetricsValue(*code, 1) + rule.PushMetricsValue(*code, 1) if next == nil { return nil diff --git a/agent/internal/rule/callback/monitor-http-status-code_test.go b/agent/internal/rule/callback/monitor-http-status-code_test.go index 30d5b4ad..925dfa93 100644 --- a/agent/internal/rule/callback/monitor-http-status-code_test.go +++ b/agent/internal/rule/callback/monitor-http-status-code_test.go @@ -27,7 +27,7 @@ func TestNewMonitorHTTPStatusCodeCallbacks(t *testing.T) { actualProlog, ok := prolog.(callback.MonitorHTTPStatusCodePrologCallbackType) require.True(t, ok) code := rand.Int() - rule.On("AddMetricsValue", code, uint64(1)).Return().Once() + rule.On("PushMetricsValue", code, uint64(1)).Return().Once() err := actualProlog(nil, &code) // Check it behaves as expected require.NoError(t, err) diff --git a/agent/internal/rule/callback/types.go b/agent/internal/rule/callback/types.go index ee4ba772..0e9b28fc 100644 --- a/agent/internal/rule/callback/types.go +++ b/agent/internal/rule/callback/types.go @@ -7,7 +7,7 @@ package callback type Context interface { // Get the rule configuration. Config() interface{} - // Add a new metrics value for the given key to the default metrics store + // Push a new metrics value for the given key into the default metrics store // given by the rule. - AddMetricsValue(key interface{}, value uint64) + PushMetricsValue(key interface{}, value uint64) } From 9631765cbd6bafe7cbf94e507b954b7789f8999f Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 19 Jul 2019 09:03:40 +0200 Subject: [PATCH 29/47] agent/rule: fix configuration descriptor --- agent/internal/rule/callback.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index 8520d3be..2a79a64f 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -6,6 +6,7 @@ package rule import ( "fmt" + "reflect" "time" "github.com/sqreen/go-agent/agent/internal/backend/api" @@ -73,14 +74,14 @@ func NewCallbackContext(r *api.Rule, logger plog.ErrorLogger, metricsEngine *met } func newCallbackConfig(data *api.RuleData) (config interface{}) { - if nbData := len(data.Values); nbData > 1 { + if nbData := len(data.Values); nbData == 1 && reflect.TypeOf(data.Values[0].Value).Kind() != reflect.Slice { + config = data.Values[0].Value + } else { configArray := make([]interface{}, 0, nbData) for _, e := range data.Values { configArray = append(configArray, e.Value) } config = configArray - } else if nbData == 1 { - config = data.Values[0].Value } return config } From b752a2fd0fc31d1b59b428766a3109b260b24537 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 19 Jul 2019 09:06:48 +0200 Subject: [PATCH 30/47] sdk/middleware/echo: fix comment typo --- sdk/middleware/sqecho/echo.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/middleware/sqecho/echo.go b/sdk/middleware/sqecho/echo.go index 49b6902b..fae95618 100644 --- a/sdk/middleware/sqecho/echo.go +++ b/sdk/middleware/sqecho/echo.go @@ -65,7 +65,7 @@ func Middleware() echo.MiddlewareFunc { contextKey := sdk.HTTPRequestRecordContextKey.String c.Set(contextKey, sdk.FromContext(r.Context())) c.Response().After(func() { - // Hack for now to monitor the status code because Gin doesn't use the + // Hack for now to monitor the status code because Echo doesn't use the // HTTP ResponseWriter when overwriting it through c.Writer = ... sqhttp.ResponseWriter{}.WriteHeader(c.Response().Status) }) From b81c1ef7d7a3f45ac44f1a9d518221bb87f549de Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 19 Jul 2019 09:15:54 +0200 Subject: [PATCH 31/47] agent/actor: fix a dev regression on security actions --- agent/internal/actor/http.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/agent/internal/actor/http.go b/agent/internal/actor/http.go index b99e7f41..ad8a13c2 100644 --- a/agent/internal/actor/http.go +++ b/agent/internal/actor/http.go @@ -28,6 +28,8 @@ const ( // request handler level to perform the security response. func NewIPActionHTTPHandler(action Action, ip net.IP) (http.Handler, error) { switch actual := action.(type) { + case *timedAction: + return NewIPActionHTTPHandler(actual.Action, ip) case blockAction: return newBlockHTTPHandler(blockIPEventName, newBlockedIPEventProperties(actual, ip)), nil case *redirectAction: @@ -40,6 +42,8 @@ func NewIPActionHTTPHandler(action Action, ip net.IP) (http.Handler, error) { // the request handler level to perform the security response. func NewUserActionHTTPHandler(action Action, userID map[string]string) (http.Handler, error) { switch actual := action.(type) { + case *timedAction: + return NewUserActionHTTPHandler(actual.Action, userID) case blockAction: return newBlockHTTPHandler(blockUserEventName, newBlockedUserEventProperties(actual, userID)), nil case *redirectAction: From a7bbd3351dc5efc698b9c9b80469d7910b629f0d Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 19 Jul 2019 09:19:30 +0200 Subject: [PATCH 32/47] agent/config: add an option to modify the sdk metrics stores periods during integration tests The signup/login rules are js rules that are not yet supported by the go agent. This configuration option allows to modify the period of the signup and login stores so that it is reduced during the integration tests to get them faster. --- agent/internal/agent.go | 13 +++++++++---- agent/internal/config/config.go | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index d4360120..3d34d77a 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -156,6 +156,11 @@ func New(cfg *config.Config) *Agent { metrics := metrics.NewEngine(logger) + // TODO: remove this SDK metrics period config when the corresponding js rule + // is supported + sdkMetricsPeriod := time.Duration(cfg.SDKMetricsPeriod()) * time.Second + logger.Debugf("using sdk metrics store period of %s", sdkMetricsPeriod) + // Agent graceful stopping using context cancellation. ctx, cancel := context.WithCancel(context.Background()) return &Agent{ @@ -163,10 +168,10 @@ func New(cfg *config.Config) *Agent { isDone: make(chan struct{}), metrics: metrics, staticMetrics: staticMetrics{ - sdkUserLoginSuccess: metrics.NewStore("sdk-login-success", 60*time.Second), - sdkUserLoginFailure: metrics.NewStore("sdk-login-fail", 60*time.Second), - sdkUserSignup: metrics.NewStore("sdk-signup", 60*time.Second), - whitelistedIP: metrics.NewStore("whitelisted", 60*time.Second), + sdkUserLoginSuccess: metrics.NewStore("sdk-login-success", sdkMetricsPeriod), + sdkUserLoginFailure: metrics.NewStore("sdk-login-fail", sdkMetricsPeriod), + sdkUserSignup: metrics.NewStore("sdk-signup", sdkMetricsPeriod), + whitelistedIP: metrics.NewStore("whitelisted", sdkMetricsPeriod), }, ctx: ctx, cancel: cancel, diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index a43540eb..d28ff6d8 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -14,6 +14,7 @@ import ( "net/http" "os" "path/filepath" + "strconv" "strings" "time" @@ -205,12 +206,14 @@ const ( configKeyDisable = `disable` configKeyStripHTTPReferer = `strip_http_referer` configKeyRules = `rules` + configKeySDKMetricsPeriod = `sdk_metrics_period` ) // User configuration's default values. const ( configDefaultBackendHTTPAPIBaseURL = `https://back.sqreen.com` configDefaultLogLevel = `info` + configDefaultSDKMetricsPeriod = 60 ) func New(logger *plog.Logger) *Config { @@ -248,6 +251,7 @@ func New(logger *plog.Logger) *Config { manager.SetDefault(configKeyDisable, "") manager.SetDefault(configKeyStripHTTPReferer, "") manager.SetDefault(configKeyRules, "") + manager.SetDefault(configKeySDKMetricsPeriod, configDefaultSDKMetricsPeriod) err := manager.ReadInConfig() if err != nil { @@ -311,6 +315,17 @@ func (c *Config) LocalRulesFile() string { return sanitizeString(c.GetString(configKeyRules)) } +// SDKMetricsPeriod returns the period to use for the SDK metric stores. +// This is temporary until the SDK rules are implemented and required for +// integration tests which require a shorter time. +func (c *Config) SDKMetricsPeriod() int { + p, err := strconv.Atoi(sanitizeString(c.GetString(configKeySDKMetricsPeriod))) + if err != nil { + return configDefaultSDKMetricsPeriod + } + return p +} + func sanitizeString(s string) string { return strings.TrimSpace(s) } From 4b2d1a6dd9b7c5739139c500ce1e69ec117bd478 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 19 Jul 2019 09:25:26 +0200 Subject: [PATCH 33/47] agent/metrics: fix a dev regression on flushing metrics stores --- agent/internal/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 3d34d77a..2a6c17ec 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -299,7 +299,7 @@ func (a *Agent) Serve() error { func makeAPIMetrics(logger plog.ErrorLogger, expiredMetrics map[string]*metrics.ReadyStore) []api.MetricResponse { var metricsArray []api.MetricResponse if readyMetrics := expiredMetrics; len(readyMetrics) > 0 { - metricsArray = make([]api.MetricResponse, len(readyMetrics)) + metricsArray = make([]api.MetricResponse, 0, len(readyMetrics)) for name, values := range readyMetrics { observations := make(map[string]uint64, len(values.Metrics())) for k, v := range values.Metrics() { From 838d04891a3ee3473752d643ac93d8e390d6868f Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 19 Jul 2019 09:28:16 +0200 Subject: [PATCH 34/47] agent/rules: fix a dev regression on local rules files --- agent/internal/agent.go | 18 ++++++++++-------- agent/internal/client.go | 4 ++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 2a6c17ec..d724f371 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -366,16 +366,18 @@ func (a *Agent) RulesReload() error { // Insert local rules if any localRulesJSON := a.config.LocalRulesFile() - buf, err := ioutil.ReadFile(localRulesJSON) - if err == nil { - var localRules []api.Rule - err = json.Unmarshal(buf, &localRules) + if localRulesJSON != "" { + buf, err := ioutil.ReadFile(localRulesJSON) if err == nil { - rulespack.Rules = append(rulespack.Rules, localRules...) + var localRules []api.Rule + err = json.Unmarshal(buf, &localRules) + if err == nil { + rulespack.Rules = append(rulespack.Rules, localRules...) + } + } + if err != nil { + a.logger.Error(sqerrors.Wrap(err, "config: could not read the local rules file")) } - } - if err != nil { - a.logger.Error(sqerrors.Wrap(err, "config: could not read the local rules file")) } a.rules.SetRules(rulespack.PackID, rulespack.Rules) diff --git a/agent/internal/client.go b/agent/internal/client.go index 8edcc1e2..cdcb8d4f 100644 --- a/agent/internal/client.go +++ b/agent/internal/client.go @@ -148,10 +148,10 @@ func TrySendAppException(logger plog.DebugLogger, cfg *config.Config, exception req.Header.Add(config.BackendHTTPAPIHeaderAppName, cfg.AppName()) req.Header.Add("Content-Type", "application/json") - logger.Debugf("sending app exception:\n%s\n", req, (*backend.HTTPRequestStringer)(req)) + logger.Debugf("sending app exception:\n%s\n", (*backend.HTTPRequestStringer)(req)) res, err := http.DefaultClient.Do(req) if err != nil { return } - logger.Debugf("received app exception response:\n%s\n", res, (*backend.HTTPResponseStringer)(res)) + logger.Debugf("received app exception response:\n%s\n", (*backend.HTTPResponseStringer)(res)) } From 5a3617d85c683ad6a1504d0f174b064396ea95a8 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 22 Jul 2019 15:39:18 +0200 Subject: [PATCH 35/47] agent/rule: signature verification Add the ECDSA signature verification of rules. The received rule is not sent the way it was signed, so the original message that was signed needs to be re-created. This message is the string representation of the json structure whose fields are sorted in lexicographical order. Only the field names provided in the signature structure should be used. Since Go structures cannot be addresses by field name strings, the hack is to parse the rule as a `map[string]json.RawMessage` to only get the required keys and then unmarshal it to `interface{}` in order to get native Go types and marshal it back to json in lexicographical order. The nice thing here is that only `map[string]interface{}` (for structures) needs to be specifically considered, while other types already have a stable json serialization. --- agent/internal/agent.go | 9 +- agent/internal/backend/api/api.go | 29 +++-- agent/internal/backend/api/json_test.go | 4 - agent/internal/backend/api/jsonpb.go | 70 +++++++++++ agent/internal/config/config.go | 9 +- agent/internal/rule/rule.go | 14 ++- agent/internal/rule/rule_test.go | 147 ++++++++++++++++++++++-- agent/internal/rule/signature.go | 73 ++++++++++++ agent/internal/rule/signature_test.go | 89 ++++++++++++++ 9 files changed, 419 insertions(+), 25 deletions(-) create mode 100644 agent/internal/rule/signature.go create mode 100644 agent/internal/rule/signature_test.go diff --git a/agent/internal/agent.go b/agent/internal/agent.go index d724f371..472d2706 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -156,6 +156,13 @@ func New(cfg *config.Config) *Agent { metrics := metrics.NewEngine(logger) + publicKey, err := rule.NewECDSAPublicKey(config.PublicKey) + if err != nil { + logger.Error(sqerrors.Wrap(err, "ecdsa public key")) + return nil + } + rulesEngine := rule.NewEngine(logger, metrics, publicKey) + // TODO: remove this SDK metrics period config when the corresponding js rule // is supported sdkMetricsPeriod := time.Duration(cfg.SDKMetricsPeriod()) * time.Second @@ -179,7 +186,7 @@ func New(cfg *config.Config) *Agent { appInfo: app.NewInfo(logger), client: backend.NewClient(cfg.BackendHTTPAPIBaseURL(), cfg, logger), actors: actor.NewStore(logger), - rules: rule.NewEngine(logger, metrics), + rules: rulesEngine, } } diff --git a/agent/internal/backend/api/api.go b/agent/internal/backend/api/api.go index c3d1998d..7538b9a5 100644 --- a/agent/internal/backend/api/api.go +++ b/agent/internal/backend/api/api.go @@ -41,7 +41,6 @@ type AppLoginResponse struct { Commands []CommandRequest `protobuf:"bytes,3,rep,name=commands,proto3" json:"commands"` Features AppLoginResponse_Feature `protobuf:"bytes,4,opt,name=features,proto3" json:"features"` PackId string `protobuf:"bytes,5,opt,name=pack_id,json=packId,proto3" json:"pack_id"` - Rules []Rule `protobuf:"bytes,6,rep,name=rules,proto3" json:"rules"` } type AppLoginResponse_Feature struct { @@ -164,10 +163,28 @@ type BatchRequest_Event struct { } type Rule struct { - Name string `json:"name"` - Hookpoint Hookpoint `json:"hookpoint"` - Data RuleData `json:"data"` - Metrics []MetricDefinition `json:"metrics"` + Name string `json:"name"` + Hookpoint Hookpoint `json:"hookpoint"` + Data RuleData `json:"data"` + Metrics []MetricDefinition `json:"metrics"` + Signature RuleSignature `json:"signature"` + Conditions RuleConditions `json:"conditions"` + Callbacks RuleCallbacks `json:"callbacks"` +} + +type RuleConditions struct{} +type RuleCallbacks struct{} + +type ECDSASignature struct { + Keys []string `json:"keys"` + Value string `json:"value"` + // Custom field where the signed message is reconstructed out of the list of + // keys + Message []byte `json:"-"` +} + +type RuleSignature struct { + ECDSASignature ECDSASignature `json:"v0_9"` } type MetricDefinition struct { @@ -406,7 +423,6 @@ type AppLoginResponseFace interface { GetCommands() []CommandRequest GetFeatures() AppLoginResponse_Feature GetPackId() string - GetRules() []Rule } func NewAppLoginResponseFromFace(that AppLoginResponseFace) *AppLoginResponse { @@ -416,7 +432,6 @@ func NewAppLoginResponseFromFace(that AppLoginResponseFace) *AppLoginResponse { this.Commands = that.GetCommands() this.Features = that.GetFeatures() this.PackId = that.GetPackId() - this.Rules = that.GetRules() return this } diff --git a/agent/internal/backend/api/json_test.go b/agent/internal/backend/api/json_test.go index 1367423a..444f9ef3 100644 --- a/agent/internal/backend/api/json_test.go +++ b/agent/internal/backend/api/json_test.go @@ -232,10 +232,6 @@ func (this *AppLoginResponse) GetPackId() string { return this.PackId } -func (this *AppLoginResponse) GetRules() []api.Rule { - return this.Rules -} - type AppLoginResponse_Feature api.AppLoginResponse_Feature func (this *AppLoginResponse_Feature) GetBatchSize() uint32 { diff --git a/agent/internal/backend/api/jsonpb.go b/agent/internal/backend/api/jsonpb.go index 555fa168..60d3bb42 100644 --- a/agent/internal/backend/api/jsonpb.go +++ b/agent/internal/backend/api/jsonpb.go @@ -7,6 +7,8 @@ package api import ( "encoding/json" "fmt" + "sort" + "strings" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" ) @@ -159,3 +161,71 @@ func (v *RuleDataEntry) MarshalJSON() ([]byte, error) { } return json.Marshal(discriminant) } + +func (r *Rule) UnmarshalJSON(data []byte) error { + type rule Rule + if err := json.Unmarshal(data, (*rule)(r)); err != nil { + return err + } + + var keys map[string]json.RawMessage + if err := json.Unmarshal(data, &keys); err != nil { + return err + } + + signature := &r.Signature.ECDSASignature + kv := make(map[string]interface{}, len(signature.Keys)) + for _, k := range signature.Keys { + rawValue, exists := keys[k] + if !exists { + continue + } + var v interface{} + if err := json.Unmarshal(rawValue, &v); err != nil { + return err + } + kv[k] = v + } + message, err := LexicographicalOrderJSONMarshalMap(kv) + if err != nil { + return err + } + signature.Message = message + + return nil +} + +func LexicographicalOrderJSONMarshal(o interface{}) ([]byte, error) { + switch actual := o.(type) { + case map[string]interface{}: + return LexicographicalOrderJSONMarshalMap(actual) + default: + return json.Marshal(o) + } +} + +func LexicographicalOrderJSONMarshalMap(o map[string]interface{}) ([]byte, error) { + if len(o) == 0 { + return []byte(`{}`), nil + } + // Get the list of keys + keys := make([]string, 0, len(o)) + for k := range o { + keys = append(keys, k) + } + // Sort the list of keys + sort.Strings(keys) + for i, k := range keys { + v, err := LexicographicalOrderJSONMarshal(o[k]) + if err != nil { + return nil, err + } + jsonKey, err := json.Marshal(k) + if err != nil { + return nil, sqerrors.Wrap(err, "map string key marshaling") + } + k = string(jsonKey) + keys[i] = fmt.Sprintf("%s:%s", k, v) + } + return []byte(fmt.Sprintf("{%s}", strings.Join(keys, ","))), nil +} diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index d28ff6d8..102543f1 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -34,6 +34,13 @@ const ( ErrorMessage_UnsupportedCommand = "command is not supported" ) +const PublicKey string = `-----BEGIN PUBLIC KEY----- +MIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQA39oWMHR8sxb9LRaM5evZ7mw03iwJ +WNHuDeGqgPo1HmvuMfLnAyVLwaMXpGPuvbqhC1U65PG90bTJLpvNokQf0VMA5Tpi +m+NXwl7bjqa03vO/HErLbq3zBRysrZnC4OhJOF1jazkAg0psQOea2r5HcMcPHgMK +fnWXiKWnZX+uOWPuerE= +-----END PUBLIC KEY-----` + type HTTPAPIEndpoint struct { Method, URL string } @@ -60,7 +67,7 @@ var ( ActionsPack: HTTPAPIEndpoint{http.MethodGet, "/sqreen/v0/actionspack"}, RulesPack: HTTPAPIEndpoint{http.MethodGet, "/sqreen/v0/rulespack"}, } - + // Header name of the API token. BackendHTTPAPIHeaderToken = "X-Api-Key" diff --git a/agent/internal/rule/rule.go b/agent/internal/rule/rule.go index 7087f9bc..fe1c88f0 100644 --- a/agent/internal/rule/rule.go +++ b/agent/internal/rule/rule.go @@ -17,6 +17,7 @@ package rule import ( + "crypto/ecdsa" "fmt" "github.com/sqreen/go-agent/agent/internal/backend/api" @@ -36,6 +37,7 @@ type Engine struct { cfg *config.Config enabled bool metricsEngine *metrics.Engine + publicKey *ecdsa.PublicKey } // Logger interface required by this package. @@ -45,10 +47,11 @@ type Logger interface { } // NewEngine returns a new rule engine. -func NewEngine(logger Logger, metricsEngine *metrics.Engine) *Engine { +func NewEngine(logger Logger, metricsEngine *metrics.Engine, publicKey *ecdsa.PublicKey) *Engine { return &Engine{ logger: logger, metricsEngine: metricsEngine, + publicKey: publicKey, } } @@ -97,13 +100,18 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { // newHookDescriptors walks the list of received rules and creates the map of // hook descriptors indexed by their hook pointer. A hook descriptor contains // all it takes to enable and disable rules at run time. -func newHookDescriptors(logger Logger, rules []api.Rule, metricsEngine *metrics.Engine) hookDescriptors { +func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.PublicKey, metricsEngine *metrics.Engine) hookDescriptors { // Create and configure the list of callbacks according to the given rules var hookDescriptors = make(hookDescriptors) for i := len(rules) - 1; i >= 0; i-- { r := rules[i] - hookpoint := r.Hookpoint + // Verify the signature + if err := VerifyRuleSignature(&r, publicKey); err != nil { + logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: signature verification", r.Name))) + continue + } // Find the symbol + hookpoint := r.Hookpoint symbol := fmt.Sprintf("%s.%s", hookpoint.Class, hookpoint.Method) hook := sqhook.Find(symbol) if hook == nil { diff --git a/agent/internal/rule/rule_test.go b/agent/internal/rule/rule_test.go index fa06fac1..10c57451 100644 --- a/agent/internal/rule/rule_test.go +++ b/agent/internal/rule/rule_test.go @@ -5,6 +5,13 @@ package rule_test import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha512" + "encoding/asn1" + "encoding/base64" + "math/big" "net/http" "os" "reflect" @@ -24,8 +31,12 @@ func func2(_ http.ResponseWriter, _ *http.Request, _ http.Header, _ int, _ []byt type empty struct{} func TestEngineUsage(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + publicKey := &privateKey.PublicKey + logger := plog.NewLogger(plog.Debug, os.Stderr, 0) - engine := rule.NewEngine(logger, metrics.NewEngine(plog.NewLogger(plog.Debug, os.Stderr, 0))) + engine := rule.NewEngine(logger, metrics.NewEngine(plog.NewLogger(plog.Debug, os.Stderr, 0)), publicKey) hookFunc1 := sqhook.New(func1) require.NotNil(t, hookFunc1) hookFunc2 := sqhook.New(func2) @@ -58,6 +69,7 @@ func TestEngineUsage(t *testing.T) { {&api.CustomErrorPageRuleDataEntry{}}, }, }, + Signature: MakeSignature(privateKey, `{"name":"a valid rule"}`), }, { Name: "another valid rule", @@ -71,6 +83,7 @@ func TestEngineUsage(t *testing.T) { {&api.CustomErrorPageRuleDataEntry{}}, }, }, + Signature: MakeSignature(privateKey, `{"name":"another valid rule"}`), }, }) @@ -136,26 +149,142 @@ func TestEngineUsage(t *testing.T) { {&api.CustomErrorPageRuleDataEntry{}}, }, }, + Signature: MakeSignature(privateKey, `{"name":"another valid rule"}`), }, }) // Check the callbacks were removed for func1 and not func2 - prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + prologFunc1 := hookFunc1.Prolog() require.Nil(t, prologFunc1) - require.Nil(t, epilogFunc1) - prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + prologFunc2 := hookFunc2.Prolog() require.NotNil(t, prologFunc2) - require.Nil(t, epilogFunc2) }) t.Run("replace the enabled rules with an empty array of rules", func(t *testing.T) { // Set the rules with an empty array while enabled engine.SetRules("yet another pack id", []api.Rule{}) // Check the callbacks were all removed for func1 and not func2 - prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + prologFunc1 := hookFunc1.Prolog() require.Nil(t, prologFunc1) - require.Nil(t, epilogFunc1) - prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + prologFunc2 := hookFunc2.Prolog() require.Nil(t, prologFunc2) - require.Nil(t, epilogFunc2) }) + + t.Run("add rules with signature issues", func(t *testing.T) { + validSignature := MakeSignature(privateKey, `{"name":"a valid rule"}`).ECDSASignature + + // Modify the rules while enabled + engine.SetRules("a pack id", []api.Rule{ + { + Name: "a valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func1", + Callback: "WriteCustomErrorPage", + }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, + Signature: api.RuleSignature{ /*zero value*/ }, + }, + { + Name: "a valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func1", + Callback: "WriteCustomErrorPage", + }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, + Signature: api.RuleSignature{ + ECDSASignature: api.ECDSASignature{ + Message: validSignature.Message, + /* zero signature value */ + }, + }, + }, + { + Name: "a valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func1", + Callback: "WriteCustomErrorPage", + }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, + Signature: api.RuleSignature{ + ECDSASignature: api.ECDSASignature{ + Value: validSignature.Value, + /* zero message value */ + }, + }, + }, + { + Name: "a valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func1", + Callback: "WriteCustomErrorPage", + }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, + Signature: api.RuleSignature{ + ECDSASignature: api.ECDSASignature{ + Value: validSignature.Value, + Message: []byte(`wrong message`), + }, + }, + }, + { + Name: "a valid rule", + Hookpoint: api.Hookpoint{ + Class: reflect.TypeOf(empty{}).PkgPath(), + Method: "func1", + Callback: "WriteCustomErrorPage", + }, + Data: api.RuleData{ + Values: []api.RuleDataEntry{ + {&api.CustomErrorPageRuleDataEntry{}}, + }, + }, + Signature: api.RuleSignature{ + ECDSASignature: api.ECDSASignature{ + Value: `wrong value`, + Message: validSignature.Message, + }, + }, + }, + }) + // Check the callbacks were removed for func1 and not func2 + prologFunc1 := hookFunc1.Prolog() + require.Nil(t, prologFunc1) + }) +} + +func MakeSignature(privateKey *ecdsa.PrivateKey, message string) api.RuleSignature { + hash := sha512.Sum512([]byte(message)) + r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:]) + if err != nil { + panic(err) + } + signature, err := asn1.Marshal(struct{ R, S *big.Int }{R: r, S: s}) + if err != nil { + panic(err) + } + return api.RuleSignature{ + ECDSASignature: api.ECDSASignature{ + Message: []byte(message), + Value: base64.StdEncoding.EncodeToString(signature), + }, + } } diff --git a/agent/internal/rule/signature.go b/agent/internal/rule/signature.go new file mode 100644 index 00000000..014632f3 --- /dev/null +++ b/agent/internal/rule/signature.go @@ -0,0 +1,73 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package rule + +import ( + "crypto/ecdsa" + "crypto/sha512" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/pem" + "math/big" + + "github.com/sqreen/go-agent/agent/internal/backend/api" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" +) + +// NewECDSAPublicKey creates a ECDSA public key from a PEM public key. +func NewECDSAPublicKey(PEMPublicKey string) (*ecdsa.PublicKey, error) { + // decode the key, assuming it's in PEM format + block, _ := pem.Decode([]byte(PEMPublicKey)) + if block == nil { + return nil, sqerrors.New("failed to decode the PEM public key") + } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, sqerrors.Wrap(err, "failed to parse ECDSA public key") + } + publicKey, ok := pub.(*ecdsa.PublicKey) + if !ok { + return nil, sqerrors.Errorf("unexpected public key type `%T`", pub) + } + return publicKey, nil +} + +// Verify returns a non-nil error when message verification against the public +// key failed, nil otherwise. +func Verify(publicKey *ecdsa.PublicKey, hash []byte, signature []byte) error { + // unmarshal the R and S components of the ASN.1-encoded signature into our + // signature data structure + var sig struct { + R, S *big.Int + } + if _, err := asn1.Unmarshal(signature, &sig); err != nil { + return err + } + valid := ecdsa.Verify( + publicKey, + hash, + sig.R, + sig.S, + ) + if !valid { + return sqerrors.New("invalid signature") + } + // signature is valid + return nil +} + +// VerifyRuleSignature returns a non-nil error when the rule signature is +// invalid, nil otherwise. +func VerifyRuleSignature(r *api.Rule, publicKey *ecdsa.PublicKey) error { + signature := r.Signature.ECDSASignature + // first decode the signature to extract the DER-encoded byte string + der, err := base64.StdEncoding.DecodeString(signature.Value) + if err != nil { + return sqerrors.Wrap(err, "base64 decoding") + } + hash := sha512.Sum512([]byte(signature.Message)) + return Verify(publicKey, hash[:], der) +} diff --git a/agent/internal/rule/signature_test.go b/agent/internal/rule/signature_test.go new file mode 100644 index 00000000..91b430b4 --- /dev/null +++ b/agent/internal/rule/signature_test.go @@ -0,0 +1,89 @@ +// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package rule_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha512" + "encoding/asn1" + "math/big" + "testing" + + "github.com/sqreen/go-agent/agent/internal/config" + "github.com/sqreen/go-agent/agent/internal/rule" + "github.com/sqreen/go-agent/tools/testlib" + "github.com/stretchr/testify/require" +) + +func TestNewECDSAPublicKey(t *testing.T) { + t.Run("invalid format pubkey", func(t *testing.T) { + _, err := rule.NewECDSAPublicKey(testlib.RandString(0, 100)) + require.Error(t, err) + }) + + t.Run("invalid pem pubkey", func(t *testing.T) { + const publicKey string = `-----BEGIN PUBLIC KEY----- +MIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQA39oWMHR8sxb9LRaM5evZ7mw03iwJ +WNHuDeGqgPo1HmvuMfLnAyVLwaMXpGPuvbqhC1U65PG90bTJLpvNokQf0VMA5Tpi +m+NXwl7bjqa03vO/HExLbq3zBRysrZnC4OhJOF1jazkAg0psQOea2r5HcMcPHgMK +fnWXiKWnZX+uOWPuerE= +-----END PUBLIC KEY-----` + _, err := rule.NewECDSAPublicKey(publicKey) + require.Error(t, err) + }) + + t.Run("valid pem pubkey but not ecdsa", func(t *testing.T) { + const publicKey string = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAryQICCl6NZ5gDKrnSztO +3Hy8PEUcuyvg/ikC+VcIo2SFFSf18a3IMYldIugqqqZCs4/4uVW3sbdLs/6PfgdX +7O9D22ZiFWHPYA2k2N744MNiCD1UE+tJyllUhSblK48bn+v1oZHCM0nYQ2NqUkvS +j+hwUU3RiWl7x3D2s9wSdNt7XUtW05a/FXehsPSiJfKvHJJnGOX0BgTvkLnkAOTd +OrUZ/wK69Dzu4IvrN4vs9Nes8vbwPa/ddZEzGR0cQMt0JBkhk9kU/qwqUseP1QRJ +5I1jR4g8aYPL/ke9K35PxZWuDp3U0UPAZ3PjFAh+5T+fc7gzCs9dPzSHloruU+gl +FQIDAQAB +-----END PUBLIC KEY-----` + _, err := rule.NewECDSAPublicKey(publicKey) + require.Error(t, err) + }) + + t.Run("valid ecdsa pem pubkey", func(t *testing.T) { + pub, err := rule.NewECDSAPublicKey(config.PublicKey) + require.NoError(t, err) + require.NotNil(t, pub) + }) + +} + +func TestVerify(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + msg := []byte("hello, world") + hash := sha512.Sum512(msg) + + r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:]) + require.NoError(t, err) + + t.Run("invalid signature", func(t *testing.T) { + signature, err := asn1.Marshal(struct{ R, S *big.Int }{R: big.NewInt(0).Add(r, big.NewInt(33)), S: s}) + require.NoError(t, err) + err = rule.Verify(&privateKey.PublicKey, hash[:], signature) + require.Error(t, err) + }) + + t.Run("invalid asn1", func(t *testing.T) { + err = rule.Verify(&privateKey.PublicKey, hash[:], []byte("oops")) + require.Error(t, err) + }) + + t.Run("valid signature", func(t *testing.T) { + signature, err := asn1.Marshal(struct{ R, S *big.Int }{R: r, S: s}) + require.NoError(t, err) + err = rule.Verify(&privateKey.PublicKey, hash[:], signature) + require.NoError(t, err) + }) +} From 9a680ae12bca9dfce66b8e9d3ea30cd3600686e1 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 22 Jul 2019 16:33:40 +0200 Subject: [PATCH 36/47] sqlib/sqhook: remove relying on a context thanks to function closures Remove the two callback pointers in favor of just one: the prolog returning a non-nil epilog function closure when required. That way, any context can be shared between the two, removing the need to allocate a call-context that may be unneeded. The other great benefit also is that it makes the epilog typed with its actual type rather than an `interface{}`, so it also removes a type-assertion in favor of just checking it is not nil. --- Makefile | 2 +- agent/internal/backend/client_test.go | 7 +- agent/internal/httphandler/write-response.go | 15 +- agent/internal/rule/callback.go | 8 +- .../rule/callback/add-security-headers.go | 22 +- .../callback/add-security-headers_test.go | 9 +- agent/internal/rule/callback/callback_test.go | 77 ++--- .../rule/callback/monitor-http-status-code.go | 34 +- .../callback/monitor-http-status-code_test.go | 7 +- .../rule/callback/write-custom-error-page.go | 19 +- .../callback/write-custom-error-page_test.go | 9 +- .../rule/callback/write-http-redirection.go | 19 +- .../callback/write-http-redirection_test.go | 7 +- agent/internal/rule/callback_test.go | 2 +- agent/internal/rule/rule.go | 37 +-- agent/internal/rule/rule_test.go | 24 +- agent/sqlib/sqhook/hook.go | 235 +++++++------- agent/sqlib/sqhook/hook_test.go | 304 +++++++++--------- sdk/middleware/sqhttp/http.go | 32 +- 19 files changed, 409 insertions(+), 460 deletions(-) diff --git a/Makefile b/Makefile index e62b2db4..b3520659 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ protobufs := $(patsubst %.proto,%.pb.go,$(shell find agent -name '*.proto')) protoc/flags := -I. -Ivendor --gogo_out=google/protobuf/any.proto=github.com/gogo/protobuf/types,Mgoogle/protobuf/duration.proto=github.com/gogo/protobuf/types,Mgoogle/protobuf/struct.proto=github.com/gogo/protobuf/types,Mgoogle/protobuf/timestamp.proto=github.com/gogo/protobuf/types,Mgoogle/protobuf/wrappers.proto=github.com/gogo/protobuf/types:. test/packages/everything := ./agent/... ./sdk/... test/packages := $(or $(TEST_PACKAGE), $(test/packages/everything)) -test/options := $(TEST_OPTIONS) +test/options := $(TEST_OPTIONS) -timeout 30m benchmark := $(or $(BENCHMARK), .) benchmark/results := tools/benchmark/results benchmark/result = $(benchmark/results)/$(git/ref/head)/$(shell date '+%Y-%m-%d-%H-%M-%S') diff --git a/agent/internal/backend/client_test.go b/agent/internal/backend/client_test.go index 67ac43cf..dbc266e3 100644 --- a/agent/internal/backend/client_test.go +++ b/agent/internal/backend/client_test.go @@ -23,7 +23,7 @@ import ( var ( logger = plog.NewLogger(plog.Debug, os.Stderr, 0) cfg = config.New(logger) - fuzzer = fuzz.New().Funcs(FuzzStruct, FuzzCommandRequest, FuzzRuleDataValue) + fuzzer = fuzz.New().Funcs(FuzzStruct, FuzzCommandRequest, FuzzRuleDataValue, FuzzRule) ) func TestClient(t *testing.T) { @@ -345,3 +345,8 @@ func FuzzRuleDataValue(e *api.RuleDataEntry, c fuzz.Continue) { c.Fuzz(&v.StatusCode) e.Value = v } + +func FuzzRule(e *api.Rule, c fuzz.Continue) { + c.Fuzz(e) + e.Signature = api.RuleSignature{ECDSASignature: api.ECDSASignature{Message: []byte(`{}`)}} +} diff --git a/agent/internal/httphandler/write-response.go b/agent/internal/httphandler/write-response.go index 6ca09d24..b694dcd0 100644 --- a/agent/internal/httphandler/write-response.go +++ b/agent/internal/httphandler/write-response.go @@ -20,15 +20,14 @@ func init() { // The statusCode is the only mandatory argument. Headers and body can be nil. func WriteResponse(w http.ResponseWriter, r *http.Request, headers http.Header, statusCode int, body []byte) { { - type Prolog = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]uint8) error - type Epilog = func(*sqhook.Context) - ctx := sqhook.Context{} - prolog, epilog := writeResponseHook.Callbacks() - if epilog, ok := epilog.(Epilog); ok { - defer epilog(&ctx) - } + type Epilog = func() + type Prolog = func(*http.ResponseWriter, **http.Request, *http.Header, *int, *[]uint8) (Epilog, error) + prolog := writeResponseHook.Prolog() if prolog, ok := prolog.(Prolog); ok { - err := prolog(&ctx, &w, &r, &headers, &statusCode, &body) + epilog, err := prolog(&w, &r, &headers, &statusCode, &body) + if epilog != nil { + defer epilog() + } if err != nil { return } diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index 2a79a64f..1998be34 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -20,15 +20,15 @@ import ( // CallbackConstructorFunc is a function returning a callback function // configured with the given data. The data types are known by the constructor // that can type-assert them. -type CallbacksConstructorFunc func(rule callback.Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) +type CallbacksConstructorFunc func(rule callback.Context, nextProlog sqhook.PrologCallback) (prolog sqhook.PrologCallback, err error) // NewCallbacks returns the prolog and epilog callbacks of the given callback // name. And error is returned if the callback name is unknown. -func NewCallbacks(name string, rule *CallbackContext, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewCallbacks(name string, rule *CallbackContext, nextProlog sqhook.PrologCallback) (prolog sqhook.PrologCallback, err error) { var callbacksCtor CallbacksConstructorFunc switch name { default: - return nil, nil, sqerrors.Errorf("undefined callback name `%s`", name) + return nil, sqerrors.Errorf("undefined callback name `%s`", name) case "WriteCustomErrorPage": callbacksCtor = callback.NewWriteCustomErrorPageCallbacks case "WriteHTTPRedirection": @@ -38,7 +38,7 @@ func NewCallbacks(name string, rule *CallbackContext, nextProlog, nextEpilog sqh case "MonitorHTTPStatusCode": callbacksCtor = callback.NewMonitorHTTPStatusCodeCallbacks } - return callbacksCtor(rule, nextProlog, nextEpilog) + return callbacksCtor(rule, nextProlog) } type CallbackContext struct { diff --git a/agent/internal/rule/callback/add-security-headers.go b/agent/internal/rule/callback/add-security-headers.go index 4278d662..94152481 100644 --- a/agent/internal/rule/callback/add-security-headers.go +++ b/agent/internal/rule/callback/add-security-headers.go @@ -14,7 +14,7 @@ import ( // NewAddSecurityHeadersCallbacks returns the native prolog and epilog callbacks // to be hooked to `sqhttp.MiddlewareWithError` in order to add HTTP headers // provided by the rule's data. -func NewAddSecurityHeadersCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewAddSecurityHeadersCallbacks(rule Context, nextProlog sqhook.PrologCallback) (prolog sqhook.PrologCallback, err error) { var headers http.Header if cfg := rule.Config(); cfg != nil { cfg, ok := rule.Config().([]interface{}) @@ -38,7 +38,7 @@ func NewAddSecurityHeadersCallbacks(rule Context, nextProlog, nextEpilog sqhook. } } if len(headers) == 0 { - return nil, nil, sqerrors.New("there are no headers to add") + return nil, sqerrors.New("there are no headers to add") } // Next callbacks to call @@ -47,29 +47,23 @@ func NewAddSecurityHeadersCallbacks(rule Context, nextProlog, nextEpilog sqhook. err = sqerrors.Errorf("unexpected next prolog type `%T` instead of `%T`", nextProlog, AddSecurityHeadersPrologCallbackType(nil)) return } - // No epilog in this callback, so simply check and pass the given one - if _, ok := nextEpilog.(AddSecurityHeadersEpilogCallbackType); nextEpilog != nil && !ok { - err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, AddSecurityHeadersEpilogCallbackType(nil)) - return - } - return newAddHeadersPrologCallback(headers, actualNextProlog), nextEpilog, nil + return newAddHeadersPrologCallback(headers, actualNextProlog), nil } -type AddSecurityHeadersPrologCallbackType = func(*sqhook.Context, *http.ResponseWriter) error -type AddSecurityHeadersEpilogCallbackType = func(*sqhook.Context) +type AddSecurityHeadersEpilogCallbackType = func(*error) +type AddSecurityHeadersPrologCallbackType = func(*http.ResponseWriter) (AddSecurityHeadersEpilogCallbackType, error) // The prolog callback modifies the function arguments in order to replace the // written status code and body. func newAddHeadersPrologCallback(headers http.Header, next AddSecurityHeadersPrologCallbackType) AddSecurityHeadersPrologCallbackType { - return func(ctx *sqhook.Context, w *http.ResponseWriter) error { + return func(w *http.ResponseWriter) (AddSecurityHeadersEpilogCallbackType, error) { responseHeaders := (*w).Header() for k, v := range headers { responseHeaders[k] = v } - if next == nil { - return nil + return nil, nil } - return next(ctx, w) + return next(w) } } diff --git a/agent/internal/rule/callback/add-security-headers_test.go b/agent/internal/rule/callback/add-security-headers_test.go index 25c61d6f..62c234b2 100644 --- a/agent/internal/rule/callback/add-security-headers_test.go +++ b/agent/internal/rule/callback/add-security-headers_test.go @@ -28,6 +28,8 @@ func TestNewAddSecurityHeadersCallbacks(t *testing.T) { []string{}, []string{"one"}, []string{"one", "two", "three"}, + []interface{}{[]string{"one", "two"}, []string{"three"}}, + []interface{}{[]string{"one", "two"}, []string{"three", "four"}, "nope"}, }, ValidTestCases: []ValidTestCase{ { @@ -38,7 +40,7 @@ func TestNewAddSecurityHeadersCallbacks(t *testing.T) { []string{"canonical-header", "the value"}, }, }, - TestCallbacks: func(t *testing.T, _ *FakeRule, prolog, epilog sqhook.Callback) { + TestCallbacks: func(t *testing.T, _ *FakeRule, prolog sqhook.PrologCallback) { expectedHeaders := http.Header{ "K": []string{"v"}, "One": []string{"two"}, @@ -47,16 +49,15 @@ func TestNewAddSecurityHeadersCallbacks(t *testing.T) { actualProlog, ok := prolog.(callback.AddSecurityHeadersPrologCallbackType) require.True(t, ok) var rec http.ResponseWriter = httptest.NewRecorder() - err := actualProlog(nil, &rec) + epilog, err := actualProlog(&rec) // Check it behaves as expected require.NoError(t, err) require.Equal(t, expectedHeaders, rec.Header()) // Test the epilog if any if epilog != nil { - actualEpilog, ok := epilog.(callback.AddSecurityHeadersEpilogCallbackType) require.True(t, ok) - actualEpilog(&sqhook.Context{}) + epilog(nil) } }, }, diff --git a/agent/internal/rule/callback/callback_test.go b/agent/internal/rule/callback/callback_test.go index ce4be5d6..f66e45d5 100644 --- a/agent/internal/rule/callback/callback_test.go +++ b/agent/internal/rule/callback/callback_test.go @@ -24,92 +24,74 @@ type TestConfig struct { type ValidTestCase struct { Rule *FakeRule - TestCallbacks func(t *testing.T, rule *FakeRule, prolog, epilog sqhook.Callback) + TestCallbacks func(t *testing.T, rule *FakeRule, prolog sqhook.PrologCallback) } func RunCallbackTest(t *testing.T, config TestConfig) { for _, data := range config.InvalidTestCases { data := data t.Run("with incorrect data", func(t *testing.T) { - prolog, epilog, err := config.CallbacksCtor(&FakeRule{config: data}, nil, nil) + prolog, err := config.CallbacksCtor(&FakeRule{config: data}, nil) require.Error(t, err) require.Nil(t, prolog) - require.Nil(t, epilog) }) } for _, tc := range config.ValidTestCases { tc := tc t.Run("with correct data", func(t *testing.T) { - t.Run("without next callbacks", func(t *testing.T) { + t.Run("without next callback", func(t *testing.T) { // Instantiate the callback with the given correct rule data - prolog, epilog, err := config.CallbacksCtor(tc.Rule, nil, nil) + prolog, err := config.CallbacksCtor(tc.Rule, nil) require.NoError(t, err) - checkCallbacksValues(t, config, prolog, epilog) - tc.TestCallbacks(t, tc.Rule, prolog, epilog) + checkCallbacksValues(t, config, prolog) + tc.TestCallbacks(t, tc.Rule, prolog) }) - t.Run("with next callbacks", func(t *testing.T) { - t.Run("wrong next prolog type", func(t *testing.T) { - prolog, epilog, err := config.CallbacksCtor(tc.Rule, 33, nil) + t.Run("with next callback", func(t *testing.T) { + t.Run("with wrong next prolog type", func(t *testing.T) { + prolog, err := config.CallbacksCtor(tc.Rule, 33) require.Error(t, err) require.Nil(t, prolog) - require.Nil(t, epilog) - }) - - t.Run("wrong next epilog type", func(t *testing.T) { - prolog, epilog, err := config.CallbacksCtor(tc.Rule, nil, func() {}) - require.Error(t, err) - require.Nil(t, prolog) - require.Nil(t, epilog) }) t.Run("with correct next prolog", func(t *testing.T) { var called bool nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { called = true - return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} + return []reflect.Value{ + reflect.Zero(config.EpilogType), + reflect.Zero(reflect.TypeOf((*error)(nil)).Elem()), + } }).Interface() - prolog, epilog, err := config.CallbacksCtor(tc.Rule, nextProlog, nil) + prolog, err := config.CallbacksCtor(tc.Rule, nextProlog) require.NoError(t, err) - checkCallbacksValues(t, config, prolog, epilog) + checkCallbacksValues(t, config, prolog) require.NotNil(t, prolog) - tc.TestCallbacks(t, tc.Rule, prolog, epilog) + tc.TestCallbacks(t, tc.Rule, prolog) require.True(t, called) }) t.Run("with correct next epilog", func(t *testing.T) { - var called bool + var calledProlog, calledEpilog bool nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { - called = true + calledEpilog = true return - }).Interface() - - prolog, epilog, err := config.CallbacksCtor(tc.Rule, nil, nextEpilog) - require.NoError(t, err) - checkCallbacksValues(t, config, prolog, epilog) - require.NotNil(t, epilog) - tc.TestCallbacks(t, tc.Rule, prolog, epilog) - require.True(t, called) - }) + }) - t.Run("with both correct next callbacks", func(t *testing.T) { - var calledProlog, calledEpilog bool nextProlog := reflect.MakeFunc(config.PrologType, func(args []reflect.Value) (results []reflect.Value) { calledProlog = true - return []reflect.Value{reflect.Zero(reflect.TypeOf((*error)(nil)).Elem())} - }).Interface() - nextEpilog := reflect.MakeFunc(config.EpilogType, func(args []reflect.Value) (results []reflect.Value) { - calledEpilog = true - return + return []reflect.Value{ + nextEpilog, + reflect.Zero(reflect.TypeOf((*error)(nil)).Elem()), + } }).Interface() - prolog, epilog, err := config.CallbacksCtor(tc.Rule, nextProlog, nextEpilog) + prolog, err := config.CallbacksCtor(tc.Rule, nextProlog) require.NoError(t, err) - require.NotNil(t, prolog) - require.NotNil(t, epilog) - tc.TestCallbacks(t, tc.Rule, prolog, epilog) + checkCallbacksValues(t, config, prolog) + tc.TestCallbacks(t, tc.Rule, prolog) require.True(t, calledProlog) require.True(t, calledEpilog) }) @@ -118,13 +100,10 @@ func RunCallbackTest(t *testing.T, config TestConfig) { } } -func checkCallbacksValues(t *testing.T, config TestConfig, prolog, epilog sqhook.Callback) { - if config.ExpectProlog { +func checkCallbacksValues(t *testing.T, config TestConfig, prolog sqhook.PrologCallback) { + if config.ExpectProlog || config.ExpectEpilog { require.NotNil(t, prolog) } - if config.ExpectEpilog { - require.NotNil(t, epilog) - } } type FakeRule struct { diff --git a/agent/internal/rule/callback/monitor-http-status-code.go b/agent/internal/rule/callback/monitor-http-status-code.go index d7c867f7..0efb961e 100644 --- a/agent/internal/rule/callback/monitor-http-status-code.go +++ b/agent/internal/rule/callback/monitor-http-status-code.go @@ -9,31 +9,37 @@ import ( "github.com/sqreen/go-agent/agent/sqlib/sqhook" ) -func NewMonitorHTTPStatusCodeCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewMonitorHTTPStatusCodeCallbacks(rule Context, nextProlog sqhook.PrologCallback) (prolog sqhook.PrologCallback, err error) { // Next callbacks to call actualNextProlog, ok := nextProlog.(MonitorHTTPStatusCodePrologCallbackType) if nextProlog != nil && !ok { err = sqerrors.Errorf("unexpected next prolog type `%T` instead of `%T`", nextProlog, MonitorHTTPStatusCodePrologCallbackType(nil)) return } - // No epilog in this callback, so simply check and pass the given one - if _, ok := nextEpilog.(MonitorHTTPStatusCodeEpilogCallbackType); nextEpilog != nil && !ok { - err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, MonitorHTTPStatusCodeEpilogCallbackType(nil)) - return - } - return newMonitorHTTPStatusCodePrologCallback(rule, actualNextProlog), nextEpilog, nil + return newMonitorHTTPStatusCodePrologCallback(rule, actualNextProlog), nil } func newMonitorHTTPStatusCodePrologCallback(rule Context, next MonitorHTTPStatusCodePrologCallbackType) MonitorHTTPStatusCodePrologCallbackType { - return func(ctx *sqhook.Context, code *int) error { - rule.PushMetricsValue(*code, 1) + return func(r sqhook.MethodReceiver, code *int) (MonitorHTTPStatusCodeEpilogCallbackType, error) { + var ( + nextEpilog MonitorHTTPStatusCodeEpilogCallbackType + err error + ) + if next != nil { + nextEpilog, err = next(r, code) + } + return newMonitorHTTPStatusCodeEpilogCallback(rule, code, nextEpilog), err + } +} - if next == nil { - return nil +func newMonitorHTTPStatusCodeEpilogCallback(rule Context, code *int, next MonitorHTTPStatusCodeEpilogCallbackType) MonitorHTTPStatusCodeEpilogCallbackType { + return func() { + if next != nil { + defer next() } - return next(ctx, code) + rule.PushMetricsValue(*code, 1) } } -type MonitorHTTPStatusCodePrologCallbackType = func(*sqhook.Context, *int) error -type MonitorHTTPStatusCodeEpilogCallbackType = func(*sqhook.Context) +type MonitorHTTPStatusCodeEpilogCallbackType = func() +type MonitorHTTPStatusCodePrologCallbackType = func(sqhook.MethodReceiver, *int) (MonitorHTTPStatusCodeEpilogCallbackType, error) diff --git a/agent/internal/rule/callback/monitor-http-status-code_test.go b/agent/internal/rule/callback/monitor-http-status-code_test.go index 925dfa93..a5d8af9e 100644 --- a/agent/internal/rule/callback/monitor-http-status-code_test.go +++ b/agent/internal/rule/callback/monitor-http-status-code_test.go @@ -23,20 +23,19 @@ func TestNewMonitorHTTPStatusCodeCallbacks(t *testing.T) { ValidTestCases: []ValidTestCase{ { Rule: &FakeRule{}, - TestCallbacks: func(t *testing.T, rule *FakeRule, prolog, epilog sqhook.Callback) { + TestCallbacks: func(t *testing.T, rule *FakeRule, prolog sqhook.PrologCallback) { actualProlog, ok := prolog.(callback.MonitorHTTPStatusCodePrologCallbackType) require.True(t, ok) code := rand.Int() rule.On("PushMetricsValue", code, uint64(1)).Return().Once() - err := actualProlog(nil, &code) + epilog, err := actualProlog(sqhook.MethodReceiver{}, &code) // Check it behaves as expected require.NoError(t, err) // Test the epilog if any if epilog != nil { - actualEpilog, ok := epilog.(callback.MonitorHTTPStatusCodeEpilogCallbackType) require.True(t, ok) - actualEpilog(&sqhook.Context{}) + epilog() } }, }, diff --git a/agent/internal/rule/callback/write-custom-error-page.go b/agent/internal/rule/callback/write-custom-error-page.go index cd3e6a07..79b645ef 100644 --- a/agent/internal/rule/callback/write-custom-error-page.go +++ b/agent/internal/rule/callback/write-custom-error-page.go @@ -16,7 +16,7 @@ import ( // callbacks modifying the arguments of `httphandler.WriteResponse` in order to // modify the http status code and error page that are provided by the rule's // data. -func NewWriteCustomErrorPageCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewWriteCustomErrorPageCallbacks(rule Context, nextProlog sqhook.PrologCallback) (prolog sqhook.PrologCallback, err error) { var statusCode = 500 if cfg := rule.Config(); cfg != nil { cfg, ok := cfg.(*api.CustomErrorPageRuleDataEntry) @@ -33,28 +33,23 @@ func NewWriteCustomErrorPageCallbacks(rule Context, nextProlog, nextEpilog sqhoo err = sqerrors.Errorf("unexpected next prolog type `%T`", nextProlog) return } - // No epilog in this callback, so simply check and pass the given one - if _, ok := nextEpilog.(WriteCustomErrorPageEpilogCallbackType); nextEpilog != nil && !ok { - err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, WriteCustomErrorPageEpilogCallbackType(nil)) - return - } - return newWriteCustomErrorPagePrologCallback(statusCode, []byte(blockedBySqreenPage), actualNextProlog), nextEpilog, nil + return newWriteCustomErrorPagePrologCallback(statusCode, []byte(blockedBySqreenPage), actualNextProlog), nil } -type WriteCustomErrorPagePrologCallbackType = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) error -type WriteCustomErrorPageEpilogCallbackType = func(*sqhook.Context) +type WriteCustomErrorPageEpilogCallbackType = func() +type WriteCustomErrorPagePrologCallbackType = func(*http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) (WriteCustomErrorPageEpilogCallbackType, error) // The prolog callback modifies the function arguments in order to replace the // written status code and body. func newWriteCustomErrorPagePrologCallback(statusCode int, body []byte, next WriteCustomErrorPagePrologCallbackType) WriteCustomErrorPagePrologCallbackType { - return func(ctx *sqhook.Context, callerWriter *http.ResponseWriter, callerRequest **http.Request, callerHeaders *http.Header, callerStatusCode *int, callerBody *[]byte) error { + return func(callerWriter *http.ResponseWriter, callerRequest **http.Request, callerHeaders *http.Header, callerStatusCode *int, callerBody *[]byte) (WriteCustomErrorPageEpilogCallbackType, error) { *callerStatusCode = statusCode *callerBody = body if next == nil { - return nil + return nil, nil } - return next(ctx, callerWriter, callerRequest, callerHeaders, callerStatusCode, callerBody) + return next(callerWriter, callerRequest, callerHeaders, callerStatusCode, callerBody) } } diff --git a/agent/internal/rule/callback/write-custom-error-page_test.go b/agent/internal/rule/callback/write-custom-error-page_test.go index c1968b18..53fdfe84 100644 --- a/agent/internal/rule/callback/write-custom-error-page_test.go +++ b/agent/internal/rule/callback/write-custom-error-page_test.go @@ -39,15 +39,15 @@ func TestNewWriteCustomErrorPageCallbacks(t *testing.T) { }) } -func testWriteCustomErrorPageCallbacks(expectedStatusCode int) func(t *testing.T, rule *FakeRule, prolog sqhook.Callback, epilog sqhook.Callback) { - return func(t *testing.T, _ *FakeRule, prolog, epilog sqhook.Callback) { +func testWriteCustomErrorPageCallbacks(expectedStatusCode int) func(t *testing.T, rule *FakeRule, prolog sqhook.PrologCallback) { + return func(t *testing.T, _ *FakeRule, prolog sqhook.PrologCallback) { actualProlog, ok := prolog.(callback.WriteCustomErrorPagePrologCallbackType) require.True(t, ok) var ( statusCode int body []byte ) - err := actualProlog(nil, nil, nil, nil, &statusCode, &body) + epilog, err := actualProlog(nil, nil, nil, &statusCode, &body) // Check it behaves as expected require.NoError(t, err) require.Equal(t, expectedStatusCode, statusCode) @@ -55,9 +55,8 @@ func testWriteCustomErrorPageCallbacks(expectedStatusCode int) func(t *testing.T // Test the epilog if any if epilog != nil { - actualEpilog, ok := epilog.(callback.AddSecurityHeadersEpilogCallbackType) require.True(t, ok) - actualEpilog(&sqhook.Context{}) + epilog() } } } diff --git a/agent/internal/rule/callback/write-http-redirection.go b/agent/internal/rule/callback/write-http-redirection.go index ba7f6b89..a487dadf 100644 --- a/agent/internal/rule/callback/write-http-redirection.go +++ b/agent/internal/rule/callback/write-http-redirection.go @@ -17,7 +17,7 @@ import ( // callbacks modifying the arguments of `httphandler.WriteResponse` in order to // modify the http status code and headers in order to perform an HTTP // redirection to the URL provided by the rule's data. -func NewWriteHTTPRedirectionCallbacks(rule Context, nextProlog, nextEpilog sqhook.Callback) (prolog, epilog sqhook.Callback, err error) { +func NewWriteHTTPRedirectionCallbacks(rule Context, nextProlog sqhook.PrologCallback) (prolog sqhook.PrologCallback, err error) { var redirectionURL string if cfg := rule.Config(); cfg != nil { cfg, ok := cfg.(*api.RedirectionRuleDataEntry) @@ -42,21 +42,16 @@ func NewWriteHTTPRedirectionCallbacks(rule Context, nextProlog, nextEpilog sqhoo err = sqerrors.Errorf("unexpected next prolog type `%T`", nextProlog) return } - // No epilog in this callback, so simply check and pass the given one - if _, ok := nextEpilog.(WriteHTTPRedirectionEpilogCallbackType); nextEpilog != nil && !ok { - err = sqerrors.Errorf("unexpected next epilog type `%T` instead of `%T`", nextEpilog, WriteHTTPRedirectionEpilogCallbackType(nil)) - return - } - return newWriteHTTPRedirectionPrologCallback(redirectionURL, actualNextProlog), nextEpilog, nil + return newWriteHTTPRedirectionPrologCallback(redirectionURL, actualNextProlog), nil } -type WriteHTTPRedirectionPrologCallbackType = func(*sqhook.Context, *http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) error -type WriteHTTPRedirectionEpilogCallbackType = func(*sqhook.Context) +type WriteHTTPRedirectionEpilogCallbackType = func() +type WriteHTTPRedirectionPrologCallbackType = func(*http.ResponseWriter, **http.Request, *http.Header, *int, *[]byte) (WriteHTTPRedirectionEpilogCallbackType, error) // The prolog callback modifies the function arguments in order to perform an // HTTP redirection. func newWriteHTTPRedirectionPrologCallback(url string, next WriteHTTPRedirectionPrologCallbackType) WriteHTTPRedirectionPrologCallbackType { - return func(ctx *sqhook.Context, callerWriter *http.ResponseWriter, callerRequest **http.Request, callerHeaders *http.Header, callerStatusCode *int, callerBody *[]byte) error { + return func(callerWriter *http.ResponseWriter, callerRequest **http.Request, callerHeaders *http.Header, callerStatusCode *int, callerBody *[]byte) (WriteHTTPRedirectionEpilogCallbackType, error) { *callerStatusCode = http.StatusSeeOther if *callerHeaders == nil { *callerHeaders = make(http.Header) @@ -64,8 +59,8 @@ func newWriteHTTPRedirectionPrologCallback(url string, next WriteHTTPRedirection callerHeaders.Set("Location", url) if next == nil { - return nil + return nil, nil } - return next(ctx, callerWriter, callerRequest, callerHeaders, callerStatusCode, callerBody) + return next(callerWriter, callerRequest, callerHeaders, callerStatusCode, callerBody) } } diff --git a/agent/internal/rule/callback/write-http-redirection_test.go b/agent/internal/rule/callback/write-http-redirection_test.go index 9edd432e..6388d126 100644 --- a/agent/internal/rule/callback/write-http-redirection_test.go +++ b/agent/internal/rule/callback/write-http-redirection_test.go @@ -34,7 +34,7 @@ func TestNewWriteHTTPRedirectionCallbacks(t *testing.T) { Rule: &FakeRule{ config: &api.RedirectionRuleDataEntry{"http://sqreen.com"}, }, - TestCallbacks: func(t *testing.T, rule *FakeRule, prolog, epilog sqhook.Callback) { + TestCallbacks: func(t *testing.T, rule *FakeRule, prolog sqhook.PrologCallback) { // Call it and check the behaviour follows the rule's data actualProlog, ok := prolog.(callback.WriteHTTPRedirectionPrologCallbackType) require.True(t, ok) @@ -42,7 +42,7 @@ func TestNewWriteHTTPRedirectionCallbacks(t *testing.T) { statusCode int headers http.Header ) - err := actualProlog(nil, nil, nil, &headers, &statusCode, nil) + epilog, err := actualProlog(nil, nil, &headers, &statusCode, nil) // Check it behaves as expected require.NoError(t, err) require.Equal(t, http.StatusSeeOther, statusCode) @@ -51,9 +51,8 @@ func TestNewWriteHTTPRedirectionCallbacks(t *testing.T) { // Test the epilog if any if epilog != nil { - actualEpilog, ok := epilog.(callback.WriteHTTPRedirectionEpilogCallbackType) require.True(t, ok) - actualEpilog(&sqhook.Context{}) + epilog() } }, }, diff --git a/agent/internal/rule/callback_test.go b/agent/internal/rule/callback_test.go index b234e432..61e051d6 100644 --- a/agent/internal/rule/callback_test.go +++ b/agent/internal/rule/callback_test.go @@ -46,7 +46,7 @@ func TestNewCallbacks(t *testing.T) { } { tc := tc t.Run(tc.testName, func(t *testing.T) { - _, _, err := rule.NewCallbacks(tc.name, tc.rule, nil, nil) + _, err := rule.NewCallbacks(tc.name, tc.rule, nil) if tc.shouldSucceed { require.NoError(t, err) } else { diff --git a/agent/internal/rule/rule.go b/agent/internal/rule/rule.go index fe1c88f0..f4386896 100644 --- a/agent/internal/rule/rule.go +++ b/agent/internal/rule/rule.go @@ -64,16 +64,16 @@ func (e *Engine) PackID() string { // them by atomically modifying the hooks, and removing what is left. func (e *Engine) SetRules(packID string, rules []api.Rule) { // Create the net rule descriptors and replace the existing ones - ruleDescriptors := newHookDescriptors(e.logger, rules, e.metricsEngine) + ruleDescriptors := newHookDescriptors(e.logger, rules, e.publicKey, e.metricsEngine) e.setRules(packID, ruleDescriptors) } func (e *Engine) setRules(packID string, descriptors hookDescriptors) { - for hook, callback := range descriptors { + for hook, prolog := range descriptors { if e.enabled { // TODO: chain multiple callbacks per hookpoint using a callback of callbacks // Attach the callback to the hook - err := hook.Attach(callback.prolog, callback.epilog) + err := hook.Attach(prolog) if err != nil { e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not attach the callbacks"))) continue @@ -86,7 +86,7 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { // Disable previously enabled rules that were not replaced by new ones. for hook := range e.hooks { - err := hook.Attach(nil, nil) + err := hook.Attach(nil) if err != nil { e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not attach the callbacks"))) continue @@ -119,19 +119,16 @@ func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.Public continue } // Instantiate the callback - next := hookDescriptors.Get(hook) + nextProlog := hookDescriptors.Get(hook) ruleDescriptor := NewCallbackContext(&r, logger, metricsEngine) - prolog, epilog, err := NewCallbacks(hookpoint.Callback, ruleDescriptor, next.prolog, next.epilog) + prolog, err := NewCallbacks(hookpoint.Callback, ruleDescriptor, nextProlog) if err != nil { logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not instantiate the callbacks", r.Name))) continue } // Create the descriptor with everything required to be able to enable or // disable it afterwards. - hookDescriptors.Set(hook, callbacksDescriptor{ - prolog: prolog, - epilog: epilog, - }) + hookDescriptors.Set(hook, prolog) } // Nothing in the end if len(hookDescriptors) == 0 { @@ -142,10 +139,10 @@ func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.Public // Enable the hooks of the ongoing configured rules. func (e *Engine) Enable() { - for hook, callback := range e.hooks { - err := hook.Attach(callback.prolog, callback.epilog) + for hook, prolog := range e.hooks { + err := hook.Attach(prolog) if err != nil { - e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not attach the callbacks `%v` and `%v` to hook `%v`", callback.prolog, callback.epilog, hook))) + e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not attach the callbacks `%v` to hook `%v`", prolog, hook))) } } e.enabled = true @@ -154,7 +151,7 @@ func (e *Engine) Enable() { // Disable the hooks currently attached to callbacks. func (e *Engine) Disable() { for hook := range e.hooks { - err := hook.Attach(nil, nil) + err := hook.Attach(nil) if err != nil { e.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule: could not disable hook `%v`", hook))) } @@ -162,16 +159,12 @@ func (e *Engine) Disable() { e.enabled = false } -type hookDescriptors map[*sqhook.Hook]callbacksDescriptor +type hookDescriptors map[*sqhook.Hook]sqhook.PrologCallback -type callbacksDescriptor struct { - prolog, epilog sqhook.Callback +func (m hookDescriptors) Set(hook *sqhook.Hook, prolog sqhook.PrologCallback) { + m[hook] = prolog } -func (m hookDescriptors) Set(hook *sqhook.Hook, descriptor callbacksDescriptor) { - m[hook] = descriptor -} - -func (m hookDescriptors) Get(hook *sqhook.Hook) callbacksDescriptor { +func (m hookDescriptors) Get(hook *sqhook.Hook) sqhook.PrologCallback { return m[hook] } diff --git a/agent/internal/rule/rule_test.go b/agent/internal/rule/rule_test.go index 10c57451..4ef0c5e0 100644 --- a/agent/internal/rule/rule_test.go +++ b/agent/internal/rule/rule_test.go @@ -89,48 +89,40 @@ func TestEngineUsage(t *testing.T) { t.Run("callbacks are not attached when disabled", func(t *testing.T) { // Check the callbacks were not attached because rules are disabled - prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + prologFunc1 := hookFunc1.Prolog() require.Nil(t, prologFunc1) - require.Nil(t, epilogFunc1) - prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + prologFunc2 := hookFunc2.Prolog() require.Nil(t, prologFunc2) - require.Nil(t, epilogFunc2) }) t.Run("enabling the rules attaches the callbacks", func(t *testing.T) { // Enable the rules engine.Enable() // Check the callbacks were now attached - prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + prologFunc1 := hookFunc1.Prolog() require.NotNil(t, prologFunc1) - require.Nil(t, epilogFunc1) - prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + prologFunc2 := hookFunc2.Prolog() require.NotNil(t, prologFunc2) - require.Nil(t, epilogFunc2) }) t.Run("disabling the rules removes the callbacks", func(t *testing.T) { // Disable the rules engine.Disable() // Check the callbacks were all removed for func1 and not func2 - prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + prologFunc1 := hookFunc1.Prolog() require.Nil(t, prologFunc1) - require.Nil(t, epilogFunc1) - prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + prologFunc2 := hookFunc2.Prolog() require.Nil(t, prologFunc2) - require.Nil(t, epilogFunc2) }) t.Run("enabling the rules again sets back the callbacks", func(t *testing.T) { // Enable again the rules engine.Enable() // Check the callbacks are attached again - prologFunc1, epilogFunc1 := hookFunc1.Callbacks() + prologFunc1 := hookFunc1.Prolog() require.NotNil(t, prologFunc1) - require.Nil(t, epilogFunc1) - prologFunc2, epilogFunc2 := hookFunc2.Callbacks() + prologFunc2 := hookFunc2.Prolog() require.NotNil(t, prologFunc2) - require.Nil(t, epilogFunc2) }) }) diff --git a/agent/sqlib/sqhook/hook.go b/agent/sqlib/sqhook/hook.go index 5b859adf..5974ef72 100644 --- a/agent/sqlib/sqhook/hook.go +++ b/agent/sqlib/sqhook/hook.go @@ -3,13 +3,13 @@ // https://www.sqreen.io/terms.html // Package sqhook provides a pure Go implementation of hooks to be inserted -// into function definitions in order to be able to attach prolog and epilog -// callbacks giving read/write access to the arguments and returned values of -// the function call at run time. +// into function definitions in order to be able to attach at run time prolog +// and epilog callbacks getting read/write access to the arguments and returned +// values of the function call. // // A hook needs to be globally created and associated to a function symbol at -// package initialization time. Prolog and epilog callbacks can then be -// accessed in the function call to pass the call arguments and return values. +// package initialization time. Callbacks can then be accessed in the function +// call to pass the call arguments and return values. // // On the other side, callbacks can be attached to a hook at run time. Prolog // callbacks get read/write access to the arguments of the function call before @@ -20,9 +20,15 @@ // Given a function F: // func F(A, B, C) (R, S, T) // The expected prolog signature is: -// type prolog = func(*sqhook.Context, *A, *B, *C) error +// type prolog = func(*A, *B, *C) (epilog, error) // The expected epilog signature is: -// type epilog = func(*sqhook.Context, *R, *S, *T) +// type epilog = func(*R, *S, *T) +// +// Note 1: the prolog callback returns the epilog callback - which can be nil +// when not required - so that context can be shared using a closure. +// +// Note 2: a prolog for a method should accept the method receiver pointer as +// first argument wrapped into a `sqhook.MethodReceiver` value. // // Example: // // Define the hook globally @@ -37,36 +43,37 @@ // func Example(arg1 int, arg2 string) (ret1 []byte, ret2 error) { // // Use the hook first and call its callbacks // { -// type Prolog = func(*sqhook.Context, *int, *string) error -// type Epilog = func(*sqhook.Context, *[]byte, *error) -// // Create a call context -// ctx := sqhook.Context{} -// prolog, epilog := exampleHook.Callbacks() -// // If an epilog is set, defer the call to the epilog -// if epilog, ok := epilog.(Epilog); ok { -// // Pass pointers to the return values -// defer epilog(&ctx, &ret1, &ret2) -// } -// // If a prolog is set, call it +// type Epilog = func(*[]byte, *error) +// type Prolog = func(*int, *string) (Epilog, error) +// // Get the prolog callback and call it if it is not nil +// prolog := exampleHook.Prolog() // if prolog, ok := prolog.(Prolog); ok { // // Pass pointers to the arguments -// err := prolog(&ctx, &w, &r, &headers, &statusCode, &body) +// epilog, err := prolog(&w, &r, &headers, &statusCode, &body) // // If an error is returned, the function execution is aborted. -// // The deferred epilog call will still be executed before returning. +// // The epilog still needs to be called if set. A deferred call to it +// // does the job. +// if epilog != nil { +// // Pass pointers to the return values +// defer epilog(&ret1, &ret2) +// } // if err != nil { // return // } // } // } -// // .. function code ... +// /* .. function code ... */ // } // // // Main requirements: -// - Concurrent access and modification of callbacks. -// - Reentrant implementation of callbacks with a call context when data needs -// to be shared between the prolog and epilog. // +// - Concurrent access and modification of callbacks. +// - Ability to read/write arguments and return values. +// - Hook to the prolog and epilog of a function. +// - Epilog callbacks should be able to recover from a panic. +// - Callbacks should be reentrant. If any context needs to be shared, it +// should be done through the closure. // - Fast call dispatch for callbacks that don't need to be generic, ie. // callbacks that are designed to be attached to specific functions. // Type-assertion instead of `reflect.Call` is therefore used while generic @@ -101,34 +108,29 @@ var index = make(map[string]*Hook) type Hook struct { // The function type where the hook is used. fnType reflect.Type - // Pointer to a structure containing the callbacks in order to be able to - // atomically modify the pointer. - attached *callbacks + // Currently attached callback. + attached *PrologCallback // Symbol name where the hook is used. Required for the stringer. symbol string } -type callbacks struct { - prolog, epilog Callback -} - -// Callback is a function expecting a Context pointer as first argument, -// followed by the pointers to the arguments of the hooked function for a -// prolog, and followed by the pointers to the returned values for an epilog. -type Callback interface{} - -// Context is a call context for the hook. It is shared between the prolog and -// epilog and is unique for each function call. It allows callbacks to provide -// reentrant implementations when memory needs to be shared for a given call. -type Context []interface{} - -// MethodReceiver is store in the context when hooking a method. -type MethodReceiver interface{} +// PrologCallback is an interface type to a prolog function. +// Given a function F: +// func F(A, B, C) (R, S, T) +// The expected prolog signature is: +// type prolog = func(*A, *B, *C) (epilog, error) +// The expected epilog signature is: +// type epilog = func(*R, *S, *T) +// The returned epilog value can be nil when there is no need for epilog. +type PrologCallback interface{} -type Error int +// MethodReceiver should be the first argument of the prolog of a method. +type MethodReceiver struct{ Receiver interface{} } // Errors that hooks can return in order to modify the control flow of the // function. +type Error int + const ( _ Error = iota // Abort the execution of the function by returning from it. @@ -175,69 +177,81 @@ func Find(symbol string) *Hook { return index[symbol] } -// Attach atomically attaches prolog and epilog callbacks to the hook. It is -// possible to pass nil values when only one type of callback is required. If -// both arguments are nil, the callbacks are removed. -func (h *Hook) Attach(prolog, epilog Callback) error { - var cbs *callbacks - if prolog != nil || epilog != nil { - cbs = &callbacks{} - if prolog != nil { - // Create the list of argument types - argTypes := make([]reflect.Type, 0, h.fnType.NumIn()) - for i := 0; i < h.fnType.NumIn(); i++ { - argTypes = append(argTypes, h.fnType.In(i)) - } - if err := validateProlog(prolog, argTypes); err != nil { - return err - } - cbs.prolog = prolog - } - if epilog != nil { - // Create the list of return types - retTypes := make([]reflect.Type, 0, h.fnType.NumOut()) - for i := 0; i < h.fnType.NumOut(); i++ { - retTypes = append(retTypes, h.fnType.Out(i)) - } - if err := validateEpilog(epilog, retTypes); err != nil { - return err - } - cbs.epilog = epilog - } +// Attach atomically attaches a prolog callback to the hook. It is +// possible to pass a `nil` value to remove the attached callback. +func (h *Hook) Attach(prolog PrologCallback) error { + addr := (*unsafe.Pointer)(unsafe.Pointer(&h.attached)) + if prolog == nil { + atomic.StorePointer(addr, nil) + return nil } - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&h.attached)), unsafe.Pointer(cbs)) + if err := validateProlog(prolog, h.fnType); err != nil { + return err + } + atomic.StorePointer(addr, unsafe.Pointer(&prolog)) return nil } -// Callbacks atomically accesses the attached prolog and epilog callbacks. -func (h *Hook) Callbacks() (prolog, epilog Callback) { - attached := (*callbacks)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&h.attached)))) +// Callbacks atomically accesses the attached prolog. +func (h *Hook) Prolog() (prolog PrologCallback) { + addr := (*unsafe.Pointer)(unsafe.Pointer(&h.attached)) + attached := (*PrologCallback)(atomic.LoadPointer(addr)) if attached == nil { - return nil, nil + return nil } - return attached.prolog, attached.epilog + return *attached } // validateProlog validates that the prolog has the expected signature. -func validateProlog(prolog Callback, argTypes []reflect.Type) error { - if err := validateCallback(prolog, argTypes); err != nil { +func validateProlog(prolog PrologCallback, fnType reflect.Type) (err error) { + defer func() { + if err != nil { + err = sqerrors.Wrap(err, "prolog validation error") + } + }() + // Create the list of argument types + callbackArgsTypes := make([]reflect.Type, 0, fnType.NumIn()) + for i := 0; i < fnType.NumIn(); i++ { + callbackArgsTypes = append(callbackArgsTypes, fnType.In(i)) + } + // Check the prolog args are pointers to the callback args + prologType := reflect.TypeOf(prolog) + if err := validateCallback(prologType, callbackArgsTypes); err != nil { return err } - t := reflect.TypeOf(prolog) - if t.NumOut() != 1 || !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - return sqerrors.New("validation: the prolog callback should return an error value") + // Check the prolog returns two values + if numPrologOut, numCallbackOut := prologType.NumOut(), 2; numPrologOut != numCallbackOut { + return sqerrors.Errorf("wrong number of returned values, expected `%d` but got `%d`", numCallbackOut, numPrologOut) + } + // Check the second returned value is an error + if ret1Type := prologType.Out(1); ret1Type != reflect.TypeOf((*error)(nil)).Elem() { + return sqerrors.Errorf("unexpected second return type `%s` instead of `error`", ret1Type) + } + // Check the first returned value is the expected epilog type + epilogType := prologType.Out(0) + if err := validateEpilog(epilogType, fnType); err != nil { + return err } return nil } // validateEpilog validates that the epilog has the expected signature. -func validateEpilog(epilog Callback, argTypes []reflect.Type) error { - if err := validateCallback(epilog, argTypes); err != nil { +func validateEpilog(epilogType reflect.Type, fnType reflect.Type) (err error) { + defer func() { + if err != nil { + err = sqerrors.Wrap(err, "epilog validation error") + } + }() + // Create the list of argument types + callbackRetTypes := make([]reflect.Type, 0, fnType.NumOut()) + for i := 0; i < fnType.NumOut(); i++ { + callbackRetTypes = append(callbackRetTypes, fnType.Out(i)) + } + if err := validateCallback(epilogType, callbackRetTypes); err != nil { return err } - t := reflect.TypeOf(epilog) - if t.NumOut() != 0 { - return sqerrors.New("validation: the epilog callback should not return values") + if numOut := epilogType.NumOut(); numOut != 0 { + return sqerrors.Errorf("unexpected number of return values `%d` instead of `0`", numOut) } return nil } @@ -245,37 +259,30 @@ func validateEpilog(epilog Callback, argTypes []reflect.Type) error { // validateCallback validates the fact that the callback is a function whose // first argument is the hook context and the rest of its arguments can be // assigned the hook argument values. -func validateCallback(callback Callback, argTypes []reflect.Type) (err error) { - defer func() { - if err != nil { - err = sqerrors.Wrap(err, "validation error") - } - }() - callbackType := reflect.TypeOf(callback) +func validateCallback(callbackType reflect.Type, argTypes []reflect.Type) error { // Check the callback is a function if callbackType.Kind() != reflect.Func { return sqerrors.New("the callback argument is not a function") } callbackArgc := callbackType.NumIn() - // Check the callback accepts a hook context as first argument - if callbackArgc < 1 { - return sqerrors.New("the callback should expect a hook context as first argument") - } - if !reflect.TypeOf((*Context)(nil)).AssignableTo(callbackType.In(0)) { - return sqerrors.New("the callback should expect a hook context as first argument") - } - // Check the argument count - fnArgc := len(argTypes) - if callbackArgc-1 != fnArgc && callbackArgc != fnArgc { - return sqerrors.Errorf("the callback arguments count `%d` is not compatible to the hook arguments count `%d`", callbackArgc, fnArgc) + // Check the callback accepts the same number of arguments than the function + // Note that the method receiver is in the argument list of the type + // definition. + if callbackArgc != len(argTypes) { + return sqerrors.Errorf("the callback should have the same arguments: `%d` callback arguments while expecting `%d`", callbackArgc, len(argTypes)) } - // Check arguments are assignable - var i int - for i = 1; i < callbackArgc; i++ { - argPtrType := reflect.PtrTo(argTypes[i-1]) - callbackArgType := callbackType.In(i) - if !argPtrType.AssignableTo(callbackArgType) { - return sqerrors.Errorf("hook argument `%d` of type `%s` cannot be assigned to the callback argument `%d` of type `%s`", i-1, argPtrType, i, callbackArgType) + // Check arguments are pointers to the same types than the function arguments. + if callbackArgc > 0 { + i := 0 + if callbackType.In(0) == reflect.TypeOf(MethodReceiver{}) { + i++ + } + for ; i < callbackArgc; i++ { + argPtrType := reflect.PtrTo(argTypes[i]) + callbackArgType := callbackType.In(i) + if argPtrType != callbackArgType { + return sqerrors.Errorf("argument `%d` has type `%s` instead of `%s`", i, callbackArgType, argPtrType) + } } } return nil diff --git a/agent/sqlib/sqhook/hook_test.go b/agent/sqlib/sqhook/hook_test.go index 6a187bf7..ae3dd5a8 100644 --- a/agent/sqlib/sqhook/hook_test.go +++ b/agent/sqlib/sqhook/hook_test.go @@ -8,7 +8,9 @@ import ( "errors" "fmt" "reflect" + "sync/atomic" "testing" + "unsafe" fuzz "github.com/google/gofuzz" "github.com/sqreen/go-agent/agent/sqlib/sqhook" @@ -21,8 +23,52 @@ func (example) method() {} func (example) ExportedMethod() {} func (*example) methodPointerReceiver() {} -func function(_ int, _ string, _ bool) error { return nil } -func ExportedFunction(_ int, _ string, _ bool) error { return nil } +func function(_ int, _ string, _ bool) (float32, error) { return 0, nil } +func ExportedFunction(_ int, _ string, _ bool) error { return nil } + +func TestGoAssumptions(t *testing.T) { + t.Run("getting a function pointer using reflect", func(t *testing.T) { + var fn interface{} = function + require.Equal(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(function).Pointer()) + }) + + t.Run("atomic store a function pointer", func(t *testing.T) { + var cb *sqhook.PrologCallback + addr := (*unsafe.Pointer)((unsafe.Pointer)(&cb)) + + // Atomically store the function pointer + var v sqhook.PrologCallback = function + atomic.StorePointer(addr, unsafe.Pointer(&v)) + // Atomic load in order to ensure the sequential order of the memory + // accesses to &fn. Non-atomic reads could be otherwise reordered. + atomicLoad := atomic.LoadPointer(addr) // sequential read barrier under the hood + + // Check + require.Equal(t, unsafe.Pointer(&v), atomicLoad) + require.Equal(t, &v, cb) + require.True(t, reflect.TypeOf(*cb) == reflect.TypeOf(function)) + + // Atomically store nil + atomic.StorePointer(addr, nil) + atomicLoad = atomic.LoadPointer(addr) // sequential read barrier under the hood + + // Check + require.Equal(t, unsafe.Pointer(nil), atomicLoad) + require.Equal(t, (*sqhook.PrologCallback)(nil), cb) + }) + + t.Run("the first argument of a method is the method receiver", func(t *testing.T) { + require.Equal(t, reflect.TypeOf(example{}).Name(), reflect.TypeOf(example.method).In(0).Name()) + }) + + t.Run("types can be compared using operator == on reflect.Type", func(t *testing.T) { + val1 := example{} + val2 := example{} + val3 := &example{} + require.True(t, reflect.TypeOf(val1) == reflect.TypeOf(val2)) + require.True(t, reflect.TypeOf(val1) != reflect.TypeOf(val3) && reflect.TypeOf(val3) == reflect.PtrTo(reflect.TypeOf(val1))) + }) +} func TestNew(t *testing.T) { for _, tc := range []struct { @@ -76,169 +122,116 @@ func TestFind(t *testing.T) { func TestAttach(t *testing.T) { for _, tc := range []struct { - function, expectedProlog, expectedEpilog interface{} - notExpectedPrologs, notExpectedEpilogs []interface{} + function, expected interface{} + unexpected []interface{} }{ { - function: func() {}, - expectedProlog: func(*sqhook.Context) error { return nil }, - expectedEpilog: func(*sqhook.Context) {}, - notExpectedPrologs: []interface{}{ + function: func() {}, + expected: func() (func(), error) { return nil, nil }, + unexpected: []interface{}{ + "not even a function", + func() (func(), int) { return nil, 33 }, func() error { return nil }, - func(*sqhook.Context) {}, - func() {}, - }, - notExpectedEpilogs: []interface{}{ - func() error { return nil }, - func(*sqhook.Context) error { return nil }, + func() func() { return nil }, func() {}, + func() (func() error, error) { return nil, nil }, }, }, { - function: example.method, - expectedProlog: func(*sqhook.Context) error { return nil }, - expectedEpilog: func(*sqhook.Context) {}, - notExpectedPrologs: []interface{}{ - func() error { return nil }, - func(*sqhook.Context) {}, - func() {}, - }, - notExpectedEpilogs: []interface{}{ + function: example.method, + expected: func(sqhook.MethodReceiver) (func(), error) { return nil, nil }, + unexpected: []interface{}{ func() error { return nil }, - func(*sqhook.Context) error { return nil }, func() {}, + func(example) (func(), error) { return nil, nil }, + func(*example) (func() error, error) { return nil, nil }, + func(*example) func() { return nil }, }, }, { - function: example.ExportedMethod, - expectedProlog: func(*sqhook.Context) error { return nil }, - expectedEpilog: func(*sqhook.Context) {}, - notExpectedPrologs: []interface{}{ + function: example.method, + expected: func(*example) (func(), error) { return nil, nil }, + unexpected: []interface{}{ func() error { return nil }, - func(*sqhook.Context) {}, - func() {}, - }, - notExpectedEpilogs: []interface{}{ - func() error { return nil }, - func(*sqhook.Context) error { return nil }, func() {}, + func(example) (func(), error) { return nil, nil }, + func(*example) (func() error, error) { return nil, nil }, + func(*example) func() { return nil }, }, }, { - function: (*example).methodPointerReceiver, - expectedProlog: func(*sqhook.Context) error { return nil }, - expectedEpilog: func(*sqhook.Context) {}, - notExpectedPrologs: []interface{}{ + function: example.ExportedMethod, + expected: func(*example) (func(), error) { return nil, nil }, + unexpected: []interface{}{ func() error { return nil }, - func(*sqhook.Context) {}, func() {}, + func(example) (func(), error) { return nil, nil }, + func(*example) (func() error, error) { return nil, nil }, + func(*example) error { return nil }, }, - notExpectedEpilogs: []interface{}{ + }, + { + function: (*example).methodPointerReceiver, + expected: func(**example) (func(), error) { return nil, nil }, + unexpected: []interface{}{ func() error { return nil }, - func(*sqhook.Context) error { return nil }, func() {}, + func(*example) (func(), error) { return nil, nil }, + func(example) (func(), error) { return nil, nil }, + func(**example) func() error { return nil }, + func(**example) error { return nil }, }, }, { - function: function, - expectedProlog: func(*sqhook.Context, *int, *string, *bool) error { return nil }, - expectedEpilog: func(*sqhook.Context, *error) {}, - notExpectedPrologs: []interface{}{ - func(*sqhook.Context, *int, *bool, *bool) error { return nil }, + function: function, + expected: func(*int, *string, *bool) (func(*float32, *error), error) { return nil, nil }, + unexpected: []interface{}{ + func(*int, *bool, *bool) error { return nil }, func(*int, *string, *bool) error { return nil }, - func(*sqhook.Context, *int, *string, *bool) {}, - func(*sqhook.Context, int, string, bool) error { return nil }, - func(*sqhook.Context, *int, *bool) error { return nil }, - func(*sqhook.Context, *int) {}, - }, - notExpectedEpilogs: []interface{}{ - func(*error) {}, - func(*sqhook.Context, error) {}, - func() {}, + func(*int, *string, *bool) {}, + func(int, string, bool) error { return nil }, + func(*int, *bool) error { return nil }, + func(*int) {}, + func(*int, *string, *bool) (func(*error), error) { return nil, nil }, + func(*int, *string, *bool) (func(float32, *error), error) { return nil, nil }, + func(*int, *string, bool) (func(*float32, *error), error) { return nil, nil }, }, }, { - function: ExportedFunction, - expectedProlog: func(*sqhook.Context, *int, *string, *bool) error { return nil }, - expectedEpilog: func(*sqhook.Context, *error) {}, - notExpectedPrologs: []interface{}{ - func(*sqhook.Context, *int, *string, *string) error { return nil }, + function: ExportedFunction, + expected: func(*int, *string, *bool) (func(*error), error) { return nil, nil }, + unexpected: []interface{}{ + func(*int, *bool, *bool) error { return nil }, func(*int, *string, *bool) error { return nil }, - func(*sqhook.Context, *int, *string, *bool) {}, - func(*sqhook.Context, int, string, bool) error { return nil }, - func(*sqhook.Context, *int, *bool) error { return nil }, - }, - notExpectedEpilogs: []interface{}{ - func(*error) {}, - func(*sqhook.Context, error) {}, - func(*sqhook.Context, *int) {}, - func() {}, + func(*int, *string, *bool) {}, + func(int, string, bool) error { return nil }, + func(*int, *bool) error { return nil }, + func(*int) {}, + func(*int, *string, *bool) (func(error), error) { return nil, nil }, + func(*int, string, *bool) (func(*error), error) { return nil, nil }, }, }, } { tc := tc t.Run(fmt.Sprintf("%T", tc.function), func(t *testing.T) { - t.Run("expected callbacks", func(t *testing.T) { - t.Run("non-nil prolog and epilog", func(t *testing.T) { - hook := sqhook.New(tc.function) - require.NotNil(t, hook) - err := hook.Attach(tc.expectedProlog, tc.expectedEpilog) - require.NoError(t, err) - prolog, epilog := hook.Callbacks() - require.Equal(t, reflect.ValueOf(prolog).Pointer(), reflect.ValueOf(tc.expectedProlog).Pointer()) - require.Equal(t, reflect.ValueOf(epilog).Pointer(), reflect.ValueOf(tc.expectedEpilog).Pointer()) - }) - t.Run("nil prolog", func(t *testing.T) { - hook := sqhook.New(tc.function) - require.NotNil(t, hook) - err := hook.Attach(nil, tc.expectedEpilog) - require.NoError(t, err) - prolog, epilog := hook.Callbacks() - require.Nil(t, prolog) - require.Equal(t, reflect.ValueOf(epilog).Pointer(), reflect.ValueOf(tc.expectedEpilog).Pointer()) - }) - t.Run("nil epilog", func(t *testing.T) { - hook := sqhook.New(tc.function) - require.NotNil(t, hook) - err := hook.Attach(tc.expectedProlog, nil) - require.NoError(t, err) - prolog, epilog := hook.Callbacks() - require.Equal(t, reflect.ValueOf(prolog).Pointer(), reflect.ValueOf(tc.expectedProlog).Pointer()) - require.Nil(t, epilog) - }) - t.Run("nil prolog and epilog", func(t *testing.T) { - hook := sqhook.New(tc.function) - require.NotNil(t, hook) - err := hook.Attach(nil, nil) - require.NoError(t, err) - prolog, epilog := hook.Callbacks() - require.Nil(t, prolog) - require.Nil(t, epilog) - }) + t.Run("expected callback type", func(t *testing.T) { + hook := sqhook.New(tc.function) + require.NotNil(t, hook) + err := hook.Attach(tc.expected) + require.NoError(t, err) + prolog := hook.Prolog() + require.Equal(t, reflect.ValueOf(prolog).Pointer(), reflect.ValueOf(tc.expected).Pointer()) }) - t.Run("not expected callbacks", func(t *testing.T) { - for _, notExpectedProlog := range tc.notExpectedPrologs { + t.Run("not expected callback types", func(t *testing.T) { + for _, notExpectedProlog := range tc.unexpected { notExpectedProlog := notExpectedProlog t.Run(fmt.Sprintf("%T", notExpectedProlog), func(t *testing.T) { hook := sqhook.New(tc.function) require.NotNil(t, hook) - err := hook.Attach(notExpectedProlog, tc.expectedEpilog) + err := hook.Attach(notExpectedProlog) require.Error(t, err) - prolog, epilog := hook.Callbacks() + prolog := hook.Prolog() require.Nil(t, prolog) - require.Nil(t, epilog) - }) - } - for _, notExpectedEpilog := range tc.notExpectedEpilogs { - notExpectedEpilog := notExpectedEpilog - t.Run(fmt.Sprintf("%T", notExpectedEpilog), func(t *testing.T) { - hook := sqhook.New(tc.function) - require.NotNil(t, hook) - err := hook.Attach(tc.expectedProlog, notExpectedEpilog) - require.Error(t, err) - prolog, epilog := hook.Callbacks() - require.Nil(t, prolog) - require.Nil(t, epilog) }) } }) @@ -249,33 +242,27 @@ func TestAttach(t *testing.T) { func TestEnableDisable(t *testing.T) { hook := sqhook.New(example.ExportedMethod) require.NotNil(t, hook) - err := hook.Attach(func(*sqhook.Context) error { return nil }, func(*sqhook.Context) {}) + err := hook.Attach(func(*example) (func(), error) { return nil, nil }) require.NoError(t, err) - prolog, epilog := hook.Callbacks() + prolog := hook.Prolog() require.NotNil(t, prolog) - require.NotNil(t, epilog) - hook.Attach(nil, nil) - prolog, epilog = hook.Callbacks() + err = hook.Attach(nil) + require.NoError(t, err) + prolog = hook.Prolog() require.Nil(t, prolog) - require.Nil(t, epilog) } -func TestUsage(t *testing.T) { - t.Run("attaching nil", func(t *testing.T) { - hook := sqhook.New(example.method) - require.NotNil(t, hook) - err := hook.Attach(func(*sqhook.Context) error { return nil }, nil) - require.NoError(t, err) - prolog, epilog := hook.Callbacks() - require.NotNil(t, prolog) - require.Nil(t, epilog) - err = hook.Attach(nil, nil) - require.NoError(t, err) - prolog, epilog = hook.Callbacks() - require.Nil(t, prolog) - require.Nil(t, epilog) - }) +func TestString(t *testing.T) { + hook := sqhook.New(example.ExportedMethod) + require.NotEmpty(t, hook.String()) +} + +func TestError(t *testing.T) { + err := sqhook.AbortError + require.NotEmpty(t, err.Error()) +} +func TestUsage(t *testing.T) { t.Run("hooking a function and reading and writing the arguments and return values", func(t *testing.T) { var hook *sqhook.Hook @@ -303,15 +290,15 @@ func TestUsage(t *testing.T) { example := func(a int, b string, c bool, d []byte) (e float32, f error) { { - type Prolog = func(*sqhook.Context, *int, *string, *bool, *[]byte) error - type Epilog = func(*sqhook.Context, *float32, *error) - ctx := sqhook.Context{} - prolog, epilog := hook.Callbacks() - if epilog, ok := epilog.(Epilog); ok { - defer epilog(&ctx, &e, &f) - } + type Epilog = func(*float32, *error) + type Prolog = func(*int, *string, *bool, *[]byte) (Epilog, error) + prolog := hook.Prolog() + if prolog, ok := prolog.(Prolog); ok { - err := prolog(&ctx, &a, &b, &c, &d) + epilog, err := prolog(&a, &b, &c, &d) + if epilog != nil { + defer epilog(&e, &f) + } if err != nil { return } @@ -331,7 +318,7 @@ func TestUsage(t *testing.T) { hook = sqhook.New(example) require.NotNil(t, hook) err := hook.Attach( - func(ctx *sqhook.Context, a *int, b *string, c *bool, d *[]byte) error { + func(a *int, b *string, c *bool, d *[]byte) (func(*float32, *error), error) { require.Equal(t, callA, *a) require.Equal(t, callB, *b) require.Equal(t, callC, *c) @@ -341,12 +328,11 @@ func TestUsage(t *testing.T) { *b = expectedB *c = expectedC *d = expectedD - return nil - }, - func(ctx *sqhook.Context, e *float32, f *error) { - // Modify the return values - *e = expectedE - *f = expectedF + return func(e *float32, f *error) { + // Modify the return values + *e = expectedE + *f = expectedF + }, nil }) require.NoError(t, err) diff --git a/sdk/middleware/sqhttp/http.go b/sdk/middleware/sqhttp/http.go index a6983298..d1e62cc3 100644 --- a/sdk/middleware/sqhttp/http.go +++ b/sdk/middleware/sqhttp/http.go @@ -132,15 +132,15 @@ func (AbortRequestError) Error() string { // a hook and called by the closure. func addSecurityHeaders(w http.ResponseWriter) (err error) { { - type Prolog = func(*sqhook.Context, *http.ResponseWriter) error - type Epilog = func(*sqhook.Context, *error) - ctx := sqhook.Context{} - prolog, epilog := addSecurityHeaderHook.Callbacks() - if epilog, ok := epilog.(Epilog); ok { - defer epilog(&ctx, &err) - } + type Epilog = func(*error) + type Prolog = func(*http.ResponseWriter) (Epilog, error) + prolog := addSecurityHeaderHook.Prolog() if prolog, ok := prolog.(Prolog); ok { - if err := prolog(&ctx, &w); err != nil { + epilog, err := prolog(&w) + if epilog != nil { + defer epilog(&err) + } + if err != nil { return err } } @@ -167,15 +167,15 @@ type responseWriter struct { func (w responseWriter) WriteHeader(statusCode int) { { - type Prolog = func(*sqhook.Context, *int) error - type Epilog = func(*sqhook.Context) - ctx := sqhook.Context{sqhook.MethodReceiver(&w)} - prolog, epilog := responseWriterWriteHeader.Callbacks() - if epilog, ok := epilog.(Epilog); ok { - defer epilog(&ctx) - } + type Epilog = func() + type Prolog = func(sqhook.MethodReceiver, *int) (Epilog, error) + prolog := responseWriterWriteHeader.Prolog() if prolog, ok := prolog.(Prolog); ok { - if err := prolog(&ctx, &statusCode); err != nil { + epilog, err := prolog(sqhook.MethodReceiver{&w}, &statusCode) + if epilog != nil { + defer epilog() + } + if err != nil { return } } From d2aa5d39a6319970518b206f430d1fe4cc66eac4 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 22 Jul 2019 19:02:35 +0200 Subject: [PATCH 37/47] agent/client: forbid non-printable characters in app-name --- agent/internal/backend/client_test.go | 75 +++++++++++++++++++++++++++ agent/internal/client.go | 35 +++++++++---- 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/agent/internal/backend/client_test.go b/agent/internal/backend/client_test.go index dbc266e3..2e6cfa34 100644 --- a/agent/internal/backend/client_test.go +++ b/agent/internal/backend/client_test.go @@ -12,6 +12,7 @@ import ( fuzz "github.com/google/gofuzz" . "github.com/onsi/gomega" "github.com/onsi/gomega/ghttp" + "github.com/sqreen/go-agent/agent/internal" "github.com/sqreen/go-agent/agent/internal/backend" "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/config" @@ -350,3 +351,77 @@ func FuzzRule(e *api.Rule, c fuzz.Continue) { c.Fuzz(e) e.Signature = api.RuleSignature{ECDSASignature: api.ECDSASignature{Message: []byte(`{}`)}} } + +func TestValidateCredentialsConfiguration(t *testing.T) { + for _, tc := range []struct { + Name, Token, AppName string + ShouldFail bool + }{ + { + Name: "valid org token strings", + Token: "org_ok", + AppName: "ok", + }, + { + Name: "valid non-org", + Token: "ok", + }, + { + Name: "invalid credentials with empty strings", + Token: "", + AppName: "", + ShouldFail: true, + }, + { + Name: "invalid credentials with empty token and non-empty app-name", + Token: "", + AppName: "ok", + ShouldFail: true, + }, + { + Name: "invalid credentials with valid org token but empty app-name", + Token: "org_ok", + AppName: "", + ShouldFail: true, + }, + { + Name: "invalid credentials with token ok but invalid app-name", + Token: "org_ok", + AppName: "ko\nko", + ShouldFail: true, + }, + { + Name: "invalid credentials with token ok but invalid app-name", + Token: "org_ok", + AppName: "koko\x00\x01\x02ok", + ShouldFail: true, + }, + { + Name: "invalid credentials with token ok but invalid app-name", + Token: "org_ok", + AppName: "koko\tok", + ShouldFail: true, + }, + { + Name: "valid credentials with a space in app-name", + Token: "org_ok", + AppName: "ok ok ok", + }, + { + Name: "invalid credentials with invalid token character", + Token: "org_ok\nko", + AppName: "ok", + ShouldFail: true, + }, + } { + tc := tc + t.Run(tc.Name, func(t *testing.T) { + err := internal.ValidateCredentialsConfiguration(tc.Token, tc.AppName) + if tc.ShouldFail { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/agent/internal/client.go b/agent/internal/client.go index cdcb8d4f..41213c89 100644 --- a/agent/internal/client.go +++ b/agent/internal/client.go @@ -12,6 +12,7 @@ import ( "net/http" "strings" "time" + "unicode" "github.com/pkg/errors" "github.com/sqreen/go-agent/agent/internal/app" @@ -19,6 +20,7 @@ import ( "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/config" "github.com/sqreen/go-agent/agent/internal/plog" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" "github.com/sqreen/go-agent/agent/sqlib/sqtime" ) @@ -99,11 +101,6 @@ func appLogin(ctx context.Context, logger *plog.Logger, client *backend.Client, } } -var ( - ErrMissingAppName = errors.New("missing application name") - ErrMissingToken = errors.New("missing token") -) - type InvalidCredentialsConfiguration struct { error } @@ -117,16 +114,32 @@ func (e InvalidCredentialsConfiguration) Cause() error { } func ValidateCredentialsConfiguration(token, appName string) (err error) { + defer func() { + if err != nil { + err = InvalidCredentialsConfiguration{err} + } + }() + if token == "" { - err = ErrMissingToken - } else if strings.HasPrefix(token, config.BackendHTTPAPIOrganizationTokenPrefix) && appName == "" { - err = ErrMissingAppName + return sqerrors.New("missing application name") + } + if strings.HasPrefix(token, config.BackendHTTPAPIOrganizationTokenPrefix) && appName == "" { + return sqerrors.New("missing token") } - if err == nil { - return err + + for _, r := range appName { + if !unicode.IsPrint(r) { + return sqerrors.Errorf("forbidden non-printable character `%q` in the application name `%q`", r, appName) + } + } + + for _, r := range token { + if !unicode.IsPrint(r) { + return sqerrors.Errorf("forbidden non-printable character `%q` in the token `%q`", r, token) + } } - return InvalidCredentialsConfiguration{err} + return nil } // TrySendAppException is a special client function allowing to send app-level From fbadc4075facb1a0c712c3f7cfbedf36b8f0dfaa Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Mon, 22 Jul 2019 23:06:12 +0200 Subject: [PATCH 38/47] go: update go.mod --- go.mod | 1 - go.sum | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/go.mod b/go.mod index d442c4e1..4a40d31f 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( github.com/valyala/fasttemplate v0.0.0-20170224212429-dcecefd839c4 // indirect golang.org/x/net v0.0.0-20190311183353-d8887717615a golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 - google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 google.golang.org/grpc v1.20.1 gopkg.in/go-playground/assert.v1 v1.2.1 // indirect gopkg.in/go-playground/validator.v8 v8.18.2 // indirect diff --git a/go.sum b/go.sum index 64ff2f0d..19d4f0c0 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/gin-gonic/gin v1.3.0 h1:kCmZyPklC0gVdL728E6Aj20uYBJV93nj/TkwBTKhFbs= github.com/gin-gonic/gin v1.3.0/go.mod h1:7cKuhb5qV2ggCFctp2fJQ+ErvciLZrIeoOSOm6mUr7Y= github.com/gogo/protobuf v1.2.0 h1:xU6/SpYbvkNYiptHJYEDRseDLvYE7wSqhYYNy0QSUzI= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= @@ -91,8 +92,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3 h1:eH6Eip3UpmR+yM/qI9Ijluzb1bNv/cAU/n+6l8tRSis= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= From 888458d1ebb9f180e5f468bd394586ed71b56f27 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 23 Jul 2019 13:19:22 +0200 Subject: [PATCH 39/47] agent/whitelist: fix a dev regression not properly whitelisting non-sdk events The solution is simple for now and should be changed in favor of an agent middleware in the future. --- agent/internal/agent.go | 3 +-- agent/internal/metrics.go | 12 +----------- agent/internal/metrics/metrics.go | 10 ++++++---- agent/internal/request.go | 5 +++++ agent/types/types.go | 3 +++ sdk/middleware/sqhttp/http.go | 10 +++++++--- sdk/record.go | 7 +++++++ sdk/sdk_test.go | 23 ----------------------- tools/testlib/agent.go | 11 +++++++++++ 9 files changed, 41 insertions(+), 43 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 472d2706..ae762d0a 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -199,8 +199,7 @@ func (a *Agent) NewRequestRecord(req *http.Request) types.RequestRecord { } if whitelisted { a.addWhitelistEvent(matched) - return &WhitelistedHTTPRequestRecord{ - } + return &WhitelistedHTTPRequestRecord{} } return &HTTPRequestRecord{ request: req, diff --git a/agent/internal/metrics.go b/agent/internal/metrics.go index beec4470..a762f585 100644 --- a/agent/internal/metrics.go +++ b/agent/internal/metrics.go @@ -28,20 +28,10 @@ func (a *Agent) addUserEvent(event userEventFace) { store.Add(event, 1) } -type WhitelistedIP struct { - MatchedWhitelistEntry string -} - -func (m WhitelistedIP) bucketID() (string, error) { - return m.MatchedWhitelistEntry, nil -} - func (a *Agent) addWhitelistEvent(matchedWhitelistEntry string) { if a.config.Disable() || a.metrics == nil { // Agent is disabled or not yet initialized return } - a.staticMetrics.whitelistedIP.Add(WhitelistedIP{ - MatchedWhitelistEntry: matchedWhitelistEntry, - }, 1) + a.staticMetrics.whitelistedIP.Add(matchedWhitelistEntry, 1) } diff --git a/agent/internal/metrics/metrics.go b/agent/internal/metrics/metrics.go index 013637f3..ec00996b 100644 --- a/agent/internal/metrics/metrics.go +++ b/agent/internal/metrics/metrics.go @@ -163,14 +163,16 @@ func (s *Store) Add(key interface{}, delta uint64) error { // Note that this is not possible to unlock and perform the atomic // operation because of possible concurrent `Flush()`. } else { + if l := len(s.set); l == 0 { + // Set the deadline when the first value inserted into the metrics store + s.deadline = time.Now().Add(s.period) + } else { + + } // The value still doesn't exist and we need to insert it into the store's // map. value := delta s.set[key] = &value - // Set the deadline when the first value inserted into the metrics store - if s.deadline.IsZero() { - s.deadline = time.Now().Add(s.period) - } } } diff --git a/agent/internal/request.go b/agent/internal/request.go index b978573c..9fc5ce9d 100644 --- a/agent/internal/request.go +++ b/agent/internal/request.go @@ -227,6 +227,10 @@ func (ctx *HTTPRequestRecord) Close() { ctx.agent.AddHTTPRequestRecordEvent(NewHTTPRequestRecordEvent(ctx, ctx.agent.RulespackID())) } +func (ctx *HTTPRequestRecord) Whitelisted() bool { + return false +} + func (ctx *HTTPRequestRecord) addSilentEvent(event *HTTPRequestEvent) { ctx.addEvent_(event, true) } @@ -321,6 +325,7 @@ func (WhitelistedHTTPRequestRecord) Close() { func (WhitelistedHTTPRequestRecord) WithTimestamp(time.Time) {} func (WhitelistedHTTPRequestRecord) WithProperties(types.EventProperties) {} func (WhitelistedHTTPRequestRecord) WithUserIdentifiers(map[string]string) {} +func (WhitelistedHTTPRequestRecord) Whitelisted() bool { return true } type getClientIPConfigFace interface { HTTPClientIPHeader() string diff --git a/agent/types/types.go b/agent/types/types.go index 8a4874d5..b88a71ae 100644 --- a/agent/types/types.go +++ b/agent/types/types.go @@ -46,6 +46,9 @@ type RequestRecord interface { UserSecurityResponse() http.Handler // Close needs to be called when the request is done. Close() + // Whitelisted returns true when the request is whitelisted, false otherwise. + // TODO: move the sqhttp middleware into the agent to get rid of this method + Whitelisted() bool } type CustomEvent interface { diff --git a/sdk/middleware/sqhttp/http.go b/sdk/middleware/sqhttp/http.go index d1e62cc3..6eee4708 100644 --- a/sdk/middleware/sqhttp/http.go +++ b/sdk/middleware/sqhttp/http.go @@ -65,16 +65,20 @@ func MiddlewareWithError(next Handler) Handler { // TODO: move this middleware function into the agent internal package (which // needs restructuring the SDK) return HandlerFunc(func(w http.ResponseWriter, r *http.Request) (err error) { - if err := addSecurityHeaders(w); err != nil { - return err - } // Create a new sqreen request wrapper. req := sdk.NewHTTPRequest(r) + if req.Record().Whitelisted() { + return next.ServeHTTP(w, r) + } defer req.Close() // Use the newly created request compliant with `sdk.FromContext()`. r = req.Request() // Wrap the response writer to monitor the http status codes. w = ResponseWriter{w} + // Add security headers + if err := addSecurityHeaders(w); err != nil { + return err + } // Check if an early security action is already required such as based on // the request IP address. if handler := req.SecurityResponse(); handler != nil { diff --git a/sdk/record.go b/sdk/record.go index c6ccb3ee..508fbd0a 100644 --- a/sdk/record.go +++ b/sdk/record.go @@ -106,3 +106,10 @@ func (ctx *HTTPRequestRecord) ForUser(id EventUserIdentifiersMap) *UserHTTPReque id: id, } } + +func (ctx *HTTPRequestRecord) Whitelisted() bool { + if ctx == nil { + return true + } + return ctx.record.Whitelisted() +} diff --git a/sdk/sdk_test.go b/sdk/sdk_test.go index 836f8833..0564186c 100644 --- a/sdk/sdk_test.go +++ b/sdk/sdk_test.go @@ -312,29 +312,6 @@ func TestDisabled(t *testing.T) { req.Close() sdk.GracefulStop() }) - - t.Run("with a whitelisted request", func(t *testing.T) { - agent := &testlib.AgentMockup{} - defer agent.AssertExpectations(t) - sdk.SetAgent(agent) - - record := whitelistedRecord{} - agent.ExpectNewRequestRecord(mock.Anything).Return(record).Once() - - req := newTestRequest() - sqReq := sdk.NewHTTPRequest(req) - require.NotNil(t, sqReq) - req = sqReq.Request() - require.NotNil(t, req) - - // When getting the SDK context out of the request wrapper. - require.NotPanics(t, useTheSDK(t, sqReq.Record())) - - // Other methods - sqReq.SecurityResponse() - sqReq.UserSecurityResponse() - sqReq.Close() - }) } func TestEventPropertyMap(t *testing.T) { diff --git a/tools/testlib/agent.go b/tools/testlib/agent.go index 4d580cfd..d02fecc7 100644 --- a/tools/testlib/agent.go +++ b/tools/testlib/agent.go @@ -46,6 +46,7 @@ func NewAgentForMiddlewareTestsWithoutSecurityResponse() (*AgentMockup, *HTTPReq agent := &AgentMockup{} record := &HTTPRequestRecordMockup{} agent.ExpectNewRequestRecord(mock.Anything).Return(record).Once() + record.ExpectWhitelisted().Return(false).Once() record.ExpectSecurityResponse().Return(nil).Once() record.ExpectUserSecurityResponse().Return(nil).Maybe() // Some tests don't call it, such as those returning a handler error record.ExpectClose().Once() @@ -56,6 +57,7 @@ func NewAgentForMiddlewareTestsWithSecurityResponse(actionHandler http.Handler) agent := &AgentMockup{} record := &HTTPRequestRecordMockup{} agent.ExpectNewRequestRecord(mock.Anything).Return(record).Once() + record.ExpectWhitelisted().Return(false).Once() record.ExpectSecurityResponse().Return(actionHandler).Once() record.ExpectClose().Once() return agent, record @@ -65,6 +67,7 @@ func NewAgentForMiddlewareTestsWithUserSecurityResponse(actionHandler http.Handl agent := &AgentMockup{} record := &HTTPRequestRecordMockup{} agent.ExpectNewRequestRecord(mock.Anything).Return(record).Once() + record.ExpectWhitelisted().Return(false).Once() record.ExpectSecurityResponse().Return(nil).Once() record.ExpectUserSecurityResponse().Return(actionHandler) record.ExpectClose().Once() @@ -87,6 +90,14 @@ func (r *HTTPRequestRecordMockup) ExpectTrackEvent(event string) *mock.Call { return r.On("NewCustomEvent", event) } +func (r *HTTPRequestRecordMockup) Whitelisted() bool { + return r.Called().Bool(0) +} + +func (r *HTTPRequestRecordMockup) ExpectWhitelisted() *mock.Call { + return r.On("Whitelisted") +} + func (r *HTTPRequestRecordMockup) Close() { r.Called() } From 101f84510d094573c8c3af75b93c273fde5a8a83 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 23 Jul 2019 16:01:22 +0200 Subject: [PATCH 40/47] agent/metrics: add max metrics store length error handling When the metrics store map is bigger than the maximum configured length, an error is returned. This error should not be logger throught the usual "agent exception" management as it is supposed to happen under stressed conditions, which means that adding the error to a queue would not help at all, and it is rather now possible to aggregate the error into another metrics store. The returned error can therefore be used as a key that should aggregate with every other error of the same type, so that memory is not an issue and new insertions to the error metrics store is still rare. --- agent/internal/agent.go | 7 +-- agent/internal/config/config.go | 22 ++++++-- agent/internal/metrics.go | 44 ++++++++++++++-- agent/internal/metrics/metrics.go | 36 +++++++++---- agent/internal/metrics/metrics_test.go | 72 +++++++++++++++++++++++--- agent/internal/rule/callback.go | 18 +++++-- agent/internal/rule/callback_test.go | 2 +- agent/internal/rule/rule.go | 8 +-- agent/internal/rule/rule_test.go | 14 ++--- 9 files changed, 179 insertions(+), 44 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index ae762d0a..0728d25e 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -140,7 +140,7 @@ type Agent struct { } type staticMetrics struct { - sdkUserLoginSuccess, sdkUserLoginFailure, sdkUserSignup, whitelistedIP *metrics.Store + sdkUserLoginSuccess, sdkUserLoginFailure, sdkUserSignup, whitelistedIP, errors *metrics.Store } // Error channel buffer length. @@ -154,7 +154,7 @@ func New(cfg *config.Config) *Agent { return nil } - metrics := metrics.NewEngine(logger) + metrics := metrics.NewEngine(logger, cfg.MaxMetricsStoreLength()) publicKey, err := rule.NewECDSAPublicKey(config.PublicKey) if err != nil { @@ -179,6 +179,7 @@ func New(cfg *config.Config) *Agent { sdkUserLoginFailure: metrics.NewStore("sdk-login-fail", sdkMetricsPeriod), sdkUserSignup: metrics.NewStore("sdk-signup", sdkMetricsPeriod), whitelistedIP: metrics.NewStore("whitelisted", sdkMetricsPeriod), + errors: metrics.NewStore("whitelisted", config.ErrorMetricsPeriod), }, ctx: ctx, cancel: cancel, @@ -386,7 +387,7 @@ func (a *Agent) RulesReload() error { } } - a.rules.SetRules(rulespack.PackID, rulespack.Rules) + a.rules.SetRules(rulespack.PackID, rulespack.Rules, a.staticMetrics.errors) return nil } diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index 102543f1..8af7c372 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -45,10 +45,11 @@ type HTTPAPIEndpoint struct { Method, URL string } -const ( - // Default value of network timeouts. - DefaultNetworkTimeout = 5 * time.Second -) +// Error metrics store period. +const ErrorMetricsPeriod = time.Minute + +// Default value of network timeouts. +const DefaultNetworkTimeout = 5 * time.Second // Backend client configuration. var ( @@ -214,6 +215,7 @@ const ( configKeyStripHTTPReferer = `strip_http_referer` configKeyRules = `rules` configKeySDKMetricsPeriod = `sdk_metrics_period` + configKeyMaxMetricsStoreLength = `max_metrics_store_length` ) // User configuration's default values. @@ -221,6 +223,7 @@ const ( configDefaultBackendHTTPAPIBaseURL = `https://back.sqreen.com` configDefaultLogLevel = `info` configDefaultSDKMetricsPeriod = 60 + configDefaultMaxMetricsStoreLength = 100 * 1024 * 1024 ) func New(logger *plog.Logger) *Config { @@ -259,6 +262,7 @@ func New(logger *plog.Logger) *Config { manager.SetDefault(configKeyStripHTTPReferer, "") manager.SetDefault(configKeyRules, "") manager.SetDefault(configKeySDKMetricsPeriod, configDefaultSDKMetricsPeriod) + manager.SetDefault(configKeyMaxMetricsStoreLength, configDefaultMaxMetricsStoreLength) err := manager.ReadInConfig() if err != nil { @@ -333,6 +337,16 @@ func (c *Config) SDKMetricsPeriod() int { return p } +// MaxMetricsStoreLength returns the maximum length a metrics store should not +// exceed. After this limit, new metrics values will be dropped. +func (c *Config) MaxMetricsStoreLength() uint { + n := c.GetInt(configKeyMaxMetricsStoreLength) + if n < 0 { + n = 0 + } + return uint(n) +} + func sanitizeString(s string) string { return strings.TrimSpace(s) } diff --git a/agent/internal/metrics.go b/agent/internal/metrics.go index a762f585..914ef3b4 100644 --- a/agent/internal/metrics.go +++ b/agent/internal/metrics.go @@ -4,28 +4,49 @@ package internal -import "github.com/sqreen/go-agent/agent/internal/metrics" +import ( + "github.com/sqreen/go-agent/agent/internal/metrics" + "github.com/sqreen/go-agent/agent/sqlib/sqerrors" +) func (a *Agent) addUserEvent(event userEventFace) { if a.config.Disable() || a.metrics == nil { // Disabled or not yet initialized agent return } - var store *metrics.Store + var ( + store *metrics.Store + logFmt string + ) switch actual := event.(type) { case *authUserEvent: if actual.loginSuccess { store = a.staticMetrics.sdkUserLoginSuccess + logFmt = "user event: user login success `%v`" } else { store = a.staticMetrics.sdkUserLoginFailure + logFmt = "user event: user login failure `%v`" } case *signupUserEvent: store = a.staticMetrics.sdkUserSignup + logFmt = "user event: user signup `%v`" default: - // TODO: log error + a.logger.Error(sqerrors.Errorf("user event: unexpected user event type `%T`", actual)) return } - store.Add(event, 1) + a.logger.Debug(logFmt, event) + if err := store.Add(event, 1); err != nil { + sqErr := sqerrors.Wrap(err, "user event: could not update the user metrics store") + switch actualErr := err.(type) { + case metrics.MaxMetricsStoreLengthError: + a.logger.Debug(sqErr) + if err := a.staticMetrics.errors.Add(actualErr, 1); err != nil { + a.logger.Debugf("could not update the metrics store: %v", err) + } + default: + a.logger.Error(sqErr) + } + } } func (a *Agent) addWhitelistEvent(matchedWhitelistEntry string) { @@ -33,5 +54,18 @@ func (a *Agent) addWhitelistEvent(matchedWhitelistEntry string) { // Agent is disabled or not yet initialized return } - a.staticMetrics.whitelistedIP.Add(matchedWhitelistEntry, 1) + a.logger.Debug("request whitelisted for `%s`", matchedWhitelistEntry) + err := a.staticMetrics.whitelistedIP.Add(matchedWhitelistEntry, 1) + if err != nil { + sqErr := sqerrors.Wrap(err, "whitelist event: could not update the whitelist metrics store") + switch actualErr := err.(type) { + case metrics.MaxMetricsStoreLengthError: + a.logger.Debug(sqErr) + if err := a.staticMetrics.errors.Add(actualErr, 1); err != nil { + a.logger.Debugf("could not update the metrics store: %v", err) + } + default: + a.logger.Error(sqErr) + } + } } diff --git a/agent/internal/metrics/metrics.go b/agent/internal/metrics/metrics.go index ec00996b..95a5e4ed 100644 --- a/agent/internal/metrics/metrics.go +++ b/agent/internal/metrics/metrics.go @@ -48,6 +48,7 @@ package metrics import ( + "fmt" "reflect" "sync" "sync/atomic" @@ -61,20 +62,22 @@ import ( // the existing ones. Engine's methods are not thread-safe and designed to be // used by a single goroutine. type Engine struct { - logger plog.DebugLogger - stores map[string]*Store + logger plog.DebugLogger + stores map[string]*Store + maxMetricsStoreLen uint } -func NewEngine(logger plog.DebugLogger) *Engine { +func NewEngine(logger plog.DebugLogger, maxMetricsStoreLen uint) *Engine { return &Engine{ - logger: logger, - stores: make(map[string]*Store), + logger: logger, + stores: make(map[string]*Store), + maxMetricsStoreLen: maxMetricsStoreLen, } } // NewStore creates and registers a new metrics store. func (e *Engine) NewStore(id string, period time.Duration) *Store { - store := newStore(period) + store := newStore(period, e.maxMetricsStoreLen) e.stores[id] = store return store } @@ -109,18 +112,30 @@ type Store struct { deadline time.Time // Minimum time duration the data should be kept. period time.Duration + // Maximum map length. New keys are dropped when reached. + // The length is unlimited when 0. + maxLen uint } type StoreMap map[interface{}]*uint64 type ReadyStoreMap map[interface{}]uint64 -func newStore(period time.Duration) *Store { +func newStore(period time.Duration, maxLen uint) *Store { return &Store{ set: make(StoreMap), period: period, + maxLen: maxLen, } } +type MaxMetricsStoreLengthError struct { + MaxLen uint +} + +func (e MaxMetricsStoreLengthError) Error() string { + return fmt.Sprintf("new metrics store key dropped as the metrics store has reached its maximum length `%d`", e.MaxLen) +} + // Add delta to the given key, inserting it if it doesn't exist. This method // is thread-safe and optimized for updating existing key which is lock-free // when not concurrently retrieving (method `Flush()`) or inserting a new key. @@ -164,10 +179,11 @@ func (s *Store) Add(key interface{}, delta uint64) error { // operation because of possible concurrent `Flush()`. } else { if l := len(s.set); l == 0 { - // Set the deadline when the first value inserted into the metrics store + // Set the deadline when the first valuMaxMetricsStoreLengthe inserted into the metrics store s.deadline = time.Now().Add(s.period) - } else { - + } else if s.maxLen > 0 && uint(l) >= s.maxLen { + // The maximum length is reached - no more new insertions are allowed + return MaxMetricsStoreLengthError{MaxLen: s.maxLen} } // The value still doesn't exist and we need to insert it into the store's // map. diff --git a/agent/internal/metrics/metrics_test.go b/agent/internal/metrics/metrics_test.go index ebc800d6..b2da966d 100644 --- a/agent/internal/metrics/metrics_test.go +++ b/agent/internal/metrics/metrics_test.go @@ -21,7 +21,7 @@ import ( var logger = plog.NewLogger(plog.Debug, os.Stderr, 0) func TestUsage(t *testing.T) { - engine := metrics.NewEngine(logger) + engine := metrics.NewEngine(logger, 100000000) t.Run("store usage", func(t *testing.T) { t.Run("empty stores are never ready", func(t *testing.T) { @@ -114,9 +114,19 @@ func TestUsage(t *testing.T) { d [33]byte } + type T1 struct{} + type T2 struct{} + var v1, v2 interface{} = T1{}, T2{} + ptr := &Struct{} require.NotPanics(t, func() { + require.NoError(t, store.Add("string", 1)) + require.NoError(t, store.Add(T1{}, 1)) + require.NoError(t, store.Add(v1, 3)) + require.NoError(t, store.Add(T2{}, 3)) + require.NoError(t, store.Add(v2, 5)) + require.NoError(t, store.Add("string", 1)) require.NoError(t, store.Add("string", 1)) require.NoError(t, store.Add(33, 1)) require.NoError(t, store.Add(Struct{ @@ -134,7 +144,7 @@ func TestUsage(t *testing.T) { require.True(t, store.Ready()) old := store.Flush() require.Equal(t, metrics.ReadyStoreMap{ - "string": 1, + "string": 3, 33: 1, Struct{ a: 33, @@ -142,7 +152,9 @@ func TestUsage(t *testing.T) { c: 4.815162342, d: [33]byte{}, }: 1, - ptr: 1, + ptr: 1, + v1: 4, + T2{}: 8, }, old.Metrics()) }) }) @@ -152,7 +164,7 @@ func TestUsage(t *testing.T) { // Create a store that will be checked more often than actually required by // its period. So that we cover the case where the store is not always // ready. - engine := metrics.NewEngine(logger) + engine := metrics.NewEngine(logger, 100000000) // The reader will be awaken 4 times per store period, so only it will see // a ready store only once out of four. readerPeriod := time.Microsecond @@ -249,10 +261,58 @@ func TestUsage(t *testing.T) { require.Equal(t, uint64(nbWriters), v) } }) + + t.Run("metrics store max length and store error aggregation", func(t *testing.T) { + var maxLen uint = 3 + engine := metrics.NewEngine(logger, maxLen) + period := time.Millisecond + s1 := engine.NewStore("s1", period) + errors := engine.NewStore("errors", period) + + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k2", 1)) + require.NoError(t, s1.Add("k3", 33)) + + err := s1.Add("k4", 2) + require.Error(t, err) + require.NoError(t, errors.Add(err, 1)) + + err = s1.Add("k4", 55) + require.Error(t, err) + require.NoError(t, errors.Add(err, 1)) + + err = s1.Add("k4", 1) + require.Error(t, err) + require.NoError(t, errors.Add(err, 1)) + + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k1", 1)) + require.NoError(t, s1.Add("k2", 1)) + require.NoError(t, s1.Add("k3", 33)) + + time.Sleep(period) + require.True(t, s1.Ready()) + _ = s1.Flush() + + require.NoError(t, s1.Add("k4", 2)) + require.NoError(t, s1.Add("k4", 1)) + require.NoError(t, s1.Add("k4", 5)) + + // Errors were properly aggregated + require.True(t, errors.Ready()) + readyErrors := errors.Flush() + + require.Equal(t, metrics.ReadyStoreMap{metrics.MaxMetricsStoreLengthError{MaxLen: maxLen}: 3}, readyErrors.Metrics()) + }) } func BenchmarkStore(b *testing.B) { - engine := metrics.NewEngine(logger) + engine := metrics.NewEngine(logger, 100000000) type structKeyType struct { n int @@ -472,7 +532,7 @@ func BenchmarkStore(b *testing.B) { } func BenchmarkUsage(b *testing.B) { - engine := metrics.NewEngine(logger) + engine := metrics.NewEngine(logger, 100000000) for p := 1; p <= 1000; p *= 10 { p := p diff --git a/agent/internal/rule/callback.go b/agent/internal/rule/callback.go index 1998be34..576bb09a 100644 --- a/agent/internal/rule/callback.go +++ b/agent/internal/rule/callback.go @@ -11,7 +11,6 @@ import ( "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/metrics" - "github.com/sqreen/go-agent/agent/internal/plog" "github.com/sqreen/go-agent/agent/internal/rule/callback" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" "github.com/sqreen/go-agent/agent/sqlib/sqhook" @@ -45,11 +44,12 @@ type CallbackContext struct { config interface{} metricsStores map[string]*metrics.Store defaultMetricsStore *metrics.Store - logger plog.ErrorLogger + errorMetricsStore *metrics.Store + logger Logger name string } -func NewCallbackContext(r *api.Rule, logger plog.ErrorLogger, metricsEngine *metrics.Engine) *CallbackContext { +func NewCallbackContext(r *api.Rule, logger Logger, metricsEngine *metrics.Engine, errorMetricsStore *metrics.Store) *CallbackContext { config := newCallbackConfig(&r.Data) var ( @@ -68,6 +68,7 @@ func NewCallbackContext(r *api.Rule, logger plog.ErrorLogger, metricsEngine *met config: config, metricsStores: metricsStores, defaultMetricsStore: defaultMetricsStore, + errorMetricsStore: errorMetricsStore, name: r.Name, logger: logger, } @@ -93,6 +94,15 @@ func (d *CallbackContext) Config() interface{} { func (d *CallbackContext) PushMetricsValue(key interface{}, value uint64) { err := d.defaultMetricsStore.Add(key, value) if err != nil { - d.logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not add a value to the default metrics store", d.name))) + sqErr := sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not add a value to the default metrics store", d.name)) + switch actualErr := err.(type) { + case metrics.MaxMetricsStoreLengthError: + d.logger.Debug(sqErr) + if err := d.errorMetricsStore.Add(actualErr, 1); err != nil { + d.logger.Debugf("could not update the error metrics store: %v", err) + } + default: + d.logger.Error(sqErr) + } } } diff --git a/agent/internal/rule/callback_test.go b/agent/internal/rule/callback_test.go index 61e051d6..781a816f 100644 --- a/agent/internal/rule/callback_test.go +++ b/agent/internal/rule/callback_test.go @@ -40,7 +40,7 @@ func TestNewCallbacks(t *testing.T) { {&api.CustomErrorPageRuleDataEntry{}}, }, }, - }, nil, nil), + }, nil, nil, nil), shouldSucceed: true, }, } { diff --git a/agent/internal/rule/rule.go b/agent/internal/rule/rule.go index f4386896..3e4dae9c 100644 --- a/agent/internal/rule/rule.go +++ b/agent/internal/rule/rule.go @@ -62,9 +62,9 @@ func (e *Engine) PackID() string { // SetRules set the currents rules. If rules were already set, it will replace // them by atomically modifying the hooks, and removing what is left. -func (e *Engine) SetRules(packID string, rules []api.Rule) { +func (e *Engine) SetRules(packID string, rules []api.Rule, errorMetricsStore *metrics.Store) { // Create the net rule descriptors and replace the existing ones - ruleDescriptors := newHookDescriptors(e.logger, rules, e.publicKey, e.metricsEngine) + ruleDescriptors := newHookDescriptors(e.logger, rules, e.publicKey, e.metricsEngine, errorMetricsStore) e.setRules(packID, ruleDescriptors) } @@ -100,7 +100,7 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { // newHookDescriptors walks the list of received rules and creates the map of // hook descriptors indexed by their hook pointer. A hook descriptor contains // all it takes to enable and disable rules at run time. -func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.PublicKey, metricsEngine *metrics.Engine) hookDescriptors { +func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.PublicKey, metricsEngine *metrics.Engine, errorMetricsStore *metrics.Store) hookDescriptors { // Create and configure the list of callbacks according to the given rules var hookDescriptors = make(hookDescriptors) for i := len(rules) - 1; i >= 0; i-- { @@ -120,7 +120,7 @@ func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.Public } // Instantiate the callback nextProlog := hookDescriptors.Get(hook) - ruleDescriptor := NewCallbackContext(&r, logger, metricsEngine) + ruleDescriptor := NewCallbackContext(&r, logger, metricsEngine, errorMetricsStore) prolog, err := NewCallbacks(hookpoint.Callback, ruleDescriptor, nextProlog) if err != nil { logger.Error(sqerrors.Wrap(err, fmt.Sprintf("rule `%s`: could not instantiate the callbacks", r.Name))) diff --git a/agent/internal/rule/rule_test.go b/agent/internal/rule/rule_test.go index 4ef0c5e0..9749a4a0 100644 --- a/agent/internal/rule/rule_test.go +++ b/agent/internal/rule/rule_test.go @@ -36,7 +36,7 @@ func TestEngineUsage(t *testing.T) { publicKey := &privateKey.PublicKey logger := plog.NewLogger(plog.Debug, os.Stderr, 0) - engine := rule.NewEngine(logger, metrics.NewEngine(plog.NewLogger(plog.Debug, os.Stderr, 0)), publicKey) + engine := rule.NewEngine(logger, metrics.NewEngine(plog.NewLogger(plog.Debug, os.Stderr, 0), 100000000), publicKey) hookFunc1 := sqhook.New(func1) require.NotNil(t, hookFunc1) hookFunc2 := sqhook.New(func2) @@ -44,13 +44,13 @@ func TestEngineUsage(t *testing.T) { t.Run("empty state", func(t *testing.T) { require.Empty(t, engine.PackID()) - engine.SetRules("my pack id", nil) + engine.SetRules("my pack id", nil, nil) require.Equal(t, engine.PackID(), "my pack id") // No problem enabling/disabling the engine engine.Enable() engine.Disable() engine.Enable() - engine.SetRules("my other pack id", []api.Rule{}) + engine.SetRules("my other pack id", []api.Rule{}, nil) require.Equal(t, engine.PackID(), "my other pack id") }) @@ -85,7 +85,7 @@ func TestEngineUsage(t *testing.T) { }, Signature: MakeSignature(privateKey, `{"name":"another valid rule"}`), }, - }) + }, nil) t.Run("callbacks are not attached when disabled", func(t *testing.T) { // Check the callbacks were not attached because rules are disabled @@ -143,7 +143,7 @@ func TestEngineUsage(t *testing.T) { }, Signature: MakeSignature(privateKey, `{"name":"another valid rule"}`), }, - }) + }, nil) // Check the callbacks were removed for func1 and not func2 prologFunc1 := hookFunc1.Prolog() require.Nil(t, prologFunc1) @@ -153,7 +153,7 @@ func TestEngineUsage(t *testing.T) { t.Run("replace the enabled rules with an empty array of rules", func(t *testing.T) { // Set the rules with an empty array while enabled - engine.SetRules("yet another pack id", []api.Rule{}) + engine.SetRules("yet another pack id", []api.Rule{}, nil) // Check the callbacks were all removed for func1 and not func2 prologFunc1 := hookFunc1.Prolog() require.Nil(t, prologFunc1) @@ -256,7 +256,7 @@ func TestEngineUsage(t *testing.T) { }, }, }, - }) + }, nil) // Check the callbacks were removed for func1 and not func2 prologFunc1 := hookFunc1.Prolog() require.Nil(t, prologFunc1) From 4c15aa0d98fd707998de676dd4866bc19623efb4 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 23 Jul 2019 16:09:08 +0200 Subject: [PATCH 41/47] agent/metrics: fix dev regression - wrong metrics store name --- agent/internal/agent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 0728d25e..d017a183 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -179,7 +179,7 @@ func New(cfg *config.Config) *Agent { sdkUserLoginFailure: metrics.NewStore("sdk-login-fail", sdkMetricsPeriod), sdkUserSignup: metrics.NewStore("sdk-signup", sdkMetricsPeriod), whitelistedIP: metrics.NewStore("whitelisted", sdkMetricsPeriod), - errors: metrics.NewStore("whitelisted", config.ErrorMetricsPeriod), + errors: metrics.NewStore("errors", config.ErrorMetricsPeriod), }, ctx: ctx, cancel: cancel, From d3a7c630924b262a2b37ebcd6fef161963c535e0 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 23 Jul 2019 21:08:05 +0200 Subject: [PATCH 42/47] agent: better configuration settings logs Configuration settings were impossible to understand using logs and they are now clearly explained. A user can now understand: - what configuration source was used: files or environment variables - what proxy settings were used - the full settings dump in debug log level - why the agent does not started The default early log-level of the agent, before it reads the user configuration is now set to `info` rather than `debug` which was showing error stack traces to users, considered too confusing as it looked like a fatal exception. --- agent/internal/agent.go | 13 +++-- agent/internal/backend/client.go | 7 ++- agent/internal/client.go | 4 +- agent/internal/config/config.go | 87 +++++++++++++++++++--------- agent/internal/config/config_test.go | 7 --- 5 files changed, 75 insertions(+), 43 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index d017a183..994f970c 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -62,7 +62,7 @@ func Start() { // - the correctness of sub-level error handling (ie. they don't panic). // Any panics from these would stop the execution of this level. backoff := sqtime.NewBackoff(time.Second, time.Hour, 2) - logger := plog.NewLogger(plog.Debug, os.Stderr, 0) + logger := plog.NewLogger(plog.Info, os.Stderr, 0) for { err := sqsafe.Call(func() error { // Level 2 @@ -147,10 +147,10 @@ type staticMetrics struct { const errorChanBufferLength = 256 func New(cfg *config.Config) *Agent { - logger := plog.NewLogger(plog.ParseLogLevel(cfg.LogLevel()), os.Stderr, errorChanBufferLength) + logger := plog.NewLogger(cfg.LogLevel(), os.Stderr, errorChanBufferLength) if cfg.Disable() { - logger.Info("agent disabled by configuration") + logger.Info("config: empty token value or explicitly disabled agent") return nil } @@ -241,7 +241,8 @@ func (a *Agent) Serve() error { heartbeat = config.BackendHTTPAPIDefaultHeartbeatDelay } - a.logger.Info("up and running - heartbeat set to ", heartbeat) + a.logger.Infof("go agent v%s up and running", version) + a.logger.Infof("agent: heartbeat set to %s", heartbeat) ticker := time.Tick(heartbeat) batchSize := int(appLoginRes.Features.BatchSize) @@ -336,7 +337,7 @@ func (a *Agent) InstrumentationEnable() error { } a.rules.Enable() sdk.SetAgent(a) - a.logger.Info("instrumentation enabled") + a.logger.Debug("instrumentation enabled") return nil } @@ -346,7 +347,7 @@ func (a *Agent) InstrumentationDisable() error { sdk.SetAgent(nil) a.rules.Disable() err := a.actors.SetActions(nil) - a.logger.Info("instrumentation disabled") + a.logger.Debug("instrumentation disabled") return err } diff --git a/agent/internal/backend/client.go b/agent/internal/backend/client.go index 27a231fc..d92b2844 100644 --- a/agent/internal/backend/client.go +++ b/agent/internal/backend/client.go @@ -35,11 +35,14 @@ func NewClient(backendURL string, cfg *config.Config, logger *plog.Logger) *Clie if proxySettings := cfg.BackendHTTPAPIProxy(); proxySettings == "" { // No user settings. The default transport uses standard global proxy // settings *_PROXY environment variables. - logger.Info("using proxy settings as indicated by the environment variables HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions)") + dummyReq, _ := http.NewRequest("GET", backendURL, nil) + if proxyURL, _ := http.ProxyFromEnvironment(dummyReq); proxyURL != nil { + logger.Infof("client: using system http proxy `%s` as indicated by the system environment variables http_proxy, https_proxy and no_proxy (or their uppercase alternatives)", proxyURL) + } transport = (http.DefaultTransport).(*http.Transport) } else { // Use the settings. - logger.Info("using configured https proxy ", proxySettings) + logger.Infof("client: using configured https proxy `%s`", proxySettings) proxyCfg := httpproxy.Config{ HTTPSProxy: proxySettings, } diff --git a/agent/internal/client.go b/agent/internal/client.go index 41213c89..c891b2ce 100644 --- a/agent/internal/client.go +++ b/agent/internal/client.go @@ -93,9 +93,9 @@ func appLogin(ctx context.Context, logger *plog.Logger, client *backend.Client, appLoginRes = nil d, max := backoff.Next() if max { - return nil, NewLoginError(errors.New("maximum number of retries reached")) + return nil, NewLoginError(errors.New("login: maximum number of retries reached")) } - logger.Debugf("retrying the request in %s", d) + logger.Debugf("login: retrying the request in %s", d) time.Sleep(d) } } diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index 8af7c372..8c0c3a71 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -10,6 +10,7 @@ package config import ( + "fmt" "net" "net/http" "os" @@ -18,7 +19,6 @@ import ( "strings" "time" - "github.com/pkg/errors" "github.com/sqreen/go-agent/agent/internal/plog" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" @@ -232,44 +232,79 @@ func New(logger *plog.Logger) *Config { manager.AutomaticEnv() manager.SetConfigName(configFileBasename) - // Configuration file path options + // Default values of configurable parameters + parameters := []struct { + key string + defaultValue interface{} + secretFromChar int + hidden bool + }{ + {key: configKeyBackendHTTPAPIBaseURL, defaultValue: configDefaultBackendHTTPAPIBaseURL}, + {key: configKeyLogLevel, defaultValue: configDefaultLogLevel}, + {key: configKeyBackendHTTPAPIToken, defaultValue: "", secretFromChar: 6}, + {key: configKeyAppName, defaultValue: ""}, + {key: configKeyHTTPClientIPHeader, defaultValue: ""}, + {key: configKeyHTTPClientIPHeaderFormat, defaultValue: ""}, + {key: configKeyBackendHTTPAPIProxy, defaultValue: ""}, + {key: configKeyDisable, defaultValue: ""}, + {key: configKeyStripHTTPReferer, defaultValue: ""}, + {key: configKeyRules, defaultValue: "", hidden: true}, + {key: configKeySDKMetricsPeriod, defaultValue: configDefaultSDKMetricsPeriod, hidden: true}, + {key: configKeyMaxMetricsStoreLength, defaultValue: configDefaultMaxMetricsStoreLength, hidden: true}, + } + for _, p := range parameters { + manager.SetDefault(p.key, p.defaultValue) + } + + // Configuration file settings configFileEnvVar := strings.ToUpper(configEnvPrefix + "_" + configEnvKeyConfigFile) - if file := os.Getenv(configFileEnvVar); file != "" { + configFile := os.Getenv(configFileEnvVar) + if configFile != "" { // File location enforced by the user - manager.SetConfigFile(file) + manager.SetConfigFile(configFile) + logger.Infof("config: configuration file enforced by the environment variable `%s` to `%s`", configFileEnvVar, configFile) } else { // Not enforced: add possible paths in precedence order - // 1. Current working directory path: manager.AddConfigPath(`.`) - // 2. Executable path exec, err := os.Executable() if err != nil { - logger.Error(errors.Wrap(err, "could not read the executable file path")) + logger.Error(sqerrors.Wrap(err, "config: could not read the executable file path")) } else { manager.AddConfigPath(filepath.Dir(exec)) } } - - manager.SetDefault(configKeyBackendHTTPAPIBaseURL, configDefaultBackendHTTPAPIBaseURL) - manager.SetDefault(configKeyLogLevel, configDefaultLogLevel) - manager.SetDefault(configKeyAppName, "") - manager.SetDefault(configKeyHTTPClientIPHeader, "") - manager.SetDefault(configKeyHTTPClientIPHeaderFormat, "") - manager.SetDefault(configKeyBackendHTTPAPIProxy, "") - manager.SetDefault(configKeyDisable, "") - manager.SetDefault(configKeyStripHTTPReferer, "") - manager.SetDefault(configKeyRules, "") - manager.SetDefault(configKeySDKMetricsPeriod, configDefaultSDKMetricsPeriod) - manager.SetDefault(configKeyMaxMetricsStoreLength, configDefaultMaxMetricsStoreLength) - - err := manager.ReadInConfig() - if err != nil { - logger.Error(sqerrors.Wrap(err, "could not read the configuration file")) + // Try to read a configuration file according to the previous settings + if readErr, fileUsed := manager.ReadInConfig(), manager.ConfigFileUsed(); readErr != nil && fileUsed != "" { + // Could not read despite the fact of having found a file + logger.Error(sqerrors.Wrap(readErr, fmt.Sprintf("config: could not read the configuration file `%s`: falling back to environment variables", fileUsed))) + } else if fileUsed != "" { + // A file was found and no error reading it + logger.Infof("config: reading configuration settings from file `%s`", fileUsed) + } else { + logger.Infof("config: reading configuration settings from environment variables") } - return &Config{manager} + cfg := &Config{manager} + if cfg.LogLevel() == plog.Debug { + logger.Infof("config: setting: %s = %q", configFileEnvVar, configFile) + for _, p := range parameters { + if !p.hidden { + v := cfg.GetString(p.key) + if p.secretFromChar > 0 && len(v) > 0 { + secret := make([]byte, 0, len(v)) + secret = append(secret, v[:p.secretFromChar]...) + for range v[p.secretFromChar:] { + secret = append(secret, '*') + } + v = string(secret) + } + logger.Infof("config: settings: %s = %q", p.key, v) + } + } + } + return cfg } // BackendHTTPAPIBaseURL returns the base URL of the backend HTTP API. @@ -283,8 +318,8 @@ func (c *Config) BackendHTTPAPIToken() string { } // LogLevel returns the log level. -func (c *Config) LogLevel() string { - return sanitizeString(c.GetString(configKeyLogLevel)) +func (c *Config) LogLevel() plog.LogLevel { + return plog.ParseLogLevel(sanitizeString(c.GetString(configKeyLogLevel))) } // AppName returns the app name. diff --git a/agent/internal/config/config_test.go b/agent/internal/config/config_test.go index 62a64ab9..e15535f7 100644 --- a/agent/internal/config/config_test.go +++ b/agent/internal/config/config_test.go @@ -39,13 +39,6 @@ func TestUserConfig(t *testing.T) { ConfigKey: configKeyBackendHTTPAPIToken, SomeValue: testlib.RandString(2, 30), }, - { - Name: "Log Level", - GetCfgValue: cfg.LogLevel, - ConfigKey: configKeyLogLevel, - DefaultValue: configDefaultLogLevel, - SomeValue: testlib.RandString(2, 30), - }, { Name: "App Name", GetCfgValue: cfg.AppName, From 8d390d020fb524fc0ff5401460501124e68c79ef Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 24 Jul 2019 12:53:07 +0200 Subject: [PATCH 43/47] agent: fix dev regression - using pointer addresses of user events as metrics keys So it still worked but no longer aggregating since every event has distinct address. --- agent/internal/agent.go | 2 +- agent/internal/config/config.go | 7 ++-- agent/internal/metrics.go | 49 +++++++++++++++++++++++--- agent/internal/metrics/metrics_test.go | 1 + agent/internal/request.go | 38 -------------------- 5 files changed, 49 insertions(+), 48 deletions(-) diff --git a/agent/internal/agent.go b/agent/internal/agent.go index 994f970c..e8e2d3a7 100644 --- a/agent/internal/agent.go +++ b/agent/internal/agent.go @@ -242,7 +242,7 @@ func (a *Agent) Serve() error { } a.logger.Infof("go agent v%s up and running", version) - a.logger.Infof("agent: heartbeat set to %s", heartbeat) + a.logger.Debugf("agent: heartbeat set to %s", heartbeat) ticker := time.Tick(heartbeat) batchSize := int(appLoginRes.Features.BatchSize) diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index 8c0c3a71..3f9fbe87 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -15,7 +15,6 @@ import ( "net/http" "os" "path/filepath" - "strconv" "strings" "time" @@ -365,9 +364,9 @@ func (c *Config) LocalRulesFile() string { // This is temporary until the SDK rules are implemented and required for // integration tests which require a shorter time. func (c *Config) SDKMetricsPeriod() int { - p, err := strconv.Atoi(sanitizeString(c.GetString(configKeySDKMetricsPeriod))) - if err != nil { - return configDefaultSDKMetricsPeriod + p := c.GetInt(configKeySDKMetricsPeriod) + if p < 0 { + return 60 } return p } diff --git a/agent/internal/metrics.go b/agent/internal/metrics.go index 914ef3b4..7bfd2300 100644 --- a/agent/internal/metrics.go +++ b/agent/internal/metrics.go @@ -5,6 +5,8 @@ package internal import ( + "encoding/json" + "github.com/sqreen/go-agent/agent/internal/metrics" "github.com/sqreen/go-agent/agent/sqlib/sqerrors" ) @@ -18,24 +20,33 @@ func (a *Agent) addUserEvent(event userEventFace) { store *metrics.Store logFmt string ) + var uevent *userEvent switch actual := event.(type) { case *authUserEvent: + uevent = actual.userEvent if actual.loginSuccess { store = a.staticMetrics.sdkUserLoginSuccess - logFmt = "user event: user login success `%v`" + logFmt = "user event: user login success `%+v`" } else { store = a.staticMetrics.sdkUserLoginFailure - logFmt = "user event: user login failure `%v`" + logFmt = "user event: user login failure `%+v`" } case *signupUserEvent: + uevent = actual.userEvent store = a.staticMetrics.sdkUserSignup - logFmt = "user event: user signup `%v`" + logFmt = "user event: user signup `%+v`" default: a.logger.Error(sqerrors.Errorf("user event: unexpected user event type `%T`", actual)) return } - a.logger.Debug(logFmt, event) - if err := store.Add(event, 1); err != nil { + key, err := UserEventMetricsStoreKey(uevent) + if err != nil { + a.logger.Error(sqerrors.Wrap(err, "user event: could not create a metrics store key")) + return + } + a.logger.Debugf(logFmt, key) + + if err := store.Add(key, 1); err != nil { sqErr := sqerrors.Wrap(err, "user event: could not update the user metrics store") switch actualErr := err.(type) { case metrics.MaxMetricsStoreLengthError: @@ -69,3 +80,31 @@ func (a *Agent) addWhitelistEvent(matchedWhitelistEntry string) { } } } + +func UserEventMetricsStoreKey(event *userEvent) (json.Marshaler, error) { + var keys [][]interface{} + for prop, val := range event.userIdentifiers { + keys = append(keys, []interface{}{prop, val}) + } + jsonKeys, _ := json.Marshal(keys) + return userMetricsKey{ + Keys: string(jsonKeys), + IP: event.ip.String(), + }, nil +} + +type userMetricsKey struct { + Keys string + IP string +} + +func (e userMetricsKey) MarshalJSON() ([]byte, error) { + v := struct { + Keys json.RawMessage `json:"keys"` + IP string `json:"ip"` + }{ + Keys: json.RawMessage(e.Keys), + IP: e.IP, + } + return json.Marshal(&v) +} diff --git a/agent/internal/metrics/metrics_test.go b/agent/internal/metrics/metrics_test.go index b2da966d..dcb771bc 100644 --- a/agent/internal/metrics/metrics_test.go +++ b/agent/internal/metrics/metrics_test.go @@ -97,6 +97,7 @@ func TestUsage(t *testing.T) { require.NotPanics(t, func() { require.Error(t, store.Add([]byte("no slices"), 1)) + require.Error(t, store.Add(map[string]string{"a": "b", "c": "d"}, 21)) require.Error(t, store.Add(Struct2{ a: 33, b: "string", diff --git a/agent/internal/request.go b/agent/internal/request.go index 9fc5ce9d..aec983c6 100644 --- a/agent/internal/request.go +++ b/agent/internal/request.go @@ -5,7 +5,6 @@ package internal import ( - "encoding/json" "fmt" "net" "net/http" @@ -69,47 +68,10 @@ type authUserEvent struct { func (_ *authUserEvent) isUserEvent() {} -func (e *authUserEvent) MarshalJSON() ([]byte, error) { - k := &userMetricKey{ - id: e.userEvent.userIdentifiers, - ip: e.userEvent.ip, - } - return k.MarshalJSON() -} - -type userMetricKey struct { - id EventUserIdentifiersMap - ip net.IP -} - -func (k *userMetricKey) MarshalJSON() ([]byte, error) { - var keys [][]interface{} - for prop, val := range k.id { - keys = append(keys, []interface{}{prop, val}) - } - v := struct { - Keys [][]interface{} `json:"keys"` - IP string `json:"ip"` - }{ - Keys: keys, - IP: k.ip.String(), - } - buf, err := json.Marshal(&v) - return buf, err -} - type signupUserEvent struct { *userEvent } -func (e *signupUserEvent) MarshalJSON() ([]byte, error) { - k := &userMetricKey{ - id: e.userEvent.userIdentifiers, - ip: e.userEvent.ip, - } - return k.MarshalJSON() -} - func (_ *signupUserEvent) isUserEvent() {} type EventPropertyMap map[string]string From 988a511057e3d2b343b469314039d63501265339 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 24 Jul 2019 12:57:28 +0200 Subject: [PATCH 44/47] agent: update the version to v0.1.0-beta.6 --- agent/internal/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agent/internal/version.go b/agent/internal/version.go index a975968c..412fa216 100644 --- a/agent/internal/version.go +++ b/agent/internal/version.go @@ -4,4 +4,4 @@ package internal -const version = "0.1.0-beta.5" +const version = "0.1.0-beta.6" From 84b8caf8aff5e7fa15307d0c8ff657202da31f8d Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 24 Jul 2019 18:36:33 +0200 Subject: [PATCH 45/47] agent/rule: detect vendoring to find hooks Symbols can be prefixed by the vendor folder name when the dependency is into the vendor directory tree. Our rules don't include those user-specific vendor prefixes so we have here to detect it by looking for `vendor/` into the agent package name. A prefix can be computed out of it in order to use it when a symbol could not found by retrying with this prefix. Example: Symbols in myapp/vendor/apackage/apackage.go would be prefixed by `myapp/vendor/` while our rules may refer to `apackage/...`. --- agent/internal/app/runtime.go | 12 ++++++++++++ agent/internal/rule/rule.go | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/agent/internal/app/runtime.go b/agent/internal/app/runtime.go index 81c4eace..58e8b261 100644 --- a/agent/internal/app/runtime.go +++ b/agent/internal/app/runtime.go @@ -9,6 +9,7 @@ import ( "debug/gosym" "fmt" "os" + "reflect" "runtime" "strings" "time" @@ -199,3 +200,14 @@ func executable(logger *plog.Logger) string { } return name } + +func VendorPrefix() string { + type t struct{} + pkg := reflect.TypeOf(t{}).PkgPath() + vendor := "vendor/" + i := strings.Index(pkg, vendor) + if i == -1 { + return "" + } + return pkg[:i+len(vendor)] +} diff --git a/agent/internal/rule/rule.go b/agent/internal/rule/rule.go index 3e4dae9c..329eab07 100644 --- a/agent/internal/rule/rule.go +++ b/agent/internal/rule/rule.go @@ -20,6 +20,7 @@ import ( "crypto/ecdsa" "fmt" + "github.com/sqreen/go-agent/agent/internal/app" "github.com/sqreen/go-agent/agent/internal/backend/api" "github.com/sqreen/go-agent/agent/internal/config" "github.com/sqreen/go-agent/agent/internal/metrics" @@ -38,6 +39,7 @@ type Engine struct { enabled bool metricsEngine *metrics.Engine publicKey *ecdsa.PublicKey + vendorPrefix string } // Logger interface required by this package. @@ -52,6 +54,7 @@ func NewEngine(logger Logger, metricsEngine *metrics.Engine, publicKey *ecdsa.Pu logger: logger, metricsEngine: metricsEngine, publicKey: publicKey, + vendorPrefix: app.VendorPrefix(), } } @@ -64,7 +67,7 @@ func (e *Engine) PackID() string { // them by atomically modifying the hooks, and removing what is left. func (e *Engine) SetRules(packID string, rules []api.Rule, errorMetricsStore *metrics.Store) { // Create the net rule descriptors and replace the existing ones - ruleDescriptors := newHookDescriptors(e.logger, rules, e.publicKey, e.metricsEngine, errorMetricsStore) + ruleDescriptors := newHookDescriptors(e.logger, rules, e.publicKey, e.metricsEngine, errorMetricsStore, e.vendorPrefix) e.setRules(packID, ruleDescriptors) } @@ -100,7 +103,7 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { // newHookDescriptors walks the list of received rules and creates the map of // hook descriptors indexed by their hook pointer. A hook descriptor contains // all it takes to enable and disable rules at run time. -func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.PublicKey, metricsEngine *metrics.Engine, errorMetricsStore *metrics.Store) hookDescriptors { +func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.PublicKey, metricsEngine *metrics.Engine, errorMetricsStore *metrics.Store, vendorPrefix string) hookDescriptors { // Create and configure the list of callbacks according to the given rules var hookDescriptors = make(hookDescriptors) for i := len(rules) - 1; i >= 0; i-- { @@ -114,10 +117,16 @@ func newHookDescriptors(logger Logger, rules []api.Rule, publicKey *ecdsa.Public hookpoint := r.Hookpoint symbol := fmt.Sprintf("%s.%s", hookpoint.Class, hookpoint.Method) hook := sqhook.Find(symbol) + if hook == nil && vendorPrefix != "" { + hook = sqhook.Find(vendorPrefix + symbol) + } if hook == nil { - logger.Debugf("rule `%s` ignored: symbol `%s` cannot be hooked", r.Name, symbol) + logger.Debugf("rule `%s` ignored: symbol `%s` could not be found", r.Name, symbol) continue + } else { + logger.Debugf("rule `%s`: successfully found hook `%s`", r.Name, hook) } + // Instantiate the callback nextProlog := hookDescriptors.Get(hook) ruleDescriptor := NewCallbackContext(&r, logger, metricsEngine, errorMetricsStore) From 6febfde896fbedcc8c07cd777b807182001066ad Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 25 Jul 2019 14:02:00 +0200 Subject: [PATCH 46/47] repo: update the changelog --- CHANGELOG.md | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f31be7a..68949364 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,38 @@ +# v0.1.0-beta.6 + +## New Features + +- Fully-featured playbooks with the added ability into the agent to redirect the + request to a given URL. (#72) + +- Configurable protection behaviour of the agent when blocking a request by + either customizing the HTTP status code that is used for the blocking HTML + page, or by redirecting to a given URL instead. + Dashboard page: https://my.sqreen.com/application/goto/settings/global#protection-mode + +- HTTP response status code monitoring. (#75) + Dashboard page: https://my.sqreen.com/application/goto/monitoring + +- Support for browser security headers protection modules allowing to enable + various browser security options allowing to restrict modern browsers from + running into some preventable vulnerabilities: + + - [Content Security Policy][csp] protection module allowing to prevent + cross-site scripting attacks. (#74) + Dashboard page: https://my.sqreen.com/application/goto/modules/csp + + - Security headers protection module allowing to protect against client-side + vulnerabilities in the browser. (#73) + Dashboard page: https://my.sqreen.com/application/goto/modules/headers + +## Minor Changes + +- Better agent configuration logs clearly stating where does the configuration + come from (file in search path, enforced file or environment variables), + along with the possibility to display the full settings using the `debug` + log-level. + + # v0.1.0-beta.5 ## New Features @@ -211,5 +246,6 @@ share your impressions with us. - sdk: better documentation with examples. [Security Automation]: https://docs.sqreen.com/security-automation/introduction/ -[playbook]: https://docs.sqreen.com/security-automation/introduction-playbooks -[playbooks]: https://docs.sqreen.com/security-automation/introduction-playbooks +[playbook]: https://docs.sqreen.com/security-automation/introduction-playbooks/ +[playbooks]: https://docs.sqreen.com/security-automation/introduction-playbooks/ +[csp]: https://docs.sqreen.com/using-sqreen/automatically-set-content-security-policy/ \ No newline at end of file From 7b1903d82fc53d0e7b26a4f520987169d9f2d15a Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 25 Jul 2019 14:44:51 +0200 Subject: [PATCH 47/47] repo: update the readme --- README.md | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 4b750dd7..ee286f97 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -![Sqreen](https://s3-eu-west-1.amazonaws.com/sqreen-assets/npm/20171113/sqreen_horizontal_250.png) +![Sqreen](https://sqreen-assets.s3-eu-west-1.amazonaws.com/logos/sqreen-logo-264-1.svg) # [Sqreen](https://www.sqreen.com/)'s Application Security Management for Go @@ -9,22 +9,25 @@ [![codecov](https://codecov.io/gh/sqreen/go-agent/branch/master/graph/badge.svg)](https://codecov.io/gh/sqreen/go-agent) [![Go Report Card](https://goreportcard.com/badge/github.com/sqreen/go-agent)](https://goreportcard.com/report/github.com/sqreen/go-agent) -Sqreen monitors your application security and helps you easily protect it from -common vulnerabilities or advanced attacks. +After performance monitoring (APM), error and log monitoring it’s time to add a +security component into your app. Sqreen’s microagent automatically monitors +sensitive app’s routines, blocks attacks and reports actionable infos to your +dashboard. -- Gain visibility into your application security. -- One-click protection from common vulnerabilities. -- Easily enforce custom protection rules into your app. -- Identify malicious users before they cause harm. -- Integrate with your workflow. +![Dashboard](https://sqreen-assets.s3-eu-west-1.amazonaws.com/miscellaneous/dashboard.gif) -![Dashboard](https://d33wubrfki0l68.cloudfront.net/0fe441513f505601d03b25249deddd8fd1eb2a49/e2da6/img/new/illustrations/dashboard-mockup.png) +Sqreen provides automatic defense against attacks: -Sqreen also protects applications against common security threats such as -database injections, cross-site scripting attacks, scans, or authentication -activity inside the application to detect and block account takeover attacks. It -monitors functions in the application (I/O, authentication, network, command -execution, etc.) and provides dedicated security logic at run-time. +- Protect with security modules: RASP (Runtime Application Self-Protection), + in-app WAF (Web Application Firewall), Account takeovers and more. + +- Sqreen’s modules adapt to your application stack with no need of configuration. + +- Prevent attacks from the OWASP Top 10 (Injections, XSS and more), 0-days, + data Leaks, and more. + +- Create security automation playbooks that automatically react against + your advanced business-logic threats. For more details, visit [sqreen.com](https://www.sqreen.com/) @@ -33,7 +36,7 @@ For more details, visit [sqreen.com](https://www.sqreen.com/) 1. Download the Go agent and the SDK using `go get`: ```sh - $ go get github.com/sqreen/go-agent/... + $ go get github.com/sqreen/go-agent@v0.1.0-beta.6 ``` 1. Import the package `agent` in your `main` package of your app: @@ -59,9 +62,7 @@ For more details, visit [sqreen.com](https://www.sqreen.com/) - [sqhttp](https://godoc.org/github.com/sqreen/go-agent/sdk/middleware/sqhttp) for the standard net/http package. - [Gin](https://godoc.org/github.com/sqreen/go-agent/sdk/middleware/sqgin) - [Echo](https://godoc.org/github.com/sqreen/go-agent/sdk/middleware/sqecho) - - Coming soon: [Iris](https://github.com/sqreen/go-agent/pull/22), - [gRPC](https://github.com/sqreen/go-agent/pull/23) (please upvote if - interested ;)). + - [gRPC](https://godoc.org/github.com/sqreen/go-agent/sdk/middleware/sqgrpc) If your framework is not in the list, it is usually possible to use the standard `net/http` middleware. If not, please open an issue in this