Skip to content

Commit ebedc8b

Browse files
committed
refact pkg/database: extract decisionfilter.go
1 parent c231554 commit ebedc8b

File tree

2 files changed

+208
-196
lines changed

2 files changed

+208
-196
lines changed

pkg/database/decisionfilter.go

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
package database
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"strings"
7+
8+
"github.com/pkg/errors"
9+
10+
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
11+
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
12+
"github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
13+
"github.com/crowdsecurity/crowdsec/pkg/types"
14+
)
15+
16+
func applyDecisionFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) {
17+
var (
18+
err error
19+
start_ip, start_sfx, end_ip, end_sfx int64
20+
ip_sz int
21+
)
22+
23+
contains := true
24+
/*if contains is true, return bans that *contains* the given value (value is the inner)
25+
else, return bans that are *contained* by the given value (value is the outer)*/
26+
27+
/*the simulated filter is a bit different : if it's not present *or* set to false, specifically exclude records with simulated to true */
28+
if v, ok := filter["simulated"]; ok {
29+
if v[0] == "false" {
30+
query = query.Where(decision.SimulatedEQ(false))
31+
}
32+
33+
delete(filter, "simulated")
34+
} else {
35+
query = query.Where(decision.SimulatedEQ(false))
36+
}
37+
38+
for param, value := range filter {
39+
switch param {
40+
case "contains":
41+
contains, err = strconv.ParseBool(value[0])
42+
if err != nil {
43+
return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err)
44+
}
45+
case "scopes", "scope": // Swagger mentions both of them, let's just support both to make sure we don't break anything
46+
scopes := strings.Split(value[0], ",")
47+
for i, scope := range scopes {
48+
switch strings.ToLower(scope) {
49+
case "ip":
50+
scopes[i] = types.Ip
51+
case "range":
52+
scopes[i] = types.Range
53+
case "country":
54+
scopes[i] = types.Country
55+
case "as":
56+
scopes[i] = types.AS
57+
}
58+
}
59+
60+
query = query.Where(decision.ScopeIn(scopes...))
61+
case "value":
62+
query = query.Where(decision.ValueEQ(value[0]))
63+
case "type":
64+
query = query.Where(decision.TypeEQ(value[0]))
65+
case "origins":
66+
query = query.Where(
67+
decision.OriginIn(strings.Split(value[0], ",")...),
68+
)
69+
case "scenarios_containing":
70+
predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold)
71+
query = query.Where(decision.Or(predicates...))
72+
case "scenarios_not_containing":
73+
predicates := decisionPredicatesFromStr(value[0], decision.ScenarioContainsFold)
74+
query = query.Where(decision.Not(
75+
decision.Or(
76+
predicates...,
77+
),
78+
))
79+
case "ip", "range":
80+
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(value[0])
81+
if err != nil {
82+
return nil, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", value[0], err)
83+
}
84+
case "limit":
85+
limit, err := strconv.Atoi(value[0])
86+
if err != nil {
87+
return nil, errors.Wrapf(InvalidFilter, "invalid limit value : %s", err)
88+
}
89+
90+
query = query.Limit(limit)
91+
case "offset":
92+
offset, err := strconv.Atoi(value[0])
93+
if err != nil {
94+
return nil, errors.Wrapf(InvalidFilter, "invalid offset value : %s", err)
95+
}
96+
97+
query = query.Offset(offset)
98+
case "id_gt":
99+
id, err := strconv.Atoi(value[0])
100+
if err != nil {
101+
return nil, errors.Wrapf(InvalidFilter, "invalid id_gt value : %s", err)
102+
}
103+
104+
query = query.Where(decision.IDGT(id))
105+
}
106+
}
107+
108+
query, err = decisionIPFilter(query, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
109+
if err != nil {
110+
return nil, fmt.Errorf("fail to apply StartIpEndIpFilter: %w", err)
111+
}
112+
113+
return query, nil
114+
}
115+
116+
func decisionIPv4Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) {
117+
if contains {
118+
/*Decision contains {start_ip,end_ip}*/
119+
return decisions.Where(decision.And(
120+
decision.StartIPLTE(start_ip),
121+
decision.EndIPGTE(end_ip),
122+
decision.IPSizeEQ(int64(ip_sz)))), nil
123+
}
124+
125+
/*Decision is contained within {start_ip,end_ip}*/
126+
return decisions.Where(decision.And(
127+
decision.StartIPGTE(start_ip),
128+
decision.EndIPLTE(end_ip),
129+
decision.IPSizeEQ(int64(ip_sz)))), nil
130+
}
131+
132+
func decisionIPv6Filter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) {
133+
/*decision contains {start_ip,end_ip}*/
134+
if contains {
135+
return decisions.Where(decision.And(
136+
// matching addr size
137+
decision.IPSizeEQ(int64(ip_sz)),
138+
decision.Or(
139+
// decision.start_ip < query.start_ip
140+
decision.StartIPLT(start_ip),
141+
decision.And(
142+
// decision.start_ip == query.start_ip
143+
decision.StartIPEQ(start_ip),
144+
// decision.start_suffix <= query.start_suffix
145+
decision.StartSuffixLTE(start_sfx),
146+
)),
147+
decision.Or(
148+
// decision.end_ip > query.end_ip
149+
decision.EndIPGT(end_ip),
150+
decision.And(
151+
// decision.end_ip == query.end_ip
152+
decision.EndIPEQ(end_ip),
153+
// decision.end_suffix >= query.end_suffix
154+
decision.EndSuffixGTE(end_sfx),
155+
),
156+
),
157+
)), nil
158+
}
159+
160+
/*decision is contained within {start_ip,end_ip}*/
161+
return decisions.Where(decision.And(
162+
// matching addr size
163+
decision.IPSizeEQ(int64(ip_sz)),
164+
decision.Or(
165+
// decision.start_ip > query.start_ip
166+
decision.StartIPGT(start_ip),
167+
decision.And(
168+
// decision.start_ip == query.start_ip
169+
decision.StartIPEQ(start_ip),
170+
// decision.start_suffix >= query.start_suffix
171+
decision.StartSuffixGTE(start_sfx),
172+
)),
173+
decision.Or(
174+
// decision.end_ip < query.end_ip
175+
decision.EndIPLT(end_ip),
176+
decision.And(
177+
// decision.end_ip == query.end_ip
178+
decision.EndIPEQ(end_ip),
179+
// decision.end_suffix <= query.end_suffix
180+
decision.EndSuffixLTE(end_sfx),
181+
),
182+
),
183+
)), nil
184+
}
185+
186+
func decisionIPFilter(decisions *ent.DecisionQuery, contains bool, ip_sz int, start_ip int64, start_sfx int64, end_ip int64, end_sfx int64) (*ent.DecisionQuery, error) {
187+
switch ip_sz {
188+
case 4:
189+
return decisionIPv4Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
190+
case 16:
191+
return decisionIPv6Filter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
192+
case 0:
193+
return decisions, nil
194+
default:
195+
return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz)
196+
}
197+
}
198+
199+
func decisionPredicatesFromStr(s string, predicateFunc func(string) predicate.Decision) []predicate.Decision {
200+
words := strings.Split(s, ",")
201+
predicates := make([]predicate.Decision, len(words))
202+
203+
for i, word := range words {
204+
predicates[i] = predicateFunc(word)
205+
}
206+
207+
return predicates
208+
}

0 commit comments

Comments
 (0)