diff --git a/pkg/database/alertfilter.go b/pkg/database/alertfilter.go index d966901401c..ce95d6045c6 100644 --- a/pkg/database/alertfilter.go +++ b/pkg/database/alertfilter.go @@ -253,7 +253,7 @@ func alertPredicatesFromFilter(filter map[string][]string) ([]predicate.Alert, e return predicates, nil } -func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) { +func applyAlertFilter(alerts *ent.AlertQuery, filter map[string][]string) (*ent.AlertQuery, error) { preds, err := alertPredicatesFromFilter(filter) if err != nil { return nil, err diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 2fcbb8a5f49..e0674d83cb8 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -750,7 +750,7 @@ func (c *Client) CreateAlert(ctx context.Context, machineID string, alertList [] return alertIDs, nil } -func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string][]string) (map[string]int, error) { +func (c *Client) AlertsCountPerScenario(ctx context.Context, filter map[string][]string) (map[string]int, error) { var res []struct { Scenario string Count int @@ -758,7 +758,7 @@ func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string] query := c.Ent.Alert.Query() - query, err := BuildAlertRequestFromFilter(query, filters) + query, err := applyAlertFilter(query, filter) if err != nil { return nil, fmt.Errorf("failed to build alert request: %w", err) } @@ -809,7 +809,7 @@ func (c *Client) QueryAlertWithFilter(ctx context.Context, filter map[string][]s for { alerts := c.Ent.Alert.Query() - alerts, err := BuildAlertRequestFromFilter(alerts, filter) + alerts, err := applyAlertFilter(alerts, filter) if err != nil { return nil, err } diff --git a/pkg/database/decisionfilter.go b/pkg/database/decisionfilter.go new file mode 100644 index 00000000000..5fef955d9b7 --- /dev/null +++ b/pkg/database/decisionfilter.go @@ -0,0 +1,208 @@ +package database + +import ( + "fmt" + "strconv" + "strings" + + "github.com/pkg/errors" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" + "github.com/crowdsecurity/crowdsec/pkg/types" +) + +func applyDecisionFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) { + var ( + err error + start_ip, start_sfx, end_ip, end_sfx int64 + ip_sz int + ) + + contains := true + /*if contains is true, return bans that *contains* the given value (value is the inner) + else, return bans that are *contained* by the given value (value is the outer)*/ + + /*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ + if v, ok := filter["simulated"]; ok { + if v[0] == "false" { + query = query.Where(decision.SimulatedEQ(false)) + } + + delete(filter, "simulated") + } else { + query = query.Where(decision.SimulatedEQ(false)) + } + + for param, value := range filter { + switch param { + case "contains": + contains, err = strconv.ParseBool(value[0]) + if err != nil { + return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) + } + case "scopes", "scope": // Swagger mentions both of them, let's just support both to make sure we don't break anything + scopes := strings.Split(value[0], ",") + for i, scope := range scopes { + switch strings.ToLower(scope) { + case "ip": + scopes[i] = types.Ip + case "range": + scopes[i] = types.Range + case "country": + scopes[i] = types.Country + case "as": + scopes[i] = types.AS + } + } + + query = query.Where(decision.ScopeIn(scopes...)) + case "value": + query = query.Where(decision.ValueEQ(value[0])) + case "type": + query = query.Where(decision.TypeEQ(value[0])) + case "origins": + query = query.Where( + decision.OriginIn(strings.Split(value[0], ",")...), + ) + case "scenarios_containing": + predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold) + query = query.Where(decision.Or(predicates...)) + case "scenarios_not_containing": + predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold) + query = query.Where(decision.Not( + decision.Or( + predicates..., + ), + )) + case "ip", "range": + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0]) + if err != nil { + return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err) + } + case "limit": + limit, err := strconv.Atoi(value[0]) + if err != nil { + return nil, errors.Wrapf(InvalidFilter, "invalid limit value : %s", err) + } + + query = query.Limit(limit) + case "offset": + offset, err := strconv.Atoi(value[0]) + if err != nil { + return nil, errors.Wrapf(InvalidFilter, "invalid offset value : %s", err) + } + + query = query.Offset(offset) + case "id_gt": + id, err := strconv.Atoi(value[0]) + if err != nil { + return nil, errors.Wrapf(InvalidFilter, "invalid id_gt value : %s", err) + } + + query = query.Where(decision.IDGT(id)) + } + } + + query, err = decisionIPFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + if err != nil { + return nil, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) + } + + return query, nil +} + +func decisionIPv4Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { + if contains { + /*Decision contains {start_ip,end_ip}*/ + return decisions.Where(decision.And( + decision.StartIPLTE(start_ip), + decision.EndIPGTE(end_ip), + decision.IPSizeEQ(int64(ip_sz)))), nil + } + + /*Decision is contained within {start_ip,end_ip}*/ + return decisions.Where(decision.And( + decision.StartIPGTE(start_ip), + decision.EndIPLTE(end_ip), + decision.IPSizeEQ(int64(ip_sz)))), nil +} + +func decisionIPv6Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { + /*decision contains {start_ip,end_ip}*/ + if contains { + return decisions.Where(decision.And( + // matching addr size + decision.IPSizeEQ(int64(ip_sz)), + decision.Or( + // decision.start_ip < query.start_ip + decision.StartIPLT(start_ip), + decision.And( + // decision.start_ip == query.start_ip + decision.StartIPEQ(start_ip), + // decision.start_suffix <= query.start_suffix + decision.StartSuffixLTE(start_sfx), + )), + decision.Or( + // decision.end_ip > query.end_ip + decision.EndIPGT(end_ip), + decision.And( + // decision.end_ip == query.end_ip + decision.EndIPEQ(end_ip), + // decision.end_suffix >= query.end_suffix + decision.EndSuffixGTE(end_sfx), + ), + ), + )), nil + } + + /*decision is contained within {start_ip,end_ip}*/ + return decisions.Where(decision.And( + // matching addr size + decision.IPSizeEQ(int64(ip_sz)), + decision.Or( + // decision.start_ip > query.start_ip + decision.StartIPGT(start_ip), + decision.And( + // decision.start_ip == query.start_ip + decision.StartIPEQ(start_ip), + // decision.start_suffix >= query.start_suffix + decision.StartSuffixGTE(start_sfx), + )), + decision.Or( + // decision.end_ip < query.end_ip + decision.EndIPLT(end_ip), + decision.And( + // decision.end_ip == query.end_ip + decision.EndIPEQ(end_ip), + // decision.end_suffix <= query.end_suffix + decision.EndSuffixLTE(end_sfx), + ), + ), + )), nil +} + +func decisionIPFilter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { + switch ip_sz { + case 4: + return decisionIPv4Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + case 16: + return decisionIPv6Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) + case 0: + return decisions, nil + default: + return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz) + } +} + +func decisionPredicatesFromStr(s string, predicateFunc func(string) predicate.Decision) []predicate.Decision { + words := strings.Split(s, ",") + predicates := make([]predicate.Decision, len(words)) + + for i, word := range words { + predicates[i] = predicateFunc(word) + } + + return predicates +} diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 52d0e341c48..dd5963f3005 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strconv" - "strings" "time" "entgo.io/ent/dialect/sql" @@ -14,7 +13,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" - "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -27,116 +25,16 @@ type DecisionsByScenario struct { Type string } -func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) { - var ( - err error - start_ip, start_sfx, end_ip, end_sfx int64 - ip_sz int - ) - - contains := true - /*if contains is true, return bans that *contains* the given value (value is the inner) - else, return bans that are *contained* by the given value (value is the outer)*/ - - /*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */ - if v, ok := filter["simulated"]; ok { - if v[0] == "false" { - query = query.Where(decision.SimulatedEQ(false)) - } - - delete(filter, "simulated") - } else { - query = query.Where(decision.SimulatedEQ(false)) - } - - for param, value := range filter { - switch param { - case "contains": - contains, err = strconv.ParseBool(value[0]) - if err != nil { - return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) - } - case "scopes", "scope": // Swagger mentions both of them, let's just support both to make sure we don't break anything - scopes := strings.Split(value[0], ",") - for i, scope := range scopes { - switch strings.ToLower(scope) { - case "ip": - scopes[i] = types.Ip - case "range": - scopes[i] = types.Range - case "country": - scopes[i] = types.Country - case "as": - scopes[i] = types.AS - } - } - - query = query.Where(decision.ScopeIn(scopes...)) - case "value": - query = query.Where(decision.ValueEQ(value[0])) - case "type": - query = query.Where(decision.TypeEQ(value[0])) - case "origins": - query = query.Where( - decision.OriginIn(strings.Split(value[0], ",")...), - ) - case "scenarios_containing": - predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold) - query = query.Where(decision.Or(predicates...)) - case "scenarios_not_containing": - predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold) - query = query.Where(decision.Not( - decision.Or( - predicates..., - ), - )) - case "ip", "range": - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0]) - if err != nil { - return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err) - } - case "limit": - limit, err := strconv.Atoi(value[0]) - if err != nil { - return nil, errors.Wrapf(InvalidFilter, "invalid limit value : %s", err) - } - - query = query.Limit(limit) - case "offset": - offset, err := strconv.Atoi(value[0]) - if err != nil { - return nil, errors.Wrapf(InvalidFilter, "invalid offset value : %s", err) - } - - query = query.Offset(offset) - case "id_gt": - id, err := strconv.Atoi(value[0]) - if err != nil { - return nil, errors.Wrapf(InvalidFilter, "invalid id_gt value : %s", err) - } - - query = query.Where(decision.IDGT(id)) - } - } - - query, err = decisionIPFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) - if err != nil { - return nil, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err) - } - - return query, nil -} - -func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) // Allow a bouncer to ask for non-deduplicated results - if v, ok := filters["dedup"]; !ok || v[0] != "false" { + if v, ok := filter["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } - query, err := BuildDecisionRequestWithFilter(query, filters) + query, err := applyDecisionFilter(query, filter) if err != nil { c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") @@ -153,16 +51,16 @@ func (c *Client) QueryAllDecisionsWithFilters(ctx context.Context, filters map[s return data, nil } -func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsWithFilters(ctx context.Context, filter map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) // Allow a bouncer to ask for non-deduplicated results - if v, ok := filters["dedup"]; !ok || v[0] != "false" { + if v, ok := filter["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } - query, err := BuildDecisionRequestWithFilter(query, filters) + query, err := applyDecisionFilter(query, filter) query = query.Order(ent.Asc(decision.FieldID)) @@ -185,7 +83,7 @@ func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*Decisions decision.UntilGT(time.Now().UTC()), ) - query, err := BuildDecisionRequestWithFilter(query, make(map[string][]string)) + query, err := applyDecisionFilter(query, make(map[string][]string)) if err != nil { c.Log.Warningf("QueryDecisionCountByScenario : %s", err) return nil, errors.Wrap(QueryFail, "count all decisions with filters") @@ -211,7 +109,7 @@ func (c *Client) QueryDecisionWithFilter(ctx context.Context, filter map[string] decisions := c.Ent.Decision.Query(). Where(decision.UntilGTE(time.Now().UTC())) - decisions, err = BuildDecisionRequestWithFilter(decisions, filter) + decisions, err = applyDecisionFilter(decisions, filter) if err != nil { return []*ent.Decision{}, err } @@ -263,7 +161,7 @@ func longestDecisionForScopeTypeValue(s *sql.Selector) { ) } -func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filter map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilLT(time.Now().UTC()), ) @@ -273,11 +171,11 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, sinc } // Allow a bouncer to ask for non-deduplicated results - if v, ok := filters["dedup"]; !ok || v[0] != "false" { + if v, ok := filter["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } - query, err := BuildDecisionRequestWithFilter(query, filters) + query, err := applyDecisionFilter(query, filter) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") @@ -294,7 +192,7 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(ctx context.Context, sinc return data, nil } -func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filters map[string][]string) ([]*ent.Decision, error) { +func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *time.Time, filter map[string][]string) ([]*ent.Decision, error) { query := c.Ent.Decision.Query().Where( decision.UntilGT(time.Now().UTC()), ) @@ -304,11 +202,11 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *t } // Allow a bouncer to ask for non-deduplicated results - if v, ok := filters["dedup"]; !ok || v[0] != "false" { + if v, ok := filter["dedup"]; !ok || v[0] != "false" { query = query.Where(longestDecisionForScopeTypeValue) } - query, err := BuildDecisionRequestWithFilter(query, filters) + query, err := applyDecisionFilter(query, filter) if err != nil { c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) @@ -537,97 +435,3 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decision return decision.Until.Sub(time.Now().UTC()), nil } - -func decisionIPv4Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { - if contains { - /*Decision contains {start_ip,end_ip}*/ - return decisions.Where(decision.And( - decision.StartIPLTE(start_ip), - decision.EndIPGTE(end_ip), - decision.IPSizeEQ(int64(ip_sz)))), nil - } - - /*Decision is contained within {start_ip,end_ip}*/ - return decisions.Where(decision.And( - decision.StartIPGTE(start_ip), - decision.EndIPLTE(end_ip), - decision.IPSizeEQ(int64(ip_sz)))), nil -} - -func decisionIPv6Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { - /*decision contains {start_ip,end_ip}*/ - if contains { - return decisions.Where(decision.And( - // matching addr size - decision.IPSizeEQ(int64(ip_sz)), - decision.Or( - // decision.start_ip < query.start_ip - decision.StartIPLT(start_ip), - decision.And( - // decision.start_ip == query.start_ip - decision.StartIPEQ(start_ip), - // decision.start_suffix <= query.start_suffix - decision.StartSuffixLTE(start_sfx), - )), - decision.Or( - // decision.end_ip > query.end_ip - decision.EndIPGT(end_ip), - decision.And( - // decision.end_ip == query.end_ip - decision.EndIPEQ(end_ip), - // decision.end_suffix >= query.end_suffix - decision.EndSuffixGTE(end_sfx), - ), - ), - )), nil - } - - /*decision is contained within {start_ip,end_ip}*/ - return decisions.Where(decision.And( - // matching addr size - decision.IPSizeEQ(int64(ip_sz)), - decision.Or( - // decision.start_ip > query.start_ip - decision.StartIPGT(start_ip), - decision.And( - // decision.start_ip == query.start_ip - decision.StartIPEQ(start_ip), - // decision.start_suffix >= query.start_suffix - decision.StartSuffixGTE(start_sfx), - )), - decision.Or( - // decision.end_ip < query.end_ip - decision.EndIPLT(end_ip), - decision.And( - // decision.end_ip == query.end_ip - decision.EndIPEQ(end_ip), - // decision.end_suffix <= query.end_suffix - decision.EndSuffixLTE(end_sfx), - ), - ), - )), nil -} - -func decisionIPFilter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) { - switch ip_sz { - case 4: - return decisionIPv4Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) - case 16: - return decisionIPv6Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) - case 0: - return decisions, nil - default: - return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz) - } -} - -func decisionPredicatesFromStr(s string, predicateFunc func(string) predicate.Decision) []predicate.Decision { - words := strings.Split(s, ",") - predicates := make([]predicate.Decision, len(words)) - - for i, word := range words { - predicates[i] = predicateFunc(word) - } - - return predicates -}