diff --git a/filtron.go b/filtron.go index aa377b1..b3ac744 100644 --- a/filtron.go +++ b/filtron.go @@ -17,6 +17,7 @@ func main() { listen := flag.String("listen", "127.0.0.1:4004", "Proxy listen address") apiAddr := flag.String("api", "127.0.0.1:4005", "API listen address") ruleFile := flag.String("rules", "rules.json", "JSON rule list") + debug := flag.Bool("debug", false, "Log debug information") readBufferSize := flag.Int("read-buffer-size", 16*1024, "Read buffer size") printVersionInfo := flag.Bool("version", false, "Version information") flag.Parse() @@ -32,6 +33,6 @@ func main() { return } log.Println(rule.RulesLength(rules), "rules loaded from", *ruleFile) - p := proxy.Listen(*listen, *target, *readBufferSize, &rules) + p := proxy.Listen(*listen, *target, *readBufferSize, &rules, *debug) api.Listen(*apiAddr, *ruleFile, p) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 615db76..b9c7e01 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -19,16 +19,19 @@ type Proxy struct { NumberOfRequests uint Rules *[]*rule.Rule client *fasthttp.HostClient + debug bool } -func Listen(address, target string, readBufferSize int, rules *[]*rule.Rule) *Proxy { +func Listen(address, target string, readBufferSize int, rules *[]*rule.Rule, debug bool) *Proxy { p := &Proxy{ NumberOfRequests: 0, Rules: rules, client: &fasthttp.HostClient{Addr: target, ReadBufferSize: readBufferSize}, + debug: debug, } go func(address string, p *Proxy) { log.Println("Proxy listens on", address) + log.Println("Target on", target) fasthttp.ListenAndServe(address, p.Handler) }(address, p) return p @@ -54,7 +57,7 @@ func (p *Proxy) Handler(ctx *fasthttp.RequestCtx) { err := p.client.Do(appRequest, resp) if err != nil { log.Println("Response error:", err, resp) - ctx.SetStatusCode(429) + ctx.SetStatusCode(500) return } diff --git a/rule/rule.go b/rule/rule.go index 02319b7..3b73c57 100644 --- a/rule/rule.go +++ b/rule/rule.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io/ioutil" + "log" "sync" "sync/atomic" "time" @@ -17,27 +18,29 @@ import ( ) type Rule struct { - Interval uint64 `json:"interval"` - Limit uint64 `json:"limit"` - Name string `json:"name"` - lastTick uint64 `json:"-"` - MatchCount uint64 `json:"match_count"` - filterMatchCount uint64 `json:"-"` - Filters []*selector.Selector `json:"-"` - RawFilters []string `json:"filters"` - Aggregations []*Aggregation `json:"-"` - RawAggregations []string `json:"aggregations"` - Actions []action.Action `json:"-"` - RawActions []action.ActionJSON `json:"actions"` - SubRules []*Rule `json:"subrules"` - Disabled bool `json:"disabled"` - Stop bool `json:"stop"` -} - -type Aggregation struct { sync.RWMutex - Values map[string]uint64 - Selector *selector.Selector + + Name string `json:"name"` + Interval uint64 `json:"interval"` + Limit uint64 `json:"limit"` + RequestCount uint64 `json:"requestCount"` + MatchCount uint64 `json:"matchCount"` + Stop bool `json:"stop"` + Disabled bool `json:"disabled"` + Filters []*selector.Selector `json:"-"` + RawFilters []string `json:"filters"` + AggregationSelectors []*selector.Selector `json:"-"` + AggregationValues map[string]*AggregationValue `json:"values"` + DefaultAggregationValue *AggregationValue `json:"-"` + RawAggregations []string `json:"aggregations"` + Actions []action.Action `json:"-"` + RawActions []action.ActionJSON `json:"actions"` + SubRules []*Rule `json:"subrules"` +} + +type AggregationValue struct { + LastTick uint64 `json:"lastTick"` + Count uint64 `json:"count"` } func Evaluate(rules *[]*Rule, ctx *fasthttp.RequestCtx) types.ResponseState { @@ -55,7 +58,7 @@ func validateRuleList(rules *[]*Rule, state *types.ResponseState, ctx *fasthttp. prevMatchCount := rule.MatchCount s := rule.Validate(ctx, *state) - + log.Println("rule ", rule.Name, s) if s > *state { *state = s } @@ -112,8 +115,6 @@ func ParseJSON(jsonData []byte) ([]*Rule, error) { } func (r *Rule) Init() error { - r.filterMatchCount = 0 - r.lastTick = uint64(time.Now().Unix()) if len(r.RawActions) == 0 && len(r.SubRules) == 0 { return errors.New(fmt.Sprintf("At least one subrule or action required in rule: %q", r.Name)) } @@ -141,17 +142,19 @@ func (r *Rule) Init() error { } func (r *Rule) ParseAggregations(aggregations []string) error { - r.Aggregations = make([]*Aggregation, 0, len(aggregations)) + selectors := make([]*selector.Selector, 0, len(aggregations)) for _, t := range aggregations { s, err := selector.Parse(t) if err != nil { return errors.New(fmt.Sprintf("Cannot parse selector '%v': %v", t, err)) } - a := &Aggregation{ - Values: make(map[string]uint64), - Selector: s, - } - r.Aggregations = append(r.Aggregations, a) + selectors = append(selectors, s) + } + r.AggregationSelectors = selectors + r.AggregationValues = make(map[string]*AggregationValue) + if len(aggregations) == 0 { + r.DefaultAggregationValue = NewAggreationValue() + r.AggregationValues["*"] = r.DefaultAggregationValue } return nil } @@ -169,38 +172,24 @@ func (r *Rule) ParseFilters(filters []string) error { } func (r *Rule) Validate(ctx *fasthttp.RequestCtx, rs types.ResponseState) types.ResponseState { - curTime := uint64(time.Now().Unix()) - if r.Limit != 0 && curTime-r.lastTick >= r.Interval { - r.filterMatchCount = 0 - atomic.StoreUint64(&r.filterMatchCount, 0) - atomic.StoreUint64(&r.lastTick, curTime) - for _, a := range r.Aggregations { - a.Lock() - a.Values = make(map[string]uint64) - a.Unlock() - } - } + // Does it pass all the filters ? for _, t := range r.Filters { if _, found := t.Match(ctx); !found { return types.UNTOUCHED } } - matched := false - state := rs - if len(r.Aggregations) == 0 { - atomic.AddUint64(&r.filterMatchCount, 1) - if r.filterMatchCount > r.Limit { - matched = true - } - } else { - for _, a := range r.Aggregations { - if a.Get(ctx) > r.Limit { - matched = true - } - } + + // + requestCount := atomic.AddUint64(&r.RequestCount, 1) + if requestCount%10 == 0 { + r.EraseOldAggregationValues() } - if matched { + + // Does it hit the limit ? + state := rs + if r.Match(ctx) { atomic.AddUint64(&r.MatchCount, 1) + for _, a := range r.Actions { // Skip serving actions if we already had one s := a.GetResponseState() @@ -223,13 +212,82 @@ func (r *Rule) Validate(ctx *fasthttp.RequestCtx, rs types.ResponseState) types. return state } -func (a *Aggregation) Get(ctx *fasthttp.RequestCtx) uint64 { - if val, found := a.Selector.Match(ctx); found { - a.Lock() - a.Values[val] += 1 - v := a.Values[val] - a.Unlock() - return v +func (r *Rule) Match(ctx *fasthttp.RequestCtx) bool { + // Match the aggregation: increment & check if it is above the limit + + // Get the AggregationValue for the context + var av *AggregationValue + if len(r.AggregationSelectors) == 0 { + // No aggregations: default value + av = r.DefaultAggregationValue + } else { + // Aggregation: get the key + key := "" + for _, s := range r.AggregationSelectors { + // Check + value, _ := s.Match(ctx) + // Concat + key = key + "|" + value + } + log.Println("Aggregation key", key) + + // Check if value exists : no --> call NewAggreationValue + var ok bool + var newAv *AggregationValue = nil + if av, ok = r.AggregationValues[key]; !ok { + // memory allocation outside the Lock/Unlock block + newAv = NewAggreationValue() + } + + // + r.Lock() + av, ok = r.AggregationValues[key] + if !ok { + if newAv == nil { + // Should not happen + newAv = NewAggreationValue() + } + av = newAv + r.AggregationValues[key] = av + } + r.Unlock() + } + // Increment, and return true is the limit has been reached + return r.IncAndMatch(av) +} + +func (r *Rule) IncAndMatch(av *AggregationValue) bool { + if r.Limit > 0 { + curTime := uint64(time.Now().Unix()) + + if curTime-atomic.LoadUint64(&av.LastTick) >= r.Interval { + atomic.StoreUint64(&av.Count, 0) + atomic.StoreUint64(&av.LastTick, curTime) + } + return atomic.AddUint64(&av.Count, 1) > r.Limit + } else { + atomic.AddUint64(&av.Count, 1) + return true + } +} + +func (r *Rule) EraseOldAggregationValues() { + if len(r.AggregationValues) > 1 { + curTime := uint64(time.Now().Unix()) + + r.Lock() + for k, av := range r.AggregationValues { + if curTime-av.LastTick >= r.Interval { + delete(r.AggregationValues, k) + } + } + r.Unlock() + } +} + +func NewAggreationValue() *AggregationValue { + return &AggregationValue{ + Count: 0, + LastTick: uint64(time.Now().Unix()), } - return 0 } diff --git a/selector/selector.go b/selector/selector.go index edae033..6e042d5 100644 --- a/selector/selector.go +++ b/selector/selector.go @@ -81,5 +81,6 @@ func (s *Selector) Match(ctx *fasthttp.RequestCtx) (string, bool) { if s.Negate { found = !found } + log.Println(" * ", s.RequestAttr, "[", s.SubAttr, "]=", string(matchSlice), ";", s.Regexp, "negate=", s.Negate, "found=", found) return string(matchSlice), found }