diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 903037b..f723912 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -39,7 +39,7 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: - version: v1.54 + version: v1.55 args: --issues-exit-code=1 --timeout 10m only-new-issues: false # the cache is already managed above, enabling it here diff --git a/.golangci.yml b/.golangci.yml index 902eff5..a58310d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,6 +1,14 @@ # see https://github.com/golangci/golangci-lint/blob/master/.golangci.example.yml linters-settings: + gci: + sections: + - standard + - default + - prefix(github.com/crowdsecurity) + - prefix(github.com/crowdsecurity/crowdsec) + - prefix(github.com/crowdsecurity/cs-blocklist-mirror) + gocyclo: min-complexity: 30 @@ -16,16 +24,28 @@ linters-settings: govet: check-shadowing: true + lll: line-length: 140 + misspell: locale: US + + nlreturn: + block-size: 4 + nolintlint: - allow-leading-space: true # don't require machine-readable nolint directives (i.e. with no leading space) allow-unused: false # report any unused nolint directives require-explanation: false # don't require an explanation for nolint directives require-specific: false # don't require nolint directives to be specific about which linter is being skipped + depguard: + rules: + main: + deny: + - pkg: "github.com/pkg/errors" + desc: "errors.New() is deprecated in favor of fmt.Errorf()" + linters: enable-all: true disable: @@ -33,7 +53,6 @@ linters: # DEPRECATED by golangi-lint # - deadcode # The owner seems to have abandoned the linter. Replaced by unused. - - depguard # Go linter that checks if package imports are in a list of acceptable packages - exhaustivestruct # The owner seems to have abandoned the linter. Replaced by exhaustruct. - golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes - ifshort # Checks that your code uses short syntax for if-statements whenever possible @@ -55,6 +74,7 @@ linters: # - containedctx # containedctx is a linter that detects struct contained context.Context field # - contextcheck # check the function whether use a non-inherited context # - decorder # check declaration order and count of types, constants, variables and functions + # - depguard # Go linter that checks if package imports are in a list of acceptable packages # - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) # - durationcheck # check for two durations multiplied together # - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. @@ -64,6 +84,7 @@ linters: # - exhaustive # check exhaustiveness of enum switch statements # - exportloopref # checks for pointers to enclosing loop variables # - funlen # Tool for detection of long functions + # - ginkgolinter # enforces standards of using ginkgo and gomega # - gochecknoinits # Checks that no init functions are present in Go code # - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification # - goheader # Checks is file header matches to pattern @@ -160,7 +181,7 @@ linters: issues: max-issues-per-linter: 0 - max-same-issues: 10 + max-same-issues: 0 exclude-rules: # `err` is often shadowed, we may continue to do it - linters: diff --git a/cmd/root.go b/cmd/root.go index 435ccfd..3120941 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -13,12 +13,13 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" - "github.com/crowdsecurity/crowdsec/pkg/apiclient" csbouncer "github.com/crowdsecurity/go-cs-bouncer" "github.com/crowdsecurity/go-cs-lib/csdaemon" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/cs-blocklist-mirror/pkg/cfg" "github.com/crowdsecurity/cs-blocklist-mirror/pkg/registry" "github.com/crowdsecurity/cs-blocklist-mirror/pkg/server" @@ -39,6 +40,7 @@ func HandleSignals(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() } + return nil } @@ -123,10 +125,10 @@ func Execute() error { }) g.Go(func() error { - err := server.RunServer(ctx, g, config) - if err != nil { + if err := server.RunServer(ctx, g, config); err != nil { return fmt.Errorf("blocklist server failed: %w", err) } + return nil }) diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index 41acc7e..ece5652 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -79,22 +79,26 @@ func (cfg *Config) ValidateAndSetDefaults() error { if cfg.CrowdsecConfig.UpdateFrequency == "" { logrus.Warn("update_frequency is not provided") + cfg.CrowdsecConfig.UpdateFrequency = "10s" } if cfg.ConfigVersion == "" { logrus.Warn("config version is not provided; assuming v1.0") + cfg.ConfigVersion = "v1.0" } if cfg.ListenURI == "" { logrus.Warn("listen_uri is not provided ; assuming 127.0.0.1:41412") + cfg.ListenURI = "127.0.0.1:41412" } validAuthenticationTypes := []string{"basic", "ip_based", "none"} alreadyUsedEndpoint := make(map[string]struct{}) validFormats := []string{} + for format := range formatters.ByName { validFormats = append(validFormats, format) } @@ -103,10 +107,13 @@ func (cfg *Config) ValidateAndSetDefaults() error { if _, ok := alreadyUsedEndpoint[blockList.Endpoint]; ok { return fmt.Errorf("%s endpoint used more than once", blockList.Endpoint) } + alreadyUsedEndpoint[blockList.Endpoint] = struct{}{} + if !slices.Contains(validFormats, blockList.Format) { return fmt.Errorf("%s format is not supported. Supported formats are '%s'", blockList.Format, strings.Join(validFormats, ",")) } + if !slices.Contains(validAuthenticationTypes, strings.ToLower(blockList.Authentication.Type)) && blockList.Authentication.Type != "" { return fmt.Errorf( "%s authentication type is not supported. Supported authentication types are '%s'", @@ -121,10 +128,12 @@ func (cfg *Config) ValidateAndSetDefaults() error { func MergedConfig(configPath string) ([]byte, error) { patcher := yamlpatch.NewPatcher(configPath, ".local") + data, err := patcher.MergedPatchContent() if err != nil { return nil, err } + return data, nil } diff --git a/pkg/cfg/logging.go b/pkg/cfg/logging.go index 9d3874f..098ea2f 100644 --- a/pkg/cfg/logging.go +++ b/pkg/cfg/logging.go @@ -76,14 +76,17 @@ func (c *LoggingConfig) validate() error { if c.LogMedia != "stdout" && c.LogMedia != "file" { return fmt.Errorf("log_media should be either 'stdout' or 'file'") } + return nil } func (c *LoggingConfig) setup(fileName string) error { c.setDefaults() + if err := c.validate(); err != nil { return err } + log.SetLevel(*c.LogLevel) if c.LogMedia == "stdout" { diff --git a/pkg/formatters/formatters.go b/pkg/formatters/formatters.go index 630f771..442d895 100644 --- a/pkg/formatters/formatters.go +++ b/pkg/formatters/formatters.go @@ -25,21 +25,26 @@ func PlainText(w http.ResponseWriter, r *http.Request) { func MikroTik(w http.ResponseWriter, r *http.Request) { decisions := r.Context().Value(registry.GlobalDecisionRegistry.Key).([]*models.Decision) + listName := r.URL.Query().Get("listname") if listName == "" { listName = "CrowdSec" } + if !r.URL.Query().Has("ipv6only") { fmt.Fprintf(w, "/ip firewall address-list remove [find list=%s]\n", listName) } + if !r.URL.Query().Has("ipv4only") { fmt.Fprintf(w, "/ipv6 firewall address-list remove [find list=%s]\n", listName) } + for _, decision := range decisions { var ipType = "/ip" if strings.Contains(*decision.Value, ":") { ipType = "/ipv6" } + fmt.Fprintf(w, "%s firewall address-list add list=%s address=%s comment=\"%s for %s\"\n", ipType, @@ -54,16 +59,18 @@ func MikroTik(w http.ResponseWriter, r *http.Request) { func F5(w http.ResponseWriter, r *http.Request) { decisions := r.Context().Value(registry.GlobalDecisionRegistry.Key).([]*models.Decision) for _, decision := range decisions { - var category = *decision.Scenario + category := *decision.Scenario if strings.Contains(*decision.Scenario, "/") { category = strings.Split(*decision.Scenario, "/")[1] } + switch strings.ToLower(*decision.Scope) { case "ip": mask := 32 if strings.Contains(*decision.Value, ":") { mask = 64 } + fmt.Fprintf(w, "%s,%d,bl,%s\n", *decision.Value, diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index abcf67f..90473b9 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -28,29 +28,36 @@ func (dr *DecisionRegistry) AddDecisions(decisions []*models.Decision) { if _, ok := dr.ActiveDecisionsByValue[*decision.Value]; !ok { activeDecisionCount.Inc() } + dr.ActiveDecisionsByValue[*decision.Value] = decision } } func (dr *DecisionRegistry) GetActiveDecisions(filter url.Values) []*models.Decision { ret := make([]*models.Decision, 0, len(dr.ActiveDecisionsByValue)) + for _, v := range dr.ActiveDecisionsByValue { if filter.Has("ipv6only") && strings.Contains(*v.Value, ".") { continue } + if filter.Has("ipv4only") && strings.Contains(*v.Value, ":") { continue } + if filter.Has("origin") && !strings.EqualFold(*v.Origin, filter.Get("origin")) { continue } + ret = append(ret, v) } + if !filter.Has("nosort") { sort.SliceStable(ret, func(i, j int) bool { return *ret[i].Value < *ret[j].Value }) } + return ret } diff --git a/pkg/server/logging.go b/pkg/server/logging.go index 3f60078..c16d641 100644 --- a/pkg/server/logging.go +++ b/pkg/server/logging.go @@ -28,6 +28,7 @@ type responseLogger struct { func (l *responseLogger) Write(b []byte) (int, error) { size, err := l.w.Write(b) l.size += size + return size, err } @@ -51,6 +52,7 @@ func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { // WriteHeader has not been called yet l.status = http.StatusSwitchingProtocols } + return conn, rw, err } @@ -83,6 +85,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { url := *req.URL h.handler.ServeHTTP(w, req) + if req.MultipartForm != nil { req.MultipartForm.RemoveAll() } @@ -100,6 +103,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func makeLogger(w http.ResponseWriter) (*responseLogger, http.ResponseWriter) { logger := &responseLogger{w: w, status: http.StatusOK} + return logger, httpsnoop.Wrap(w, httpsnoop.Hooks{ Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc { return logger.Write @@ -118,25 +122,31 @@ func appendQuoted(buf []byte, s string) []byte { for width := 0; len(s) > 0; s = s[width:] { r := rune(s[0]) width = 1 + if r >= utf8.RuneSelf { r, width = utf8.DecodeRuneInString(s) } + if width == 1 && r == utf8.RuneError { buf = append(buf, `\x`...) buf = append(buf, lowerhex[s[0]>>4]) buf = append(buf, lowerhex[s[0]&0xF]) + continue } + if r == rune('"') || r == '\\' { // always backslashed buf = append(buf, '\\') buf = append(buf, byte(r)) continue } + if strconv.IsPrint(r) { n := utf8.EncodeRune(runeTmp[:], r) buf = append(buf, runeTmp[:n]...) continue } + switch r { case '\a': buf = append(buf, `\a`...) @@ -174,6 +184,7 @@ func appendQuoted(buf []byte, s string) []byte { } } } + return buf } @@ -182,6 +193,7 @@ func appendQuoted(buf []byte, s string) []byte { // status and size are used to provide the response HTTP status and size. func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int, size int) []byte { username := "-" + if url.User != nil { if name := url.User.Username(); name != "" { username = name @@ -201,6 +213,7 @@ func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int if req.ProtoMajor == 2 && req.Method == "CONNECT" { uri = req.Host } + if uri == "" { uri = url.RequestURI() } @@ -221,6 +234,7 @@ func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int buf = append(buf, strconv.Itoa(status)...) buf = append(buf, " "...) buf = append(buf, strconv.Itoa(size)...) + return buf } diff --git a/pkg/server/server.go b/pkg/server/server.go index aaa3a11..2cf1f2c 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -28,6 +28,7 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error if err != nil { return err } + http.HandleFunc(blockListCFG.Endpoint, f) log.Infof("serving blocklist in format %s at endpoint %s", blockListCFG.Format, blockListCFG.Endpoint) } @@ -39,11 +40,13 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error } var logHandler http.Handler + if config.EnableAccessLogs { logger, err := config.Logging.LoggerForFile(BlocklistMirrorAccessLogFilePath) if err != nil { return err } + logHandler = CombinedLoggingHandler(logger, http.DefaultServeMux) } @@ -57,6 +60,7 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error if err != nil && !errors.Is(err, http.ErrServerClosed) { return err } + return nil }) @@ -74,7 +78,9 @@ func listenAndServe(server *http.Server, config cfg.Config) error { log.Infof("Starting server with TLS at %s", config.ListenURI) return server.ListenAndServeTLS(config.TLS.CertFile, config.TLS.KeyFile) } + log.Infof("Starting server at %s", config.ListenURI) + return server.ListenAndServe() } @@ -102,12 +108,14 @@ func satisfiesBasicAuth(r *http.Request, user, password string) bool { if _, ok := r.Header[http.CanonicalHeaderKey("Authorization")]; !ok { return false } + expectedVal := fmt.Sprintf("Basic %s", basicAuth(user, password)) foundVal := r.Header[http.CanonicalHeaderKey("Authorization")][0] log.WithFields(log.Fields{ "expected": expectedVal, "found": foundVal, }).Debug("checking basic auth") + return expectedVal == foundVal } @@ -119,19 +127,24 @@ func toValidCIDR(ip string) string { if strings.Contains(ip, ":") { return ip + "/128" } + return ip + "/32" } func getTrustedIPs(ips []string) ([]net.IPNet, error) { trustedIPs := make([]net.IPNet, 0) + for _, ip := range ips { cidr := toValidCIDR(ip) + _, ipNet, err := net.ParseCIDR(cidr) if err != nil { return nil, err } + trustedIPs = append(trustedIPs, *ipNet) } + return trustedIPs, nil } @@ -142,6 +155,7 @@ func networksContainIP(networks []net.IPNet, ip string) bool { return true } } + return false } @@ -161,6 +175,7 @@ func decisionMiddleware(next http.HandlerFunc) func(w http.ResponseWriter, r *ht http.Error(w, "no decisions available", http.StatusNotFound) return } + ctx := context.WithValue(r.Context(), registry.GlobalDecisionRegistry.Key, decisions) next.ServeHTTP(w, r.WithContext(ctx)) } @@ -172,14 +187,18 @@ func authMiddleware(blockListCfg *cfg.BlockListConfig, next http.HandlerFunc) fu if err != nil { log.Errorf("error while spliting hostport for %s: %v", r.RemoteAddr, err) http.Error(w, "internal error", http.StatusInternalServerError) + return } + trustedIPs, err := getTrustedIPs(blockListCfg.Authentication.TrustedIPs) if err != nil { log.Errorf("error while parsing trusted IPs: %v", err) http.Error(w, "internal error", http.StatusInternalServerError) + return } + switch strings.ToLower(blockListCfg.Authentication.Type) { case "ip_based": if !networksContainIP(trustedIPs, ip) { @@ -193,6 +212,7 @@ func authMiddleware(blockListCfg *cfg.BlockListConfig, next http.HandlerFunc) fu } case "", "none": } + next.ServeHTTP(w, r) } } @@ -201,5 +221,6 @@ func getHandlerForBlockList(blockListCfg *cfg.BlockListConfig) (func(w http.Resp if _, ok := formatters.ByName[blockListCfg.Format]; !ok { return nil, fmt.Errorf("unknown format %s", blockListCfg.Format) } + return authMiddleware(blockListCfg, metricsMiddleware(blockListCfg, decisionMiddleware(formatters.ByName[blockListCfg.Format]))), nil }